from __future__ import annotations
import json
import math
from typing import Any, Dict, List, Optional, TypedDict
from langgraph.graph import StateGraph, END
from pydantic import ValidationError
from src.agent.schemas import ExtractedData
# ---------- JSON safety ----------
def _json_safe(obj):
    """Recursively replace non-finite floats so the payload is valid JSON.

    NaN and +/-inf have no representation in strict JSON, so every such float
    is swapped for ``None``. Dicts, lists and tuples are rebuilt with their
    values sanitized; dict keys and all other objects are passed through
    unchanged.

    Args:
        obj: Any value, e.g. the dict produced by a model's ``model_dump()``.

    Returns:
        ``None`` for a non-finite float, a sanitized copy for a dict, list or
        tuple, and ``obj`` itself for everything else (including ``None``).
    """
    if isinstance(obj, float):
        # math.isfinite is False for NaN, +inf and -inf alike.
        return obj if math.isfinite(obj) else None
    if isinstance(obj, dict):
        return {key: _json_safe(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [_json_safe(item) for item in obj]
    if isinstance(obj, tuple):
        # Tuples are serialized as JSON arrays, so their items need the same
        # treatment; the tuple type is kept so clean input still compares equal.
        return tuple(_json_safe(item) for item in obj)
    return obj
# ---------- OpenAI helper ----------
def _openai_client(api_key: str):
    """Build an OpenAI client authenticated with ``api_key``.

    The ``openai`` package is imported inside the function, so it is only
    needed once a client is actually requested.
    """
    from openai import OpenAI

    # Generous timeout and retry budget to ride out flaky networks.
    resilience = {"timeout": 60, "max_retries": 5}
    return OpenAI(api_key=api_key, **resilience)
# System prompt used by _llm_extract. It spells out the target JSON shape
# ("employees" plus "rejected"), the no-guessing rules, and the normalization
# rules the model is asked to apply to the raw text.
EXTRACT_SYSTEM = """You are a data extraction + validation agent.
Your job: convert messy text into STRICT JSON that matches this schema:
{
"employees": [
{
"user_id": int,
"name": string,
"age": int|null,
"email": string|null,
"salary": number|null,
"join_date": "YYYY-MM-DD"|null,
"department": one of ["Artificial Intelligence","AI/ML","Machine Learning","Data Science"],
"performance_score": number|null (0..10),
"location": string|null,
"job_title": string|null
}
],
"rejected": [
{ "raw_record": string, "reasons": [string, ...] }
]
}
CRITICAL RULES (NO HALLUCINATION):
- NEVER invent user_id. If user_id is missing/uncertain, DO NOT guess.
Put that record into "rejected" with reason "missing user_id".
- NEVER guess values from vague text like "maybe", "around", "probably", "approx".
Use null for uncertain optional fields.
- If a record cannot be made schema-valid WITHOUT guessing required fields, reject it.
- Do not fabricate emails or domains. If email is invalid -> null (or reject only if required, but email is optional here).
Normalization rules:
- Output JSON ONLY, no markdown.
- If a field is missing, set it to null (not empty string).
- Normalize department values:
AI/ai/Artificial Intelligence -> "Artificial Intelligence"
AI/ML -> "AI/ML"
ML/Machine Learning -> "Machine Learning"
DataScience/Data science -> "Data Science"
- Convert word numbers (e.g., "twenty nine") to integers when clear.
- Convert dates to ISO YYYY-MM-DD if possible, else null.
- Salary: remove $ and commas; if missing, null.
- performance_score must be 0..10; if value is out of range or unclear -> null.
"""
# System prompt used by _llm_correct. The user message that accompanies it
# carries the previous JSON output and the validation error text; the prompt
# asks the model to repair the JSON without inventing data.
# NOTE(review): this prompt refers to "the schema" but does not restate it, and
# the correction request does not include the original raw text - confirm that
# this is intentional.
CORRECT_SYSTEM = """You are a self-correcting data validation agent.
You will be given:
- the previous JSON you produced
- a validation error message describing why it failed
Fix the JSON to satisfy the schema.
CRITICAL RULES (NO HALLUCINATION):
- NEVER invent user_id. If user_id is missing/uncertain, reject the record instead of guessing.
- NEVER guess uncertain values (maybe/around/probably). Use null for optional fields.
- Prefer moving problematic records to "rejected" with clear reasons rather than fabricating data.
Rules:
- Output JSON ONLY.
- Keep valid records in "employees".
- Put non-fixable records in "rejected" with reasons.
- Use null for missing fields (not empty strings).
"""
# ---------- LangGraph State ----------
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the agent graph."""

    # Input text; _llm_extract sends it to the model.
    raw_text: str
    # Current attempt number; run_agent starts it at 1 and the graph bumps it
    # after every correction.
    attempt: int
    # _should_retry stops retrying once attempt reaches this value.
    max_attempts: int
    # Most recent model output with surrounding whitespace stripped.
    last_json_text: str
    # str() of the last ValidationError; "" after a successful validation.
    validation_error: str
    # Validated, JSON-safe payload; None while validation has not passed.
    result: Optional[Dict[str, Any]]
    # One entry appended per extract / validate / correct step.
    log: List[Dict[str, Any]]
def _llm_extract(state: AgentState, api_key: str, model: str) -> AgentState:
    """Ask the model to turn ``state["raw_text"]`` into schema-shaped JSON.

    The stripped response text is stored in ``state["last_json_text"]`` and a
    log entry (output truncated to 2000 characters) is appended to
    ``state["log"]``.

    Args:
        state: Agent state; mutated in place.
        api_key: OpenAI API key used to build the client.
        model: Model name passed to the Responses API call.

    Returns:
        The same ``state`` object after the update.
    """
    client = _openai_client(api_key)
    user_message = json.dumps({"raw_text": state["raw_text"]})
    messages = [
        {"role": "system", "content": EXTRACT_SYSTEM},
        {"role": "user", "content": user_message},
    ]
    response = client.responses.create(
        model=model,
        input=messages,
        temperature=0,
        max_output_tokens=1400,
    )

    # A missing or empty output_text is treated as an empty string.
    text = response.output_text
    cleaned = text.strip() if text else ""

    state["last_json_text"] = cleaned
    entry = {"step": "extract", "attempt": state["attempt"], "output": cleaned[:2000]}
    state["log"].append(entry)
    return state
def _validate(state: AgentState) -> AgentState:
    """Validate ``state["last_json_text"]`` against ``ExtractedData``.

    On success the dumped model, passed through ``_json_safe``, is stored in
    ``state["result"]`` and ``state["validation_error"]`` is cleared. On a
    ``ValidationError`` the result is reset to ``None`` and the error text is
    stored. Either way one "validate" entry is appended to ``state["log"]``.

    Returns:
        The same ``state`` object after the update.
    """
    try:
        parsed = ExtractedData.model_validate_json(state["last_json_text"])
        payload = _json_safe(parsed.model_dump())
    except ValidationError as exc:
        message = str(exc)
        state["result"] = None
        state["validation_error"] = message
        entry = {
            "step": "validate",
            "attempt": state["attempt"],
            "status": "fail",
            # Keep the log entry bounded; the full text stays in the state.
            "error": message[:2000],
        }
    else:
        state["result"] = payload
        state["validation_error"] = ""
        entry = {"step": "validate", "attempt": state["attempt"], "status": "pass"}

    state["log"].append(entry)
    return state
def _llm_correct(state: AgentState, api_key: str, model: str) -> AgentState:
    """Ask the model to repair the last JSON output using the validation error.

    The request carries ``state["last_json_text"]`` and
    ``state["validation_error"]``. The stripped response replaces
    ``state["last_json_text"]`` and a "correct" entry (output truncated to
    2000 characters) is appended to ``state["log"]``.

    Args:
        state: Agent state; mutated in place.
        api_key: OpenAI API key used to build the client.
        model: Model name passed to the Responses API call.

    Returns:
        The same ``state`` object after the update.
    """
    client = _openai_client(api_key)
    user_message = json.dumps(
        {
            "previous_json": state["last_json_text"],
            "validation_error": state["validation_error"],
        }
    )
    messages = [
        {"role": "system", "content": CORRECT_SYSTEM},
        {"role": "user", "content": user_message},
    ]
    response = client.responses.create(
        model=model,
        input=messages,
        temperature=0,
        max_output_tokens=1400,
    )

    # A missing or empty output_text is treated as an empty string.
    text = response.output_text
    cleaned = text.strip() if text else ""

    state["last_json_text"] = cleaned
    entry = {"step": "correct", "attempt": state["attempt"], "output": cleaned[:2000]}
    state["log"].append(entry)
    return state
def _should_retry(state: AgentState) -> str:
    """Pick the edge to follow after a validation step.

    Returns:
        ``"finalize"`` when a validated result exists or the attempt counter
        has reached ``max_attempts``; ``"retry"`` otherwise.
    """
    finished = (
        state["result"] is not None
        or state["attempt"] >= state["max_attempts"]
    )
    return "finalize" if finished else "retry"
def build_graph(api_key: str, model: str):
    """Assemble and compile the extract -> validate -> correct loop.

    Flow: "extract" is the entry point and feeds "validate". After
    validation, ``_should_retry`` either ends the run ("finalize") or routes
    to "correct", which is followed by "inc_attempt" and then "validate"
    again.

    Args:
        api_key: OpenAI API key forwarded to the LLM nodes.
        model: Model name forwarded to the LLM nodes.

    Returns:
        The compiled graph.
    """

    def extract_node(state):
        return _llm_extract(state, api_key, model)

    def correct_node(state):
        return _llm_correct(state, api_key, model)

    def inc_attempt(state: AgentState) -> AgentState:
        # Runs after each correction so the next validation counts as a new attempt.
        state["attempt"] += 1
        return state

    builder = StateGraph(AgentState)

    builder.add_node("extract", extract_node)
    builder.add_node("validate", _validate)
    builder.add_node("correct", correct_node)
    builder.add_node("inc_attempt", inc_attempt)

    builder.set_entry_point("extract")
    builder.add_edge("extract", "validate")

    routes = {"retry": "correct", "finalize": END}
    builder.add_conditional_edges("validate", _should_retry, routes)

    builder.add_edge("correct", "inc_attempt")
    builder.add_edge("inc_attempt", "validate")

    return builder.compile()
def run_agent(raw_text: str, api_key: str, model: str = "gpt-4.1-mini", max_attempts: int = 3):
    """Run the extract / validate / correct graph on ``raw_text``.

    Args:
        raw_text: Text to extract records from.
        api_key: OpenAI API key forwarded to the graph's LLM nodes.
        model: Model name forwarded to the graph's LLM nodes.
        max_attempts: Attempt number at which the graph stops retrying; the
            counter starts at 1 and grows by one after each correction.

    Returns:
        The final state returned by the graph. ``result`` holds the validated
        payload, or ``None`` if validation never passed; ``validation_error``
        and ``log`` describe what happened.
    """
    graph = build_graph(api_key, model)
    init: AgentState = {
        "raw_text": raw_text,
        "attempt": 1,
        "max_attempts": max_attempts,
        "last_json_text": "",
        "validation_error": "",
        "result": None,
        "log": [],
    }

    # Worst case the graph runs extract + validate, then three steps
    # (correct, inc_attempt, validate) per retry: 3 * max_attempts - 1 steps.
    # LangGraph's default recursion limit (25) would abort the run with
    # GraphRecursionError once max_attempts reaches 9, so scale the limit to
    # the requested budget and never go below the default.
    recursion_limit = max(25, 3 * max_attempts + 5)
    final_state = graph.invoke(init, config={"recursion_limit": recursion_limit})
    return final_state
|