File size: 3,466 Bytes
c2567b3 9867c26 c2567b3 9867c26 51f6238 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
# ---- Streamlit page configuration ----
st.set_page_config(
    page_title="Phi-2 QLoRA Chatbot",
    page_icon="🤖",
    layout="wide",
)

# Chat transcript must survive Streamlit's rerun-on-interaction model;
# seed it once per session (st.session_state is a MutableMapping).
st.session_state.setdefault("messages", [])
@st.cache_resource
def load_model():
    """Load the Phi-2 base model, attach the local QLoRA adapter, and
    return ``(model, tokenizer)``.

    Cached with ``st.cache_resource`` so the heavyweight load happens
    once per server process, not on every Streamlit rerun.
    """
    # Tokenizer first — it is cheap and shares the base checkpoint name.
    tokenizer = AutoTokenizer.from_pretrained(
        "microsoft/phi-2",
        trust_remote_code=True,
    )

    # Base weights in fp16, placed automatically across available devices.
    backbone = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2",
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )

    # Wrap the backbone with the fine-tuned LoRA adapter from disk.
    adapted = PeftModel.from_pretrained(
        backbone,
        "phi2-qlora-output",
        torch_dtype=torch.float16,
        device_map="auto",
    )

    return adapted, tokenizer
def generate_response(prompt, model, tokenizer, max_length=512, temperature=0.7):
    """Generate a completion for *prompt* and return only the new text.

    Args:
        prompt: User text to complete.
        model: Causal-LM (Phi-2 + LoRA adapter) exposing ``.generate()``.
        tokenizer: Matching HF tokenizer.
        max_length: Cap on prompt tokens kept (truncation) and on the
            number of NEW tokens generated.
        temperature: Sampling temperature (used with ``do_sample=True``).

    Returns:
        The decoded completion with the prompt removed, stripped of
        surrounding whitespace.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # BUG FIX: the original passed max_length=, which bounds
            # prompt + completion combined — a prompt near the limit left
            # no room to generate anything. max_new_tokens bounds only
            # the completion.
            max_new_tokens=max_length,
            temperature=temperature,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # BUG FIX: slicing the decoded string by len(prompt) is fragile —
    # decode(encode(prompt)) need not reproduce the prompt byte-for-byte
    # (whitespace/BPE normalization). Slice off the prompt *tokens*
    # instead and decode only the newly generated ids.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response.strip()
# ---- Load the model once at startup; abort the app on failure ----
try:
    model, tokenizer = load_model()
except Exception as exc:
    # Surface the failure in the UI and halt this script run.
    st.error(f"Error loading model: {str(exc)}")
    st.stop()
else:
    st.success("Model loaded successfully!")
# ---- Chat interface ----
st.title("Phi-2 QLoRA Chatbot 🤖")

# Replay the stored transcript so history persists across reruns.
for past_turn in st.session_state.messages:
    with st.chat_message(past_turn["role"]):
        st.write(past_turn["content"])

# Handle a new user turn, if any was submitted this rerun.
user_text = st.chat_input()
if user_text:
    # Echo the user's message and record it.
    with st.chat_message("user"):
        st.write(user_text)
    st.session_state.messages.append({"role": "user", "content": user_text})

    # Produce and record the assistant's reply.
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            reply = generate_response(user_text, model, tokenizer)
            st.write(reply)
    st.session_state.messages.append({"role": "assistant", "content": reply})
# ---- Sidebar: model background and starter prompts ----
with st.sidebar:
    st.title("About")
    st.markdown("""
    This chatbot uses a fine-tuned version of the Microsoft Phi-2 model,
    trained using the QLoRA technique. The model has been optimized for
    specific conversational tasks while maintaining efficiency through
    parameter-efficient fine-tuning.
    """)

    st.title("Example Prompts")
    example_prompts = [
        "Can you explain how quantum computing works in simple terms?",
        "Write a short story about a robot learning to feel emotions.",
        "What are the main differences between Python and JavaScript?",
        "Give me some tips for improving my public speaking skills.",
        "Explain the concept of climate change to a 10-year-old.",
    ]
    st.markdown("### Try these prompts to get started:")
    # One button per suggestion (clicks are not wired to the chat input).
    for suggestion in example_prompts:
        st.button(suggestion)