import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# --- Streamlit page config (must be the first Streamlit call in the script) ---
st.set_page_config(page_title="TinyLLaMA Chatbot", layout="centered")
# Inference runs on CPU only — no GPU is assumed in the deployment environment.
device = torch.device("cpu")
# --- Model/tokenizer loading, cached across Streamlit reruns ---
@st.cache_resource
def load_model():
    """Load the fine-tuned TinyLLaMA tokenizer and model onto the CPU.

    Returns:
        (tokenizer, model) — the model is in eval mode on ``device``.
    """
    model_path = "cbt-tinyllama/cbt-tinyllama-merged"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Some checkpoints ship without a pad token; fall back to EOS so that
    # padding-aware code paths still work.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
    model.eval()
    return tokenizer, model

tokenizer, model = load_model()
# --- Custom CSS for the chat bubbles rendered below (injected as raw HTML) ---
st.markdown("""
<style>
.user-bubble {
background-color: #DCF8C6;
padding: 10px;
border-radius: 20px;
margin-bottom: 10px;
width: fit-content;
max-width: 80%;
align-self: flex-end;
}
.bot-bubble {
background-color: #F1F0F0;
padding: 10px;
border-radius: 20px;
margin-bottom: 10px;
width: fit-content;
max-width: 80%;
align-self: flex-start;
}
.chat-container {
display: flex;
flex-direction: column;
}
</style>
""", unsafe_allow_html=True)
# --- Page header ---
st.title("🤖 TinyLLaMA Chatbot")
st.markdown("A conversational assistant powered by your fine-tuned TinyLLaMA model.")

# --- Conversation history lives in session state so it survives reruns ---
st.session_state.setdefault("messages", [])

# --- User input box (returns None until the user submits a message) ---
user_input = st.chat_input("Type your message...")
# --- Response generation ---
def generate_response(prompt):
    """Generate the assistant's reply for the accumulated chat transcript.

    Args:
        prompt: Full conversation text ending with "Assistant:".

    Returns:
        The newly generated assistant text, truncated before any
        hallucinated "User:" turn the model may have produced.
    """
    max_new_tokens = 100
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Trim the *left* of the prompt so that prompt + generated tokens still
    # fit inside the positional-embedding window. (Previously the prompt was
    # trimmed to max_position_embeddings itself, leaving no room for the
    # max_new_tokens generated positions.)
    max_prompt_len = model.config.max_position_embeddings - max_new_tokens
    if input_ids.size(1) > max_prompt_len:
        input_ids = input_ids[:, -max_prompt_len:]

    # A single un-padded sequence: the mask is all ones. Comparing against
    # pad_token_id would be wrong here because pad == eos after load_model(),
    # so genuine EOS tokens inside the prompt would be masked out.
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.8,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. Slicing the decoded string by
    # len(prompt) is fragile: decode() does not round-trip the prompt text
    # exactly, so a character offset can drop or leak text.
    new_tokens = output_ids[0][input_ids.size(1):]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Stop before any hallucinated next user turn.
    return response.split("User:")[0].strip()
# --- Handle a newly submitted user message ---
if user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})

    # Rebuild the whole history as a plain-text transcript for the model.
    transcript = "".join(
        f"{'User' if m['role'] == 'user' else 'Assistant'}: {m['content']}\n"
        for m in st.session_state.messages
    )
    bot_reply = generate_response(transcript + "Assistant:")
    st.session_state.messages.append({"role": "assistant", "content": bot_reply})

# --- Render the conversation as styled chat bubbles ---
# NOTE(review): message content is interpolated into raw HTML, so markup in
# user input is rendered verbatim — consider escaping if this app is exposed.
for msg in st.session_state.messages:
    bubble, speaker = (
        ("user-bubble", "You") if msg["role"] == "user" else ("bot-bubble", "Bot")
    )
    st.markdown(
        f'<div class="chat-container"><div class="{bubble}"><b>{speaker}:</b><br>{msg["content"]}</div></div>',
        unsafe_allow_html=True,
    )
|