# mistral / app.py
# Manith Marapperuma
# Update app.py
# 0edd432 verified
# Import necessary libraries
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Load the model and tokenizer
model_name = "mistralai/Mistral-7B-v0.1"


@st.cache_resource
def _load_model_and_tokenizer(name: str):
    """Load the tokenizer and causal-LM weights for ``name`` once per process.

    Streamlit re-executes this script from the top on every widget
    interaction; without caching, the checkpoint would be loaded again on
    each button click. ``st.cache_resource`` keeps the returned objects alive
    across reruns and sessions.

    Args:
        name: Hugging Face model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(name)
    # device_map="auto" lets the library decide where the weights are placed.
    loaded_model = AutoModelForCausalLM.from_pretrained(name, device_map="auto")
    return loaded_tokenizer, loaded_model


# Keep the original module-level names so the rest of the file is unaffected.
tokenizer, model = _load_model_and_tokenizer(model_name)
# Streamlit app
def main() -> None:
    """Render the chat page and answer the prompt when "Send" is clicked.

    Shows a title, a text input pre-filled with a greeting and a "Send"
    button. On click, the input is tokenized, up to 100 new tokens are
    sampled from the module-level ``model``, and the decoded text is shown
    in a text area.
    """
    st.title("Mistral Chatbot")
    user_input = st.text_input("You: ", "Hello, chatbot!")

    if st.button("Send"):
        with st.spinner("Thinking..."):
            # Tokenize the user input and generate a response
            model_inputs = tokenizer(user_input, return_tensors="pt")
            # Send the tensors to the device the model actually lives on
            # instead of hard-coding "cuda": the model is loaded with
            # device_map="auto", so on a machine without a GPU a fixed
            # "cuda" target would raise.
            target_device = model.device
            model_inputs = {
                name: tensor.to(target_device)
                for name, tensor in model_inputs.items()
            }
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=100,
                do_sample=True,
            )
            # NOTE: generated_ids[0] is decoded in full, so the displayed
            # text starts with the prompt followed by the continuation.
            chatbot_response = tokenizer.decode(
                generated_ids[0], skip_special_tokens=True
            )
        st.text_area("Chatbot:", value=chatbot_response, height=200)
# Script entry point: build the page only when this file is executed as the
# main script, not when it is imported as a module.
if __name__ == "__main__":
    main()