File size: 2,610 Bytes
2511a1d
dfd7076
 
39baba6
2511a1d
dfd7076
39baba6
dfd7076
 
2511a1d
dfd7076
 
39baba6
 
 
dfd7076
 
 
 
 
 
39baba6
dfd7076
 
39baba6
dfd7076
 
39baba6
2511a1d
39baba6
dfd7076
 
 
 
 
2511a1d
dfd7076
2511a1d
dfd7076
39baba6
dfd7076
 
2511a1d
dfd7076
 
 
 
2511a1d
dfd7076
2511a1d
dfd7076
 
39baba6
dfd7076
 
 
39baba6
 
dfd7076
39baba6
 
 
dfd7076
2511a1d
dfd7076
39baba6
dfd7076
 
 
39baba6
dfd7076
 
 
39baba6
dfd7076
 
 
 
 
 
 
39baba6
 
 
2511a1d
39baba6
 
 
 
 
 
2511a1d
dfd7076
39baba6
2511a1d
39baba6
dfd7076
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import get_json_schema
import torch

# -----------------------
# Load model
# -----------------------
# Hugging Face Hub repo id of the fine-tuned router model.
model_name = "bhaiyahnsingh45/functiongemma-multiagent-router"

# Module-level load: downloads (or reads from cache) the tokenizer and weights
# once at import time. device_map="auto" places the model on the available
# accelerator(s); torch_dtype="auto" keeps the checkpoint's native precision.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="auto"
)

# -----------------------
# Agents
# -----------------------
def technical_support_agent(issue_type: str, priority: str) -> str:
    """
    Route a customer query to the Technical Support team.

    Args:
        issue_type: Short description of the technical problem (e.g. "login failure").
        priority: Priority level of the issue (e.g. "low", "high").

    Returns:
        A human-readable confirmation of the routing decision.
    """
    # NOTE: the docstring above is load-bearing — transformers'
    # get_json_schema() (used to build AGENT_TOOLS) raises if a tool
    # function lacks a docstring documenting every argument.
    return f"πŸ› οΈ Routing to Technical Support: {issue_type} ({priority})"

def billing_agent(request_type: str, urgency: str) -> str:
    """
    Route a customer query to the Billing team.

    Args:
        request_type: Kind of billing request (e.g. "refund", "invoice copy").
        urgency: How urgent the request is (e.g. "normal", "urgent").

    Returns:
        A human-readable confirmation of the routing decision.
    """
    # NOTE: the docstring above is load-bearing — transformers'
    # get_json_schema() (used to build AGENT_TOOLS) raises if a tool
    # function lacks a docstring documenting every argument.
    return f"πŸ’° Routing to Billing: {request_type} ({urgency})"

def product_info_agent(query_type: str, category: str) -> str:
    """
    Route a customer query to the Product Information team.

    Args:
        query_type: Kind of product question (e.g. "availability", "specs").
        category: Product category the question is about (e.g. "laptops").

    Returns:
        A human-readable confirmation of the routing decision.
    """
    # NOTE: the docstring above is load-bearing — transformers'
    # get_json_schema() (used to build AGENT_TOOLS) raises if a tool
    # function lacks a docstring documenting every argument.
    return f"πŸ“¦ Routing to Product Info: {query_type} ({category})"

# Tool schemas
# JSON schemas derived from the agent functions' signatures, passed to the
# chat template so the model can choose which agent to call.
# NOTE(review): transformers' get_json_schema requires each tool function to
# carry a docstring documenting every argument — verify the three agent
# functions above satisfy this, or this list construction will raise at import.
AGENT_TOOLS = [
    get_json_schema(technical_support_agent),
    get_json_schema(billing_agent),
    get_json_schema(product_info_agent)
]

# Instruction prepended to every conversation sent to the router model.
SYSTEM_MSG = "You are an intelligent routing agent that directs customer queries to the appropriate specialized agent."

# -----------------------
# Core inference
# -----------------------
def route_query(user_query: str):
    """Run the router model on a single user query and return the generated text.

    Builds a two-message conversation (developer instruction + user query),
    renders it through the model's chat template with the agent tool schemas,
    generates up to 128 new tokens, and decodes only the newly generated
    portion of the output.
    """
    conversation = [
        {"role": "developer", "content": SYSTEM_MSG},
        {"role": "user", "content": user_query},
    ]

    encoded = tokenizer.apply_chat_template(
        conversation,
        tools=AGENT_TOOLS,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Move every input tensor onto the same device as the model weights.
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    generated = model.generate(
        **encoded,
        max_new_tokens=128,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Strip the prompt tokens so only the model's continuation is decoded.
    prompt_length = len(encoded["input_ids"][0])
    return tokenizer.decode(
        generated[0][prompt_length:],
        skip_special_tokens=True,
    )


# -----------------------
# Chatbot logic
# -----------------------
def chat_fn(message, history):
    """Handle one chat turn: route the message and append the exchange.

    Args:
        message: The user's latest message text.
        history: Current chat history as a list of (user, bot) tuples, or
            None/empty right after the UI's Clear button fires.

    Returns:
        The updated history twice (the submit handler wires both outputs
        to the chatbot component).
    """
    # BUGFIX: the Clear button can reset the chatbot state to None; without
    # this guard the next message crashes on None.append(...).
    history = history or []
    response = route_query(message)
    history.append((message, response))
    return history, history


# -----------------------
# UI
# -----------------------
with gr.Blocks() as demo:
    gr.Markdown("## πŸ€– Multi-Agent Router Chatbot")
    gr.Markdown("Ask anything about billing, product, or technical issues.")

    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type your query here...")
    clear = gr.Button("Clear")

    # chat_fn returns (history, history); both outputs target the chatbot.
    msg.submit(chat_fn, [msg, chatbot], [chatbot, chatbot])
    # BUGFIX: reset to an empty list rather than None — chat_fn appends to
    # the stored history, and None would crash the first message after a clear.
    clear.click(lambda: [], None, chatbot, queue=False)

# Launch
demo.launch()