# Patchwork-26 / app.py
# theguywhosucks's picture
# Update app.py
# b6370c0 verified
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
# --- Load HF model ---
model_name = "theguywhosucks/haste"

# Optional auth token for gated/private repos, read from the
# "identification" Space secret; None when the secret is absent.
hf_token = os.environ.get("identification")
_auth = {"token": hf_token or None}

# Tokenizer for the chat model.
tokenizer = AutoTokenizer.from_pretrained(model_name, **_auth)

# Model weights in fp16, placed automatically across available devices.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    **_auth,
)
# --- Chat logic ---
def chat_fn(user_input, history):
    """Generate an assistant reply for *user_input* given the chat *history*.

    Args:
        user_input: The latest user message (str).
        history: List of [user, assistant] message pairs, or None/empty on
            the first turn.

    Returns:
        Tuple of (updated history, "") — the empty string clears the textbox.
    """
    history = history or []

    # Rebuild the plain-text conversation prompt from prior turns.
    parts = [f"User: {u}\nAssistant: {a}\n" for u, a in history]
    parts.append(f"User: {user_input}\nAssistant: ")
    conversation = "".join(parts)

    inputs = tokenizer(conversation, return_tensors="pt").to(model.device)

    # Inference only — no_grad avoids building an autograd graph
    # (saves memory and time during generation).
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the whole prompt.
    # This is more robust than splitting the full decoded text on
    # "Assistant:", which breaks if a user message contains that marker.
    prompt_len = inputs["input_ids"].shape[1]
    reply = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    # The model may continue the transcript with a hallucinated next
    # "User:" turn; keep only the assistant's own reply.
    reply = reply.split("User:")[0].strip()

    history.append([user_input, reply])
    return history, ""  # second value clears the input textbox
# --- Gradio UI ---
with gr.Blocks() as demo:
gr.Markdown("<h2 style='text-align:center;'>🤖 HASTE Chatbot</h2>")
chatbot = gr.Chatbot(height=600)
with gr.Row():
user_input = gr.Textbox(placeholder="Type a message...", show_label=False, lines=1)
send_btn = gr.Button("Send")
send_btn.click(chat_fn, inputs=[user_input, chatbot], outputs=[chatbot, user_input])
user_input.submit(chat_fn, inputs=[user_input, chatbot], outputs=[chatbot, user_input])
if __name__ == "__main__":
demo.launch()