"""Streamlit chat front-end for the Mistral-7B base model."""

import os

import streamlit as st
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "mistralai/Mistral-7B-v0.1"
# Key looked up both in Streamlit secrets and in the process environment.
TOKEN_KEY = "HUGGINGFACE_TOKEN"

# Page configuration (must be the first Streamlit command in the script)
st.set_page_config(page_title="Mistral Chatbot", layout="wide")

# Title
st.title("Chatbot with Mistral")

# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
st.sidebar.info(f"Using device: {device}")


def _get_hf_token():
    """Return the Hugging Face token from Streamlit secrets or the environment.

    Returns:
        The token string, or None when neither source provides one.
    """
    try:
        token = st.secrets[TOKEN_KEY]
    except Exception:
        # Accessing st.secrets raises (instead of reporting the key as
        # missing) when no secrets.toml exists, and the exception class
        # differs between Streamlit versions. Treat any failure as
        # "not configured" so the environment-variable fallback still works.
        token = None
    return token or os.getenv(TOKEN_KEY)


# Authentication setup
def setup_environment():
    """Log in to the Hugging Face Hub with the configured token.

    Halts the current Streamlit run via ``st.stop()`` when no token is
    configured.

    Returns:
        True if ``login`` succeeded, False if it raised.
    """
    hf_token = _get_hf_token()
    if not hf_token:
        st.error("Please set your Hugging Face token in the secrets or environment variables")
        st.stop()
    try:
        login(token=hf_token)
    except Exception as e:
        st.error(f"Authentication failed: {str(e)}")
        return False
    return True


# Model loading with caching
@st.cache_resource
def load_model():
    """Load the tokenizer and causal-LM model for ``MODEL_NAME``.

    ``st.cache_resource`` memoizes the result, so reruns reuse the loaded
    weights instead of loading them again.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Half precision only on GPU; many float16 kernels are unavailable on CPU.
    dtype = torch.float16 if device == "cuda" else torch.float32
    # device_map="auto" lets accelerate place the weights itself. Calling
    # model.to(device) afterwards is redundant and is rejected for models
    # dispatched across devices or offloaded, so it is deliberately omitted.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        device_map="auto",
    )
    return tokenizer, model


# Text generation function
def generate_text(prompt, tokenizer, model, max_new_tokens=100,
                  temperature=0.7, top_p=0.95):
    """Sample a continuation of *prompt* and return only the new text.

    Args:
        prompt: User text fed to the model as-is.
        tokenizer: Tokenizer returned by ``load_model``.
        model: Causal-LM model returned by ``load_model``.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded continuation, without the prompt and without
        surrounding whitespace.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # With device_map="auto" the weights may not sit on the global `device`
    # (e.g. "cuda:0", or sharded); model.device is where inputs must go.
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)

    # Some tokenizers (presumably including Mistral's) define no pad token;
    # fall back to EOS so generate() does not warn or fail.
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=pad_token_id,
        )

    # generate() returns prompt + continuation for decoder-only models;
    # slice the prompt off so the bot does not echo the user's message.
    new_tokens = outputs[0, input_ids.shape[-1]:].cpu()
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def _render_system_info():
    """Write device (and, on CUDA, GPU name and allocated memory) to the sidebar."""
    st.sidebar.markdown("---")
    st.sidebar.markdown("### System Info")
    st.sidebar.markdown(f"Device: **{device}**")
    if device == "cuda":
        st.sidebar.markdown(f"GPU: **{torch.cuda.get_device_name(0)}**")
        st.sidebar.markdown(f"Memory Allocated: **{torch.cuda.memory_allocated(0)/1024**2:.2f}MB**")


def _last_user_message(chat_history):
    """Return the most recent "You" message in *chat_history*, or None."""
    for role, message in reversed(chat_history):
        if role == "You":
            return message
    return None


def _is_repeat_question(chat_history, user_input):
    """Return True if *user_input* repeats the user's previous message.

    The comparison ignores case and surrounding whitespace. It checks the
    last user entry, not the last entry overall, which is always a bot reply.
    """
    previous = _last_user_message(chat_history)
    if previous is None:
        return False
    return previous.strip().casefold() == user_input.strip().casefold()


def _render_chat_history(chat_history):
    """Write every ``(role, message)`` pair in order, tagged by role."""
    for role, message in chat_history:
        if role == "You":
            st.write(f"👤 **You:** {message}")
        else:
            st.write(f"🤖 **Bot:** {message}")


# Main application flow
def main():
    """Run one Streamlit pass: authenticate, load the model, handle chat."""
    if not setup_environment():
        return

    _render_system_info()

    # Chat history survives reruns via session state as (role, message) pairs.
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    try:
        with st.spinner(f"Loading model on {device}..."):
            tokenizer, model = load_model()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return

    user_input = st.text_input("Enter your message:", key="user_input")

    # st.button is evaluated first so the button is always rendered;
    # whitespace-only input is ignored.
    if st.button("Send") and user_input.strip():
        history = st.session_state.chat_history
        if _is_repeat_question(history, user_input):
            st.warning("You already asked this question. Please ask something else.")
        else:
            try:
                with st.spinner("Generating response..."):
                    response = generate_text(user_input, tokenizer, model)
            except Exception as e:
                # e.g. CUDA out-of-memory; report instead of crashing the run.
                st.error(f"Error generating response: {str(e)}")
            else:
                history.append(("You", user_input))
                history.append(("Bot", response))

    _render_chat_history(st.session_state.chat_history)


if __name__ == "__main__":
    main()