File size: 2,110 Bytes
ce88406
5fad509
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Hugging Face model id: distilled 400M-parameter BlenderBot conversational model.
model_name = "facebook/blenderbot-400M-distill"
# Prefer GPU when available; the model and all input tensors use this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Loaded once at import time — first run downloads the weights to the HF cache.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

# Persona text prepended to every prompt to steer tone and format.
# NOTE(review): BlenderBot is not instruction-tuned, so this preamble is
# best-effort steering rather than an enforced system prompt — verify output.
system_preamble = (
    "You are a friendly, energetic motivational coach. "
    "Keep answers concise, positive, and actionable. "
    "When the user asks for exercises or steps, provide a short numbered list. "
)

def generate_response(history, user_message):
    """Generate one coach reply for *user_message* given the prior *history*.

    Args:
        history: list of ``(user_text, coach_text)`` tuples from earlier turns.
        user_message: the new user utterance to respond to.

    Returns:
        The model's decoded reply as a plain string (leading/trailing
        whitespace stripped, any hallucinated follow-on "User:" turn removed).
    """
    # Flatten the transcript into a single prompt; the final "Coach:" cue
    # asks the model to continue with the coach's next turn.
    turns = [f"User: {u}\nCoach: {r}" for u, r in history]
    turns.append(f"User: {user_message}\nCoach:")
    prompt = system_preamble + "\n" + "\n".join(turns)
    # Clamp truncation to the model's real input budget: BlenderBot-400M's
    # positional embeddings are far smaller than 1024 tokens, and feeding a
    # longer sequence raises an index error inside the model.
    max_len = min(1024, getattr(tokenizer, "model_max_length", 1024))
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=max_len
    ).to(device)
    # Inference only — disable autograd to avoid building a gradient graph.
    with torch.no_grad():
        out = model.generate(
            **inputs, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=0.8
        )
    reply = tokenizer.decode(out[0], skip_special_tokens=True).strip()
    # The model sometimes continues the transcript with an invented user turn;
    # keep only the coach's own reply.
    reply = reply.split("User:", 1)[0].strip()
    return reply

def chat(user_message, chat_history):
    """Append one user/coach exchange to the history and return it.

    The history is returned twice because the caller feeds it to two
    Gradio outputs: the visible Chatbot and the session State.
    """
    history = [] if chat_history is None else chat_history
    reply = generate_response(history, user_message)
    history.append((user_message, reply))
    return history, history

with gr.Blocks(title="Motivational Coach") as demo:
    gr.Markdown("<h2 style='text-align:center'>Motivational Coach — powered by BlenderBot</h2>")
    with gr.Row():
        chatbot = gr.Chatbot(elem_id="chatbot", label="Coach")
        # Gradio requires `scale` to be an integer (relative width units);
        # the original float 0.3 is rejected by Gradio 4.x. min_width keeps
        # the button column from collapsing on narrow screens.
        with gr.Column(scale=1, min_width=120):
            clear_btn = gr.Button("Clear")
    msg = gr.Textbox(placeholder="Write your message here...", show_label=False)
    # Per-session conversation history: list of (user, coach) tuples.
    state = gr.State([])

    def submit_message(message, state):
        """Handle a submitted message: clear the textbox, update chat and state."""
        new_history, state_out = chat(message, state)
        # First output ("") resets the input textbox after submission.
        return "", new_history, state_out

    msg.submit(submit_message, inputs=[msg, state], outputs=[msg, chatbot, state])
    # Reset both the visible chat log and the server-side history together.
    clear_btn.click(lambda: ([], []), None, [chatbot, state])

if __name__ == "__main__":
    # Bind on all interfaces so the app is reachable from containers/remote hosts.
    demo.launch(server_name="0.0.0.0", server_port=7860)