import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import get_json_schema

# -----------------------
# Load model
# -----------------------
model_name = "bhaiyahnsingh45/functiongemma-multiagent-router"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="auto",
)


# -----------------------
# Agents
# -----------------------
# Each agent needs a Google-style docstring with an ``Args:`` entry per
# parameter: ``get_json_schema`` builds the tool description from it and raises
# if the docstring or any argument description is missing.
# NOTE(review): the description text below is rendered into the model prompt;
# presumably it should mirror the wording used when the router was fine-tuned —
# confirm against the training data.
def technical_support_agent(issue_type: str, priority: str) -> str:
    """Route a technical issue to the technical support agent.

    Args:
        issue_type: Kind of technical issue the customer is reporting.
        priority: Priority level assigned to the issue.

    Returns:
        A human-readable routing message.
    """
    return f"🛠️ Routing to Technical Support: {issue_type} ({priority})"


def billing_agent(request_type: str, urgency: str) -> str:
    """Route a billing or payment request to the billing agent.

    Args:
        request_type: Kind of billing request the customer is making.
        urgency: Urgency level of the request.

    Returns:
        A human-readable routing message.
    """
    return f"💰 Routing to Billing: {request_type} ({urgency})"


def product_info_agent(query_type: str, category: str) -> str:
    """Route a product question to the product information agent.

    Args:
        query_type: Kind of product question the customer is asking.
        category: Product category the question relates to.

    Returns:
        A human-readable routing message.
    """
    return f"📦 Routing to Product Info: {query_type} ({category})"


# Tool schemas passed to the chat template so the model can pick an agent.
AGENT_TOOLS = [
    get_json_schema(technical_support_agent),
    get_json_schema(billing_agent),
    get_json_schema(product_info_agent),
]

SYSTEM_MSG = "You are an intelligent routing agent that directs customer queries to the appropriate specialized agent."
# -----------------------
# Core inference
# -----------------------
def route_query(user_query: str) -> str:
    """Ask the router model which agent should handle ``user_query``.

    Builds a two-message conversation (developer instruction plus the user
    query), renders it through the tokenizer's chat template together with the
    agent tool schemas, and generates up to 128 new tokens.

    Args:
        user_query: Raw customer message to route.

    Returns:
        The decoded text of the newly generated tokens only (the prompt tokens
        are sliced off), with special tokens skipped.
    """
    messages = [
        {"role": "developer", "content": SYSTEM_MSG},
        {"role": "user", "content": user_query},
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        tools=AGENT_TOOLS,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Move every input tensor to the device the model lives on.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[-1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() returns prompt + completion; keep only the completion.
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)


# -----------------------
# Chatbot logic
# -----------------------
def chat_fn(message, history):
    """Append the router's answer for ``message`` to the chat history.

    Args:
        message: Text the user submitted.
        history: Existing list of ``(user, bot)`` pairs; ``None`` is treated
            as an empty history.

    Returns:
        A ``(history, history)`` pair holding the updated history twice, the
        same shape the function has always returned. The input list is not
        mutated. A blank or whitespace-only message leaves the history
        unchanged and skips generation.
    """
    # NOTE(review): (user, bot) tuples are the legacy Chatbot history format;
    # newer Gradio releases expect role/content dicts — confirm against the
    # Gradio version this app is pinned to.
    updated = list(history or [])
    if not message or not message.strip():
        return updated, updated

    updated.append((message, route_query(message)))
    return updated, updated


def _submit(message, history):
    """Gradio submit handler: route the message, then clear the textbox."""
    updated, _ = chat_fn(message, history)
    return "", updated


# -----------------------
# UI
# -----------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Multi-Agent Router Chatbot")
    gr.Markdown("Ask anything about billing, product, or technical issues.")

    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type your query here...")
    clear = gr.Button("Clear")

    # Outputs: empty string resets the textbox, updated history refreshes the chat.
    msg.submit(_submit, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

# Launch
demo.launch()