import html

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
| |
| st.set_page_config(page_title="TinyLLaMA Chatbot", layout="centered") |
|
|
| |
| device = torch.device("cpu") |
|
|
| |
| @st.cache_resource |
| def load_model(): |
| model_path = "cbt-tinyllama/cbt-tinyllama-merged" |
| tokenizer = AutoTokenizer.from_pretrained(model_path) |
| |
| |
| if tokenizer.pad_token is None: |
| tokenizer.pad_token = tokenizer.eos_token |
|
|
| model = AutoModelForCausalLM.from_pretrained(model_path) |
| model.to(device) |
| model.eval() |
| return tokenizer, model |
|
|
| tokenizer, model = load_model() |
|
|
| |
| st.markdown(""" |
| <style> |
| .user-bubble { |
| background-color: #DCF8C6; |
| padding: 10px; |
| border-radius: 20px; |
| margin-bottom: 10px; |
| width: fit-content; |
| max-width: 80%; |
| align-self: flex-end; |
| } |
| .bot-bubble { |
| background-color: #F1F0F0; |
| padding: 10px; |
| border-radius: 20px; |
| margin-bottom: 10px; |
| width: fit-content; |
| max-width: 80%; |
| align-self: flex-start; |
| } |
| .chat-container { |
| display: flex; |
| flex-direction: column; |
| } |
| </style> |
| """, unsafe_allow_html=True) |
|
|
| |
| st.title("🤖 TinyLLaMA Chatbot") |
| st.markdown("A conversational assistant powered by your fine-tuned TinyLLaMA model.") |
|
|
| |
| if "messages" not in st.session_state: |
| st.session_state.messages = [] |
|
|
| |
| user_input = st.chat_input("Type your message...") |
|
|
| |
| def generate_response(prompt): |
| input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device) |
| attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device) |
|
|
| |
| max_length = model.config.max_position_embeddings |
| if input_ids.size(1) > max_length: |
| input_ids = input_ids[:, -max_length:] |
| attention_mask = attention_mask[:, -max_length:] |
|
|
| with torch.no_grad(): |
| output_ids = model.generate( |
| input_ids, |
| attention_mask=attention_mask, |
| max_new_tokens=100, |
| do_sample=True, |
| top_k=50, |
| top_p=0.95, |
| temperature=0.8, |
| pad_token_id=tokenizer.eos_token_id |
| ) |
|
|
| decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True) |
| |
| response = decoded[len(prompt):].split("User:")[0].strip() |
| return response |
|
|
| |
| if user_input: |
| st.session_state.messages.append({"role": "user", "content": user_input}) |
|
|
| |
| prompt = "" |
| for msg in st.session_state.messages: |
| role = "User" if msg["role"] == "user" else "Assistant" |
| prompt += f"{role}: {msg['content']}\n" |
| prompt += "Assistant:" |
|
|
| bot_reply = generate_response(prompt) |
| st.session_state.messages.append({"role": "assistant", "content": bot_reply}) |
|
|
| |
| for msg in st.session_state.messages: |
| if msg["role"] == "user": |
| st.markdown(f'<div class="chat-container"><div class="user-bubble"><b>You:</b><br>{msg["content"]}</div></div>', unsafe_allow_html=True) |
| else: |
| st.markdown(f'<div class="chat-container"><div class="bot-bubble"><b>Bot:</b><br>{msg["content"]}</div></div>', unsafe_allow_html=True) |
|
|
|
|
|
|