File size: 2,670 Bytes
a35c379
31ab657
 
a35c379
31ab657
 
 
 
 
 
 
 
a35c379
31ab657
a35c379
31ab657
 
 
 
 
 
 
 
 
 
a35c379
31ab657
 
a35c379
31ab657
 
 
a35c379
31ab657
 
a35c379
31ab657
 
 
 
 
a35c379
31ab657
 
a35c379
31ab657
 
 
 
 
 
 
 
 
 
 
a35c379
31ab657
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a35c379
31ab657
 
 
a35c379
31ab657
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer

# Checkpoint that provides both the tokenizer and the causal-LM weights
model_id = "sajeewa/empathy-chat-gemma"

# Tokenizer first, then the weights in half precision; device placement is
# left to the library through device_map="auto".
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.float16
)

# Prompt-size limit, in tokens, that respond() checks before generating
MAX_TOKENS = 2048

# Persona instructions; this message is always the first history entry
system_prompt = dict(
    role="system",
    content="".join(
        (
            "You are an empathetic AI and your friend. Always give lovely caring messages. ",
            "Understand the user's feelings, then provide a caring response. ",
            "Talk like a sweet friend using words like 'baby', 'cutie', etc. ",
            "Use emojis when helpful. Try to continue the conversation in a gentle, emotional tone.",
        )
    ),
)

# Running conversation sent to the model, seeded with the system prompt
chat_history = [system_prompt]

def _encode_chat_history():
    """Render ``chat_history`` with the chat template and tokenize the result."""
    prompt = tokenizer.apply_chat_template(
        chat_history,
        tokenize=False,
        # Open the assistant turn so the model writes a reply instead of
        # continuing the user's message.
        add_generation_prompt=True,
    )
    # The rendered template already carries its own special tokens; adding
    # them again here could duplicate them (e.g. a second BOS).
    return tokenizer(prompt, return_tensors="pt", add_special_tokens=False)


# Define a function to generate responses
def respond(user_input, history):
    """Generate the assistant's reply to ``user_input`` and record it.

    The message and the reply are appended to the module-level
    ``chat_history`` (the conversation sent to the model) and to ``history``
    (the conversation displayed in the UI).

    Args:
        user_input: Text entered by the user. Blank input is ignored.
        history: Displayed conversation as a list of
            ``(user_message, assistant_message)`` pairs. ``None`` is treated
            as an empty conversation.

    Returns:
        A 2-tuple holding the updated display history twice, matching the two
        output slots the event handlers bind this function to.
    """
    global chat_history

    if history is None:
        history = []

    # Nothing to answer: skip generation and keep both histories untouched.
    if not user_input or not user_input.strip():
        return history, history

    # Add user message
    chat_history.append({"role": "user", "content": user_input})

    # Token length control: drop the oldest exchanges while the prompt is over
    # budget, but never the system prompt (index 0) or the message just added.
    inputs = _encode_chat_history()
    while inputs["input_ids"].shape[-1] > MAX_TOKENS and len(chat_history) > 2:
        chat_history.pop(1)
        # Drop the reply to the removed message as well, so the trimmed
        # history does not start with an orphaned assistant turn.
        if len(chat_history) > 2 and chat_history[1]["role"] == "assistant":
            chat_history.pop(1)
        inputs = _encode_chat_history()

    # Prepare model input
    inputs = inputs.to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    # Generate response
    output = model.generate(
        **inputs,
        max_new_tokens=128,
        temperature=0.7,
        top_p=0.95,
        top_k=50,
        do_sample=True,
    )

    # The returned sequence starts with the prompt tokens. Slice by token
    # count: a character offset into the decoded text is unreliable because
    # decoding with skip_special_tokens=True removes the template markers.
    new_tokens = output[0][prompt_length:]
    new_response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Add assistant's response to history
    chat_history.append({"role": "assistant", "content": new_response})

    # Show full conversation
    history.append((user_input, new_response))
    return history, history

# Define reset function
def reset_chat():
    """Start a fresh conversation.

    Rebinds the module-level ``chat_history`` to a new list that holds only
    the system prompt.

    Returns:
        Two new empty lists, one for each output slot this function is bound to.
    """
    global chat_history
    cleared_display, cleared_state = [], []
    chat_history = [system_prompt]
    return cleared_display, cleared_state

# Gradio UI: a title, the chat display, a message box, and two buttons.
with gr.Blocks() as demo:
    gr.Markdown("## 💬 Empathy Chat with Gemma")
    # Displays the conversation; its value is passed to respond() as `history`.
    chatbot = gr.Chatbot()
    with gr.Row():
        msg = gr.Textbox(label="Your Message", placeholder="Tell me how you feel...")
    with gr.Row():
        send = gr.Button("Send")
        clear = gr.Button("Reset Chat")

    # The Send button and the textbox's submit event run the same handler.
    # `chatbot` is listed twice in `outputs` because respond() and reset_chat()
    # each return two values.
    send.click(fn=respond, inputs=[msg, chatbot], outputs=[chatbot, chatbot])
    clear.click(fn=reset_chat, outputs=[chatbot, chatbot])
    msg.submit(fn=respond, inputs=[msg, chatbot], outputs=[chatbot, chatbot])

# Launch the app
demo.launch()