Spaces:
Sleeping
Sleeping
| import json | |
| import re | |
| import operator | |
| from typing import Literal, Dict, Any | |
| from typing_extensions import TypedDict, Annotated | |
| import gradio as gr | |
| from langchain_core.tools import tool # (recommended import) | |
| from langchain_huggingface import HuggingFacePipeline | |
| from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, BaseMessage | |
| from langgraph.graph import StateGraph, START, END | |
# 1) Model
# Generation settings forwarded to the Hugging Face text-generation pipeline.
# NOTE(review): transformers only honours temperature/top_k when sampling is
# enabled (do_sample=True in the call or the model's generation config);
# otherwise decoding is greedy. Confirm which behaviour is intended.
_PIPELINE_KWARGS = dict(
    max_new_tokens=96,
    top_k=50,
    temperature=0.1,
    return_full_text=False,  # return only the continuation, not the prompt
)

llm = HuggingFacePipeline.from_model_id(
    model_id="microsoft/Phi-3-mini-4k-instruct",
    task="text-generation",
    pipeline_kwargs=_PIPELINE_KWARGS,
)
# 2) Tools
# Each arithmetic function is wrapped with LangChain's ``@tool`` decorator so it
# becomes a tool object. Such objects expose ``.name`` (used to build
# ``tools_by_name`` below) and ``.invoke(dict)`` (used by ``tool_node``).
# Plain functions have neither attribute, so without the decorator the
# ``tools_by_name`` comprehension raises AttributeError at import time.
# The docstrings double as the tool descriptions.
@tool
def multiply(a: int, b: int) -> int:
    """Multiply a and b."""
    return a * b


@tool
def add(a: int, b: int) -> int:
    """Add a and b."""
    return a + b


@tool
def divide(a: int, b: int) -> float:
    """Divide a by b."""
    # Raises ZeroDivisionError when b == 0.
    return a / b


tools = [add, multiply, divide]
# Tool name -> tool object, so a parsed {"tool": "<name>", ...} can be dispatched.
tools_by_name = {t.name: t for t in tools}
# 3) State
class MessagesState(TypedDict):
    """State passed between the nodes of the agent graph."""

    # Conversation history. The ``operator.add`` reducer annotation makes
    # LangGraph concatenate each node's returned "messages" list onto this one.
    messages: Annotated[list[BaseMessage], operator.add]
    # Number of model invocations so far; llm_call returns it incremented by
    # one. It has no reducer, so each returned value replaces the previous one
    # (LangGraph's default for un-annotated keys).
    llm_calls: int
# System prompt placed at the top of every model call. It asks the model for a
# single JSON object: a tool call while no tool result exists yet, and a
# {"final": ...} answer once one does.
SYSTEM = "\n".join(
    (
        "You are an arithmetic tool user.",
        "You must output exactly ONE object, preferably STRICT JSON.",
        "If there is NO tool result yet, output a tool call:",
        '{"tool": "add"|"multiply"|"divide", "args": {"a": <int>, "b": <int>}}',
        "If there IS a tool result already, output the final answer:",
        '{"final": "<answer>"}',
        "No extra text. Use double quotes if possible.",
    )
)
def _format_for_phi(messages: list[BaseMessage]) -> str:
    """Render the conversation as a plain-text prompt for the model.

    The prompt is the system instructions, a blank line, one line per turn
    (``User:``, ``Assistant:`` or ``Tool result:``), and a trailing
    ``Assistant:`` cue for the model to continue from.

    Earlier assistant turns are included so that, once a tool has run, the
    model can see which tool call produced the ``Tool result`` it is reading.
    Messages of any other type are skipped.
    """
    parts = [SYSTEM, ""]
    for m in messages:
        if isinstance(m, HumanMessage):
            parts.append(f"User: {m.content}")
        elif isinstance(m, AIMessage):
            parts.append(f"Assistant: {m.content}")
        elif isinstance(m, ToolMessage):
            parts.append(f"Tool result: {m.content}")
    parts.append("Assistant:")
    return "\n".join(parts)
