File size: 5,260 Bytes
0bf71ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59a05a5
 
 
 
 
0bf71ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4890422
0bf71ce
 
 
 
 
 
 
 
 
 
59a05a5
 
 
 
 
 
 
 
 
 
 
 
473ab10
 
 
 
 
 
 
 
 
 
 
 
0bf71ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59a05a5
 
 
473ab10
 
 
 
 
 
59a05a5
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
Pydantic models for the Invoice Processing Pipeline environment.

Action:  Agent submits extracted/cleaned/reconciled invoice data as JSON.
Observation: Agent receives raw invoice text, feedback, and task context.
State:   Tracks episode progress, attempts, and scores.
"""

from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field


# ---------------------------------------------------------------------------
# Action
# ---------------------------------------------------------------------------

class InvoiceAction(BaseModel):
    """Payload the agent submits on every environment step.

    ``extracted_data`` carries the task-specific answer; ``explanation`` is
    optional free-form reasoning that accompanies it.
    """

    # Required task-specific answer. The expected shape for each task is
    # spelled out in the description so it appears in the JSON schema.
    extracted_data: Dict[str, Any] = Field(
        default=...,
        description=(
            "JSON object with extracted/cleaned invoice fields. Structure "
            "depends on the task. "
            "Easy: {vendor, date, currency, total, "
            "line_items: [{description, qty, unit_price, amount}]}. "
            "Medium: {invoices: [{vendor, date, currency, total, line_items}]} "
            "(batch of cleaned invoices). "
            "Hard: {invoices: [...], "
            "discrepancies: [{invoice_idx, type, detail, expected, actual}]}. "
            "Adversarial: same schema as easy — "
            "{vendor, date, currency, total, line_items}. "
            "Negotiate: either {'question': str} to ask a clarification, "
            "or the full extraction (same schema as easy). "
            "Supply_chain: "
            "{'anomalies': [{'delivery_id', 'anomaly_type', 'detail'}]}."
        ),
    )
    # Free-text rationale; defaults to an empty string when omitted.
    explanation: str = Field(
        description="Optional reasoning about extraction or cleaning decisions.",
        default="",
    )


# ---------------------------------------------------------------------------
# Observation
# ---------------------------------------------------------------------------

class InvoiceObservation(BaseModel):
    """Everything the environment shows the agent on a given turn.

    The first four fields are required. The remaining fields have defaults;
    several of them are documented as relevant only to specific tasks.
    """

    # --- required on every observation -----------------------------------
    raw_text: str = Field(
        ...,
        description="Raw invoice text (OCR-style or CSV-style)",
    )
    task_id: str = Field(
        ...,
        description=(
            "easy | medium | hard | expert | adversarial | negotiate | "
            "supply_chain | long_horizon | personalized | curriculum"
        ),
    )
    difficulty: str = Field(..., description="Same as task_id")
    task_description: str = Field(..., description="What the agent should do")

    # --- attempt bookkeeping and grader output ---------------------------
    attempt_number: int = Field(
        description="Current attempt (0 = just reset)",
        default=0,
    )
    max_attempts: int = Field(description="Max allowed attempts", default=5)
    feedback: str = Field(
        description="Detailed grader feedback from last attempt",
        default="",
    )
    hint: str = Field(
        description="Hint shown after 2+ failed attempts",
        default="",
    )
    reference_data: str = Field(
        description="For hard task: purchase order data to reconcile against",
        default="",
    )
    reward_breakdown: Optional[Dict[str, Any]] = Field(
        description=(
            "Per-field score breakdown for easy, adversarial, and negotiate "
            "tasks. Example: {'vendor': {'score': 0.15, 'max': 0.15, "
            "'status': 'correct'}, 'date': {'score': 0.0, 'max': 0.10, "
            "'status': 'wrong'}, ...}"
        ),
        default=None,
    )

    # --- task-specific extras --------------------------------------------
    conversation_history: List[Dict[str, Any]] = Field(
        description=(
            "For negotiate task: list of "
            "{'role': 'agent'|'env', 'content': str} turns."
        ),
        default_factory=list,
    )
    phase: Optional[int] = Field(
        description=(
            "For long_horizon task: current phase "
            "(1=extract, 2=reconcile, 3=audit, 4=forecast)."
        ),
        default=None,
    )
    phase_context: Optional[str] = Field(
        description=(
            "For long_horizon task: accumulated findings from prior phases "
            "passed to next phase."
        ),
        default=None,
    )
    agent_profile: Optional[Dict[str, Any]] = Field(
        description=(
            "For personalized task: agent's historical performance profile "
            "used to adapt difficulty."
        ),
        default=None,
    )


# ---------------------------------------------------------------------------
# State
# ---------------------------------------------------------------------------

class InvoiceState(BaseModel):
    """Internal per-episode bookkeeping: progress, rewards, and task extras."""

    # Episode identity and progress.
    episode_id: str = ""
    task_id: str = "easy"
    step_count: int = 0
    done: bool = False

    # Reward history (most recent, best so far, and the full sequence).
    last_reward: float = 0.0
    best_reward: float = 0.0
    rewards: List[float] = Field(default_factory=list)

    # Negotiate-task dialogue state.
    conversation_history: List[Dict[str, Any]] = Field(default_factory=list)
    clarification_count: int = 0

    # Long-horizon: current phase, per-phase scores, and the context string
    # carried forward into the next phase.
    phase: int = 1
    phase_scores: List[float] = Field(default_factory=list)
    phase_context: str = ""

    # Personalized: tracks agent weak areas across steps.
    agent_profile: Dict[str, Any] = Field(default_factory=dict)


# ---------------------------------------------------------------------------
# Supply Chain (documentation model)
# ---------------------------------------------------------------------------

class SupplyChainAnomalyItem(BaseModel):
    """One anomaly entry in a supply_chain answer.

    Matches a single element of the ``anomalies`` list in the supply_chain
    payload shape documented on ``InvoiceAction.extracted_data``.
    """

    delivery_id: str = Field(
        ...,
        description="Delivery identifier the anomaly refers to.",
    )
    # Kept as a plain ``str`` (not a Literal) so validation stays exactly as
    # permissive as before; the accepted vocabulary is surfaced in the schema
    # via the description instead of living only in a code comment.
    anomaly_type: str = Field(
        ...,
        description=(
            "One of: quantity_shortfall | price_spike | "
            "unauthorized_substitution | phantom_delivery"
        ),
    )
    detail: str = Field(
        ...,
        description="Free-text explanation of the anomaly.",
    )


class SupplyChainAction(BaseModel):
    """Documented shape of a supply_chain answer: a list of anomaly entries."""

    # Required; every element must conform to SupplyChainAnomalyItem.
    anomalies: List[SupplyChainAnomalyItem] = Field(...)