import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# --- Load the HASTE model ---
model_name = "theguywhosucks/haste"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",          # automatically places weights on GPU if available
    torch_dtype=torch.float16,
)


def chat_with_haste(user_input, max_tokens, temperature, chat_history=None):
    """Generate one chat turn with the HASTE model.

    Args:
        user_input: The user's new message.
        max_tokens: Maximum number of new tokens to generate (slider value,
            may arrive as float — cast to int before use).
        temperature: Sampling temperature for generation.
        chat_history: List of (user_message, ai_message) pairs — the format
            gr.Chatbot renders. Defaults to a fresh list (never a mutable
            default, which would be shared across calls/sessions).

    Returns:
        (chat_history, chat_history): the updated pair list, once for the
        Chatbot display and once for the gr.State holding the history.
    """
    if chat_history is None:
        chat_history = []

    # Ignore empty submissions rather than generating from a blank prompt.
    if not user_input or not user_input.strip():
        return chat_history, chat_history

    # Rebuild the plain-text prompt from the paired history.
    lines = []
    for user_msg, ai_msg in chat_history:
        lines.append(f"User: {user_msg}")
        lines.append(f"AI: {ai_msg}")
    lines.append(f"User: {user_input}")
    prompt = "\n".join(lines) + "\nAI:"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Inference only — no gradients needed; saves memory during generate().
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (skip the echoed prompt).
    response = tokenizer.decode(
        output[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    ).strip()

    chat_history.append((user_input, response))
    return chat_history, chat_history


# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("## HASTE Chatbot with Adjustable Tokens & Temperature")

    chatbox = gr.Chatbot()
    with gr.Row():
        user_input = gr.Textbox(placeholder="Type your message...", label="Your Message")
        submit_btn = gr.Button("Send")
    with gr.Row():
        max_tokens_slider = gr.Slider(1, 500, value=100, step=1, label="Max Tokens")
        temp_slider = gr.Slider(0.1, 2.0, value=0.7, step=0.05, label="Temperature")

    state = gr.State([])  # chat history as list of (user, ai) pairs

    submit_btn.click(
        chat_with_haste,
        inputs=[user_input, max_tokens_slider, temp_slider, state],
        outputs=[chatbox, state],
    )

demo.launch()