|
|
import streamlit as st |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from peft import PeftModel |
|
|
import torch |
|
|
|
|
|
|
|
|
# Configure the browser tab and layout; must run before any other st.* UI call.
# NOTE(review): "π€" looks like a mojibake'd emoji (encoding artifact) — confirm
# the intended glyph before changing it; reproduced as-is here.
st.set_page_config(page_title="Phi-2 QLoRA Chatbot", page_icon="π€", layout="wide")
|
|
|
|
|
|
|
|
# Create the persisted chat transcript on the first script run; Streamlit
# re-executes the whole file on every interaction, so guard against resetting it.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the Phi-2 base model, apply the local QLoRA adapter, and return
    ``(model, tokenizer)``.

    Cached with ``st.cache_resource`` so the weights are loaded once per
    process, not on every Streamlit rerun.

    Returns:
        tuple: (PeftModel ready for inference, AutoTokenizer for "microsoft/phi-2").
    """
    # Base weights in fp16, sharded across available devices by accelerate.
    base_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2",
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )

    # Attach the LoRA adapter produced by the QLoRA fine-tuning run
    # (expected in the local "phi2-qlora-output" directory).
    model = PeftModel.from_pretrained(
        base_model,
        "phi2-qlora-output",
        torch_dtype=torch.float16,
        device_map="auto",
    )
    model.eval()  # inference only: disable dropout etc.

    tokenizer = AutoTokenizer.from_pretrained(
        "microsoft/phi-2",
        trust_remote_code=True,
    )
    # Phi-2's tokenizer typically ships without a pad token; padding with EOS
    # is the standard fallback for causal LMs (matches generate()'s
    # pad_token_id=eos usage elsewhere in this file). Guarded so an already-set
    # pad token is never overwritten.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
|
|
|
|
|
def generate_response(prompt, model, tokenizer, max_length=512, temperature=0.7):
    """Generate a completion for *prompt* and return only the new text.

    Args:
        prompt: User text to complete.
        model: Causal LM exposing ``.device`` and ``.generate`` (the PEFT model).
        tokenizer: Matching tokenizer; must define ``eos_token_id``.
        max_length: Max prompt tokens kept (truncation) and max new tokens generated.
        temperature: Sampling temperature for ``do_sample=True``.

    Returns:
        str: The generated continuation, stripped of surrounding whitespace.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # Bug fix: the original passed max_length=max_length, which counts
            # the prompt tokens too — a prompt near 512 tokens left zero budget
            # for the reply. max_new_tokens bounds only the generated part.
            max_new_tokens=max_length,
            temperature=temperature,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Bug fix: the original stripped the prompt with response[len(prompt):],
    # a character-count slice that breaks whenever decode() does not reproduce
    # the prompt byte-for-byte (tokenizer normalization, truncation). Decoding
    # only the newly generated token ids is exact.
    prompt_token_count = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_token_count:], skip_special_tokens=True)
    return response.strip()
|
|
|
|
|
|
|
|
# Top-level boundary: load the model once; on any failure show the error and
# halt the script instead of crashing the Streamlit app.
try:
    model, tokenizer = load_model()
except Exception as exc:
    st.error(f"Error loading model: {str(exc)}")
    st.stop()
else:
    st.success("Model loaded successfully!")
|
|
|
|
|
|
|
|
# Page header (glyphs reproduced as-is from the original source).
st.title("Phi-2 QLoRA Chatbot π€")

# Replay the stored transcript so the conversation survives Streamlit reruns.
for past_turn in st.session_state.messages:
    with st.chat_message(past_turn["role"]):
        st.write(past_turn["content"])
|
|
|
|
|
|
|
|
# Handle one new user turn from the chat input box.
user_prompt = st.chat_input()
if user_prompt:
    # Echo the user's message and record it in the transcript.
    with st.chat_message("user"):
        st.write(user_prompt)
    st.session_state.messages.append({"role": "user", "content": user_prompt})

    # Generate, display, and record the assistant's reply.
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            reply = generate_response(user_prompt, model, tokenizer)
            st.write(reply)
            st.session_state.messages.append({"role": "assistant", "content": reply})
|
|
|
|
|
|
|
|
with st.sidebar:
    # "About" panel describing the model behind the chat.
    st.title("About")
    st.markdown("""
    This chatbot uses a fine-tuned version of the Microsoft Phi-2 model,
    trained using the QLoRA technique. The model has been optimized for
    specific conversational tasks while maintaining efficiency through
    parameter-efficient fine-tuning.
    """)

    st.title("Example Prompts")
    example_prompts = [
        "Can you explain how quantum computing works in simple terms?",
        "Write a short story about a robot learning to feel emotions.",
        "What are the main differences between Python and JavaScript?",
        "Give me some tips for improving my public speaking skills.",
        "Explain the concept of climate change to a 10-year-old.",
    ]

    st.markdown("### Try these prompts to get started:")
    # Bug fix: the original loop reused the name `prompt`, shadowing the
    # module-level chat-input variable of the same name. Explicit keys make the
    # widgets' identity stable across reruns.
    # NOTE(review): these buttons are display-only — clicking one just reruns
    # the script without submitting the prompt; wiring them up would require
    # writing the chosen text into session_state for the main chat flow to read.
    for idx, example in enumerate(example_prompts):
        st.button(example, key=f"example_prompt_{idx}")