File size: 5,473 Bytes
8865795
 
 
 
 
 
 
979cd0f
8865795
3d29713
8865795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
979cd0f
8865795
 
 
 
979cd0f
8865795
 
 
 
979cd0f
8865795
 
 
 
 
 
 
3d29713
8865795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d29713
8865795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import json
import re
import operator
from typing import Literal, Dict, Any
from typing_extensions import TypedDict, Annotated

import gradio as gr
from langchain_core.tools import tool   # (recommended import)
from langchain_huggingface import HuggingFacePipeline
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, BaseMessage
from langgraph.graph import StateGraph, START, END

# 1) Model
# Local HuggingFace text-generation pipeline wrapped for LangChain.
# Phi-3-mini is small enough to run without a dedicated GPU server.
llm = HuggingFacePipeline.from_model_id(
    model_id="microsoft/Phi-3-mini-4k-instruct",
    task="text-generation",
    pipeline_kwargs={
        "max_new_tokens": 96,       # replies are one short JSON object, so keep generation short
        "top_k": 50,
        "temperature": 0.1,         # near-greedy decoding keeps the JSON output parseable
        "return_full_text": False,  # return only the completion, not the echoed prompt
    },
)

# 2) Tools
# NOTE: each docstring doubles as the tool's description exposed via @tool,
# so keep them short and literal.
@tool
def multiply(a: int, b: int) -> int:
    """Multiply a and b."""
    return a * b

@tool
def add(a: int, b: int) -> int:
    """Add a and b."""
    return a + b

@tool
def divide(a: int, b: int) -> float:
    """Divide a by b."""
    # Raises ZeroDivisionError when b == 0; nothing in the graph catches it,
    # so a model-produced b=0 will surface as an error to the caller.
    return a / b

tools = [add, multiply, divide]
# Name -> tool lookup used by tool_node to dispatch parsed tool calls.
tools_by_name = {t.name: t for t in tools}

# 3) State
class MessagesState(TypedDict):
    # Annotated with operator.add so LangGraph *concatenates* the message
    # lists returned by each node instead of overwriting the previous state.
    messages: Annotated[list[BaseMessage], operator.add]
    # Count of llm_call invocations (diagnostic counter; incremented per turn).
    llm_calls: int

# Prompt contract: the model must emit exactly ONE JSON object per turn --
# a tool call while no tool result exists, then a final answer afterwards.
# _parse_model_output tolerates near-JSON ("preferably STRICT"), hence the
# soft wording below. This is a runtime string; edits here change behavior.
SYSTEM = """You are an arithmetic tool user.

You must output exactly ONE object, preferably STRICT JSON.

If there is NO tool result yet, output a tool call:
{"tool": "add"|"multiply"|"divide", "args": {"a": <int>, "b": <int>}}

If there IS a tool result already, output the final answer:
{"final": "<answer>"}

No extra text. Use double quotes if possible.
""".strip()

def _format_for_phi(messages: list[BaseMessage]) -> str:
    """Render the conversation as a single prompt string for Phi-3.

    Only human turns and tool results are serialized; earlier assistant
    output is deliberately omitted so the model re-derives its next JSON
    object from the system contract plus the latest tool evidence.
    """
    rendered: list[str] = []
    for msg in messages:
        if isinstance(msg, HumanMessage):
            rendered.append(f"User: {msg.content}")
        elif isinstance(msg, ToolMessage):
            rendered.append(f"Tool result: {msg.content}")
    return "\n".join([SYSTEM, "", *rendered, "Assistant:"])

def _hard_trim_to_first_turn(text: str) -> str:
    cuts = ["\n\nUser:", "\nUser:", "\n\nAssistant:", "\nAssistant:"]
    for c in cuts:
        if c in text:
            text = text.split(c, 1)[0]
    return text.strip()

