# Hugging Face Space app (page-scrape artifacts removed):
# a tool-using arithmetic agent built with LangGraph + Phi-3-mini, served via Gradio.
import json
import re
import operator
from typing import Literal, Dict, Any
from typing_extensions import TypedDict, Annotated
import gradio as gr
from langchain_core.tools import tool # (recommended import)
from langchain_huggingface import HuggingFacePipeline
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, BaseMessage
from langgraph.graph import StateGraph, START, END
# 1) Model
# Local HF text-generation pipeline wrapped for LangChain use.
llm = HuggingFacePipeline.from_model_id(
    model_id="microsoft/Phi-3-mini-4k-instruct",
    task="text-generation",
    pipeline_kwargs={
        "max_new_tokens": 96,  # replies are single short JSON objects; keep generation cheap
        "top_k": 50,
        "temperature": 0.1,
        # NOTE(review): temperature/top_k only take effect when sampling is
        # enabled (do_sample=True); with the default greedy decoding they are
        # ignored -- confirm the intended decoding mode.
        "return_full_text": False,  # return only newly generated text, not the prompt
    },
)
# 2) Tools
@tool
def multiply(a: int, b: int) -> int:
    """Multiply a and b."""
    product = a * b
    return product
@tool
def add(a: int, b: int) -> int:
    """Add a and b."""
    total = a + b
    return total
@tool
def divide(a: int, b: int) -> float:
    """Divide a by b."""
    # True division; b == 0 propagates ZeroDivisionError to the caller.
    quotient = a / b
    return quotient
# Tool registry: list for graph wiring, dict for name-based dispatch in tool_node.
tools = [add, multiply, divide]
tools_by_name = {t.name: t for t in tools}
# 3) State
class MessagesState(TypedDict):
    """Shared LangGraph state threaded through every node."""
    # Annotated reducer: each node's returned message list is combined with the
    # existing one via operator.add (list concatenation), not overwritten.
    messages: Annotated[list[BaseMessage], operator.add]
    # Number of LLM invocations so far (incremented in llm_call).
    llm_calls: int
# Prompt contract: the model must emit exactly ONE JSON object per turn --
# a tool call while no tool result exists yet, a final answer afterwards.
# _parse_model_output depends on this shape.
SYSTEM = """You are an arithmetic tool user.
You must output exactly ONE object, preferably STRICT JSON.
If there is NO tool result yet, output a tool call:
{"tool": "add"|"multiply"|"divide", "args": {"a": <int>, "b": <int>}}
If there IS a tool result already, output the final answer:
{"final": "<answer>"}
No extra text. Use double quotes if possible.
""".strip()
def _format_for_phi(messages: list[BaseMessage]) -> str:
    """Render the conversation as a plain-text prompt ending in 'Assistant:'."""
    lines = [SYSTEM, ""]
    for msg in messages:
        if isinstance(msg, HumanMessage):
            lines.append(f"User: {msg.content}")
        elif isinstance(msg, ToolMessage):
            lines.append(f"Tool result: {msg.content}")
        # NOTE(review): AIMessage turns are not replayed into the prompt;
        # only the human question and tool results are shown -- confirm intended.
    lines.append("Assistant:")
    return "\n".join(lines)
