"""Tool-using arithmetic agent: LangGraph loop around Phi-3-mini with a Gradio UI.

The model is prompted to emit ONE JSON object per turn — either a tool call
({"tool": ..., "args": {...}}) or a final answer ({"final": ...}). The graph
alternates llm_call -> tool_node until the model emits a final answer.
"""

import json
import operator
import re
from typing import Any, Dict, Literal

import gradio as gr
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool
from langchain_huggingface import HuggingFacePipeline
from langgraph.graph import END, START, StateGraph
from typing_extensions import Annotated, TypedDict

# Safety cap: the llm_call <-> tool_node cycle has no natural bound if the
# model never emits {"final": ...}; stop after this many LLM invocations.
MAX_LLM_CALLS = 6

# 1) Model
llm = HuggingFacePipeline.from_model_id(
    model_id="microsoft/Phi-3-mini-4k-instruct",
    task="text-generation",
    pipeline_kwargs={
        "max_new_tokens": 96,
        "top_k": 50,
        "temperature": 0.1,
        # We build the prompt ourselves; only want the generated continuation.
        "return_full_text": False,
    },
)


# 2) Tools
@tool
def multiply(a: int, b: int) -> int:
    """Multiply a and b."""
    return a * b


@tool
def add(a: int, b: int) -> int:
    """Add a and b."""
    return a + b


@tool
def divide(a: int, b: int) -> float:
    """Divide a by b."""
    return a / b


tools = [add, multiply, divide]
tools_by_name = {t.name: t for t in tools}


# 3) State
class MessagesState(TypedDict):
    # operator.add as the reducer makes LangGraph APPEND each node's returned
    # messages to the running list instead of overwriting it.
    messages: Annotated[list[BaseMessage], operator.add]
    llm_calls: int


SYSTEM = """You are an arithmetic tool user. You must output exactly ONE object, preferably STRICT JSON.
If there is NO tool result yet, output a tool call:
{"tool": "add"|"multiply"|"divide", "args": {"a": , "b": }}
If there IS a tool result already, output the final answer:
{"final": ""}
No extra text. Use double quotes if possible.
""".strip()


def _format_for_phi(messages: list[BaseMessage]) -> str:
    """Render the conversation as a flat prompt string for the text-gen pipeline.

    Only Human and Tool messages are included; prior AI turns are dropped so
    the model sees system prompt + question (+ tool result) and answers fresh.
    """
    parts = [SYSTEM, ""]
    for m in messages:
        if isinstance(m, HumanMessage):
            parts.append(f"User: {m.content}")
        elif isinstance(m, ToolMessage):
            parts.append(f"Tool result: {m.content}")
    parts.append("Assistant:")
    return "\n".join(parts)


def _hard_trim_to_first_turn(text: str) -> str:
    """Cut off any hallucinated follow-up turns the model generated."""
    cuts = ["\n\nUser:", "\nUser:", "\n\nAssistant:", "\nAssistant:"]
    for c in cuts:
        if c in text:
            text = text.split(c, 1)[0]
    return text.strip()


def _extract_balanced_json(text: str) -> str:
    """Return the first brace-balanced {...} span in *text*, or "" if none.

    A non-greedy regex (r"\\{.*?\\}") would stop at the FIRST closing brace,
    truncating nested objects like {"tool": ..., "args": {...}} into invalid
    JSON — so braces are counted explicitly instead.
    """
    start = text.find("{")
    if start == -1:
        return ""
    depth = 0
    for i in range(start, len(text)):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start : i + 1]
    return ""  # unbalanced — let the caller fall back to regex scraping


