from __future__ import annotations

from typing import Any, Literal

from pydantic import BaseModel, Field


class FindingConfidence(BaseModel):
    """A single audit finding with agent-reported confidence."""

    value: Any
    confidence: float = Field(ge=0.0, le=1.0)
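

# Hypothetical helper, not part of the original module: a minimal sketch of
# how the agent-reported `confidence` on each FindingConfidence could feed the
# `confidence_brier_adjustment` term of RewardBreakdown. It assumes each
# finding is graded as correct (1) or incorrect (0) against the gold faults;
# the real grading rule is not defined in this file.
def _example_brier_score(confidences: list[float], outcomes: list[int]) -> float:
    """Mean squared error between confidences and 0/1 outcomes (lower is better)."""
    if len(confidences) != len(outcomes):
        raise ValueError("confidences and outcomes must have the same length")
    if not confidences:
        return 0.0
    return sum((c - o) ** 2 for c, o in zip(confidences, outcomes)) / len(confidences)

# Usage: _example_brier_score([0.9, 0.2], [1, 1]) is about 0.325; a confident
# wrong finding (0.2 on a correct one) dominates the penalty.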


class AuditReport(BaseModel):
    """Structured audit report submitted by the agent."""

    null_issues: dict[str, FindingConfidence]
    duplicate_row_count: FindingConfidence
    schema_violations: list[dict[str, Any]]
    drifted_columns: list[str]
    drift_details: dict[str, FindingConfidence]
    relational_issues: list[dict[str, Any]]
    recommended_fixes: list[str]
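

# Hypothetical helper, not part of the original module: list-valued findings
# such as `drifted_columns` could be compared with a gold list using set-based
# F1. Whether `base_audit_score` is computed this way is an assumption; the
# scoring logic does not appear in this file.
def _example_set_f1(predicted: list[str], gold: list[str]) -> float:
    """F1 between two collections treated as sets (duplicates are ignored)."""
    pred, ref = set(predicted), set(gold)
    if not pred and not ref:
        return 1.0  # nothing to find and nothing reported
    true_pos = len(pred & ref)
    if true_pos == 0:
        return 0.0
    precision = true_pos / len(pred)
    recall = true_pos / len(ref)
    return 2 * precision * recall / (precision + recall)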


class Action(BaseModel):
    """Agent action: run a SQL query, submit an audit report, or submit fix SQL."""

    action_type: Literal["query", "submit_report", "fix_sql"]
    sql: str | None = None
    report: AuditReport | None = None
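

# Hypothetical helper, not part of the original module: Action does not
# enforce which optional field goes with which action_type. Assuming "query"
# and "fix_sql" need `sql` and "submit_report" needs `report`, a caller could
# check an action's fields like this before dispatching it. Inside the model,
# the same rule could live in a pydantic `model_validator(mode="after")`.
def _example_action_payload_error(action_type: str, sql: object, report: object) -> str | None:
    """Return an error message for a mismatched payload, or None if it is consistent."""
    if action_type in ("query", "fix_sql"):
        if not sql:
            return f"action_type {action_type!r} requires a non-empty `sql`"
        if report is not None:
            return f"action_type {action_type!r} does not take a `report`"
    elif action_type == "submit_report":
        if report is None:
            return "action_type 'submit_report' requires a `report`"
    else:
        return f"unknown action_type {action_type!r}"
    return None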


class Observation(BaseModel):
    """Observation returned to the agent by the environment."""

    task_id: int
    task_description: str
    tables: dict[str, dict[str, str]]
    row_counts: dict[str, int]
    step: int
    max_steps: int
    query_credits_remaining: int
    phase: Literal["audit", "fix"]
    last_query_result: list[dict] | None
    last_action_error: str | None
    last_fix_score: float | None
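

# Hypothetical helper, not part of the original module: a sketch of turning
# Observation.tables and Observation.row_counts into a compact text summary,
# e.g. for an LLM prompt. It assumes `tables` maps table name to
# {column name: column type} and `row_counts` maps table name to a row count.
def _example_schema_summary(tables: dict[str, dict[str, str]], row_counts: dict[str, int]) -> str:
    lines = []
    for name in sorted(tables):
        # Columns keep the insertion order of the inner dict.
        cols = ", ".join(f"{col} {typ}" for col, typ in tables[name].items())
        rows = row_counts.get(name, "?")  # "?" when no count was reported
        lines.append(f"{name}({cols}) -- {rows} rows")
    return "\n".join(lines)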


class RewardBreakdown(BaseModel):
    """Component terms of the reward, plus their combined total."""

    base_audit_score: float
    confidence_brier_adjustment: float
    budget_efficiency_bonus: float
    fix_verification_bonus: float
    total: float
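

# Hypothetical helper, not part of the original module: one plausible shape
# for `budget_efficiency_bonus`, assuming the bonus grows linearly with the
# fraction of query credits left unused. The 0.1 cap is an illustrative
# default, not a value taken from the environment.
def _example_budget_bonus(credits_remaining: int, credits_total: int, max_bonus: float = 0.1) -> float:
    if credits_total <= 0:
        return 0.0
    # Clamp the fraction to [0, 1] in case credits_remaining is out of range.
    fraction_unused = min(max(credits_remaining / credits_total, 0.0), 1.0)
    return max_bonus * fraction_unused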


class Reward(BaseModel):
    """Step reward, bounded to [-0.5, 1.25], with its breakdown and episode status."""

    value: float = Field(ge=-0.5, le=1.25)
    breakdown: RewardBreakdown
    done: bool
    info: dict[str, Any]
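

# Hypothetical helper, not part of the original module: Reward.value is
# declared with Field(ge=-0.5, le=1.25), so pydantic raises a ValidationError
# for anything outside that range. A summed total should therefore be clamped
# before a Reward is built. Treating `total` as the plain sum of the four
# RewardBreakdown components is an assumption.
def _example_clamped_total(
    base_audit_score: float,
    confidence_brier_adjustment: float,
    budget_efficiency_bonus: float,
    fix_verification_bonus: float,
) -> float:
    raw = (
        base_audit_score
        + confidence_brier_adjustment
        + budget_efficiency_bonus
        + fix_verification_bonus
    )
    # Same bounds as the Field constraint on Reward.value.
    return min(max(raw, -0.5), 1.25)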


class EpisodeState(BaseModel):
    """Internal environment state for one episode, including the gold faults."""

    task_id: int
    seed: int
    step: int = 0
    max_steps: int = 12
    query_credits: int = 10
    phase: Literal["audit", "fix"] = "audit"
    fix_steps_remaining: int = 3
    report_submitted: bool = False
    done: bool = False
    gold_faults: dict[str, Any] = Field(default_factory=dict)
    audit_score: float = 0.0
    fix_bonus: float = 0.0
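

# Hypothetical helper, not part of the original module: a sketch of the
# bookkeeping a single "query" action might do on EpisodeState's counters,
# assuming every action consumes one step, each query also consumes one
# credit, and the episode ends once `step` reaches `max_steps`.
def _example_after_query(step: int, max_steps: int, query_credits: int) -> tuple[int, int, bool]:
    """Return (new_step, new_query_credits, done) after one query action."""
    if query_credits <= 0:
        raise ValueError("no query credits remaining")
    new_step = step + 1
    return new_step, query_credits - 1, new_step >= max_steps

# Usage with the EpisodeState defaults (max_steps=12, query_credits=10):
# _example_after_query(0, 12, 10) returns (1, 9, False).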