def _hard_trim_to_first_turn(text: str) -> str:
cuts = ["\n\nUser:", "\nUser:", "\n\nAssistant:", "\nAssistant:"]
for c in cuts:
if c in text:
text = text.split(c, 1)[0]
return text.strip()
def _parse_model_output(text: str) -> Dict[str, Any]:
m = re.search(r"\{.*?\}", text, flags=re.DOTALL)
candidate = m.group(0).strip() if m else ""
if candidate:
s = candidate.replace("“", '"').replace("”", '"').replace("’", "'")
s = re.sub(r",\s*}", "}", s)
s = re.sub(r",\s*]", "]", s)
try:
obj = json.loads(s)
if isinstance(obj, dict):
return obj
except json.JSONDecodeError:
pass
mf = re.search(r'"final"\s*:\s*("?)([^"\}\n]+)\1', text)
if mf:
return {"final": mf.group(2).strip()}
mt = re.search(r'"tool"\s*:\s*"(?P<tool>add|multiply|divide)"', text)
tool_name = mt.group("tool") if mt else None
if tool_name is None:
mt2 = re.search(r'\btool\b\s*:\s*(add|multiply|divide)', text)
if not mt2:
raise ValueError(f"Could not parse tool/final from model output:\n{text}")
tool_name = mt2.group(1)
ma = re.search(r'"a"\s*:\s*(-?\d+)', text)
mb = re.search(r'"b"\s*:\s*(-?\d+)', text)
if not ma or not mb:
raise ValueError(f"Parsed tool={tool_name} but could not parse a/b from:\n{text}")
return {"tool": tool_name, "args": {"a": int(ma.group(1)), "b": int(mb.group(1))}}
# 4) Nodes
def llm_call(state: dict):
    """One LLM step: build the prompt, generate, trim, and parse the reply."""
    generated = llm.invoke(_format_for_phi(state["messages"]))
    generated = _hard_trim_to_first_turn(generated)
    parsed = _parse_model_output(generated)
    # Stash the parsed dict on the message so the router/tool node can use it.
    ai_msg = AIMessage(content=generated, additional_kwargs={"parsed": parsed})
    call_count = state.get("llm_calls", 0) + 1
    return {"messages": [ai_msg], "llm_calls": call_count}
def tool_node(state: dict):
    """Run the tool requested by the most recent AI message, if any."""
    parsed = state["messages"][-1].additional_kwargs.get("parsed", {})
    if "tool" not in parsed:
        # Nothing was requested; contribute no new messages.
        return {"messages": []}
    name = parsed["tool"]
    result = tools_by_name[name].invoke(parsed["args"])
    observation = ToolMessage(content=str(result), tool_call_id=f"{name}-call")
    return {"messages": [observation]}
def should_continue(state: MessagesState) -> Literal["tool_node", END]:
    """Route to the tool node when the last AI message requested a tool, else finish."""
    parsed = state["messages"][-1].additional_kwargs.get("parsed", {})
    if "tool" in parsed:
        return "tool_node"
    return END
# 5) Graph
# Agent loop: START -> llm_call -> (tool_node -> llm_call)* -> END.
agent_builder = StateGraph(MessagesState)
agent_builder.add_node("llm_call", llm_call)
agent_builder.add_node("tool_node", tool_node)
agent_builder.add_edge(START, "llm_call")
# After each LLM turn, either execute the requested tool or stop.
agent_builder.add_conditional_edges("llm_call", should_continue, ["tool_node", END])
agent_builder.add_edge("tool_node", "llm_call")
agent = agent_builder.compile()
# 6) Web handler
def run_agent(user_text: str) -> str:
    """Invoke the compiled agent on user_text and extract a displayable answer."""
    result = agent.invoke({"messages": [HumanMessage(content=user_text)], "llm_calls": 0})
    # The final answer normally lives in the last AI message's parsed payload.
    last_ai = next(
        (m for m in reversed(result["messages"]) if isinstance(m, AIMessage)),
        None,
    )
    if last_ai is None:
        return "No AI output."
    parsed = last_ai.additional_kwargs.get("parsed", {})
    if "final" in parsed:
        return str(parsed["final"])
    # Fall back to the raw model text when no final answer was parsed.
    return last_ai.content
# Gradio UI: single textbox in, single textbox out, backed by run_agent.
demo = gr.Interface(
    fn=run_agent,
    inputs=gr.Textbox(label="Ask an arithmetic question", placeholder="e.g., 4 divided by 3"),
    outputs=gr.Textbox(label="Answer"),
    title="Tool-Using Arithmetic Agent (LangGraph + Phi-3-mini)",
)
if __name__ == "__main__":
    demo.launch()