from __future__ import annotations import json import math from typing import Any, Dict, List, Optional, TypedDict from langgraph.graph import StateGraph, END from pydantic import ValidationError from src.agent.schemas import ExtractedData # ---------- JSON safety ---------- def _json_safe(obj): """Recursively convert NaN/inf to None so payload becomes valid JSON.""" if obj is None: return None if isinstance(obj, float): if math.isnan(obj) or math.isinf(obj): return None return obj if isinstance(obj, dict): return {k: _json_safe(v) for k, v in obj.items()} if isinstance(obj, list): return [_json_safe(v) for v in obj] return obj # ---------- OpenAI helper ---------- def _openai_client(api_key: str): from openai import OpenAI # robust for flaky networks return OpenAI(api_key=api_key, timeout=60, max_retries=5) EXTRACT_SYSTEM = """You are a data extraction + validation agent. Your job: convert messy text into STRICT JSON that matches this schema: { "employees": [ { "user_id": int, "name": string, "age": int|null, "email": string|null, "salary": number|null, "join_date": "YYYY-MM-DD"|null, "department": one of ["Artificial Intelligence","AI/ML","Machine Learning","Data Science"], "performance_score": number|null (0..10), "location": string|null, "job_title": string|null } ], "rejected": [ { "raw_record": string, "reasons": [string, ...] } ] } CRITICAL RULES (NO HALLUCINATION): - NEVER invent user_id. If user_id is missing/uncertain, DO NOT guess. Put that record into "rejected" with reason "missing user_id". - NEVER guess values from vague text like "maybe", "around", "probably", "approx". Use null for uncertain optional fields. - If a record cannot be made schema-valid WITHOUT guessing required fields, reject it. - Do not fabricate emails or domains. If email is invalid -> null (or reject only if required, but email is optional here). Normalization rules: - Output JSON ONLY, no markdown. - If a field is missing, set it to null (not empty string). - Normalize department values: AI/ai/Artificial Intelligence -> "Artificial Intelligence" AI/ML -> "AI/ML" ML/Machine Learning -> "Machine Learning" DataScience/Data science -> "Data Science" - Convert word numbers (e.g., "twenty nine") to integers when clear. - Convert dates to ISO YYYY-MM-DD if possible, else null. - Salary: remove $ and commas; if missing, null. - performance_score must be 0..10; if value is out of range or unclear -> null. """ CORRECT_SYSTEM = """You are a self-correcting data validation agent. You will be given: - the previous JSON you produced - a validation error message describing why it failed Fix the JSON to satisfy the schema. CRITICAL RULES (NO HALLUCINATION): - NEVER invent user_id. If user_id is missing/uncertain, reject the record instead of guessing. - NEVER guess uncertain values (maybe/around/probably). Use null for optional fields. - Prefer moving problematic records to "rejected" with clear reasons rather than fabricating data. Rules: - Output JSON ONLY. - Keep valid records in "employees". - Put non-fixable records in "rejected" with reasons. - Use null for missing fields (not empty strings). """ # ---------- LangGraph State ---------- class AgentState(TypedDict): raw_text: str attempt: int max_attempts: int last_json_text: str validation_error: str result: Optional[Dict[str, Any]] log: List[Dict[str, Any]] def _llm_extract(state: AgentState, api_key: str, model: str) -> AgentState: client = _openai_client(api_key) payload = {"raw_text": state["raw_text"]} resp = client.responses.create( model=model, input=[ {"role": "system", "content": EXTRACT_SYSTEM}, {"role": "user", "content": json.dumps(payload)}, ], temperature=0, max_output_tokens=1400, ) out = (resp.output_text or "").strip() state["last_json_text"] = out state["log"].append({"step": "extract", "attempt": state["attempt"], "output": out[:2000]}) return state def _validate(state: AgentState) -> AgentState: try: data = ExtractedData.model_validate_json(state["last_json_text"]) state["result"] = _json_safe(data.model_dump()) state["validation_error"] = "" state["log"].append({"step": "validate", "attempt": state["attempt"], "status": "pass"}) except ValidationError as e: state["result"] = None state["validation_error"] = str(e) state["log"].append( { "step": "validate", "attempt": state["attempt"], "status": "fail", "error": state["validation_error"][:2000], } ) return state def _llm_correct(state: AgentState, api_key: str, model: str) -> AgentState: client = _openai_client(api_key) payload = { "previous_json": state["last_json_text"], "validation_error": state["validation_error"], } resp = client.responses.create( model=model, input=[ {"role": "system", "content": CORRECT_SYSTEM}, {"role": "user", "content": json.dumps(payload)}, ], temperature=0, max_output_tokens=1400, ) out = (resp.output_text or "").strip() state["last_json_text"] = out state["log"].append({"step": "correct", "attempt": state["attempt"], "output": out[:2000]}) return state def _should_retry(state: AgentState) -> str: if state["result"] is not None: return "finalize" if state["attempt"] >= state["max_attempts"]: return "finalize" return "retry" def build_graph(api_key: str, model: str): g = StateGraph(AgentState) g.add_node("extract", lambda s: _llm_extract(s, api_key, model)) g.add_node("validate", _validate) g.add_node("correct", lambda s: _llm_correct(s, api_key, model)) g.set_entry_point("extract") g.add_edge("extract", "validate") g.add_conditional_edges( "validate", _should_retry, {"retry": "correct", "finalize": END}, ) # after correcting, increment attempt then validate again def inc_attempt(state: AgentState) -> AgentState: state["attempt"] += 1 return state g.add_node("inc_attempt", inc_attempt) g.add_edge("correct", "inc_attempt") g.add_edge("inc_attempt", "validate") return g.compile() def run_agent(raw_text: str, api_key: str, model: str = "gpt-4.1-mini", max_attempts: int = 3): graph = build_graph(api_key, model) init: AgentState = { "raw_text": raw_text, "attempt": 1, "max_attempts": max_attempts, "last_json_text": "", "validation_error": "", "result": None, "log": [], } final_state = graph.invoke(init) return final_state