# Hugging Face Space (status at time of capture: Sleeping)
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Model setup: pull tokenizer + weights from the Hugging Face Hub ---
model_name = "theguywhosucks/haste"

# Optional Hub access token, read from the Space secret named "identification".
hf_token = os.environ.get("identification")
auth_token = hf_token if hf_token else None

# Tokenizer converts chat text to/from token ids.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=auth_token)

# Causal LM in half precision, sharded automatically across available devices.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    token=auth_token,
)
# --- Chat logic ---
def chat_fn(user_input, history):
    """Generate one assistant reply for ``user_input`` given the chat history.

    Args:
        user_input: The user's latest message (plain text).
        history: List of ``[user, assistant]`` message pairs held by the Gradio
            Chatbot component, or ``None``/empty on the first turn.

    Returns:
        Tuple of ``(updated_history, "")`` — the new history for the Chatbot
        and an empty string so the input textbox is cleared.
    """
    history = history or []

    # Rebuild the plain-text transcript used as the model prompt.
    conversation = ""
    for user_msg, bot_msg in history:
        conversation += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    conversation += f"User: {user_input}\nAssistant: "

    # Tokenize on the model's device and sample a continuation.
    inputs = tokenizer(conversation, return_tensors="pt").to(model.device)
    output = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode ONLY the newly generated tokens. Decoding the full sequence and
    # splitting on "Assistant:" (the previous approach) breaks whenever the
    # user's own message contains that marker.
    prompt_len = inputs["input_ids"].shape[-1]
    reply = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    # The model may continue past its turn and invent the next "User:" line;
    # keep only the assistant's reply.
    reply = reply.split("User:")[0].strip()

    # Append the new exchange and clear the textbox.
    history.append([user_input, reply])
    return history, ""
# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("<h2 style='text-align:center;'>🤖 HASTE Chatbot</h2>")
    chatbot = gr.Chatbot(height=600)
    with gr.Row():
        user_input = gr.Textbox(placeholder="Type a message...", show_label=False, lines=1)
        send_btn = gr.Button("Send")

    # Both clicking "Send" and pressing Enter in the textbox submit a message;
    # each wires chat_fn to update the chat log and clear the textbox.
    for trigger in (send_btn.click, user_input.submit):
        trigger(chat_fn, inputs=[user_input, chatbot], outputs=[chatbot, user_input])

if __name__ == "__main__":
    demo.launch()