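"""Streamlit chat UI for a Phi-2 model fine-tuned with QLoRA.

Loads the microsoft/phi-2 base model, attaches the local LoRA adapter
("phi2-qlora-output"), and serves a simple chat interface with
session-based history.
"""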
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

# Page config
st.set_page_config(
    page_title="Phi-2 QLoRA Chatbot",
    page_icon="🤖",
    layout="wide"
)

# Initialize session state for chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

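# Cache the loaded model across Streamlit reruns so it is only loaded once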
@st.cache_resource
def load_model():
    # Load base model and tokenizer
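    # device_map="auto" places weights on available devices (requires the
    # `accelerate` package); float16 halves memory use vs. full precision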
    base_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2",
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True
    )
    
    # Attach the LoRA adapter weights saved during QLoRA fine-tuning
    model = PeftModel.from_pretrained(
        base_model,
        "phi2-qlora-output",
        torch_dtype=torch.float16,
        device_map="auto"
    )
    
    tokenizer = AutoTokenizer.from_pretrained(
        "microsoft/phi-2",
        trust_remote_code=True
    )
    
    return model, tokenizer

def generate_response(prompt, model, tokenizer, max_new_tokens=256, temperature=0.7):
    # Note: only the latest prompt is sent to the model; earlier chat turns
    # are not included in the context.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Generate a response; max_new_tokens bounds the reply length independently
    # of the prompt length (max_length would count prompt tokens as well, so a
    # long prompt could leave no room for the reply)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id  # Phi-2 has no pad token; reuse EOS
        )

    # Decode only the newly generated tokens; slicing by token count is more
    # robust than stripping the prompt text from the decoded string
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return response.strip()

# Load the model
try:
    model, tokenizer = load_model()
    st.success("Model loaded successfully!")
except Exception as e:
    st.error(f"Error loading model: {str(e)}")
    st.stop()

# Chat interface
st.title("Phi-2 QLoRA Chatbot 🤖")

# Display chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

# Chat input
if prompt := st.chat_input():
    # Display user message
    with st.chat_message("user"):
        st.write(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
    
    # Generate and display assistant response
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            response = generate_response(prompt, model, tokenizer)
            st.write(response)
    st.session_state.messages.append({"role": "assistant", "content": response})

# Sidebar with model information and example prompts
with st.sidebar:
    st.title("About")
    st.markdown("""
    This chatbot runs a fine-tuned version of Microsoft's Phi-2 model.
    Fine-tuning used QLoRA, which trains small low-rank adapters on top of
    a quantized base model, adapting it to conversational tasks at a
    fraction of the memory cost of full fine-tuning.
    """)
    
    st.title("Example Prompts")
    example_prompts = [
        "Can you explain how quantum computing works in simple terms?",
        "Write a short story about a robot learning to feel emotions.",
        "What are the main differences between Python and JavaScript?",
        "Give me some tips for improving my public speaking skills.",
        "Explain the concept of climate change to a 10-year-old."
    ]
    
    st.markdown("### Try these prompts to get started:")
    for example in example_prompts:
        st.button(example)