# simplechatbot / app.py
# Hugging Face Space by sdgzero2ai — "Update app.py", commit 7ec72c7 (verified), 1.53 kB
# (Web-page residue from the HF file viewer converted to a comment so the file parses.)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Load the DeepSeek R1 model and tokenizer once at module import time.
# Module-level side effect: downloads/loads the checkpoint before the UI starts.
# NOTE(review): "deepseek-ai/deepseek-r1" is a very large model; loading it with a
# bare from_pretrained (default dtype, no device_map) will likely exhaust memory on
# a typical Space — confirm the deployment target and consider a smaller distill.
model_name = "deepseek-ai/deepseek-r1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
def chat_with_deepseek(user_input, history):
    """Generate a model reply to *user_input* given the running conversation.

    Parameters
    ----------
    user_input : str
        The newest user message.
    history : list[str]
        Prior conversation messages, oldest first. Mutated in place: the new
        user message and the generated response are appended before returning.

    Returns
    -------
    tuple[str, list[str]]
        (response, updated history).
    """
    # Build the prompt from the prior turns plus the new message.
    full_input = "\n".join(history + [user_input])
    inputs = tokenizer(full_input, return_tensors="pt")
    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            # Pass the attention mask explicitly so padded/edge cases are handled.
            attention_mask=inputs.attention_mask,
            # Fix: the original used max_length=150, which counts the PROMPT
            # tokens too — once the conversation exceeded ~150 tokens the model
            # could generate nothing new. max_new_tokens bounds only the reply.
            max_new_tokens=150,
            num_return_sequences=1,
            # Silence the "no pad token" warning for decoder-only models.
            pad_token_id=tokenizer.eos_token_id,
        )
    # Fix: decode only the newly generated tokens. The original decoded the
    # whole sequence, so the reply echoed the entire prompt back to the user.
    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Record the new interaction in the (mutable) history.
    history.append(user_input)
    history.append(response)
    return response, history
# Create the Gradio interface
# Create the Gradio chat interface.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()  # holds [[user_msg, bot_msg], ...] pairs
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Append the new user turn (bot side pending) and clear the textbox."""
        return "", history + [[user_message, None]]

    def bot(history):
        """Generate and fill in the bot reply for the latest user turn."""
        user_message = history[-1][0]
        # Fix: the original passed [h[0] for h in history] as context, which
        # (a) dropped every previous bot reply and (b) still contained the
        # current user message, so chat_with_deepseek duplicated it in the
        # prompt. Flatten all COMPLETED prior turns (both sides) instead.
        context = [m for pair in history[:-1] for m in pair if m is not None]
        response, _ = chat_with_deepseek(user_message, context)
        history[-1][1] = response
        return history

    # On submit: record the user turn immediately, then stream in the bot reply.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    # Clear resets the chat display.
    clear.click(lambda: None, None, chatbot, queue=False)

# Launch the Gradio app
demo.launch()