# HealthBot / chatbot.py
# Author: rashmiprajapati — commit a79c1dd ("model path changed")
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# --- Streamlit page config (must be first) ---
# Streamlit raises if set_page_config is not the first st.* call in the
# script, so this stays immediately after the imports.
st.set_page_config(page_title="TinyLLaMA Chatbot", layout="centered")
# Device: CPU only
# Inference is pinned to CPU; no CUDA availability check is performed.
device = torch.device("cpu")
# --- Load the model and tokenizer ---
@st.cache_resource
def load_model(model_path: str = "cbt-tinyllama/cbt-tinyllama-merged"):
    """Load the tokenizer and causal-LM once per session (Streamlit-cached).

    Args:
        model_path: Hugging Face hub id or local directory of the
            fine-tuned model. Defaults to the merged CBT TinyLLaMA
            checkpoint, so existing callers are unaffected.

    Returns:
        Tuple ``(tokenizer, model)`` — the model is moved to the module
        ``device`` (CPU) and switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Some causal-LM tokenizers ship without a pad token; reuse EOS so the
    # attention-mask logic in generate_response has a valid pad id.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_path)
    model.to(device)
    model.eval()  # disable dropout etc. — inference only
    return tokenizer, model

tokenizer, model = load_model()
# --- Custom styling for chat bubbles ---
# Injects raw CSS through st.markdown; unsafe_allow_html=True is required
# so Streamlit renders the <style> tag instead of escaping it. The classes
# below are referenced by the chat-rendering loop at the bottom of the file.
st.markdown("""
<style>
.user-bubble {
background-color: #DCF8C6;
padding: 10px;
border-radius: 20px;
margin-bottom: 10px;
width: fit-content;
max-width: 80%;
align-self: flex-end;
}
.bot-bubble {
background-color: #F1F0F0;
padding: 10px;
border-radius: 20px;
margin-bottom: 10px;
width: fit-content;
max-width: 80%;
align-self: flex-start;
}
.chat-container {
display: flex;
flex-direction: column;
}
</style>
""", unsafe_allow_html=True)
# --- Title ---
st.title("🤖 TinyLLaMA Chatbot")
st.markdown("A conversational assistant powered by your fine-tuned TinyLLaMA model.")

# --- Initialize chat history ---
# st.session_state implements the MutableMapping protocol, so setdefault
# creates the list only on the first run of the script.
st.session_state.setdefault("messages", [])

# --- User input ---
# chat_input returns None on reruns where the user submitted nothing.
user_input = st.chat_input("Type your message...")
# --- Generate response function ---
def generate_response(prompt: str, max_new_tokens: int = 100) -> str:
    """Generate the assistant's next turn for the accumulated transcript.

    Args:
        prompt: Full conversation transcript ending with "Assistant:".
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        Only the newly generated text, cut off before any hallucinated
        "User:" turn and stripped of surrounding whitespace.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

    # Reserve room for the tokens we are about to generate; trimming to the
    # full context size would let input + new tokens exceed the model's
    # position-embedding limit and crash/garble generation.
    max_input_len = model.config.max_position_embeddings - max_new_tokens
    if input_ids.size(1) > max_input_len:
        # Keep the most recent tokens — the tail of the conversation.
        input_ids = input_ids[:, -max_input_len:]
        attention_mask = attention_mask[:, -max_input_len:]

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.8,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the freshly generated token ids. Slicing the decoded
    # string by len(prompt) is fragile: it breaks whenever the input was
    # truncated above, or when the tokenizer does not round-trip the
    # prompt's whitespace exactly.
    new_token_ids = output_ids[0][input_ids.size(1):]
    decoded = tokenizer.decode(new_token_ids, skip_special_tokens=True)
    # Stop at the next hallucinated user turn, if the model produced one.
    return decoded.split("User:")[0].strip()
# --- Process user input ---
if user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})

    # Replay the whole conversation as a plain-text transcript so the model
    # sees full context, then cue it to produce the next assistant turn.
    transcript = "".join(
        f"{'User' if m['role'] == 'user' else 'Assistant'}: {m['content']}\n"
        for m in st.session_state.messages
    )
    bot_reply = generate_response(transcript + "Assistant:")
    st.session_state.messages.append({"role": "assistant", "content": bot_reply})
# --- Display chat ---
# Render every stored turn as an HTML bubble styled by the CSS injected above.
for msg in st.session_state.messages:
    if msg["role"] == "user":
        css_class, speaker = "user-bubble", "You"
    else:
        css_class, speaker = "bot-bubble", "Bot"
    st.markdown(
        f'<div class="chat-container"><div class="{css_class}"><b>{speaker}:</b><br>{msg["content"]}</div></div>',
        unsafe_allow_html=True,
    )