# hastedemo/app.py
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# --- Load your HASTE model ---
model_name = "theguywhosucks/haste"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",  # places the model on a GPU automatically when one is available
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,  # fp16 needs a GPU; fp32 fallback keeps CPU-only runs working
)
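# Note (an assumption about the model, not taken from the original app): if
# the HASTE tokenizer ships a chat template, building prompts with
# tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# would be more robust than the plain "User:/AI:" transcript assembled below.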
# --- Chat function ---
def chat_with_haste(user_input, max_tokens, temperature, chat_history=None):
    # Avoid the shared-mutable-default pitfall; gr.State supplies a fresh
    # history per session anyway.
    if chat_history is None:
        chat_history = []

    # Rebuild the running transcript as a plain-text prompt
    prompt_lines = []
    for user_msg, ai_msg in chat_history:
        prompt_lines.append(f"User: {user_msg}")
        prompt_lines.append(f"AI: {ai_msg}")
    prompt_lines.append(f"User: {user_input}")
    prompt = "\n".join(prompt_lines) + "\nAI:"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate
    output = model.generate(
        **inputs,
        max_new_tokens=int(max_tokens),  # slider values may arrive as floats
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens (skip the echoed prompt)
    response = tokenizer.decode(
        output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    ).strip()

    # gr.Chatbot expects a list of (user, assistant) pairs, so keep the
    # history in that shape rather than as flat "User:/AI:" strings.
    chat_history.append((user_input, response))
    return chat_history, chat_history
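# Quick sanity check outside the UI (a hypothetical example, not in the
# original app):
#   history = []
#   history, _ = chat_with_haste("Hello!", 50, 0.7, history)
#   print(history[-1][1])  # the model's reply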
# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("## HASTE Chatbot with Adjustable Tokens & Temperature")
    chatbox = gr.Chatbot()
    with gr.Row():
        user_input = gr.Textbox(placeholder="Type your message...", label="Your Message")
        submit_btn = gr.Button("Send")
    with gr.Row():
        max_tokens_slider = gr.Slider(1, 500, value=100, step=1, label="Max Tokens")
        temp_slider = gr.Slider(0.1, 2.0, value=0.7, step=0.05, label="Temperature")
    state = gr.State([])  # per-session chat history
    submit_btn.click(
        chat_with_haste,
        inputs=[user_input, max_tokens_slider, temp_slider, state],
        outputs=[chatbox, state],
    )
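    # Extra wiring (an addition, not in the original app): pressing Enter in
    # the textbox triggers the same handler as the Send button. Textbox.submit
    # is part of the standard Gradio Blocks event API.
    user_input.submit(
        chat_with_haste,
        inputs=[user_input, max_tokens_slider, temp_slider, state],
        outputs=[chatbox, state],
    )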
demo.launch()
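# demo.launch() with no arguments is what Hugging Face Spaces expects; for a
# local run, demo.launch(share=True) (a standard Gradio option) also serves a
# temporary public URL.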