Spaces:
Build error
Build error
File size: 4,034 Bytes
30ea0a8 bac2eb3 30ea0a8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 | import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
from huggingface_hub import login
# ---------------------------------------------------------------------------
# Page setup. set_page_config is issued before any other Streamlit element.
# ---------------------------------------------------------------------------
st.set_page_config(layout="wide", page_title="Mistral Chatbot")
st.title("Chatbot with Mistral")

# Choose the compute device once, at import time: "cuda" when PyTorch
# reports CUDA as available, otherwise "cpu". Other functions read `device`.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
st.sidebar.info(f"Using device: {device}")
def setup_environment():
    """Authenticate with the Hugging Face Hub.

    The token is taken from Streamlit secrets (``HUGGINGFACE_TOKEN``) when
    present and non-empty, otherwise from the ``HUGGINGFACE_TOKEN``
    environment variable.

    Returns:
        True if ``login`` succeeded; False if it raised (an error message is
        shown in the app). If no token is found anywhere, an error is shown
        and the script run is halted with ``st.stop()``.
    """
    hf_token = None
    try:
        if "HUGGINGFACE_TOKEN" in st.secrets:
            hf_token = st.secrets["HUGGINGFACE_TOKEN"]
    except Exception:
        # NOTE(review): Streamlit raises when no secrets.toml exists at all
        # (common on hosts that only inject environment variables); the
        # exception class has differed between releases, so treat any
        # failure here as "no secret" and fall back to the environment.
        hf_token = None
    if not hf_token:
        hf_token = os.getenv("HUGGINGFACE_TOKEN")

    if not hf_token:
        st.error("Please set your Hugging Face token in the secrets or environment variables")
        st.stop()

    try:
        login(token=hf_token)
    except Exception as e:
        st.error(f"Authentication failed: {str(e)}")
        return False
    return True
@st.cache_resource
def load_model():
    """Load the Mistral-7B tokenizer and causal-LM model.

    Wrapped in ``st.cache_resource`` so script reruns reuse the loaded
    objects instead of reloading the weights.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    model_name = "mistralai/Mistral-7B-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # device_map="auto" hands placement to accelerate, which may also offload
    # layers to CPU/disk when GPU memory is short. The model must not be moved
    # afterwards: .to() on a dispatched model is redundant at best and raises
    # when some weights were offloaded.
    # NOTE(review): float16 on a CPU-only host may be slow or hit ops without
    # Half support depending on the torch version — confirm before relying
    # on the CPU path.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return tokenizer, model
def generate_text(prompt, tokenizer, model, *, max_new_tokens=100,
                  temperature=0.7, top_p=0.95):
    """Sample a continuation of *prompt* from *model*.

    Args:
        prompt: User text to continue.
        tokenizer: Tokenizer matching *model*; called with
            ``return_tensors="pt"`` and used to decode the output.
        model: Causal language model exposing ``.device`` and ``.generate``.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        Only the newly generated text (the prompt is not echoed back), with
        special tokens removed.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    # Send inputs to wherever the model's parameters start; with
    # device_map="auto" that is not necessarily the global `device` string.
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)

    # If the tokenizer defines no pad token, pad with EOS explicitly instead
    # of leaving generate() to pick one.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=pad_token_id,
        )

    # For decoder-only models generate() returns prompt + continuation; drop
    # the prompt tokens so the reply does not repeat the user's message.
    new_tokens = outputs[0, input_ids.shape[-1]:].cpu()
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
def _render_system_info():
    """Write the compute device (and GPU name when on CUDA) to the sidebar."""
    st.sidebar.markdown("---")
    st.sidebar.markdown("### System Info")
    st.sidebar.markdown(f"Device: **{device}**")
    if device == "cuda":
        st.sidebar.markdown(f"GPU: **{torch.cuda.get_device_name(0)}**")


def _render_memory_usage():
    """Write allocated CUDA memory on device 0, in MB, to the sidebar."""
    if device == "cuda":
        st.sidebar.markdown(
            f"Memory Allocated: **{torch.cuda.memory_allocated(0)/1024**2:.2f}MB**"
        )


def _last_user_message(history):
    """Return the most recent message with role ``"You"`` in *history*, or None."""
    for role, message in reversed(history):
        if role == "You":
            return message
    return None


def _render_chat_history(history):
    """Write every ``(role, message)`` pair in *history* to the page, in order."""
    for role, message in history:
        if role == "You":
            st.write(f"👤 **You:** {message}")
        else:
            st.write(f"🤖 **Bot:** {message}")


def main():
    """Run one pass of the app: authenticate, load the model, handle input, render chat."""
    if not setup_environment():
        return

    _render_system_info()

    # Chat history survives reruns in session state as (role, message) pairs.
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    try:
        with st.spinner(f"Loading model on {device}..."):
            tokenizer, model = load_model()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return

    # Read only after loading, so the figure reflects memory held by the model.
    _render_memory_usage()

    user_input = st.text_input("Enter your message:", key="user_input")
    if st.button("Send"):
        message = user_input.strip()
        if message:
            # Compare with the user's previous question; history[-1] is the
            # bot's reply, so it cannot be used for this check.
            previous = _last_user_message(st.session_state.chat_history)
            if previous is not None and previous.casefold() == message.casefold():
                st.warning("You already asked this question. Please ask something else.")
            else:
                with st.spinner("Generating response..."):
                    response = generate_text(message, tokenizer, model)
                st.session_state.chat_history.append(("You", message))
                st.session_state.chat_history.append(("Bot", response))

    _render_chat_history(st.session_state.chat_history)


if __name__ == "__main__":
    main()
|