Spaces:
Build error
Build error
| import streamlit as st | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| import os | |
| from huggingface_hub import login | |
# Page-level settings, applied before any other Streamlit output is produced.
st.set_page_config(layout="wide", page_title="Mistral Chatbot")
st.title("Chatbot with Mistral")

# Pick the compute device once; the functions below read this module-level name.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
st.sidebar.info(f"Using device: {device}")
# Authentication setup
def _read_hf_token():
    """Return the Hugging Face token from Streamlit secrets, else the environment.

    Returns:
        The token string, or None/empty when it is configured nowhere.
    """
    try:
        token = st.secrets["HUGGINGFACE_TOKEN"]
    except Exception:
        # A missing key raises KeyError, but touching st.secrets at all can also
        # raise when no secrets file exists (the exception type depends on the
        # Streamlit version). Either way, treat it as "not configured here" so the
        # environment-variable fallback below is actually reached.
        token = None
    return token or os.getenv("HUGGINGFACE_TOKEN")


def setup_environment():
    """Log in to the Hugging Face Hub with the configured token.

    Returns:
        True if ``login`` succeeded, False if it raised (an error is shown).
        If no token is configured at all, an error is shown and the script
        run is halted with ``st.stop()``.
    """
    hf_token = _read_hf_token()
    if not hf_token:
        st.error("Please set your Hugging Face token in the secrets or environment variables")
        st.stop()
    try:
        login(token=hf_token)
    except Exception as e:
        st.error(f"Authentication failed: {str(e)}")
        return False
    return True
# Model loading with caching
@st.cache_resource(show_spinner=False)
def load_model():
    """Load the Mistral-7B tokenizer and model once and reuse them across reruns.

    ``st.cache_resource`` keeps the returned objects alive across Streamlit
    reruns, so the weights are not reloaded on every interaction. The caller
    shows its own spinner, so the cache's built-in spinner is disabled.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    model_name = "mistralai/Mistral-7B-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if device == "cuda":
        # Half precision on the GPU. device_map="auto" already places the
        # weights, so the model is not moved again with .to() afterwards
        # (accelerate warns against moving a model it has dispatched).
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
    else:
        # float32 on CPU: half-precision kernels are missing for some CPU ops in
        # some PyTorch versions. NOTE(review): this needs roughly 2x the RAM of fp16.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
        )
    return tokenizer, model
# Text generation function
def generate_text(prompt, tokenizer, model, *, max_new_tokens=100, temperature=0.7, top_p=0.95):
    """Sample a continuation of *prompt* and return only the newly generated text.

    Args:
        prompt: User text, fed to the model as-is (no chat template is applied).
        tokenizer: Tokenizer matching *model*.
        model: Causal language model exposing ``generate`` and ``device``.
        max_new_tokens: Upper bound on generated tokens (default 100).
        temperature: Sampling temperature (default 0.7).
        top_p: Nucleus-sampling threshold (default 0.95).

    Returns:
        The decoded continuation, without the prompt and without special
        tokens, stripped of surrounding whitespace.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # Send inputs to wherever the model's parameters live; when the model was
    # loaded with device_map="auto" this is the authoritative location.
    target = model.device
    input_ids = inputs["input_ids"].to(target)
    attention_mask = inputs["attention_mask"].to(target)

    # Some tokenizers define no pad token; fall back to EOS explicitly so
    # generate() does not have to substitute one itself on every call.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=pad_token_id,
        )

    # The returned sequence starts with the prompt tokens; drop them so the
    # bot's reply does not echo the user's message back.
    new_tokens = outputs[0, input_ids.shape[-1]:].cpu()
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# Main application flow
def _last_user_message(history):
    """Return the most recent message with role "You" in *history*, or None.

    *history* is a list of ``(role, message)`` tuples. User and bot turns are
    appended in pairs, so the final entry is normally the bot's reply rather
    than the user's last question.
    """
    for role, message in reversed(history):
        if role == "You":
            return message
    return None


def main():
    """Run one Streamlit pass: authenticate, load the model, handle input, render chat."""
    # Check authentication
    if not setup_environment():
        return

    # Display device information
    st.sidebar.markdown("---")
    st.sidebar.markdown("### System Info")
    st.sidebar.markdown(f"Device: **{device}**")
    if device == "cuda":
        st.sidebar.markdown(f"GPU: **{torch.cuda.get_device_name(0)}**")
        st.sidebar.markdown(f"Memory Allocated: **{torch.cuda.memory_allocated(0)/1024**2:.2f}MB**")

    # Initialize session state for chat history (list of (role, message) tuples)
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Load model and tokenizer
    try:
        with st.spinner(f"Loading model on {device}..."):
            tokenizer, model = load_model()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return

    # Chat interface
    user_input = st.text_input("Enter your message:", key="user_input")
    if st.button("Send"):
        prompt = user_input.strip()
        if prompt:
            # Compare against the user's previous question, not the last
            # history entry (which is the bot's reply).
            previous = _last_user_message(st.session_state.chat_history)
            if previous is not None and previous.lower() == prompt.lower():
                st.warning("You already asked this question. Please ask something else.")
            else:
                try:
                    with st.spinner("Generating response..."):
                        response = generate_text(prompt, tokenizer, model)
                except Exception as e:
                    # Report the failure in the UI and leave history untouched.
                    st.error(f"Error generating response: {e}")
                else:
                    st.session_state.chat_history.append(("You", prompt))
                    st.session_state.chat_history.append(("Bot", response))

    # Display chat history
    for role, message in st.session_state.chat_history:
        if role == "You":
            st.write(f"👤 **You:** {message}")
        else:
            st.write(f"🤖 **Bot:** {message}")
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()