File size: 3,576 Bytes
6031ee9
 
 
 
 
 
 
 
 
 
 
 
 
a79c1dd
6031ee9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import html

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# --- Streamlit page config: keep this as the first Streamlit call ---
st.set_page_config(layout="centered", page_title="TinyLLaMA Chatbot")

# Inference is CPU-only; the model and every input tensor are moved to this device.
device = torch.device("cpu")

# --- Load the model and tokenizer ---
@st.cache_resource
def load_model():
    """Load the fine-tuned TinyLLaMA tokenizer and causal-LM model.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the loaded
    objects instead of reading the checkpoint from disk again.

    Returns:
        tuple: ``(tokenizer, model)``. The model is moved to ``device`` and
        put in evaluation mode. The tokenizer is guaranteed to have a pad
        token, falling back to its EOS token when the checkpoint defines none.
    """
    checkpoint_dir = "cbt-tinyllama/cbt-tinyllama-merged"

    tok = AutoTokenizer.from_pretrained(checkpoint_dir)
    # Some checkpoints ship without a pad token; reuse EOS in that case.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    lm = AutoModelForCausalLM.from_pretrained(checkpoint_dir)
    lm.to(device)
    lm.eval()  # disable dropout etc. for inference
    return tok, lm

tokenizer, model = load_model()

# --- Custom styling for chat bubbles ---
# User bubbles align right and bot bubbles align left inside a column flexbox.
_CHAT_BUBBLE_CSS = """
    <style>
    .user-bubble {
        background-color: #DCF8C6;
        padding: 10px;
        border-radius: 20px;
        margin-bottom: 10px;
        width: fit-content;
        max-width: 80%;
        align-self: flex-end;
    }
    .bot-bubble {
        background-color: #F1F0F0;
        padding: 10px;
        border-radius: 20px;
        margin-bottom: 10px;
        width: fit-content;
        max-width: 80%;
        align-self: flex-start;
    }
    .chat-container {
        display: flex;
        flex-direction: column;
    }
    </style>
"""
st.markdown(_CHAT_BUBBLE_CSS, unsafe_allow_html=True)

# --- Header ---
st.title("🤖 TinyLLaMA Chatbot")
st.markdown("A conversational assistant powered by your fine-tuned TinyLLaMA model.")

# --- Chat history lives in session state so it survives Streamlit reruns ---
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# --- User input (None until the user submits a message) ---
user_input = st.chat_input("Type your message...")

# --- Generate response function ---
def generate_response(prompt, max_new_tokens=100):
    """Sample an assistant reply that continues ``prompt``.

    Args:
        prompt: Conversation transcript to continue, e.g. the
            ``User: ...`` / ``Assistant: ...`` text ending in an open
            ``Assistant:`` turn that the chat loop builds.
        max_new_tokens: Maximum number of tokens to generate (default 100).

    Returns:
        str: Only the newly generated text, cut at the first ``"User:"``
        (the model starting the next user turn) and stripped of surrounding
        whitespace. May be empty.
    """
    # The tokenizer's own attention mask is correct for a single unpadded
    # sequence. Deriving it from `input_ids != pad_token_id` would wrongly mask
    # real EOS tokens, because pad is aliased to EOS in load_model().
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Keep only the most recent prompt tokens, reserving room for the
    # generated tokens so prompt + output fits in the context window.
    max_prompt_len = max(1, model.config.max_position_embeddings - max_new_tokens)
    if input_ids.size(1) > max_prompt_len:
        input_ids = input_ids[:, -max_prompt_len:]
        attention_mask = attention_mask[:, -max_prompt_len:]

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.8,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the tokens generated after the prompt. Slicing the decoded
    # string by len(prompt) is unreliable: detokenization need not reproduce
    # the prompt text exactly, and the prompt may have been truncated above.
    new_token_ids = output_ids[0, input_ids.size(1):]
    decoded = tokenizer.decode(new_token_ids, skip_special_tokens=True)

    # Stop where the model begins inventing the next user turn.
    response = decoded.split("User:")[0].strip()
    return response

# --- Process user input ---
if user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})

    # Build the full prompt from history: one "User: ..." / "Assistant: ..."
    # line per message, ending with an open "Assistant:" turn for the model.
    prompt = "".join(
        f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}\n"
        for msg in st.session_state.messages
    ) + "Assistant:"

    bot_reply = generate_response(prompt)
    st.session_state.messages.append({"role": "assistant", "content": bot_reply})

# --- Display chat ---
for msg in st.session_state.messages:
    is_user = msg["role"] == "user"
    bubble_class = "user-bubble" if is_user else "bot-bubble"
    speaker = "You" if is_user else "Bot"
    # Both the user's text and the model's output are untrusted. Escape them
    # before embedding in raw HTML (unsafe_allow_html=True) so markup in a
    # message is shown as text instead of being rendered or breaking the
    # bubble. Newlines become <br> so multi-line replies keep their lines.
    safe_content = html.escape(msg["content"]).replace("\n", "<br>")
    st.markdown(
        f'<div class="chat-container"><div class="{bubble_class}">'
        f'<b>{speaker}:</b><br>{safe_content}</div></div>',
        unsafe_allow_html=True,
    )