# Liao — updateFunction — commit 979cd0f
import json
import re
import operator
from typing import Literal, Dict, Any
from typing_extensions import TypedDict, Annotated
import gradio as gr
from langchain_core.tools import tool # (recommended import)
from langchain_huggingface import HuggingFacePipeline
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, BaseMessage
from langgraph.graph import StateGraph, START, END
# 1) Model
# Local Hugging Face text-generation pipeline wrapped for LangChain.
# return_full_text=False makes the pipeline emit only newly generated tokens,
# not the echoed prompt — required because _format_for_phi builds a raw prompt.
llm = HuggingFacePipeline.from_model_id(
    model_id="microsoft/Phi-3-mini-4k-instruct",
    task="text-generation",
    pipeline_kwargs={
        "max_new_tokens": 96,  # short budget: one JSON object per turn is expected
        "top_k": 50,
        # NOTE(review): temperature/top_k only affect decoding when sampling is
        # enabled (do_sample=True); without it generation is greedy — TODO confirm.
        "temperature": 0.1,
        "return_full_text": False,
    },
)
# 2) Tools
@tool
def multiply(a: int, b: int) -> int:
    """Multiply a and b."""
    # Docstring doubles as the tool description exposed to the agent.
    product = a * b
    return product
@tool
def add(a: int, b: int) -> int:
    """Add a and b."""
    # Docstring doubles as the tool description exposed to the agent.
    total = a + b
    return total
@tool
def divide(a: int, b: int) -> float:
    """Divide a by b."""
    # Docstring doubles as the tool description exposed to the agent.
    # Raises ZeroDivisionError for b == 0, same as plain division.
    quotient = a / b
    return quotient
# Tool registry: the list is the agent's toolset; the dict maps a parsed
# tool name to the callable for dispatch in tool_node.
tools = [add, multiply, divide]
tools_by_name = {t.name: t for t in tools}
# 3) State
class MessagesState(TypedDict):
    """Graph state: accumulated conversation plus an LLM-call counter."""

    # Annotated with operator.add so LangGraph APPENDS messages returned by
    # nodes instead of replacing the list.
    messages: Annotated[list[BaseMessage], operator.add]
    llm_calls: int
# Prompt contract: the model must emit exactly ONE JSON object per turn —
# either a tool call or a final answer. _parse_model_output depends on this
# exact shape, so keep the wording in sync with the parser.
SYSTEM = """You are an arithmetic tool user.
You must output exactly ONE object, preferably STRICT JSON.
If there is NO tool result yet, output a tool call:
{"tool": "add"|"multiply"|"divide", "args": {"a": <int>, "b": <int>}}
If there IS a tool result already, output the final answer:
{"final": "<answer>"}
No extra text. Use double quotes if possible.
""".strip()
def _format_for_phi(messages: list[BaseMessage]) -> str:
    """Render the conversation as a flat text prompt for Phi-3.

    Only human turns and tool results are included — prior assistant
    turns are deliberately omitted. The prompt ends with "Assistant:"
    to cue the model's next reply.
    """
    lines = [SYSTEM, ""]
    for message in messages:
        if isinstance(message, HumanMessage):
            lines.append(f"User: {message.content}")
        elif isinstance(message, ToolMessage):
            lines.append(f"Tool result: {message.content}")
    lines.append("Assistant:")
    return "\n".join(lines)
