""" Pydantic models for the Invoice Processing Pipeline environment. Action: Agent submits extracted/cleaned/reconciled invoice data as JSON. Observation: Agent receives raw invoice text, feedback, and task context. State: Tracks episode progress, attempts, and scores. """ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field # --------------------------------------------------------------------------- # Action # --------------------------------------------------------------------------- class InvoiceAction(BaseModel): """Action the agent submits each step.""" extracted_data: Dict[str, Any] = Field( ..., description=( "JSON object with extracted/cleaned invoice fields. " "Structure depends on the task. " "Easy: {vendor, date, currency, total, line_items: [{description, qty, unit_price, amount}]}. " "Medium: {invoices: [{vendor, date, currency, total, line_items}]} (batch of cleaned invoices). " "Hard: {invoices: [...], discrepancies: [{invoice_idx, type, detail, expected, actual}]}. " "Adversarial: same schema as easy — {vendor, date, currency, total, line_items}. " "Negotiate: either {'question': str} to ask a clarification, or the full extraction " "(same schema as easy). " "Supply_chain: {'anomalies': [{'delivery_id', 'anomaly_type', 'detail'}]}." ), ) explanation: str = Field( default="", description="Optional reasoning about extraction or cleaning decisions.", ) # --------------------------------------------------------------------------- # Observation # --------------------------------------------------------------------------- class InvoiceObservation(BaseModel): """What the agent sees each turn.""" raw_text: str = Field(..., description="Raw invoice text (OCR-style or CSV-style)") task_id: str = Field(..., description="easy | medium | hard | expert | adversarial | negotiate | supply_chain | long_horizon | personalized | curriculum") difficulty: str = Field(..., description="Same as task_id") task_description: str = Field(..., description="What the agent should do") attempt_number: int = Field(default=0, description="Current attempt (0 = just reset)") max_attempts: int = Field(default=5, description="Max allowed attempts") feedback: str = Field(default="", description="Detailed grader feedback from last attempt") hint: str = Field(default="", description="Hint shown after 2+ failed attempts") reference_data: str = Field( default="", description="For hard task: purchase order data to reconcile against", ) reward_breakdown: Optional[Dict[str, Any]] = Field( default=None, description=( "Per-field score breakdown for easy, adversarial, and negotiate tasks. " "Example: {'vendor': {'score': 0.15, 'max': 0.15, 'status': 'correct'}, " "'date': {'score': 0.0, 'max': 0.10, 'status': 'wrong'}, ...}" ), ) conversation_history: List[Dict[str, Any]] = Field( default_factory=list, description="For negotiate task: list of {'role': 'agent'|'env', 'content': str} turns.", ) phase: Optional[int] = Field( default=None, description="For long_horizon task: current phase (1=extract, 2=reconcile, 3=audit, 4=forecast).", ) phase_context: Optional[str] = Field( default=None, description="For long_horizon task: accumulated findings from prior phases passed to next phase.", ) agent_profile: Optional[Dict[str, Any]] = Field( default=None, description="For personalized task: agent's historical performance profile used to adapt difficulty.", ) # --------------------------------------------------------------------------- # State # --------------------------------------------------------------------------- class InvoiceState(BaseModel): """Internal episode state.""" episode_id: str = Field(default="") task_id: str = Field(default="easy") step_count: int = Field(default=0) done: bool = Field(default=False) last_reward: float = Field(default=0.0) best_reward: float = Field(default=0.0) rewards: List[float] = Field(default_factory=list) conversation_history: List[Dict[str, Any]] = Field(default_factory=list) clarification_count: int = Field(default=0) # Long-horizon: tracks which phase and accumulated context phase: int = Field(default=1) phase_scores: List[float] = Field(default_factory=list) phase_context: str = Field(default="") # Personalized: tracks agent weak areas across steps agent_profile: Dict[str, Any] = Field(default_factory=dict) # --------------------------------------------------------------------------- # Supply Chain (documentation model) # --------------------------------------------------------------------------- class SupplyChainAnomalyItem(BaseModel): delivery_id: str anomaly_type: str # quantity_shortfall | price_spike | unauthorized_substitution | phantom_delivery detail: str class SupplyChainAction(BaseModel): anomalies: List[SupplyChainAnomalyItem]