def _parse_model_output(text: str) -> Dict[str, Any]:
    m = re.search(r"\{.*?\}", text, flags=re.DOTALL)
    candidate = m.group(0).strip() if m else ""

    if candidate:
        s = candidate.replace("“", '"').replace("”", '"').replace("’", "'")
        s = re.sub(r",\s*}", "}", s)
        s = re.sub(r",\s*]", "]", s)
        try:
            obj = json.loads(s)
            if isinstance(obj, dict):
                return obj
        except json.JSONDecodeError:
            pass

    mf = re.search(r'"final"\s*:\s*("?)([^"\}\n]+)\1', text)
    if mf:
        return {"final": mf.group(2).strip()}

    mt = re.search(r'"tool"\s*:\s*"(?P<tool>add|multiply|divide)"', text)
    tool_name = mt.group("tool") if mt else None
    if tool_name is None:
        mt2 = re.search(r'\btool\b\s*:\s*(add|multiply|divide)', text)
        if not mt2:
            raise ValueError(f"Could not parse tool/final from model output:\n{text}")
        tool_name = mt2.group(1)

    ma = re.search(r'"a"\s*:\s*(-?\d+)', text)
    mb = re.search(r'"b"\s*:\s*(-?\d+)', text)
    if not ma or not mb:
        raise ValueError(f"Parsed tool={tool_name} but could not parse a/b from:\n{text}")

    return {"tool": tool_name, "args": {"a": int(ma.group(1)), "b": int(mb.group(1))}}

# 4) Nodes
def llm_call(state: dict):
    """Run one model turn: build the prompt, invoke Phi-3, parse the reply."""
    rendered = _format_for_phi(state["messages"])
    completion = _hard_trim_to_first_turn(llm.invoke(rendered))

    # Stash the parsed dict on the message so downstream nodes don't re-parse.
    parsed = _parse_model_output(completion)
    ai_msg = AIMessage(content=completion, additional_kwargs={"parsed": parsed})
    calls_so_far = state.get("llm_calls", 0)
    return {"messages": [ai_msg], "llm_calls": calls_so_far + 1}

def tool_node(state: dict):
    """Execute the tool named in the last AI message's parsed payload."""
    parsed = state["messages"][-1].additional_kwargs.get("parsed", {})
    if "tool" not in parsed:
        # No tool requested (e.g. the model already gave a final answer).
        return {"messages": []}
    name = parsed["tool"]
    result = tools_by_name[name].invoke(parsed["args"])
    tool_msg = ToolMessage(content=str(result), tool_call_id=f"{name}-call")
    return {"messages": [tool_msg]}

def should_continue(state: MessagesState) -> Literal["tool_node", END]:
    """Route to the tool node when the model requested a tool, else finish."""
    parsed = state["messages"][-1].additional_kwargs.get("parsed", {})
    if "tool" in parsed:
        return "tool_node"
    return END

# 5) Graph
# Shape: START -> llm_call -> (tool requested? tool_node -> llm_call : END)
agent_builder = StateGraph(MessagesState)
agent_builder.add_node("llm_call", llm_call)
agent_builder.add_node("tool_node", tool_node)
agent_builder.add_edge(START, "llm_call")
# should_continue inspects the parsed payload and picks tool_node or END.
agent_builder.add_conditional_edges("llm_call", should_continue, ["tool_node", END])
# After a tool runs, loop back so the model can produce the final answer.
agent_builder.add_edge("tool_node", "llm_call")
agent = agent_builder.compile()

# 6) Web handler
def run_agent(user_text: str) -> str:
    """Gradio handler: run the agent on *user_text*, return the final answer."""
    result = agent.invoke(
        {"messages": [HumanMessage(content=user_text)], "llm_calls": 0}
    )

    # Walk backwards to the most recent AI message -- it carries the parsed output.
    last_ai = next(
        (m for m in reversed(result["messages"]) if isinstance(m, AIMessage)),
        None,
    )
    if last_ai is None:
        return "No AI output."

    parsed = last_ai.additional_kwargs.get("parsed", {})
    if "final" in parsed:
        return str(parsed["final"])
    # No structured final answer -- fall back to the raw model text.
    return last_ai.content

# Minimal Gradio UI: a single question textbox in, a single answer textbox out.
demo = gr.Interface(
    fn=run_agent,
    inputs=gr.Textbox(label="Ask an arithmetic question", placeholder="e.g., 4 divided by 3"),
    outputs=gr.Textbox(label="Answer"),
    title="Tool-Using Arithmetic Agent (LangGraph + Phi-3-mini)",
)

if __name__ == "__main__":
    demo.launch()