# Hugging Face Space app: multi-agent customer-query router chatbot.
# (Page chrome "Spaces / Runtime error" from the scraped Space removed.)
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import get_json_schema

# -----------------------
# Load model
# -----------------------
# Checkpoint fine-tuned to route customer queries to specialist agents.
model_name = "bhaiyahnsingh45/functiongemma-multiagent-router"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",   # place weights on whatever accelerator is available
    torch_dtype="auto",  # keep the checkpoint's native precision
)
| # ----------------------- | |
| # Agents | |
| # ----------------------- | |
def technical_support_agent(issue_type: str, priority: str) -> str:
    """Route a query to the Technical Support agent.

    Args:
        issue_type: Short description of the technical issue (e.g. "login").
        priority: Priority label (e.g. "high").

    Returns:
        A confirmation string describing the routing decision.
    """
    # Fix: the original literal contained mojibake ("π οΈ") from a bad
    # encoding round-trip; restored to the intended wrench emoji.
    return f"🛠️ Routing to Technical Support: {issue_type} ({priority})"
def billing_agent(request_type: str, urgency: str) -> str:
    """Route a query to the Billing agent.

    Args:
        request_type: Short description of the billing request (e.g. "refund").
        urgency: Urgency label (e.g. "urgent").

    Returns:
        A confirmation string describing the routing decision.
    """
    # Fix: the original literal contained mojibake ("π°") from a bad
    # encoding round-trip; restored to the intended money-bag emoji.
    return f"💰 Routing to Billing: {request_type} ({urgency})"
def product_info_agent(query_type: str, category: str) -> str:
    """Route a query to the Product Info agent.

    Args:
        query_type: Short description of the product question (e.g. "specs").
        category: Product category the question concerns (e.g. "laptops").

    Returns:
        A confirmation string describing the routing decision.
    """
    # Fix: the original literal contained mojibake ("π¦") from a bad
    # encoding round-trip; restored to the intended package emoji.
    return f"📦 Routing to Product Info: {query_type} ({category})"
# Tool schemas: JSON schemas derived from each agent's signature and type
# hints, passed to the chat template so the model can emit tool calls.
_AGENT_FNS = (technical_support_agent, billing_agent, product_info_agent)
AGENT_TOOLS = [get_json_schema(fn) for fn in _AGENT_FNS]

# System-level instruction given to the router model on every turn.
SYSTEM_MSG = (
    "You are an intelligent routing agent that directs customer queries "
    "to the appropriate specialized agent."
)
# -----------------------
# Core inference
# -----------------------
def route_query(user_query: str) -> str:
    """Run the router model on a single user query.

    Builds a chat prompt from the system routing instructions, the tool
    schemas and the user's message, generates a completion, and returns
    only the newly generated text.

    Args:
        user_query: The raw customer message to route.

    Returns:
        The decoded model completion (typically a routing/tool-call
        decision) with special tokens stripped.
    """
    messages = [
        # NOTE(review): assumes this checkpoint's chat template accepts the
        # "developer" role for system-level instructions — confirm against
        # the model card.
        {"role": "developer", "content": SYSTEM_MSG},
        {"role": "user", "content": user_query},
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        tools=AGENT_TOOLS,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Move every tensor to the same device as the model weights.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    # Fix: inference only — disable autograd so generate() does not build
    # a gradient graph (the original wasted memory doing so).
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the tokens generated after the prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(
        outputs[0][prompt_len:],
        skip_special_tokens=True,
    )
# -----------------------
# Chatbot logic
# -----------------------
def chat_fn(message, history):
    """Gradio event handler: route *message* and append it to the chat.

    Args:
        message: The user's new message from the textbox.
        history: List of (user, bot) pairs; may be None on the first call.

    Returns:
        The updated history twice — once per output component wired to
        this handler in the UI.
    """
    # Fix: tolerate a None history and copy instead of mutating the
    # caller's list in place.
    history = list(history or [])
    response = route_query(message)
    history.append((message, response))
    return history, history
# -----------------------
# UI
# -----------------------
with gr.Blocks() as demo:
    # Fix: the original heading contained mojibake ("π€") from a bad
    # encoding round-trip; restored to the intended robot emoji.
    gr.Markdown("## 🤖 Multi-Agent Router Chatbot")
    gr.Markdown("Ask anything about billing, product, or technical issues.")

    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type your query here...")
    clear = gr.Button("Clear")

    # Enter in the textbox routes the query and refreshes the chat display.
    msg.submit(chat_fn, [msg, chatbot], [chatbot, chatbot])
    # Clear button resets the chat display (None empties the Chatbot).
    clear.click(lambda: None, None, chatbot, queue=False)

# Launch
demo.launch()