# YAMITEK's picture
# Update app.py
# bac2eb3 verified
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
from huggingface_hub import login
# Page configuration: browser-tab title and a full-width layout.
st.set_page_config(page_title="Mistral Chatbot", layout="wide")
# Heading shown at the top of the page.
st.title("Chatbot with Mistral")
# Compute device used by the rest of the app: the GPU when PyTorch reports
# CUDA support, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
st.sidebar.info(f"Using device: {device}")
# Authentication setup
def setup_environment():
    """Authenticate with the Hugging Face Hub using a configured access token.

    The token is looked up in Streamlit secrets under ``HUGGINGFACE_TOKEN``
    and, when absent or empty there, in the ``HUGGINGFACE_TOKEN`` environment
    variable.

    Returns:
        True when ``login`` succeeds; False when it raises (the error is
        shown in the UI).

    When no token is configured anywhere, an error is shown and ``st.stop()``
    ends the current script run.
    """
    hf_token = None
    try:
        # Depending on the Streamlit version, touching st.secrets raises
        # (rather than reporting "missing") when no secrets.toml exists, e.g.
        # a local run configured only through environment variables. Treat
        # that as "no secret" so the env-var fallback below still applies.
        hf_token = st.secrets.get("HUGGINGFACE_TOKEN")
    except Exception:
        hf_token = None
    if not hf_token:
        hf_token = os.getenv("HUGGINGFACE_TOKEN")

    if not hf_token:
        st.error("Please set your Hugging Face token in the secrets or environment variables")
        st.stop()

    try:
        login(token=hf_token)
    except Exception as e:
        # UI boundary: report the failure instead of crashing the page.
        st.error(f"Authentication failed: {str(e)}")
        return False
    return True
# Model loading with caching
@st.cache_resource
def load_model(model_name="mistralai/Mistral-7B-v0.1"):
    """Load the tokenizer and causal-LM weights for *model_name*.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the
    already-loaded objects instead of reloading the weights.

    Args:
        model_name: Hugging Face Hub model id; defaults to Mistral-7B-v0.1.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        # Placement is delegated to accelerate (GPU first, then CPU/disk if
        # the weights do not fit). The model must NOT be moved afterwards
        # with model.to(...): that is redundant when everything landed on one
        # GPU and raises when any layers were offloaded to CPU or disk.
        device_map="auto",
    )
    return tokenizer, model
# Text generation function
def generate_text(prompt, tokenizer, model, *, max_new_tokens=100, temperature=0.7, top_p=0.95):
    """Sample a continuation of *prompt* from *model* and return only the new text.

    Args:
        prompt: Text fed to the model as-is (no chat template is applied).
        tokenizer: Tokenizer paired with *model*; called with
            ``return_tensors="pt"`` and used to decode the result.
        model: Causal LM exposing ``.device`` and ``.generate``.
        max_new_tokens: Upper bound on generated tokens (default 100).
        temperature: Sampling temperature (default 0.7).
        top_p: Nucleus-sampling threshold (default 0.95).

    Returns:
        The decoded newly generated tokens, stripped of surrounding whitespace.
        The prompt tokens are sliced off so the reply does not echo the user's
        message.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # Send inputs to the model's own device rather than a global "cuda"/"cpu"
    # guess; with device_map="auto" the two need not coincide.
    target_device = model.device
    input_ids = inputs["input_ids"].to(target_device)
    attention_mask = inputs["attention_mask"].to(target_device)
    prompt_length = input_ids.shape[-1]

    # Base causal-LM tokenizers often define no pad token; fall back to EOS so
    # generate() gets an explicit pad id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=pad_token_id,
        )
    # generate() returns prompt + continuation; keep only the continuation and
    # move it to CPU for decoding.
    new_tokens = outputs[0, prompt_length:].cpu()
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# Main application flow
def _render_system_info():
    """Write the compute device (plus GPU name and allocated memory on CUDA) to the sidebar."""
    st.sidebar.markdown("---")
    st.sidebar.markdown("### System Info")
    st.sidebar.markdown(f"Device: **{device}**")
    if device == "cuda":
        st.sidebar.markdown(f"GPU: **{torch.cuda.get_device_name(0)}**")
        st.sidebar.markdown(
            f"Memory Allocated: **{torch.cuda.memory_allocated(0)/1024**2:.2f}MB**"
        )


def _last_user_message(history):
    """Return the text of the most recent ("You", text) entry in *history*, or None."""
    for role, message in reversed(history):
        if role == "You":
            return message
    return None


def _render_chat_history(history):
    """Render each (role, message) pair with the user's and the bot's turns labelled."""
    for role, message in history:
        if role == "You":
            st.write(f"👤 **You:** {message}")
        else:
            st.write(f"🤖 **Bot:** {message}")


def main():
    """Run one Streamlit pass: authenticate, load the model, handle input, draw the chat.

    ``st.session_state.chat_history`` holds alternating ("You", text) and
    ("Bot", text) tuples that persist across reruns of this script.
    """
    if not setup_environment():
        return

    _render_system_info()

    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    history = st.session_state.chat_history

    try:
        with st.spinner(f"Loading model on {device}..."):
            tokenizer, model = load_model()
    except Exception as e:
        # UI boundary: show the failure instead of a raw traceback.
        st.error(f"Error loading model: {str(e)}")
        return

    user_input = st.text_input("Enter your message:", key="user_input")
    if st.button("Send"):
        message = user_input.strip()
        if message:
            # Compare against the last *user* turn. The final history entry is
            # always the bot's reply, so checking history[-1] never caught a
            # repeated question.
            previous = _last_user_message(history)
            if previous is not None and previous.casefold() == message.casefold():
                st.warning("You already asked this question. Please ask something else.")
            else:
                with st.spinner("Generating response..."):
                    response = generate_text(message, tokenizer, model)
                history.append(("You", message))
                history.append(("Bot", response))

    _render_chat_history(history)
# Entry point: build the page when this file is executed as the main script.
# NOTE(review): assumes the Streamlit runner executes the file as "__main__"
# (otherwise main() would never run) — confirm for the deployment target.
if __name__ == "__main__":
    main()