File size: 4,034 Bytes
30ea0a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bac2eb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30ea0a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
from huggingface_hub import login

# Page configuration: browser-tab title and wide layout.
# NOTE(review): this is kept ahead of every other st.* call in the file;
# some Streamlit releases reportedly require set_page_config to be the first
# Streamlit command — confirm for the pinned version before reordering.
st.set_page_config(page_title="Mistral Chatbot", layout="wide")

# Page heading shown at the top of the main area.
st.title("Chatbot with Mistral")

# Device configuration: "cuda" when PyTorch reports a usable GPU, else "cpu".
# This module-level string is read by load_model, generate_text and main.
device = "cuda" if torch.cuda.is_available() else "cpu"
st.sidebar.info(f"Using device: {device}")

# Authentication setup
def setup_environment():
    """Authenticate with the Hugging Face Hub.

    The token is taken from the ``HUGGINGFACE_TOKEN`` Streamlit secret when
    that secret exists, otherwise from the ``HUGGINGFACE_TOKEN`` environment
    variable. When no token is found, an error is shown and the current
    script run is halted with ``st.stop()``.

    Returns:
        bool: True when ``login`` succeeds; False when ``login`` raises, in
        which case the failure is shown on the page.
    """
    try:
        hf_token = st.secrets["HUGGINGFACE_TOKEN"]
    except (KeyError, FileNotFoundError):
        # KeyError: the secret is not defined. FileNotFoundError: there is no
        # secrets file at all. Either way the lookup must not abort the app,
        # otherwise the environment-variable fallback is unreachable.
        hf_token = os.getenv("HUGGINGFACE_TOKEN")

    if not hf_token:
        st.error("Please set your Hugging Face token in the secrets or environment variables")
        st.stop()

    try:
        login(token=hf_token)
    except Exception as e:
        # Top-level boundary: report any login failure to the user instead of
        # crashing the page, and let the caller decide how to proceed.
        st.error(f"Authentication failed: {str(e)}")
        return False
    return True

# Model loading with caching
@st.cache_resource
def load_model():
    """Load the Mistral-7B tokenizer and model.

    The function is wrapped in ``st.cache_resource`` so the loaded objects
    are reused instead of being loaded again on every call.

    The weights are loaded in float16 with ``device_map="auto"``, which
    leaves device placement to the library.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    model_name = "mistralai/Mistral-7B-v0.1"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",  # placement is handled automatically at load time
    )

    # Deliberately no model.to(...) afterwards: the model has already been
    # placed by device_map="auto". Moving a dispatched model again is at best
    # redundant and fails when some of its modules were offloaded.
    return tokenizer, model

# Text generation function
def generate_text(prompt, tokenizer, model):
    """Generate a sampled continuation of ``prompt`` and return it as text.

    Args:
        prompt: The text to continue.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors, and providing ``decode``.
        model: Causal language model exposing ``device`` and ``generate``.

    Returns:
        str: Only the newly generated text. The prompt tokens that
        ``generate`` returns at the start of the sequence are dropped so the
        reply does not repeat the user's message.
    """
    # Put the inputs on the device the model reports, rather than on a
    # separately computed global, so the two can never disagree.
    target_device = model.device

    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(target_device)
    attention_mask = encoded["attention_mask"].to(target_device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=100,
            temperature=0.7,
            top_p=0.95,
            do_sample=True,
        )

    # The returned sequence starts with the prompt tokens; keep only what
    # follows them and move it to CPU for decoding.
    prompt_length = input_ids.shape[-1]
    new_tokens = outputs[0][prompt_length:].cpu()
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

# Main application flow
def main():
    """Render the chat page and handle one interaction.

    Steps: authenticate, show device details in the sidebar, make sure the
    chat history exists in the session state, load the model, process a
    submitted message, and finally render the whole conversation.

    Returns early (rendering nothing further) when authentication or model
    loading fails.
    """
    # Check authentication
    if not setup_environment():
        return

    # Display device information
    st.sidebar.markdown("---")
    st.sidebar.markdown("### System Info")
    st.sidebar.markdown(f"Device: **{device}**")
    if device == "cuda":
        st.sidebar.markdown(f"GPU: **{torch.cuda.get_device_name(0)}**")
        st.sidebar.markdown(f"Memory Allocated: **{torch.cuda.memory_allocated(0)/1024**2:.2f}MB**")

    # Initialize session state for chat history: a list of (role, message)
    # tuples where role is "You" or "Bot".
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Load model and tokenizer
    try:
        with st.spinner(f"Loading model on {device}..."):
            tokenizer, model = load_model()
    except Exception as e:
        # Top-level boundary: surface any loading failure on the page.
        st.error(f"Error loading model: {str(e)}")
        return

    # Chat interface
    user_input = st.text_input("Enter your message:", key="user_input")

    if st.button("Send") and user_input:
        # Entries alternate You/Bot, so the last entry is always the bot's
        # reply. Look back for the latest "You" entry to compare questions.
        last_user_message = next(
            (
                message
                for role, message in reversed(st.session_state.chat_history)
                if role == "You"
            ),
            None,
        )
        is_repeat = (
            last_user_message is not None
            and last_user_message.lower() == user_input.lower()
        )

        if is_repeat:
            st.warning("You already asked this question. Please ask something else.")
        else:
            # Generate response
            with st.spinner("Generating response..."):
                response = generate_text(user_input, tokenizer, model)

            # Update chat history
            st.session_state.chat_history.append(("You", user_input))
            st.session_state.chat_history.append(("Bot", response))

    # Display chat history
    for role, message in st.session_state.chat_history:
        if role == "You":
            st.write(f"👤 **You:** {message}")
        else:
            st.write(f"🤖 **Bot:** {message}")

# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()