"""Streamlit chat UI for a fine-tuned TinyLLaMA causal-LM chatbot (CPU-only)."""

import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# --- Streamlit page config (must be the first Streamlit call) ---
st.set_page_config(page_title="TinyLLaMA Chatbot", layout="centered")

# Device: CPU only
device = torch.device("cpu")


# --- Load the model and tokenizer (cached across Streamlit reruns) ---
@st.cache_resource
def load_model():
    """Load the fine-tuned TinyLLaMA model and tokenizer in CPU eval mode.

    Returns:
        tuple: (tokenizer, model) ready for generation.
    """
    model_path = "cbt-tinyllama/cbt-tinyllama-merged"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Some causal-LM tokenizers ship without a pad token; reuse EOS so
    # generate() has a valid pad id.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_path)
    model.to(device)
    model.eval()
    return tokenizer, model


tokenizer, model = load_model()

# --- Custom styling for chat bubbles ---
# NOTE(review): the style block was empty in the original source; kept as-is.
st.markdown("""
""", unsafe_allow_html=True)

# --- Title ---
st.title("🤖 TinyLLaMA Chatbot")
st.markdown("A conversational assistant powered by your fine-tuned TinyLLaMA model.")

# --- Initialize chat history ---
if "messages" not in st.session_state:
    st.session_state.messages = []

# --- User input ---
user_input = st.chat_input("Type your message...")


# --- Generate response function ---
def generate_response(prompt):
    """Generate the assistant's reply for the given conversation prompt.

    The prompt is tokenized, trimmed so that prompt + new tokens fit in the
    model's context window, then sampled from the model. Only the newly
    generated tokens are decoded, and any hallucinated next "User:" turn is
    stripped off.

    Args:
        prompt: Full "User:/Assistant:" formatted conversation history,
            ending with "Assistant:".

    Returns:
        str: The assistant's reply text.
    """
    max_new_tokens = 100
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Keep only the most recent tokens, reserving room for generation so that
    # prompt + max_new_tokens never exceeds the context window.
    # (The original trimmed to max_position_embeddings with no headroom.)
    max_prompt_len = model.config.max_position_embeddings - max_new_tokens
    if input_ids.size(1) > max_prompt_len:
        input_ids = input_ids[:, -max_prompt_len:]

    # encode() never pads a single sequence, so every position is real input.
    # (The original masked positions equal to pad_token_id, which wrongly
    # masked genuine EOS tokens because pad_token is aliased to eos_token.)
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.8,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. (The original sliced the full
    # decoded string by len(prompt), which breaks whenever the prompt was
    # trimmed above or tokenization doesn't round-trip byte-for-byte.)
    new_token_ids = output_ids[0][input_ids.size(1):]
    response = tokenizer.decode(new_token_ids, skip_special_tokens=True)
    # Stop before a hallucinated next user turn, if the model produced one.
    response = response.split("User:")[0].strip()
    return response


# --- Process user input ---
if user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})

    # Build the full prompt from the chat history.
    prompt = ""
    for msg in st.session_state.messages:
        role = "User" if msg["role"] == "user" else "Assistant"
        prompt += f"{role}: {msg['content']}\n"
    prompt += "Assistant:"

    bot_reply = generate_response(prompt)
    st.session_state.messages.append({"role": "assistant", "content": bot_reply})

# --- Display chat ---
# NOTE(review): the original rendering code is truncated in the source (it
# ends mid f-string inside st.markdown). Reconstructed here with Streamlit's
# built-in chat components, which also avoids injecting unescaped user
# content via unsafe_allow_html — confirm against the intended HTML bubbles.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])