def _hard_trim_to_first_turn(text: str) -> str:
cuts = ["\n\nUser:", "\nUser:", "\n\nAssistant:", "\nAssistant:"]
for c in cuts:
if c in text:
text = text.split(c, 1)[0]
return text.strip()
def _parse_model_output(text: str) -> Dict[str, Any]:
m = re.search(r"\{.*?\}", text, flags=re.DOTALL)
candidate = m.group(0).strip() if m else ""
if candidate:
s = candidate.replace("“", '"').replace("”", '"').replace("’", "'")
s = re.sub(r",\s*}", "}", s)
s = re.sub(r",\s*]", "]", s)
try:
obj = json.loads(s)
if isinstance(obj, dict):
return obj
except json.JSONDecodeError:
pass
mf = re.search(r'"final"\s*:\s*("?)([^"\}\n]+)\1', text)
if mf:
return {"final": mf.group(2).strip()}
mt = re.search(r'"tool"\s*:\s*"(?P<tool>add|multiply|divide)"', text)
tool_name = mt.group("tool") if mt else None
if tool_name is None:
mt2 = re.search(r'\btool\b\s*:\s*(add|multiply|divide)', text)
if not mt2:
raise ValueError(f"Could not parse tool/final from model output:\n{text}")
tool_name = mt2.group(1)
ma = re.search(r'"a"\s*:\s*(-?\d+)', text)
mb = re.search(r'"b"\s*:\s*(-?\d+)', text)
if not ma or not mb:
raise ValueError(f"Parsed tool={tool_name} but could not parse a/b from:\n{text}")
return {"tool": tool_name, "args": {"a": int(ma.group(1)), "b": int(mb.group(1))}}
# 4) Nodes
def llm_call(state: dict):
    """Run one LLM step: build the prompt, invoke, trim, parse, append.

    The parsed dict is stashed in additional_kwargs["parsed"] so the
    router and tool node never have to re-parse the raw text.
    """
    generated = llm.invoke(_format_for_phi(state["messages"]))
    generated = _hard_trim_to_first_turn(generated)
    parsed = _parse_model_output(generated)
    reply = AIMessage(content=generated, additional_kwargs={"parsed": parsed})
    return {
        "messages": [reply],
        "llm_calls": state.get("llm_calls", 0) + 1,
    }
def tool_node(state: dict):
    """Execute the tool requested by the last AI message.

    Returns a ToolMessage with the observation. If the parsed request
    names an unknown tool, returns an error ToolMessage instead of
    raising KeyError, so the agent loop can recover on the next LLM
    call. Returns no messages when there is no tool request at all.
    """
    last = state["messages"][-1]
    data = last.additional_kwargs.get("parsed", {})
    if "tool" not in data:
        return {"messages": []}
    tool_name = data["tool"]
    # The strict-JSON path may return a dict without "args"; default to {}.
    args = data.get("args", {})
    selected = tools_by_name.get(tool_name)
    if selected is None:
        # Unknown tool name from the JSON path — feed the error back to the
        # model rather than crashing the graph.
        err = f"Error: unknown tool '{tool_name}'. Available: {sorted(tools_by_name)}"
        return {"messages": [ToolMessage(content=err, tool_call_id=f"{tool_name}-call")]}
    obs = selected.invoke(args)
    return {"messages": [ToolMessage(content=str(obs), tool_call_id=f"{tool_name}-call")]}
def should_continue(state: MessagesState) -> Literal["tool_node", END]:
    """Route to the tool node iff the last AI message requests a tool."""
    parsed = state["messages"][-1].additional_kwargs.get("parsed", {})
    if "tool" in parsed:
        return "tool_node"
    return END
# 5) Graph
# Build the agent loop: START → llm_call → (tool_node → llm_call)* → END.
agent_builder = StateGraph(MessagesState)
agent_builder.add_node("llm_call", llm_call)
agent_builder.add_node("tool_node", tool_node)
agent_builder.add_edge(START, "llm_call")
# should_continue routes to tool execution when the model asked for a
# tool, otherwise the run ends.
agent_builder.add_conditional_edges("llm_call", should_continue, ["tool_node", END])
agent_builder.add_edge("tool_node", "llm_call")
agent = agent_builder.compile()
# 6) Web handler
def run_agent(user_text: str) -> str:
    """Invoke the agent on one user question and extract the final answer.

    Prefers the parsed {"final": ...} value from the last AI message;
    falls back to that message's raw content when no final was parsed.
    """
    result = agent.invoke(
        {"messages": [HumanMessage(content=user_text)], "llm_calls": 0}
    )
    # Last AI message in the accumulated transcript, if any.
    last_ai = next(
        (m for m in reversed(result["messages"]) if isinstance(m, AIMessage)),
        None,
    )
    if last_ai is None:
        return "No AI output."
    parsed = last_ai.additional_kwargs.get("parsed", {})
    if "final" in parsed:
        return str(parsed["final"])
    return last_ai.content  # fallback: raw model text
# 6) Gradio UI: single textbox in, single textbox out, backed by run_agent.
demo = gr.Interface(
    fn=run_agent,
    inputs=gr.Textbox(label="Ask an arithmetic question", placeholder="e.g., 4 divided by 3"),
    outputs=gr.Textbox(label="Answer"),
    title="Tool-Using Arithmetic Agent (LangGraph + Phi-3-mini)",
)

if __name__ == "__main__":
    # Blocking call: serves the web UI until interrupted.
    demo.launch()