# JARVIS / app.py
# Committed by Krish45 — "Update app.py" (commit 6d5e95a, verified)
# File size: 1.78 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Hugging Face Hub id of the checkpoint served by this app.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

# The tokenizer also supplies the chat template applied in predict().
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the causal-LM weights. "auto" dtype/device placement are delegated to
# transformers; low_cpu_mem_usage avoids materialising a second full copy of
# the weights in host RAM while loading.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    low_cpu_mem_usage=True,
)
def _to_chat_messages(history):
    """Convert Chatbot ``(user, bot)`` pairs into chat-template message dicts.

    Falsy entries (``""`` or ``None``) on either side of a pair are skipped.
    """
    messages = []
    for human, bot in history:
        if human:
            messages.append({"role": "user", "content": human})
        if bot:
            messages.append({"role": "assistant", "content": bot})
    return messages


def predict(history, message):
    """Generate the model's reply to ``message`` and append the turn to the chat.

    Args:
        history: list of ``[user, bot]`` pairs currently shown in the Chatbot,
            or ``None``/empty for a fresh conversation. It is not modified.
        message: the new user input string from the Textbox.

    Returns:
        A ``(history, "")`` tuple: a new list of pairs for the Chatbot (with
        ``(message, reply)`` appended) and an empty string that clears the
        Textbox. A blank or whitespace-only ``message`` is ignored, and the
        existing history is returned unchanged.
    """
    # NOTE(review): assumes the Chatbot uses the pair ("tuples") value format;
    # a Chatbot configured with type="messages" would pass role/content dicts.
    # Work on a copy so the caller's list is never mutated in place.
    history = list(history or [])

    # Don't spend a full generation answering an empty submission.
    if not message or not message.strip():
        return history, ""

    messages = _to_chat_messages(history)
    messages.append({"role": "user", "content": message})

    # Render the conversation with the model's chat template. The generation
    # prompt opens a new assistant turn for the model to complete.
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(**model_inputs, max_new_tokens=512)
    # generate() returns prompt + continuation; keep only the new tokens of
    # the single sequence in this batch.
    prompt_length = model_inputs.input_ids.shape[1]
    reply = tokenizer.decode(
        generated_ids[0][prompt_length:], skip_special_tokens=True
    )

    history.append((message, reply))
    return history, ""  # updated history + clear the textbox
# Minimal chat UI: a Chatbot transcript and a Textbox. Pressing Enter in the
# Textbox sends (transcript, text) to predict(), whose two return values
# replace the transcript and the Textbox contents respectively.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type your message here...")
    msg.submit(
        fn=predict,
        inputs=[chatbot, msg],
        outputs=[chatbot, msg],
    )

# Serve on all interfaces at port 7860, with no public share link, and show
# Python errors in the browser UI.
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=False,
    show_error=True,
)