# IPTchatbot / chatbot.py
# Source: Hugging Face Space by ammoncoder123 (commit fb13fa4, ~2.88 kB)
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import torch
# ================= CACHE THE MODEL =================
@st.cache_resource
def load_model():
    """Build and cache the text-generation pipeline for the fine-tuned model.

    Wrapped in ``st.cache_resource`` so the heavyweight download/load runs
    once per server process rather than on every Streamlit script rerun.
    """
    repo = "ammoncoder123/IPTchatbotModel1-1.7B"  # ← Your correct model repo

    # Quantize weights to 4-bit so the 1.7B model fits in limited GPU memory.
    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    tok = AutoTokenizer.from_pretrained(repo)
    lm = AutoModelForCausalLM.from_pretrained(
        repo,
        quantization_config=bnb_cfg,
        device_map="auto",       # place layers on GPU automatically when one exists
        torch_dtype=torch.float16,
        # NOTE(review): trust_remote_code executes Python shipped with the
        # repo — only safe because this is the author's own model.
        trust_remote_code=True,
    )

    # Default generation settings; the caller may still override per call.
    return pipeline(
        "text-generation",
        model=lm,
        tokenizer=tok,
        max_new_tokens=300,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
    )
# Load model once (this will run on first use)
pipe = load_model()

# ==================== CHAT INTERFACE ====================
st.title("IPT Chatbot (1.7B Fine-Tuned Model)")

# Show a disclaimer
st.info("⚠️ This is a small fine-tuned model (1.7B parameters). Answers may contain inaccuracies. Always verify important information.")

# Initialize chat history (session_state persists across Streamlit reruns)
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay prior turns so the conversation survives each script rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# User input
if prompt := st.chat_input("Ask me about IPT, ICT, or anything else..."):
    # Record and echo the user's message
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate response
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # Chat-format input lets the pipeline apply the model's chat template
            chat_messages = [
                {"role": "user", "content": prompt}
            ]
            outputs = pipe(
                chat_messages,
                max_new_tokens=300,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
            )

            # BUG FIX: with chat-format (list) input, recent transformers
            # return "generated_text" as the *full conversation* — a list of
            # {"role", "content"} dicts — not a plain string. The old code
            # passed that list straight to st.markdown, showing raw dicts to
            # the user. Extract the final (assistant) turn instead, keeping
            # the string path as a fallback for plain-text pipelines.
            generated = outputs[0]["generated_text"]
            if isinstance(generated, list):
                # Last message in the returned conversation is the new reply
                response = generated[-1].get("content", "")
            else:
                response = generated
                # Older/plain pipelines echo the prompt; strip it from the front
                if response.startswith(prompt):
                    response = response[len(prompt):].strip()

            st.markdown(response)

    # Save assistant response so it renders on subsequent reruns
    st.session_state.messages.append({"role": "assistant", "content": response})

# Optional: Clear chat button
if st.button("Clear Conversation"):
    st.session_state.messages = []
    st.rerun()