def _parse_model_output(text: str) -> Dict[str, Any]:
    """Parse the model's reply into {"tool": ..., "args": {...}} or {"final": ...}.

    Strategy: (1) strict-ish JSON after light cleanup (smart quotes, trailing
    commas), (2) regex scrape for "final", (3) regex scrape for tool + int
    args. Raises ValueError when nothing usable can be recovered.
    """
    candidate = _extract_balanced_json(text)
    if candidate:
        s = candidate.replace("\u201c", '"').replace("\u201d", '"').replace("\u2019", "'")
        s = re.sub(r",\s*}", "}", s)
        s = re.sub(r",\s*]", "]", s)
        try:
            obj = json.loads(s)
            if isinstance(obj, dict):
                return obj
        except json.JSONDecodeError:
            pass

    mf = re.search(r'"final"\s*:\s*("?)([^"\}\n]+)\1', text)
    if mf:
        return {"final": mf.group(2).strip()}

    # NOTE: original had r'"(?Padd|multiply|divide)"' — an invalid named-group
    # (missing <name>), which made re.search raise re.error at runtime.
    mt = re.search(r'"tool"\s*:\s*"(?P<tool>add|multiply|divide)"', text)
    tool_name = mt.group("tool") if mt else None
    if tool_name is None:
        mt2 = re.search(r"\btool\b\s*:\s*(add|multiply|divide)", text)
        if not mt2:
            raise ValueError(f"Could not parse tool/final from model output:\n{text}")
        tool_name = mt2.group(1)

    ma = re.search(r'"a"\s*:\s*(-?\d+)', text)
    mb = re.search(r'"b"\s*:\s*(-?\d+)', text)
    if not ma or not mb:
        raise ValueError(f"Parsed tool={tool_name} but could not parse a/b from:\n{text}")
    return {"tool": tool_name, "args": {"a": int(ma.group(1)), "b": int(mb.group(1))}}


# 4) Nodes
def llm_call(state: MessagesState):
    """Invoke the LLM on the conversation and attach the parsed JSON intent."""
    prompt = _format_for_phi(state["messages"])
    raw = llm.invoke(prompt)
    raw = _hard_trim_to_first_turn(raw)
    data = _parse_model_output(raw)
    # Stash the parsed dict on the message so downstream nodes need not re-parse.
    msg = AIMessage(content=raw, additional_kwargs={"parsed": data})
    return {"messages": [msg], "llm_calls": state.get("llm_calls", 0) + 1}


def tool_node(state: MessagesState):
    """Execute the tool call requested by the last AI message, if any."""
    last = state["messages"][-1]
    data = last.additional_kwargs.get("parsed", {})
    if "tool" not in data:
        return {"messages": []}
    tool_name = data["tool"]
    obs = tools_by_name[tool_name].invoke(data["args"])
    return {"messages": [ToolMessage(content=str(obs), tool_call_id=f"{tool_name}-call")]}


def should_continue(state: MessagesState) -> Literal["tool_node", END]:
    """Route to tool execution when a tool was requested, else finish.

    Also ends the run once MAX_LLM_CALLS is reached, so a model that never
    emits {"final": ...} cannot loop forever.
    """
    if state.get("llm_calls", 0) >= MAX_LLM_CALLS:
        return END
    last = state["messages"][-1]
    data = last.additional_kwargs.get("parsed", {})
    return "tool_node" if "tool" in data else END


# 5) Graph
agent_builder = StateGraph(MessagesState)
agent_builder.add_node("llm_call", llm_call)
agent_builder.add_node("tool_node", tool_node)
agent_builder.add_edge(START, "llm_call")
agent_builder.add_conditional_edges("llm_call", should_continue, ["tool_node", END])
agent_builder.add_edge("tool_node", "llm_call")
agent = agent_builder.compile()


# 6) Web handler
def run_agent(user_text: str) -> str:
    """Run one question through the agent graph and return the final answer text."""
    out = agent.invoke({"messages": [HumanMessage(content=user_text)], "llm_calls": 0})
    # The final answer lives on the last AI message's parsed payload.
    last_ai = next((m for m in reversed(out["messages"]) if isinstance(m, AIMessage)), None)
    if last_ai is None:
        return "No AI output."
    parsed = last_ai.additional_kwargs.get("parsed", {})
    if "final" in parsed:
        return str(parsed["final"])
    return last_ai.content  # fallback: raw model text


demo = gr.Interface(
    fn=run_agent,
    inputs=gr.Textbox(label="Ask an arithmetic question", placeholder="e.g., 4 divided by 3"),
    outputs=gr.Textbox(label="Answer"),
    title="Tool-Using Arithmetic Agent (LangGraph + Phi-3-mini)",
)

if __name__ == "__main__":
    demo.launch()