| def _hard_trim_to_first_turn(text: str) -> str: | |
| cuts = ["\n\nUser:", "\nUser:", "\n\nAssistant:", "\nAssistant:"] | |
| for c in cuts: | |
| if c in text: | |
| text = text.split(c, 1)[0] | |
| return text.strip() | |
| def _parse_model_output(text: str) -> Dict[str, Any]: | |
| m = re.search(r"\{.*?\}", text, flags=re.DOTALL) | |
| candidate = m.group(0).strip() if m else "" | |
| if candidate: | |
| s = candidate.replace("“", '"').replace("”", '"').replace("’", "'") | |
| s = re.sub(r",\s*}", "}", s) | |
| s = re.sub(r",\s*]", "]", s) | |
| try: | |
| obj = json.loads(s) | |
| if isinstance(obj, dict): | |
| return obj | |
| except json.JSONDecodeError: | |
| pass | |
| mf = re.search(r'"final"\s*:\s*("?)([^"\}\n]+)\1', text) | |
| if mf: | |
| return {"final": mf.group(2).strip()} | |
| mt = re.search(r'"tool"\s*:\s*"(?P<tool>add|multiply|divide)"', text) | |
| tool_name = mt.group("tool") if mt else None | |
| if tool_name is None: | |
| mt2 = re.search(r'\btool\b\s*:\s*(add|multiply|divide)', text) | |
| if not mt2: | |
| raise ValueError(f"Could not parse tool/final from model output:\n{text}") | |
| tool_name = mt2.group(1) | |
| ma = re.search(r'"a"\s*:\s*(-?\d+)', text) | |
| mb = re.search(r'"b"\s*:\s*(-?\d+)', text) | |
| if not ma or not mb: | |
| raise ValueError(f"Parsed tool={tool_name} but could not parse a/b from:\n{text}") | |
| return {"tool": tool_name, "args": {"a": int(ma.group(1)), "b": int(mb.group(1))}} | |
# 4) Nodes
def llm_call(state: dict):
    """Run one model turn and attach the parsed decision to the reply.

    Builds the text prompt from ``state["messages"]``, invokes the model, trims
    the completion to its first turn, and parses it. Returns a state update
    that appends one AIMessage (raw text in ``content``, parsed dict under
    ``additional_kwargs["parsed"]``) and increments ``llm_calls``.

    Raises:
        ValueError: propagated from ``_parse_model_output`` when the completion
            contains neither a tool call nor a final answer.
    """
    completion = _hard_trim_to_first_turn(llm.invoke(_format_for_phi(state["messages"])))
    reply = AIMessage(
        content=completion,
        additional_kwargs={"parsed": _parse_model_output(completion)},
    )
    calls_so_far = state.get("llm_calls", 0)
    return {"messages": [reply], "llm_calls": calls_so_far + 1}
def tool_node(state: dict):
    """Execute the tool call parsed from the latest AI message.

    Returns a state update appending one ToolMessage holding the tool's
    result as a string. If the last message carries no tool call, returns an
    empty update.

    Tool failures are reported back to the model as an ``Error: ...`` tool
    result instead of raising. These include an unknown tool name, missing or
    invalid arguments, and division by zero. The model can then still produce
    a final answer, rather than the whole graph run aborting.
    """
    last = state["messages"][-1]
    data = last.additional_kwargs.get("parsed", {})
    if "tool" not in data:
        return {"messages": []}

    tool_name = data["tool"]
    # JSON output may name a tool that does not exist (or use a non-string).
    selected = tools_by_name.get(tool_name) if isinstance(tool_name, str) else None
    if selected is None:
        content = f"Error: unknown tool {tool_name!r}"
    else:
        try:
            content = str(selected.invoke(data.get("args", {})))
        # ArithmeticError covers division by zero. ValueError also covers
        # argument-validation errors (pydantic's ValidationError subclasses it).
        except (ArithmeticError, KeyError, TypeError, ValueError) as exc:
            content = f"Error: {type(exc).__name__}: {exc}"
    return {"messages": [ToolMessage(content=content, tool_call_id=f"{tool_name}-call")]}
# Upper bound on model turns per request. A normal run needs two (tool call,
# then final answer). The cap stops a model that keeps emitting tool calls from
# looping until LangGraph's recursion limit aborts the run.
MAX_LLM_CALLS = 4


def should_continue(state: MessagesState) -> Literal["tool_node", END]:
    """Route after ``llm_call``: run the requested tool, or finish.

    Returns ``"tool_node"`` when the latest AI message's parsed output contains
    a ``"tool"`` key and fewer than ``MAX_LLM_CALLS`` model calls have been
    made. Otherwise it returns ``END``.
    """
    if state.get("llm_calls", 0) >= MAX_LLM_CALLS:
        return END
    last = state["messages"][-1]
    data = last.additional_kwargs.get("parsed", {})
    return "tool_node" if "tool" in data else END
# 5) Graph
# Wiring: START -> llm_call -> (tool_node -> llm_call)* -> END,
# where should_continue decides after each llm_call.
agent_builder = StateGraph(MessagesState)
for _node_name, _node_fn in (("llm_call", llm_call), ("tool_node", tool_node)):
    agent_builder.add_node(_node_name, _node_fn)
del _node_name, _node_fn

agent_builder.add_edge(START, "llm_call")
agent_builder.add_conditional_edges("llm_call", should_continue, ["tool_node", END])
agent_builder.add_edge("tool_node", "llm_call")

agent = agent_builder.compile()
# 6) Web handler
def run_agent(user_text: str) -> str:
    """Answer *user_text* by running the agent graph once.

    Returns the ``final`` value parsed from the last AI message (as a string)
    when present. Otherwise it returns that message's raw text, or
    ``"No AI output."`` if the run produced no AI message at all.
    """
    result = agent.invoke({"messages": [HumanMessage(content=user_text)], "llm_calls": 0})
    # The most recent AI message carries the answer (or the model's last words).
    last_ai = next(
        (msg for msg in reversed(result["messages"]) if isinstance(msg, AIMessage)),
        None,
    )
    if last_ai is None:
        return "No AI output."
    parsed = last_ai.additional_kwargs.get("parsed", {})
    return str(parsed["final"]) if "final" in parsed else last_ai.content
# Gradio front end: one question box in, one answer box out, backed by run_agent.
_question_box = gr.Textbox(label="Ask an arithmetic question", placeholder="e.g., 4 divided by 3")
_answer_box = gr.Textbox(label="Answer")
demo = gr.Interface(
    fn=run_agent,
    inputs=_question_box,
    outputs=_answer_box,
    title="Tool-Using Arithmetic Agent (LangGraph + Phi-3-mini)",
)

if __name__ == "__main__":
    demo.launch()