|
|
|
|
|
import streamlit as st |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import torch |
|
|
|
|
|
|
|
|
# Checkpoint to load; any causal-LM repo id on the Hugging Face Hub works here.
model_name = "mistralai/Mistral-7B-v0.1"


@st.cache_resource(show_spinner="Loading model...")
def load_model(name: str = model_name):
    """Load the tokenizer and model once per Streamlit server process.

    Streamlit re-executes this module top-to-bottom on every widget
    interaction; without st.cache_resource the multi-GB model would be
    reloaded from disk on each button click. The cache keys on `name`,
    so changing the checkpoint triggers a fresh load.
    """
    tok = AutoTokenizer.from_pretrained(name)
    # device_map="auto" lets accelerate place the weights on GPU when
    # available, falling back to CPU otherwise.
    mdl = AutoModelForCausalLM.from_pretrained(name, device_map="auto")
    return tok, mdl


# Keep the original module-level names so main() (and any other code in
# this file) continues to reference `tokenizer` and `model` directly.
tokenizer, model = load_model()
|
|
|
|
|
|
|
|
def main():
    """Render the Streamlit chat UI and generate a single reply on demand."""
    st.title("Mistral Chatbot")

    user_input = st.text_input("You: ", "Hello, chatbot!")

    if st.button("Send"):
        with st.spinner("Thinking..."):
            model_inputs = tokenizer(user_input, return_tensors="pt")
            # Bug fix: the model is placed by device_map="auto", which may
            # resolve to CPU on machines without a GPU; the previous
            # hard-coded .to("cuda") crashed there. Follow the model's
            # actual device instead.
            device = getattr(model, "device", "cpu")
            model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
            generated_ids = model.generate(
                **model_inputs, max_new_tokens=100, do_sample=True
            )
            # generate() returns prompt + continuation; strip the echoed
            # prompt tokens so only the new reply is shown to the user.
            prompt_len = model_inputs["input_ids"].shape[1]
            chatbot_response = tokenizer.decode(
                generated_ids[0][prompt_len:], skip_special_tokens=True
            )

        st.text_area("Chatbot:", value=chatbot_response, height=200)
|
|
|
|
|
# Entry point when executed directly (e.g. `streamlit run <this file>`).
if __name__ == "__main__":


    main()
|
|
|