| | import os |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| | import gradio as gr |
| |
|
| | |
# Hugging Face access token for gated models (the Llama 3.1 repo requires
# accepting Meta's license, so an authenticated token is needed).
HF_TOKEN = os.environ.get("HF_TOKEN")

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# NOTE: `use_auth_token=` is deprecated in transformers (removed in v5);
# `token=` is the supported replacement.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=HF_TOKEN,
    # fp16 halves GPU memory use; CPU inference generally needs fp32.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
| |
|
| | |
def build_prompt(user_input, history):
    """Assemble the flat text prompt: system line, past turns, then the new message.

    `history` is a sequence of (user_message, pirate_reply) pairs; the returned
    string ends with "Pirate:" so the model continues as the pirate.
    """
    system_line = "You are a pirate chatbot who always responds in pirate speak!\n"
    past_turns = [f"User: {u}\nPirate: {p}\n" for u, p in history]
    return system_line + "".join(past_turns) + f"User: {user_input}\nPirate:"
| |
|
| | |
def chat(user_input, history):
    """Generate one pirate-speak reply for `user_input`.

    Args:
        user_input: the new user message.
        history: list of (user_message, pirate_reply) pairs from prior turns.

    Returns:
        The model's reply text with the prompt stripped off.
    """
    prompt = build_prompt(user_input, history)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Inference only — disabling gradient tracking avoids wasting memory
    # on activations that would otherwise be kept for backprop.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            # Llama has no pad token; reuse EOS to silence the generate() warning.
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode ONLY the newly generated tokens. Decoding the full sequence and
    # splitting on "Pirate:" breaks when the user's own message contains that
    # marker, and re-decodes the whole prompt for nothing.
    prompt_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # The model may hallucinate a follow-up "User:" turn; keep only its reply.
    pirate_reply = response.split("User:")[0].strip()
    return pirate_reply
| |
|
| | |
# --- Gradio UI: chat window, input box, and a clear button -------------------
with gr.Blocks() as demo:
    # The original header contained mojibake ("π΄ββ οΈ") — a UTF-8 pirate-flag
    # emoji decoded with the wrong codec; restore the intended emoji.
    gr.Markdown("## 🏴‍☠️ Talk to the Pirate Bot!")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Ask the pirate something...", label="Your Message")
    clear = gr.Button("Clear Conversation")
    # Per-session conversation state: list of (user_message, pirate_reply) pairs.
    history = gr.State([])

    def respond(user_input, history):
        """Run one chat turn; return the updated transcript for chatbot and state."""
        response = chat(user_input, history)
        history.append((user_input, response))
        return history, history

    msg.submit(respond, [msg, history], [chatbot, history])
    # Reset both the visible chat log and the stored state.
    clear.click(lambda: ([], []), None, [chatbot, history])

demo.launch()
| |
|