File size: 6,982 Bytes
b67668b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
from __future__ import annotations

import json
import math
from typing import Any, Dict, List, Optional, TypedDict

from langgraph.graph import StateGraph, END
from pydantic import ValidationError

from src.agent.schemas import ExtractedData


# ---------- JSON safety ----------
def _json_safe(obj):
    """Recursively convert NaN/inf to None so payload becomes valid JSON."""
    if obj is None:
        return None
    if isinstance(obj, float):
        if math.isnan(obj) or math.isinf(obj):
            return None
        return obj
    if isinstance(obj, dict):
        return {k: _json_safe(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_json_safe(v) for v in obj]
    return obj


# ---------- OpenAI helper ----------
def _openai_client(api_key: str):
    """Create a fresh OpenAI client authenticated with *api_key*.

    The ``openai`` package is imported lazily, at call time, rather than at
    module import.
    """
    from openai import OpenAI

    # Generous timeout and retry budget, intended to tolerate flaky networks.
    client_options = {"api_key": api_key, "timeout": 60, "max_retries": 5}
    return OpenAI(**client_options)


# System prompt for the initial extraction call: `_llm_extract` sends it as the
# "system" message, with the raw text JSON-encoded in the "user" message.
# The model's reply is later checked by `_validate` against `ExtractedData`.
# NOTE(review): the schema spelled out below is presumably meant to mirror
# `src.agent.schemas.ExtractedData` — keep the two in sync; verify against
# that module.
EXTRACT_SYSTEM = """You are a data extraction + validation agent.

Your job: convert messy text into STRICT JSON that matches this schema:

{
  "employees": [
    {
      "user_id": int,
      "name": string,
      "age": int|null,
      "email": string|null,
      "salary": number|null,
      "join_date": "YYYY-MM-DD"|null,
      "department": one of ["Artificial Intelligence","AI/ML","Machine Learning","Data Science"],
      "performance_score": number|null (0..10),
      "location": string|null,
      "job_title": string|null
    }
  ],
  "rejected": [
    { "raw_record": string, "reasons": [string, ...] }
  ]
}

CRITICAL RULES (NO HALLUCINATION):
- NEVER invent user_id. If user_id is missing/uncertain, DO NOT guess.
  Put that record into "rejected" with reason "missing user_id".
- NEVER guess values from vague text like "maybe", "around", "probably", "approx".
  Use null for uncertain optional fields.
- If a record cannot be made schema-valid WITHOUT guessing required fields, reject it.
- Do not fabricate emails or domains. If email is invalid -> null (or reject only if required, but email is optional here).

Normalization rules:
- Output JSON ONLY, no markdown.
- If a field is missing, set it to null (not empty string).
- Normalize department values:
  AI/ai/Artificial Intelligence -> "Artificial Intelligence"
  AI/ML -> "AI/ML"
  ML/Machine Learning -> "Machine Learning"
  DataScience/Data science -> "Data Science"
- Convert word numbers (e.g., "twenty nine") to integers when clear.
- Convert dates to ISO YYYY-MM-DD if possible, else null.
- Salary: remove $ and commas; if missing, null.
- performance_score must be 0..10; if value is out of range or unclear -> null.
"""

# System prompt for the self-correction call: `_llm_correct` sends it as the
# "system" message, with the previous model output and the validation error
# text JSON-encoded in the "user" message.
CORRECT_SYSTEM = """You are a self-correcting data validation agent.

You will be given:
- the previous JSON you produced
- a validation error message describing why it failed

Fix the JSON to satisfy the schema.

CRITICAL RULES (NO HALLUCINATION):
- NEVER invent user_id. If user_id is missing/uncertain, reject the record instead of guessing.
- NEVER guess uncertain values (maybe/around/probably). Use null for optional fields.
- Prefer moving problematic records to "rejected" with clear reasons rather than fabricating data.

Rules:
- Output JSON ONLY.
- Keep valid records in "employees".
- Put non-fixable records in "rejected" with reasons.
- Use null for missing fields (not empty strings).
"""


# ---------- LangGraph State ----------
class AgentState(TypedDict):
    """Mutable state dictionary threaded through the graph nodes of this module."""

    # Input text handed to the model by `_llm_extract`.
    raw_text: str
    # Attempt counter; `run_agent` starts it at 1 and the graph's
    # `inc_attempt` node adds 1 after every correction call.
    attempt: int
    # Upper bound checked by `_should_retry` (stop once attempt >= max_attempts).
    max_attempts: int
    # Most recent model output (stripped), written by extract/correct nodes.
    last_json_text: str
    # Text of the last pydantic ValidationError, or "" after a passing validation.
    validation_error: str
    # Sanitized `ExtractedData` dump on success; None until validation passes.
    result: Optional[Dict[str, Any]]
    # Append-only trace of steps ("extract", "validate", "correct").
    log: List[Dict[str, Any]]


def _llm_extract(state: AgentState, api_key: str, model: str) -> AgentState:
    """Run the initial extraction call and record its raw output.

    Sends ``EXTRACT_SYSTEM`` plus the JSON-encoded ``raw_text`` to *model*,
    stores the stripped reply in ``state["last_json_text"]`` and appends an
    ``"extract"`` entry (output truncated to 2000 characters) to the log.

    Returns:
        The same ``state`` object, mutated in place.
    """
    llm = _openai_client(api_key)
    user_message = json.dumps({"raw_text": state["raw_text"]})
    messages = [
        {"role": "system", "content": EXTRACT_SYSTEM},
        {"role": "user", "content": user_message},
    ]

    response = llm.responses.create(
        model=model,
        input=messages,
        temperature=0,
        max_output_tokens=1400,
    )

    # A missing/empty reply is normalized to an empty string.
    reply = response.output_text
    json_text = reply.strip() if reply else ""

    state["last_json_text"] = json_text
    log_entry = {
        "step": "extract",
        "attempt": state["attempt"],
        "output": json_text[:2000],
    }
    state["log"].append(log_entry)
    return state


def _validate(state: AgentState) -> AgentState:
    """Validate ``state["last_json_text"]`` against ``ExtractedData``.

    On success, ``result`` holds the JSON-safe model dump and
    ``validation_error`` is cleared. On a pydantic ``ValidationError``,
    ``result`` is reset to ``None`` and the error text is stored. Either way
    a ``"validate"`` entry is appended to the log.

    Returns:
        The same ``state`` object, mutated in place.
    """
    try:
        parsed = ExtractedData.model_validate_json(state["last_json_text"])
        state["result"] = _json_safe(parsed.model_dump())
        state["validation_error"] = ""
        entry = {"step": "validate", "attempt": state["attempt"], "status": "pass"}
    except ValidationError as exc:
        message = str(exc)
        state["result"] = None
        state["validation_error"] = message
        entry = {
            "step": "validate",
            "attempt": state["attempt"],
            "status": "fail",
            # Keep the log bounded: only the first 2000 characters are kept.
            "error": message[:2000],
        }
    state["log"].append(entry)
    return state


def _llm_correct(state: AgentState, api_key: str, model: str) -> AgentState:
    """Ask the model to repair its previous output using the validation error.

    Sends ``CORRECT_SYSTEM`` plus a JSON-encoded payload containing the
    previous output and the validation error, stores the stripped reply in
    ``state["last_json_text"]`` and appends a ``"correct"`` entry (output
    truncated to 2000 characters) to the log.

    Returns:
        The same ``state`` object, mutated in place.
    """
    llm = _openai_client(api_key)
    user_message = json.dumps(
        {
            "previous_json": state["last_json_text"],
            "validation_error": state["validation_error"],
        }
    )
    messages = [
        {"role": "system", "content": CORRECT_SYSTEM},
        {"role": "user", "content": user_message},
    ]

    response = llm.responses.create(
        model=model,
        input=messages,
        temperature=0,
        max_output_tokens=1400,
    )

    # A missing/empty reply is normalized to an empty string.
    reply = response.output_text
    json_text = reply.strip() if reply else ""

    state["last_json_text"] = json_text
    log_entry = {
        "step": "correct",
        "attempt": state["attempt"],
        "output": json_text[:2000],
    }
    state["log"].append(log_entry)
    return state


def _should_retry(state: AgentState) -> str:
    if state["result"] is not None:
        return "finalize"
    if state["attempt"] >= state["max_attempts"]:
        return "finalize"
    return "retry"


def build_graph(api_key: str, model: str):
    """Assemble and compile the extract -> validate -> (correct) loop.

    Flow: ``extract`` feeds ``validate``; ``_should_retry`` then either ends
    the run or routes to ``correct``, which passes through ``inc_attempt``
    back into ``validate``.

    Args:
        api_key: OpenAI API key forwarded to the LLM nodes.
        model: Model name forwarded to the LLM nodes.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """

    def extract_node(s):
        return _llm_extract(s, api_key, model)

    def correct_node(s):
        return _llm_correct(s, api_key, model)

    def inc_attempt(state: AgentState) -> AgentState:
        # Runs after every correction so the retry budget is consumed.
        state["attempt"] += 1
        return state

    workflow = StateGraph(AgentState)

    # Nodes
    workflow.add_node("extract", extract_node)
    workflow.add_node("validate", _validate)
    workflow.add_node("correct", correct_node)
    workflow.add_node("inc_attempt", inc_attempt)

    # Wiring
    workflow.set_entry_point("extract")
    workflow.add_edge("extract", "validate")
    branch_targets = {"retry": "correct", "finalize": END}
    workflow.add_conditional_edges("validate", _should_retry, branch_targets)
    workflow.add_edge("correct", "inc_attempt")
    workflow.add_edge("inc_attempt", "validate")

    return workflow.compile()


def run_agent(raw_text: str, api_key: str, model: str = "gpt-4.1-mini", max_attempts: int = 3):
    """Run the extraction agent on *raw_text* and return the final graph state.

    Args:
        raw_text: Text to extract records from.
        api_key: OpenAI API key.
        model: Model name used for both extraction and correction calls.
        max_attempts: Attempt budget checked by ``_should_retry``; the counter
            starts at 1.

    Returns:
        Whatever ``graph.invoke`` returns for the initial state built here.
    """
    compiled = build_graph(api_key, model)
    initial_state: AgentState = dict(
        raw_text=raw_text,
        attempt=1,
        max_attempts=max_attempts,
        last_json_text="",
        validation_error="",
        result=None,
        log=[],
    )
    return compiled.invoke(initial_state)