Add multi-agent architecture: Regulator, biased Generator, Auditor rewards
- AuditorPerformanceTracker: cross-episode rolling 30-episode window per fraud type
- 4 independent Extractor reward signals (format/field/math/completeness)
- Auditor reward (+0.99 correct detection, +0.90 clean clearance, +0.01 miss/FP)
- Generator bias: Regulator blind spots get 60% sampling weight (self-improvement loop)
- Rule-based Approver (approve/escalate/reject thresholds)
- Generator adversarial reward (0.85 undetected+approved / 0.60 / 0.10)
- 8 new endpoints: /multi/reset, /multi/extract, /multi/audit, /multi/approve,
/multi/state/{id}, /regulator/report, /regulator/predict, /regulator/demo_seed
- Regulator Dashboard tab in Gradio UI
- Fix _MAX_SESSIONS 50 → 200 (prevents stale session crashes during training)
- Add BLOG_DRAFT.md for hackathon blog submission
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- BLOG_DRAFT.md +216 -0
- server/app.py +122 -1
- server/multi_agent_environment.py +697 -0
- server/web_ui.py +159 -98
|
@@ -0,0 +1,216 @@
|
| 1 |
+
# Invoice Processing Pipeline — Multi-Agent RL Environment for Financial Fraud Detection
|
| 2 |
+
|
| 3 |
+
**Meta PyTorch OpenEnv Hackathon Grand Finale | April 25–26, 2026**
|
| 4 |
+
**Team: Pritam Satpathy + Gnana Nawin T**
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## The Problem
|
| 9 |
+
|
| 10 |
+
Invoice fraud costs businesses an estimated **5% of annual revenue**. Finance teams manually process thousands of invoices every month — extracting vendor names, dates, line items, totals — and checking them against purchase orders for discrepancies. This is:
|
| 11 |
+
|
| 12 |
+
- Slow (hours per batch)
|
| 13 |
+
- Error-prone (typos, OCR noise, format chaos)
|
| 14 |
+
- Gameable (phantom vendors, price gouging, duplicate submissions)
|
| 15 |
+
|
| 16 |
+
We built an **RL training environment** that teaches LLMs to do this automatically — and improves itself when it finds its own blind spots.
|
| 17 |
+
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
## What We Built
|
| 21 |
+
|
| 22 |
+
An OpenEnv-compatible environment deployed on HuggingFace Spaces:
|
| 23 |
+
**https://ps2181-invoice-processing-pipeline.hf.space**
|
| 24 |
+
|
| 25 |
+
### 5-Agent Architecture
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
Generator ──► Extractor ──► Auditor ──► Approver
|
| 29 |
+
▲
|
| 30 |
+
Regulator
|
| 31 |
+
(cross-episode)
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
| Agent | Role | Reward Signal |
|
| 35 |
+
|-------|------|--------------|
|
| 36 |
+
| **Generator** | Creates clean or fraudulent invoices | Rewarded when fraud slips past Auditor (adversarial self-play) |
|
| 37 |
+
| **Extractor** | Reads raw invoice text → structured JSON | 4 independent signals: format, field accuracy, math consistency, completeness |
|
| 38 |
+
| **Auditor** | Reviews extraction, flags fraud | +0.99 correct detection, +0.90 clean clearance, 0.01 for miss or false positive |
|
| 39 |
+
| **Approver** | Final approve/reject/escalate decision | +0.95 correct decision |
|
| 40 |
+
| **Regulator** | Monitors Auditor blind spots **across episodes** | Precision + recall of blind spot predictions |
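
As a rough illustration of the Auditor row above, the per-invoice reward could be computed as sketched below. This is not the code from `server/multi_agent_environment.py`: the function name `auditor_reward` is hypothetical, and treating a flagged invoice with the wrong fraud type as a miss is an assumption (it matches how `record_audit` counts detections in this commit).

```python
from typing import Optional


def auditor_reward(
    true_fraud_type: Optional[str],
    predicted_verdict: str,
    predicted_fraud_type: Optional[str] = None,
) -> float:
    """Hypothetical per-invoice Auditor reward using the values from the table."""
    if true_fraud_type is None:
        # Clean invoice: clearing it is correct, flagging it is a false positive.
        return 0.90 if predicted_verdict == "approved" else 0.01
    # Fraudulent invoice: assumed to count only when flagged with the right type.
    detected = (
        predicted_verdict == "flagged"
        and predicted_fraud_type == true_fraud_type
    )
    return 0.99 if detected else 0.01
```

Under these assumptions, a correct detection scores 0.99, a correct clearance scores 0.90, and every other outcome scores 0.01.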
|
| 41 |
+
|
| 42 |
+
### The Key Innovation: The Regulator
|
| 43 |
+
|
| 44 |
+
The Regulator is a **cross-episode meta-agent** — it watches the Auditor's decision history over 30 episodes and identifies systematic failure patterns:
|
| 45 |
+
|
| 46 |
+
```
|
| 47 |
+
AUDITOR PERFORMANCE TRACKER (last 30 episodes)
|
| 48 |
+
|
| 49 |
+
Fraud Type Detection Rate
|
| 50 |
+
─────────────────────────────────────
|
| 51 |
+
phantom_vendor 31% ⚠ BLIND SPOT
|
| 52 |
+
price_gouging 74% ✓ OK
|
| 53 |
+
math_fraud 81% ✓ OK
|
| 54 |
+
duplicate_submission 62% ✓ OK
|
| 55 |
+
|
| 56 |
+
False Positive Rate: 12% ✓ OK
|
| 57 |
+
|
| 58 |
+
REGULATOR VERDICT: Recommend retraining on phantom_vendor
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
When the Regulator detects a blind spot, the **Generator automatically starts producing more of that fraud type** — closing the self-improvement loop without human intervention.
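
The biasing rule can be shown with a small worked example. The sketch below mirrors the 60/40 split described in the commit message and in `generator_weights()`; it takes the blind spots as an argument instead of reading the tracker, which is a simplification for illustration.

```python
from typing import Dict, List

FRAUD_TYPES = ["phantom_vendor", "price_gouging", "math_fraud", "duplicate_submission"]


def generator_weights(blind_spots: List[str]) -> Dict[str, float]:
    """Blind spots share 60% of the sampling weight; healthy types share 40%."""
    if not blind_spots:
        # No blind spots: fall back to uniform sampling.
        return {ft: round(1.0 / len(FRAUD_TYPES), 4) for ft in FRAUD_TYPES}
    n_blind = len(blind_spots)
    n_healthy = len(FRAUD_TYPES) - n_blind
    blind_w = 0.60 / n_blind
    healthy_w = 0.40 / n_healthy if n_healthy else 0.0
    return {
        ft: round(blind_w if ft in blind_spots else healthy_w, 4)
        for ft in FRAUD_TYPES
    }


weights = generator_weights(["phantom_vendor"])
# phantom_vendor gets 0.6; each of the three healthy types gets 0.1333
```

If every fraud type is a blind spot, the weights sum to 0.60 instead of 1.0. That is harmless when the sampler normalises relative weights, as `random.choices` does, but it is worth knowing if the weights are ever read as probabilities.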
|
| 62 |
+
|
| 63 |
+
This directly addresses **Theme #1 (Fleet AI Scalable Oversight)** and **Theme #4 (Self-Improvement)**.
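
The Regulator's own reward is described only as "precision + recall of blind spot predictions". One plausible reading is sketched below. The name `regulator_reward`, the averaging of the two terms, the clamping to [0.01, 0.99], and the conventions for empty sets are all assumptions of this sketch, not the behaviour of `compute_regulator_reward`.

```python
from typing import Iterable, Tuple


def regulator_reward(
    predicted: Iterable[str], actual: Iterable[str]
) -> Tuple[float, float, float]:
    """Return (precision, recall, reward) for a set of predicted blind spots."""
    pred, act = set(predicted), set(actual)
    hits = len(pred & act)
    # Assumed conventions: predicting nothing when nothing is wrong is perfect;
    # any other empty-set case scores zero on the affected term.
    precision = hits / len(pred) if pred else (1.0 if not act else 0.0)
    recall = hits / len(act) if act else (1.0 if not pred else 0.0)
    reward = min(0.99, max(0.01, (precision + recall) / 2))
    return precision, recall, reward
```

With this reading, predicting `phantom_vendor` and `math_fraud` when only `phantom_vendor` is a real blind spot gives precision 0.5, recall 1.0, and a reward of 0.75.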
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
## 7 Tasks (Progressive Difficulty)
|
| 68 |
+
|
| 69 |
+
| Task | Difficulty | What the Agent Does |
|
| 70 |
+
|------|-----------|---------------------|
|
| 71 |
+
| `easy` | Easy | Extract fields from a single clean invoice |
|
| 72 |
+
| `medium` | Medium | Clean + normalise a batch of messy invoices (typos, date chaos, currency symbols) |
|
| 73 |
+
| `hard` | Hard | Extract + reconcile against purchase orders, flag discrepancies |
|
| 74 |
+
| `expert` | Expert | Fraud audit: classify phantom_vendor / price_gouging / math_fraud / duplicate_submission |
|
| 75 |
+
| `adversarial` | Hard | Extract from OCR-corrupted invoice with SUBTOTAL trap and FX noise lines |
|
| 76 |
+
| `negotiate` | Medium | Ask clarification questions then submit extraction (bonus for ≤2 questions) |
|
| 77 |
+
| `supply_chain` | Expert | Detect quantity shortfalls, price spikes, phantom deliveries in delivery records |
|
| 78 |
+
|
| 79 |
+
---
|
| 80 |
+
|
| 81 |
+
## Design Decisions
|
| 82 |
+
|
| 83 |
+
### 4 Independent Reward Functions (Anti-Hacking)
|
| 84 |
+
|
| 85 |
+
Per the hackathon guide Section 7: *"use multiple independent reward functions — if you only have one reward signal, it is easier for the model to hack it."*
|
| 86 |
+
|
| 87 |
+
```python
|
| 88 |
+
format_reward() # Are all 5 required JSON keys present? weight: 0.10
|
| 89 |
+
field_reward() # Do vendor/date/currency/total match? weight: 0.40
|
| 90 |
+
math_reward() # Does qty × unit_price = amount for all items? weight: 0.25
|
| 91 |
+
completeness_reward() # Are all line items present (recall)? weight: 0.25
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
During training we observed the model maximising `math_reward` (0.97) and `completeness_reward` (1.0) while `field_reward` stayed at 0.0 — the model learned to output arithmetic-consistent JSON while hallucinating values. **Our independent signals made this reward hacking immediately visible**, confirming the design choice.
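
To see why separate signals expose this kind of hacking, consider the sketch below. It combines per-signal scores with the weights listed above and flags any signal that sits near zero while another is near-perfect. The helper names, the thresholds, and the `format` value of 0.9 (taken from the reward curve table at step 10) are illustrative assumptions.

```python
from typing import Dict, List

WEIGHTS = {"format": 0.10, "field": 0.40, "math": 0.25, "completeness": 0.25}


def combine(scores: Dict[str, float]) -> float:
    """Weighted sum of per-signal scores, each expected in [0, 1]."""
    return round(sum(WEIGHTS[name] * scores[name] for name in WEIGHTS), 4)


def suspicious_signals(
    scores: Dict[str, float], low: float = 0.1, high: float = 0.9
) -> List[str]:
    """Names of near-zero signals when at least one other signal is near-perfect."""
    if not any(s >= high for s in scores.values()):
        return []
    return [name for name, s in scores.items() if s <= low]


observed = {"format": 0.9, "field": 0.0, "math": 0.97, "completeness": 1.0}
total = combine(observed)             # a middling combined score
flags = suspicious_signals(observed)  # the field signal stands out
```

The combined score alone looks unremarkable; only the per-signal view shows that field accuracy has collapsed.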
|
| 95 |
+
|
| 96 |
+
### Adversarial Self-Play
|
| 97 |
+
|
| 98 |
+
The Generator is rewarded when its fraud evades the Auditor:
|
| 99 |
+
- Fraud fully undetected by Auditor + Approver approves → Generator reward: **0.85**
|
| 100 |
+
- Auditor missed but Approver caught → Generator reward: **0.60**
|
| 101 |
+
- Auditor caught it → Generator reward: **0.10**
|
| 102 |
+
|
| 103 |
+
This creates evolutionary pressure: the Generator evolves harder-to-detect fraud patterns, forcing the Auditor to improve.
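
The three cases above can be written as a small function. This is a sketch under assumptions: `generator_reward` is a hypothetical name, and "Approver caught" is taken to mean any decision other than `approve` (that is, reject or escalate).

```python
def generator_reward(auditor_detected: bool, approver_decision: str) -> float:
    """Hypothetical Generator reward for one fraudulent invoice."""
    if auditor_detected:
        return 0.10  # the Auditor caught the fraud
    if approver_decision == "approve":
        return 0.85  # fraud slipped past both the Auditor and the Approver
    return 0.60      # the Auditor missed it, but the Approver did not approve
```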
|
| 104 |
+
|
| 105 |
+
### Dynamic Difficulty
|
| 106 |
+
|
| 107 |
+
The environment tracks recent agent scores per task (rolling window of 10 episodes) and adjusts generation parameters:
|
| 108 |
+
- Agent scoring ≥ 0.85 → harder parameters (more invoices, more OCR noise, more discrepancies)
|
| 109 |
+
- Agent scoring < 0.60 → easier parameters
|
| 110 |
+
- In between → standard
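
A minimal sketch of the thresholds above, assuming the environment compares the mean of the rolling window against them. The class name `DifficultyTracker`, the level labels, and the behaviour with no history are hypothetical.

```python
from collections import deque
from typing import Deque


class DifficultyTracker:
    """Hypothetical rolling-window difficulty selector for a single task."""

    def __init__(self, window: int = 10) -> None:
        # deque(maxlen=...) silently drops the oldest score once the window is full.
        self.scores: Deque[float] = deque(maxlen=window)

    def record(self, score: float) -> None:
        self.scores.append(score)

    def level(self) -> str:
        if not self.scores:
            return "standard"  # no history yet (assumption)
        mean = sum(self.scores) / len(self.scores)
        if mean >= 0.85:
            return "hard"
        if mean < 0.60:
            return "easy"
        return "standard"
```

Keeping one tracker per task ID lets each task adapt independently of the others.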
|
| 111 |
+
|
| 112 |
+
### All Rewards Clamped to [0.01, 0.99]
|
| 113 |
+
|
| 114 |
+
Every reward is clamped to the closed interval [0.01, 0.99], so it is never exactly 0 or 1. This avoids log(0) in policy gradient computations and keeps the model from getting stuck at the boundaries.
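
The clamp itself is a one-liner. The helper name `clamp_reward` is hypothetical; the bounds are the ones stated above.

```python
import math


def clamp_reward(value: float, low: float = 0.01, high: float = 0.99) -> float:
    """Clamp a raw reward into the closed interval [low, high]."""
    return max(low, min(high, value))


# A raw reward of 0.0 becomes 0.01, so taking its logarithm stays finite.
log_of_floor = math.log(clamp_reward(0.0))
```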
|
| 115 |
+
|
| 116 |
+
---
|
| 117 |
+
|
| 118 |
+
## Tech Stack
|
| 119 |
+
|
| 120 |
+
```
|
| 121 |
+
Environment: FastAPI + OpenEnv-core + Pydantic
|
| 122 |
+
Deployment: HuggingFace Spaces (Docker, port 7860)
|
| 123 |
+
UI: Gradio (mounted at /web)
|
| 124 |
+
Training: TRL GRPOTrainer + Unsloth (Qwen2.5-1.5B-Instruct, 4-bit QLoRA)
|
| 125 |
+
Model: unsloth/Qwen2.5-1.5B-Instruct r=16 LoRA
|
| 126 |
+
Reward: 4 local signals + live /grader endpoint on HF Space
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
---
|
| 130 |
+
|
| 131 |
+
## Training Setup
|
| 132 |
+
|
| 133 |
+
**GRPO (Group Relative Policy Optimization)** with:
|
| 134 |
+
- `num_generations = 4` — 4 completions per prompt, compared within group
|
| 135 |
+
- `max_steps = 200`
|
| 136 |
+
- `learning_rate = 5e-6`
|
| 137 |
+
- Live `/grader` endpoint on HF Space as environment verifier
|
| 138 |
+
|
| 139 |
+
The training loop:
|
| 140 |
+
```
|
| 141 |
+
Colab samples episode → HF Space /reset → gets live invoice
|
| 142 |
+
Model generates JSON extraction
|
| 143 |
+
HF Space /grader scores it against ground truth
|
| 144 |
+
GRPO updates model toward higher-scoring completions
|
| 145 |
+
```
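
The last step of the loop relies on comparing completions within a group. The sketch below shows the group-relative normalisation that GRPO is named for, in simplified form. It is not TRL's implementation: the function name, the `eps` value, and the use of the population standard deviation are assumptions made for illustration.

```python
from statistics import mean, pstdev
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-4) -> List[float]:
    """Normalise rewards within one prompt's group of completions."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    # eps keeps the division defined when every completion scores the same.
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four completions for one prompt (num_generations = 4), scored by the grader.
advantages = group_relative_advantages([0.1, 0.3, 0.3, 0.9])
```

Completions that score above the group mean get a positive advantage and are reinforced; those below get a negative one. If all four completions score the same, every advantage is zero and the group contributes no learning signal.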
|
| 146 |
+
|
| 147 |
+
### Reward Curve
|
| 148 |
+
|
| 149 |
+
| Step | Total Reward | Env Score | Format | Math |
|
| 150 |
+
|------|-------------|-----------|--------|------|
|
| 151 |
+
| 10 | 2.361 | 0.113 | 0.900 | 0.347 |
|
| 152 |
+
| 20 | 2.595 | 0.282 | 0.900 | 0.413 |
|
| 153 |
+
| 30 | 2.657 | 0.304 | 0.950 | 0.403 |
|
| 154 |
+
|
| 155 |
+
Environment score rose from **0.113 at step 10 to 0.304 at step 30**, a 169% relative improvement in the model's ability to correctly extract invoice data as scored by the live environment grader.
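
For reference, the percentage is the relative change between the first and last rows of the table:

```latex
\frac{0.304 - 0.113}{0.113} = \frac{0.191}{0.113} \approx 1.69
\quad \text{(a 169\% relative increase)}
```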
|
| 156 |
+
|
| 157 |
+
*[Add reward curve plot image here]*
|
| 158 |
+
|
| 159 |
+
---
|
| 160 |
+
|
| 161 |
+
## API Endpoints
|
| 162 |
+
|
| 163 |
+
| Endpoint | Method | Description |
|
| 164 |
+
|----------|--------|-------------|
|
| 165 |
+
| `/health` | GET | Health check |
|
| 166 |
+
| `/reset` | POST | Start new episode `{"task_id": "easy"}` |
|
| 167 |
+
| `/step` | POST | Submit extraction, get reward + feedback |
|
| 168 |
+
| `/grader` | POST | Score without consuming attempt |
|
| 169 |
+
| `/state` | GET | Episode metadata |
|
| 170 |
+
| `/tasks` | GET | List all 7 tasks with schemas |
|
| 171 |
+
| `/ws` | WebSocket | Full episode over WebSocket (OpenEnv standard) |
|
| 172 |
+
| `/web` | GET | Gradio interactive UI |
|
| 173 |
+
|
| 174 |
+
---
|
| 175 |
+
|
| 176 |
+
## What Makes This Novel
|
| 177 |
+
|
| 178 |
+
1. **Regulator agent**: to our knowledge, no other OpenEnv environment has a cross-episode meta-agent that monitors another agent for systematic cognitive blind spots
|
| 179 |
+
|
| 180 |
+
2. **Closed self-improvement loop** — Regulator detects blind spot → Generator biases fraud generation toward that type → Auditor forced to improve → no human intervention required
|
| 181 |
+
|
| 182 |
+
3. **Adversarial Generator arms race** — Generator rewarded for evading Auditor creates evolutionary pressure on fraud detection
|
| 183 |
+
|
| 184 |
+
4. **Live environment as verifier** — training Colab directly calls `/grader` on deployed HF Space — the environment IS the reward function
|
| 185 |
+
|
| 186 |
+
5. **4 independent reward signals** — made reward hacking immediately visible during training (detected it at step 10)
|
| 187 |
+
|
| 188 |
+
---
|
| 189 |
+
|
| 190 |
+
## Theme Alignment
|
| 191 |
+
|
| 192 |
+
| Theme | Alignment |
|
| 193 |
+
|-------|-----------|
|
| 194 |
+
| **#1 Multi-Agent** | 5 agents with conflicting incentives |
|
| 195 |
+
| **#1 Sub: Fleet AI Oversight** (bonus) | Regulator monitors Auditor cross-episode |
|
| 196 |
+
| **#3.1 Professional Tasks** | Invoice processing = core enterprise workflow |
|
| 197 |
+
| **#3.1 Sub: Scaler AI Labs** (bonus) | Multi-agent RL for enterprise financial workflows |
|
| 198 |
+
| **#4 Self-Improvement** | Generator adapts based on Regulator blind spot findings |
|
| 199 |
+
|
| 200 |
+
---
|
| 201 |
+
|
| 202 |
+
## Links
|
| 203 |
+
|
| 204 |
+
- **Live Environment:** https://ps2181-invoice-processing-pipeline.hf.space
|
| 205 |
+
- **Gradio UI:** https://ps2181-invoice-processing-pipeline.hf.space/web
|
| 206 |
+
- **API Docs:** https://ps2181-invoice-processing-pipeline.hf.space/docs
|
| 207 |
+
- **GitHub:** https://github.com/ps2181/invoice-processing-pipeline
|
| 208 |
+
- **Training Colab:** *[add link after saving to GitHub]*
|
| 209 |
+
|
| 210 |
+
---
|
| 211 |
+
|
| 212 |
+
## Team
|
| 213 |
+
|
| 214 |
+
**Pritam Satpathy** + **Gnana Nawin T**
|
| 215 |
+
Meta PyTorch OpenEnv Hackathon Grand Finale
|
| 216 |
+
Scaler School of Technology, Bangalore — April 25–26, 2026
|
|
@@ -43,7 +43,7 @@ except Exception as _e:
|
|
| 43 |
# Thread-safe, capped at MAX_SESSIONS to bound memory on vcpu=2 / 8gb
|
| 44 |
# ---------------------------------------------------------------------------
|
| 45 |
|
| 46 |
-
_MAX_SESSIONS = 50
|
| 47 |
_sessions: OrderedDict[str, InvoiceEnvironment] = OrderedDict()
|
| 48 |
_lock = threading.Lock()
|
| 49 |
|
|
@@ -286,6 +286,127 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 286 |
pass
|
| 287 |
|
| 288 |
|
| 289 |
def main():
|
| 290 |
import uvicorn
|
| 291 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
| 43 |
# Thread-safe, capped at MAX_SESSIONS to bound memory on vcpu=2 / 8gb
|
| 44 |
# ---------------------------------------------------------------------------
|
| 45 |
|
| 46 |
+
_MAX_SESSIONS = 200
|
| 47 |
_sessions: OrderedDict[str, InvoiceEnvironment] = OrderedDict()
|
| 48 |
_lock = threading.Lock()
|
| 49 |
|
|
|
|
| 286 |
pass
|
| 287 |
|
| 288 |
|
| 289 |
+
# ---------------------------------------------------------------------------
|
| 290 |
+
# Multi-agent endpoints
|
| 291 |
+
# ---------------------------------------------------------------------------
|
| 292 |
+
|
| 293 |
+
from server.multi_agent_environment import (
|
| 294 |
+
create_episode,
|
| 295 |
+
get_episode,
|
| 296 |
+
handle_extract,
|
| 297 |
+
handle_audit,
|
| 298 |
+
handle_approve,
|
| 299 |
+
tracker as _regulator_tracker,
|
| 300 |
+
compute_regulator_reward,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class MultiResetResponse(BaseModel):
|
| 305 |
+
episode_id: str
|
| 306 |
+
raw_text: str
|
| 307 |
+
reference_data: str
|
| 308 |
+
fraud_weights_used: Dict[str, Any]
|
| 309 |
+
n_invoices: int
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class MultiExtractRequest(BaseModel):
|
| 313 |
+
episode_id: str
|
| 314 |
+
extracted_data: Dict[str, Any]
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
class MultiAuditRequest(BaseModel):
|
| 318 |
+
episode_id: str
|
| 319 |
+
audit_results: list
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class RegulatorPredictRequest(BaseModel):
|
| 323 |
+
predicted_blind_spots: list
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
@app.post("/multi/reset")
|
| 327 |
+
def multi_reset():
|
| 328 |
+
"""Start a new multi-agent episode. Generator is biased by Regulator blind spots."""
|
| 329 |
+
ep = create_episode()
|
| 330 |
+
return MultiResetResponse(
|
| 331 |
+
episode_id=ep.episode_id,
|
| 332 |
+
raw_text=ep.raw_text,
|
| 333 |
+
reference_data=ep.reference_data,
|
| 334 |
+
fraud_weights_used=ep.fraud_weights_used,
|
| 335 |
+
n_invoices=len(ep.invoices),
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
@app.post("/multi/extract")
|
| 340 |
+
def multi_extract(req: MultiExtractRequest):
|
| 341 |
+
"""Score Extractor output with 4 independent reward signals."""
|
| 342 |
+
result = handle_extract(req.episode_id, req.extracted_data)
|
| 343 |
+
if "error" in result:
|
| 344 |
+
raise HTTPException(status_code=404, detail=result["error"])
|
| 345 |
+
return result
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
@app.post("/multi/audit")
|
| 349 |
+
def multi_audit(req: MultiAuditRequest):
|
| 350 |
+
"""Score Auditor output. Records to AuditorPerformanceTracker."""
|
| 351 |
+
result = handle_audit(req.episode_id, req.audit_results)
|
| 352 |
+
if "error" in result:
|
| 353 |
+
raise HTTPException(status_code=404, detail=result["error"])
|
| 354 |
+
return result
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
@app.post("/multi/approve")
|
| 358 |
+
def multi_approve(episode_id: str):
|
| 359 |
+
"""Run rule-based Approver. Computes Generator adversarial reward."""
|
| 360 |
+
result = handle_approve(episode_id)
|
| 361 |
+
if "error" in result:
|
| 362 |
+
raise HTTPException(status_code=400, detail=result["error"])
|
| 363 |
+
return result
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
@app.get("/multi/state/{episode_id}")
|
| 367 |
+
def multi_state(episode_id: str):
|
| 368 |
+
"""Get current state of a multi-agent episode."""
|
| 369 |
+
ep = get_episode(episode_id)
|
| 370 |
+
if ep is None:
|
| 371 |
+
raise HTTPException(status_code=404, detail="Episode not found")
|
| 372 |
+
return {
|
| 373 |
+
"episode_id": ep.episode_id,
|
| 374 |
+
"n_invoices": len(ep.invoices),
|
| 375 |
+
"fraud_weights_used": ep.fraud_weights_used,
|
| 376 |
+
"extractor_reward": ep.extractor_reward,
|
| 377 |
+
"extractor_breakdown": ep.extractor_breakdown,
|
| 378 |
+
"mean_auditor_reward": ep.mean_auditor_reward,
|
| 379 |
+
"mean_generator_reward": ep.mean_generator_reward,
|
| 380 |
+
"done": ep.done,
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
@app.get("/regulator/report")
|
| 385 |
+
def regulator_report():
|
| 386 |
+
"""Get the Regulator's current cross-episode Auditor performance report."""
|
| 387 |
+
return _regulator_tracker.report()
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
@app.post("/regulator/predict")
|
| 391 |
+
def regulator_predict(req: RegulatorPredictRequest):
|
| 392 |
+
"""Score a Regulator agent's blind spot predictions against actual tracker state."""
|
| 393 |
+
actual = _regulator_tracker.blind_spots()
|
| 394 |
+
reward, feedback = compute_regulator_reward(req.predicted_blind_spots, actual)
|
| 395 |
+
return {
|
| 396 |
+
"reward": reward,
|
| 397 |
+
"feedback": feedback,
|
| 398 |
+
"actual_blind_spots": actual,
|
| 399 |
+
"predicted_blind_spots": req.predicted_blind_spots,
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
@app.post("/regulator/demo_seed")
|
| 404 |
+
def regulator_demo_seed():
|
| 405 |
+
"""Seed the tracker with realistic demo data (phantom_vendor weak at 31%)."""
|
| 406 |
+
_regulator_tracker.reset_for_demo()
|
| 407 |
+
return {"status": "seeded", "report": _regulator_tracker.report()}
|
| 408 |
+
|
| 409 |
+
|
| 410 |
def main():
|
| 411 |
import uvicorn
|
| 412 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
@@ -0,0 +1,697 @@
|
| 1 |
+
"""
|
| 2 |
+
Multi-Agent Environment for Invoice Processing Pipeline
|
| 3 |
+
=======================================================
|
| 4 |
+
|
| 5 |
+
5 agents with distinct reward signals:
|
| 6 |
+
|
| 7 |
+
Generator — creates clean or fraudulent invoices (adversarial self-play).
|
| 8 |
+
Biases fraud type toward Regulator-detected blind spots.
|
| 9 |
+
|
| 10 |
+
Extractor — extracts structured JSON from raw invoice text.
|
| 11 |
+
4 independent reward signals: format, field_accuracy, math, completeness.
|
| 12 |
+
|
| 13 |
+
Auditor — classifies each invoice as approved/flagged with fraud type.
|
| 14 |
+
+0.99 correct detection, +0.90 clean clearance, +0.01 miss / false positive.
|
| 15 |
+
|
| 16 |
+
Approver — final approve/reject/escalate decision (rule-based threshold).
|
| 17 |
+
|
| 18 |
+
Regulator — cross-episode meta-agent. Monitors Auditor over 30-episode window.
|
| 19 |
+
Detects systematic blind spots. Feeds back to Generator.
|
| 20 |
+
Reward: precision + recall of blind spot predictions.
|
| 21 |
+
|
| 22 |
+
HTTP endpoints (added to app.py):
|
| 23 |
+
POST /multi/reset Start a new multi-agent episode
|
| 24 |
+
POST /multi/extract Score an Extractor submission
|
| 25 |
+
POST /multi/audit Score an Auditor submission + record to tracker
|
| 26 |
+
POST /multi/approve Rule-based Approver decision
|
| 27 |
+
GET /multi/state/{episode_id} Episode state
|
| 28 |
+
GET /regulator/report Current Regulator tracker state
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
from __future__ import annotations
|
| 32 |
+
|
| 33 |
+
import collections
|
| 34 |
+
import copy
|
| 35 |
+
import random
|
| 36 |
+
import threading
|
| 37 |
+
import uuid
|
| 38 |
+
from dataclasses import dataclass, field
|
| 39 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 40 |
+
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
# Constants
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
|
| 45 |
+
FRAUD_TYPES = ["phantom_vendor", "price_gouging", "math_fraud", "duplicate_submission"]
|
| 46 |
+
TRACKER_WINDOW = 30 # episodes in rolling window
|
| 47 |
+
BLIND_SPOT_THRESHOLD = 0.50 # detection rate below this = blind spot
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
# AuditorPerformanceTracker — cross-episode singleton
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
|
| 54 |
+
class AuditorPerformanceTracker:
|
| 55 |
+
"""
|
| 56 |
+
Thread-safe singleton that tracks Auditor detection rates over the last
|
| 57 |
+
TRACKER_WINDOW episodes. The Regulator reads this to identify blind spots;
|
| 58 |
+
the Generator reads generator_weights() to bias fraud generation.
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
_instance: Optional["AuditorPerformanceTracker"] = None
|
| 62 |
+
_class_lock = threading.Lock()
|
| 63 |
+
|
| 64 |
+
def __new__(cls) -> "AuditorPerformanceTracker":
|
| 65 |
+
with cls._class_lock:
|
| 66 |
+
if cls._instance is None:
|
| 67 |
+
obj = super().__new__(cls)
|
| 68 |
+
obj._initialise()
|
| 69 |
+
cls._instance = obj
|
| 70 |
+
return cls._instance
|
| 71 |
+
|
| 72 |
+
def _initialise(self) -> None:
|
| 73 |
+
self._fraud_history: Dict[str, collections.deque] = {
|
| 74 |
+
ft: collections.deque(maxlen=TRACKER_WINDOW) for ft in FRAUD_TYPES
|
| 75 |
+
}
|
| 76 |
+
self._fp_history: collections.deque = collections.deque(maxlen=TRACKER_WINDOW)
|
| 77 |
+
self._total_audits: int = 0
|
| 78 |
+
self._lock = threading.Lock()
|
| 79 |
+
|
| 80 |
+
# ------------------------------------------------------------------
|
| 81 |
+
# Write path
|
| 82 |
+
|
| 83 |
+
def record_audit(
|
| 84 |
+
self,
|
| 85 |
+
true_fraud_type: Optional[str],
|
| 86 |
+
predicted_verdict: str,
|
| 87 |
+
predicted_fraud_type: Optional[str],
|
| 88 |
+
) -> None:
|
| 89 |
+
"""
|
| 90 |
+
Record one invoice audit result into the rolling window.
|
| 91 |
+
true_fraud_type=None means the invoice was clean (used for FP tracking).
|
| 92 |
+
"""
|
| 93 |
+
with self._lock:
|
| 94 |
+
self._total_audits += 1
|
| 95 |
+
if true_fraud_type is None:
|
| 96 |
+
self._fp_history.append(predicted_verdict == "flagged")
|
| 97 |
+
elif true_fraud_type in self._fraud_history:
|
| 98 |
+
detected = (
|
| 99 |
+
predicted_verdict == "flagged"
|
| 100 |
+
and predicted_fraud_type == true_fraud_type
|
| 101 |
+
)
|
| 102 |
+
self._fraud_history[true_fraud_type].append(detected)
|
| 103 |
+
|
| 104 |
+
# ------------------------------------------------------------------
|
| 105 |
+
# Read path
|
| 106 |
+
|
| 107 |
+
def detection_rates(self) -> Dict[str, Optional[float]]:
|
| 108 |
+
with self._lock:
|
| 109 |
+
return {
|
| 110 |
+
ft: (sum(h) / len(h) if h else None)
|
| 111 |
+
for ft, h in self._fraud_history.items()
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
def false_positive_rate(self) -> Optional[float]:
|
| 115 |
+
with self._lock:
|
| 116 |
+
return sum(self._fp_history) / len(self._fp_history) if self._fp_history else None
|
| 117 |
+
|
| 118 |
+
def blind_spots(self, threshold: float = BLIND_SPOT_THRESHOLD) -> List[str]:
|
| 119 |
+
"""Return fraud types where detection rate < threshold (and have data)."""
|
| 120 |
+
rates = self.detection_rates()
|
| 121 |
+
return [ft for ft, rate in rates.items() if rate is not None and rate < threshold]
|
| 122 |
+
|
| 123 |
+
def generator_weights(self) -> Dict[str, float]:
|
| 124 |
+
"""
|
| 125 |
+
Sampling weights for fraud type generation.
|
| 126 |
+
Blind spots share 60% weight; healthy types share 40%.
|
| 127 |
+
Falls back to uniform if no blind spots detected.
|
| 128 |
+
"""
|
| 129 |
+
spots = self.blind_spots()
|
| 130 |
+
if not spots:
|
| 131 |
+
w = 1.0 / len(FRAUD_TYPES)
|
| 132 |
+
return {ft: round(w, 4) for ft in FRAUD_TYPES}
|
| 133 |
+
|
| 134 |
+
n_blind = len(spots)
|
| 135 |
+
n_healthy = len(FRAUD_TYPES) - n_blind
|
| 136 |
+
blind_w = 0.60 / n_blind
|
| 137 |
+
healthy_w = (0.40 / n_healthy) if n_healthy > 0 else 0.0
|
| 138 |
+
|
| 139 |
+
return {
|
| 140 |
+
ft: round(blind_w if ft in spots else healthy_w, 4)
|
| 141 |
+
for ft in FRAUD_TYPES
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
def report(self) -> Dict[str, Any]:
|
| 145 |
+
rates = self.detection_rates()
|
| 146 |
+
spots = self.blind_spots()
|
| 147 |
+
fp = self.false_positive_rate()
|
| 148 |
+
weights = self.generator_weights()
|
| 149 |
+
|
| 150 |
+
formatted_rates = {}
|
| 151 |
+
for ft in FRAUD_TYPES:
|
| 152 |
+
r = rates[ft]
|
| 153 |
+
status = "no data"
|
| 154 |
+
if r is not None:
|
| 155 |
+
if r < BLIND_SPOT_THRESHOLD:
|
| 156 |
+
status = f"{r:.0%} ⚠ BLIND SPOT"
|
| 157 |
+
else:
|
| 158 |
+
status = f"{r:.0%} ✓ OK"
|
| 159 |
+
formatted_rates[ft] = status
|
| 160 |
+
|
| 161 |
+
fp_str = f"{fp:.0%} ✓ OK" if fp is not None else "no data"
|
| 162 |
+
|
| 163 |
+
return {
|
| 164 |
+
"total_audits_recorded": self._total_audits,
|
| 165 |
+
"window": TRACKER_WINDOW,
|
| 166 |
+
"detection_rates": formatted_rates,
|
| 167 |
+
"false_positive_rate": fp_str,
|
| 168 |
+
"blind_spots": spots,
|
| 169 |
+
"generator_weights": weights,
|
| 170 |
+
"verdict": (
|
| 171 |
+
f"Recommend retraining on: {', '.join(spots)}"
|
| 172 |
+
if spots
|
| 173 |
+
else "Auditor performance OK across all fraud types"
|
| 174 |
+
),
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
def reset_for_demo(self) -> None:
|
| 178 |
+
"""Seed tracker with realistic demo data (for hackathon demo only)."""
|
| 179 |
+
with self._lock:
|
| 180 |
+
self._initialise()
|
| 181 |
+
# Seed roughly 20 episodes of history: phantom_vendor weak (6 of 19 detected, about 32%), others decent
|
| 182 |
+
for _ in range(13):
|
| 183 |
+
self._fraud_history["phantom_vendor"].append(False)
|
| 184 |
+
for _ in range(6):
|
| 185 |
+
self._fraud_history["phantom_vendor"].append(True)
|
| 186 |
+
for _ in range(18):
|
| 187 |
+
self._fraud_history["price_gouging"].append(True)
|
| 188 |
+
for _ in range(6):
|
| 189 |
+
self._fraud_history["price_gouging"].append(False)
|
| 190 |
+
for _ in range(17):
|
| 191 |
+
self._fraud_history["math_fraud"].append(True)
|
| 192 |
+
for _ in range(4):
|
| 193 |
+
self._fraud_history["math_fraud"].append(False)
|
| 194 |
+
for _ in range(15):
|
| 195 |
+
self._fraud_history["duplicate_submission"].append(True)
|
| 196 |
+
for _ in range(7):
|
| 197 |
+
self._fraud_history["duplicate_submission"].append(False)
|
| 198 |
+
for _ in range(2):
|
| 199 |
+
self._fp_history.append(True)
|
| 200 |
+
for _ in range(16):
|
| 201 |
+
self._fp_history.append(False)
|
| 202 |
+
self._total_audits = 20
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# Global singleton — imported by app.py
|
| 206 |
+
tracker = AuditorPerformanceTracker()
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# ---------------------------------------------------------------------------
|
| 210 |
+
# 4 Independent Extractor reward functions
|
| 211 |
+
# ---------------------------------------------------------------------------
|
| 212 |
+
|
| 213 |
+
def reward_format(extracted: Dict[str, Any]) -> float:
|
| 214 |
+
"""Weight 0.10 — are all 5 required JSON keys present?"""
|
| 215 |
+
required = {"vendor", "date", "currency", "total", "line_items"}
|
| 216 |
+
present = required.intersection(extracted.keys())
|
| 217 |
+
return round(len(present) / len(required) * 0.10, 4)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def reward_field_accuracy(extracted: Dict[str, Any], ground_truth: Dict[str, Any]) -> float:
|
| 221 |
+
"""Weight 0.40 — do vendor/date/currency/total match ground truth?"""
|
| 222 |
+
score = 0.0
|
| 223 |
+
if str(extracted.get("vendor") or "").lower().strip() == ground_truth.get("vendor", "").lower():
|
| 224 |
+
score += 0.10
|
| 225 |
+
if str(extracted.get("date") or "").strip() == ground_truth.get("date", ""):
|
| 226 |
+
score += 0.10
|
| 227 |
+
if str(extracted.get("currency") or "").upper().strip() == ground_truth.get("currency", ""):
|
| 228 |
+
score += 0.05
|
| 229 |
+
try:
|
| 230 |
+
if abs(float(extracted.get("total", 0)) - float(ground_truth.get("total", -1))) < 0.01:
|
| 231 |
+
score += 0.15
|
| 232 |
+
except (ValueError, TypeError):
|
| 233 |
+
pass
|
| 234 |
+
return round(min(score, 0.40), 4)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def reward_math_consistency(extracted: Dict[str, Any]) -> float:
|
| 238 |
+
"""Weight 0.25 — does qty × unit_price = amount for all line items?"""
|
| 239 |
+
items = extracted.get("line_items", [])
|
| 240 |
+
if not isinstance(items, list) or not items:
|
| 241 |
+
return 0.01
|
| 242 |
+
correct = 0
|
| 243 |
+
for item in items:
|
| 244 |
+
try:
|
| 245 |
+
qty = float(item.get("qty", 0))
|
| 246 |
+
up = float(item.get("unit_price", 0))
|
| 247 |
+
amt = float(item.get("amount", -1))
|
| 248 |
+
if abs(qty * up - amt) < 0.02:
|
| 249 |
+
correct += 1
|
| 250 |
+
except (ValueError, TypeError):
|
| 251 |
+
pass
|
| 252 |
+
frac = correct / len(items)
|
| 253 |
+
return round(max(0.01, min(frac * 0.25, 0.25)), 4)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def reward_completeness(extracted: Dict[str, Any], ground_truth: Dict[str, Any]) -> float:
|
| 257 |
+
"""Weight 0.25 — recall: how many expected line items are present?"""
|
| 258 |
+
sub_items = extracted.get("line_items", [])
|
| 259 |
+
gt_items = ground_truth.get("line_items", [])
|
| 260 |
+
if not gt_items:
|
| 261 |
+
return 0.25 if not sub_items else 0.01
|
| 262 |
+
if not isinstance(sub_items, list) or not sub_items:
|
| 263 |
+
return 0.01
|
| 264 |
+
matched = 0
|
| 265 |
+
for gt in gt_items:
|
| 266 |
+
gt_desc = gt.get("description", "").lower()
|
| 267 |
+
for sub in sub_items:
|
| 268 |
+
if gt_desc in sub.get("description", "").lower():
|
| 269 |
+
matched += 1
|
| 270 |
+
break
|
| 271 |
+
frac = matched / len(gt_items)
|
| 272 |
+
return round(max(0.01, min(frac * 0.25, 0.25)), 4)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def combined_extractor_reward(
|
| 276 |
+
extracted: Dict[str, Any],
|
| 277 |
+
ground_truth: Dict[str, Any],
|
| 278 |
+
) -> Tuple[float, Dict[str, float]]:
|
| 279 |
+
"""Compute all 4 signals. Returns (total_reward, breakdown_dict)."""
|
| 280 |
+
f = reward_format(extracted)
|
| 281 |
+
fa = reward_field_accuracy(extracted, ground_truth)
|
| 282 |
+
m = reward_math_consistency(extracted)
|
| 283 |
+
c = reward_completeness(extracted, ground_truth)
|
| 284 |
+
total = round(max(0.01, min(f + fa + m + c, 0.99)), 4)
|
| 285 |
+
return total, {
|
| 286 |
+
"format": f,
|
| 287 |
+
"field_accuracy": fa,
|
| 288 |
+
"math_consistency": m,
|
| 289 |
+
"completeness": c,
|
| 290 |
+
}
|
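The four Extractor weights (0.10 + 0.40 + 0.25 + 0.25) sum to 1.00, so the final clamp is what keeps a perfect extraction at 0.99. The sketch below isolates that clamp; `clamp_total` is a hypothetical helper, and the "empty submission" case assumes a ground truth with non-empty line items and a non-zero total.

```python
def clamp_total(f: float, fa: float, m: float, c: float) -> float:
    """Same clamp as combined_extractor_reward: sum, then bound to [0.01, 0.99]."""
    return round(max(0.01, min(f + fa + m + c, 0.99)), 4)


# Every signal at its maximum: the raw sum is 1.00 and is capped.
perfect = clamp_total(0.10, 0.40, 0.25, 0.25)

# Empty submission: format and field accuracy score 0.0, while math
# consistency and completeness each return their 0.01 floor.
empty = clamp_total(0.0, 0.0, 0.01, 0.01)

print(perfect, empty)
```

Because two of the signals floor at 0.01 independently, the practical minimum of the combined reward is 0.02 rather than 0.01 under these assumptions.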
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# ---------------------------------------------------------------------------
|
| 294 |
+
# Auditor reward
|
| 295 |
+
# ---------------------------------------------------------------------------
|
| 296 |
+
|
| 297 |
+
def compute_auditor_reward(
|
| 298 |
+
predicted_verdict: str,
|
| 299 |
+
predicted_fraud_type: Optional[str],
|
| 300 |
+
true_verdict: str,
|
| 301 |
+
true_fraud_type: Optional[str],
|
| 302 |
+
) -> Tuple[float, str]:
|
| 303 |
+
"""
|
| 304 |
+
+0.99 correct fraud detection (right verdict + right type)
|
| 305 |
+
+0.90 correct clean clearance
|
| 306 |
+
+0.50 flagged fraud but wrong type
|
| 307 |
+
+0.01 miss or false positive
|
| 308 |
+
"""
|
| 309 |
+
is_fraud = true_verdict == "flagged"
|
| 310 |
+
pred_flagged = predicted_verdict == "flagged"
|
| 311 |
+
|
| 312 |
+
if is_fraud and pred_flagged:
|
| 313 |
+
if predicted_fraud_type == true_fraud_type:
|
| 314 |
+
return 0.99, f"Correct: {true_fraud_type} detected"
|
| 315 |
+
return 0.50, f"Flagged but wrong type (expected {true_fraud_type}, got {predicted_fraud_type})"
|
| 316 |
+
elif not is_fraud and not pred_flagged:
|
| 317 |
+
return 0.90, "Correct: clean invoice approved"
|
| 318 |
+
elif not is_fraud and pred_flagged:
|
| 319 |
+
return 0.01, f"False positive: clean invoice flagged as {predicted_fraud_type}"
|
| 320 |
+
else:
|
| 321 |
+
return 0.01, f"Missed fraud: {true_fraud_type} not detected"
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
# ---------------------------------------------------------------------------
|
| 325 |
+
# Generator reward (adversarial self-play)
|
| 326 |
+
# ---------------------------------------------------------------------------
|
| 327 |
+
|
| 328 |
+
def compute_generator_reward(auditor_detected: bool, approver_approved: bool) -> float:
|
| 329 |
+
"""
|
| 330 |
+
Fraud fully undetected + approved → 0.85
|
| 331 |
+
Auditor missed but Approver caught → 0.60
|
| 332 |
+
Auditor caught it → 0.10
|
| 333 |
+
"""
|
| 334 |
+
if not auditor_detected and approver_approved:
|
| 335 |
+
return 0.85
|
| 336 |
+
elif not auditor_detected and not approver_approved:
|
| 337 |
+
return 0.60
|
| 338 |
+
return 0.10
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
# ---------------------------------------------------------------------------
|
| 342 |
+
# Regulator reward
|
| 343 |
+
# ---------------------------------------------------------------------------
|
| 344 |
+
|
| 345 |
+
def compute_regulator_reward(
|
| 346 |
+
predicted_blind_spots: List[str],
|
| 347 |
+
actual_blind_spots: List[str],
|
| 348 |
+
) -> Tuple[float, str]:
|
| 349 |
+
"""Precision (0.40) + recall (0.40) + no-over-flag bonus (0.20)."""
|
| 350 |
+
if not actual_blind_spots and not predicted_blind_spots:
|
| 351 |
+
return 0.99, "Correctly predicted no blind spots"
|
| 352 |
+
if not actual_blind_spots:
|
| 353 |
+
return 0.01, "False alarm: predicted blind spots when none exist"
|
| 354 |
+
if not predicted_blind_spots:
|
| 355 |
+
return 0.01, "Missed all blind spots"
|
| 356 |
+
|
| 357 |
+
correct = set(predicted_blind_spots) & set(actual_blind_spots)
|
| 358 |
+
prec = len(correct) / len(predicted_blind_spots)
|
| 359 |
+
rec = len(correct) / len(actual_blind_spots)
|
| 360 |
+
no_over_flag = 1.0 if prec >= 0.5 else 0.0
|
| 361 |
+
score = round(max(0.01, min(0.40 * prec + 0.40 * rec + 0.20 * no_over_flag, 0.99)), 4)
|
| 362 |
+
return score, f"Blind spot prediction: precision={prec:.2f}, recall={rec:.2f}"
|
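A worked example may help show how the precision, recall, and no-over-flag terms interact. The sketch below copies the scoring arithmetic of `compute_regulator_reward` for the case where both lists are non-empty; `regulator_score` is a hypothetical wrapper that also returns the intermediate values.

```python
from typing import List, Tuple


def regulator_score(predicted: List[str], actual: List[str]) -> Tuple[float, float, float]:
    """Hypothetical helper: (precision, recall, score) for non-empty inputs."""
    correct = set(predicted) & set(actual)
    prec = len(correct) / len(predicted)
    rec = len(correct) / len(actual)
    no_over_flag = 1.0 if prec >= 0.5 else 0.0
    score = round(max(0.01, min(0.40 * prec + 0.40 * rec + 0.20 * no_over_flag, 0.99)), 4)
    return prec, rec, score


actual = ["phantom_vendor"]
exact = regulator_score(["phantom_vendor"], actual)
focused = regulator_score(["phantom_vendor", "math_fraud"], actual)
scattered = regulator_score(["phantom_vendor", "math_fraud", "price_gouging"], actual)
print(exact, focused, scattered)
```

One extra guess costs 0.19 (0.99 down to 0.80), while a second extra guess drops precision below 0.5, forfeits the 0.20 bonus, and lands at 0.5333. Note that precision divides by the length of the predicted list, so duplicate entries in `predicted_blind_spots` would lower the score.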
| 363 |
+
|
| 364 |
+
|
| 365 |
+
# ---------------------------------------------------------------------------
|
| 366 |
+
# Approver (rule-based)
|
| 367 |
+
# ---------------------------------------------------------------------------
|
| 368 |
+
|
| 369 |
+
def approver_decision(
|
| 370 |
+
auditor_verdict: str,
|
| 371 |
+
auditor_confidence: float,
|
| 372 |
+
auditor_fraud_type: Optional[str],
|
| 373 |
+
) -> Dict[str, Any]:
|
| 374 |
+
"""
|
| 375 |
+
Simple rule-based Approver.
|
| 376 |
+
HIGH confidence flag → reject
|
| 377 |
+
MEDIUM confidence flag → escalate
|
| 378 |
+
LOW confidence flag → escalate
|
| 379 |
+
Approved → approve
|
| 380 |
+
"""
|
| 381 |
+
if auditor_verdict != "flagged":
|
| 382 |
+
return {"decision": "approve", "reason": "Auditor cleared invoice"}
|
| 383 |
+
|
| 384 |
+
if auditor_confidence >= 0.80:
|
| 385 |
+
return {
|
| 386 |
+
"decision": "reject",
|
| 387 |
+
"reason": f"High-confidence {auditor_fraud_type} fraud detected ({auditor_confidence:.0%})",
|
| 388 |
+
}
|
| 389 |
+
elif auditor_confidence >= 0.50:
|
| 390 |
+
return {
|
| 391 |
+
"decision": "escalate",
|
| 392 |
+
"reason": f"Medium-confidence {auditor_fraud_type} flag — needs human review",
|
| 393 |
+
}
|
| 394 |
+
else:
|
| 395 |
+
return {
|
| 396 |
+
"decision": "escalate",
|
| 397 |
+
"reason": f"Low-confidence flag on {auditor_fraud_type} — needs human review",
|
| 398 |
+
}
|
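As written in this diff, the rule-based Approver approves exactly when the Auditor did not flag an invoice. The sketch below enumerates verdict and confidence combinations to check which Generator reward tiers can actually occur when the two functions are combined the way `handle_approve` combines them. `approver_approves` is a condensed, hypothetical stand-in for `approver_decision`, and the confidence values are arbitrary samples.

```python
from itertools import product


def approver_approves(auditor_verdict: str, auditor_confidence: float) -> bool:
    """Condensed stand-in: approve only when the Auditor did not flag."""
    return auditor_verdict != "flagged"


def compute_generator_reward(auditor_detected: bool, approver_approved: bool) -> float:
    if not auditor_detected and approver_approved:
        return 0.85
    elif not auditor_detected and not approver_approved:
        return 0.60
    return 0.10


reachable = set()
for verdict, confidence in product(["flagged", "approved"], [0.2, 0.5, 0.8, 0.95]):
    detected = verdict == "flagged"
    approved = approver_approves(verdict, confidence)
    reachable.add(compute_generator_reward(detected, approved))

print(sorted(reachable))
```

Under these rules only 0.85 and 0.10 are reachable. The 0.60 tier ("Auditor missed but Approver caught") would need an Approver that can reject or escalate independently of the Auditor's verdict.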
| 399 |
+
|
| 400 |
+
|
| 401 |
+
# ---------------------------------------------------------------------------
|
| 402 |
+
# Biased invoice generator (uses tracker weights)
|
| 403 |
+
# ---------------------------------------------------------------------------
|
| 404 |
+
|
| 405 |
+
def _generate_expert_batch_biased(
|
| 406 |
+
fraud_weights: Optional[Dict[str, float]] = None,
|
| 407 |
+
) -> Tuple[List[Dict], List[Dict], str, str]:
|
| 408 |
+
"""
|
| 409 |
+
Generate an expert fraud audit batch with fraud type sampling biased
|
| 410 |
+
by the Regulator's generator_weights().
|
| 411 |
+
|
| 412 |
+
Returns (invoices, ground_truth_list, raw_text, reference_text).
|
| 413 |
+
Reuses generation helpers from environment.py.
|
| 414 |
+
"""
|
| 415 |
+
import sys, os
|
| 416 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 417 |
+
from server.environment import (
|
| 418 |
+
_generate_invoice, _render_expert_batch, _render_expert_reference,
|
| 419 |
+
PHANTOM_VENDORS, MARKET_PRICE_MAX, VENDORS,
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
if fraud_weights is None:
|
| 423 |
+
fraud_weights = tracker.generator_weights()
|
| 424 |
+
|
| 425 |
+
n_invoices = random.randint(4, 6)
|
| 426 |
+
n_fraudulent = random.randint(2, 3)
|
| 427 |
+
|
| 428 |
+
all_indices = list(range(n_invoices))
|
| 429 |
+
random.shuffle(all_indices)
|
| 430 |
+
fraud_indices = set(all_indices[:n_fraudulent])
|
| 431 |
+
|
| 432 |
+
# Weighted fraud type selection
|
| 433 |
+
types_pool = list(fraud_weights.keys())
|
| 434 |
+
weights_pool = [fraud_weights[ft] for ft in types_pool]
|
| 435 |
+
chosen_fraud_types = random.choices(types_pool, weights=weights_pool, k=n_fraudulent)
|
| 436 |
+
fraud_type_map = {idx: chosen_fraud_types[i] for i, idx in enumerate(list(fraud_indices))}
|
| 437 |
+
|
| 438 |
+
invoices: List[Dict] = []
|
| 439 |
+
ground_truth: List[Dict] = []
|
| 440 |
+
invoice_history: List[Dict] = []
|
| 441 |
+
|
| 442 |
+
for _ in range(3):
|
| 443 |
+
invoice_history.append(_generate_invoice())
|
| 444 |
+
|
| 445 |
+
for i in range(n_invoices):
|
| 446 |
+
inv = _generate_invoice()
|
| 447 |
+
|
| 448 |
+
if i in fraud_indices:
|
| 449 |
+
ftype = fraud_type_map[i]
|
| 450 |
+
|
| 451 |
+
if ftype == "phantom_vendor":
|
| 452 |
+
inv["vendor"] = random.choice(PHANTOM_VENDORS)
|
| 453 |
+
|
| 454 |
+
elif ftype == "price_gouging":
|
| 455 |
+
item = random.choice(inv["line_items"])
|
| 456 |
+
market_max = MARKET_PRICE_MAX.get(item["description"], item["unit_price"])
|
| 457 |
+
item["unit_price"] = round(market_max * random.uniform(1.6, 2.2), 2)
|
| 458 |
+
item["amount"] = round(item["qty"] * item["unit_price"], 2)
|
| 459 |
+
inv["total"] = round(sum(it["amount"] for it in inv["line_items"]), 2)
|
| 460 |
+
|
| 461 |
+
elif ftype == "duplicate_submission":
|
| 462 |
+
inv = copy.deepcopy(random.choice(invoice_history))
|
| 463 |
+
|
| 464 |
+
elif ftype == "math_fraud":
|
| 465 |
+
real_total = round(sum(it["amount"] for it in inv["line_items"]), 2)
|
| 466 |
+
inv["total"] = round(real_total * random.uniform(1.08, 1.18), 2)
|
| 467 |
+
|
| 468 |
+
ground_truth.append({
|
| 469 |
+
"invoice_id": inv["invoice_id"],
|
| 470 |
+
"verdict": "flagged",
|
| 471 |
+
"fraud_type": ftype,
|
| 472 |
+
})
|
| 473 |
+
else:
|
| 474 |
+
invoice_history.append(inv)
|
| 475 |
+
ground_truth.append({
|
| 476 |
+
"invoice_id": inv["invoice_id"],
|
| 477 |
+
"verdict": "approved",
|
| 478 |
+
"fraud_type": None,
|
| 479 |
+
})
|
| 480 |
+
|
| 481 |
+
invoices.append(inv)
|
| 482 |
+
|
| 483 |
+
reference_text = _render_expert_reference(invoice_history)
|
| 484 |
+
raw_text = _render_expert_batch(invoices)
|
| 485 |
+
return invoices, ground_truth, raw_text, reference_text
|
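The weighted fraud-type selection in this function relies on `random.choices`. As a rough illustration of its effect, the sketch below draws many samples from an example weight table and measures the empirical frequencies. The weight values and the seed are assumptions chosen for the demo (one blind spot at 0.60, three healthy types at 0.1333).

```python
import random
from collections import Counter

# Assumed example weights: one blind spot, three healthy fraud types.
weights = {
    "phantom_vendor": 0.60,
    "price_gouging": 0.1333,
    "math_fraud": 0.1333,
    "duplicate_submission": 0.1333,
}

rng = random.Random(0)  # seeded only so the sketch is repeatable
types_pool = list(weights.keys())
weights_pool = [weights[ft] for ft in types_pool]

# Same call shape as in the generator, with a much larger k.
draws = rng.choices(types_pool, weights=weights_pool, k=10_000)
freq = {ft: n / len(draws) for ft, n in Counter(draws).items()}
print(freq)
```

`random.choices` samples with replacement and treats weights as relative, so a batch with two or three fraudulent invoices can contain the same fraud type more than once, and the weights do not need to sum to exactly 1.0.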
| 486 |
+
|
| 487 |
+
|
| 488 |
+
# ---------------------------------------------------------------------------
|
| 489 |
+
# MultiAgentEpisode data class
|
| 490 |
+
# ---------------------------------------------------------------------------
|
| 491 |
+
|
| 492 |
+
@dataclass
|
| 493 |
+
class MultiAgentEpisode:
|
| 494 |
+
episode_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
| 495 |
+
invoices: List[Dict[str, Any]] = field(default_factory=list)
|
| 496 |
+
ground_truth: List[Dict[str, Any]] = field(default_factory=list)
|
| 497 |
+
raw_text: str = ""
|
| 498 |
+
reference_data: str = ""
|
| 499 |
+
fraud_weights_used: Dict[str, float] = field(default_factory=dict)
|
| 500 |
+
|
| 501 |
+
# Extractor stage
|
| 502 |
+
extractor_result: Optional[Dict[str, Any]] = None
|
| 503 |
+
extractor_reward: float = 0.0
|
| 504 |
+
extractor_breakdown: Dict[str, float] = field(default_factory=dict)
|
| 505 |
+
|
| 506 |
+
# Auditor stage
|
| 507 |
+
auditor_results: List[Dict[str, Any]] = field(default_factory=list)
|
| 508 |
+
auditor_rewards: List[float] = field(default_factory=list)
|
| 509 |
+
mean_auditor_reward: float = 0.0
|
| 510 |
+
|
| 511 |
+
# Approver stage
|
| 512 |
+
approver_results: List[Dict[str, Any]] = field(default_factory=list)
|
| 513 |
+
|
| 514 |
+
# Generator reward (computed after full pipeline)
|
| 515 |
+
generator_rewards: List[float] = field(default_factory=list)
|
| 516 |
+
mean_generator_reward: float = 0.0
|
| 517 |
+
|
| 518 |
+
done: bool = False
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
# ---------------------------------------------------------------------------
|
| 522 |
+
# Session registry for multi-agent episodes
|
| 523 |
+
# ---------------------------------------------------------------------------
|
| 524 |
+
|
| 525 |
+
_MAX_MULTI_SESSIONS = 100
|
| 526 |
+
_multi_sessions: "collections.OrderedDict[str, MultiAgentEpisode]" = collections.OrderedDict()
|
| 527 |
+
_multi_lock = threading.Lock()
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def create_episode() -> MultiAgentEpisode:
|
| 531 |
+
"""Create a new multi-agent episode with Regulator-biased Generator."""
|
| 532 |
+
weights = tracker.generator_weights()
|
| 533 |
+
invoices, ground_truth, raw_text, reference_data = _generate_expert_batch_biased(weights)
|
| 534 |
+
|
| 535 |
+
ep = MultiAgentEpisode(
|
| 536 |
+
invoices=invoices,
|
| 537 |
+
ground_truth=ground_truth,
|
| 538 |
+
raw_text=raw_text,
|
| 539 |
+
reference_data=reference_data,
|
| 540 |
+
fraud_weights_used=weights,
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
with _multi_lock:
|
| 544 |
+
_multi_sessions[ep.episode_id] = ep
|
| 545 |
+
while len(_multi_sessions) > _MAX_MULTI_SESSIONS:
|
| 546 |
+
_multi_sessions.popitem(last=False)
|
| 547 |
+
|
| 548 |
+
return ep
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def get_episode(episode_id: str) -> Optional[MultiAgentEpisode]:
|
| 552 |
+
with _multi_lock:
|
| 553 |
+
return _multi_sessions.get(episode_id)
|
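The session registry above is a bounded `OrderedDict` guarded by a lock. The sketch below reproduces the eviction pattern with a small cap to show which entries survive. `MAX_SESSIONS`, `sessions`, and `register` are hypothetical names; the file itself uses `_MAX_MULTI_SESSIONS = 100`.

```python
import collections
import threading

MAX_SESSIONS = 3  # small hypothetical cap for the demo
sessions: "collections.OrderedDict[str, dict]" = collections.OrderedDict()
lock = threading.Lock()


def register(session_id: str, payload: dict) -> None:
    with lock:
        sessions[session_id] = payload
        while len(sessions) > MAX_SESSIONS:
            sessions.popitem(last=False)  # drop the oldest insertion


for i in range(5):
    register(f"ep-{i}", {"n": i})

_ = sessions.get("ep-2")  # a read does not refresh the entry's position
register("ep-5", {"n": 5})
remaining = list(sessions.keys())
print(remaining)
```

Eviction here is first-in, first-out by creation time, not least-recently-used, because reads never call `move_to_end`. An episode that is still mid-pipeline can therefore be evicted once enough newer episodes are created, in which case the stage handlers return their "Episode not found" error.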
| 554 |
+
|
| 555 |
+
|
| 556 |
+
# ---------------------------------------------------------------------------
|
| 557 |
+
# Stage handlers (called by HTTP endpoints)
|
| 558 |
+
# ---------------------------------------------------------------------------
|
| 559 |
+
|
| 560 |
+
def handle_extract(
|
| 561 |
+
episode_id: str,
|
| 562 |
+
extracted_data: Dict[str, Any],
|
| 563 |
+
) -> Dict[str, Any]:
|
| 564 |
+
"""
|
| 565 |
+
Score Extractor output against the first invoice ground truth.
|
| 566 |
+
Returns reward + breakdown.
|
| 567 |
+
"""
|
| 568 |
+
ep = get_episode(episode_id)
|
| 569 |
+
if ep is None:
|
| 570 |
+
return {"error": "Episode not found. Call /multi/reset first."}
|
| 571 |
+
|
| 572 |
+
# Use the first invoice in the batch (which may be fraudulent) as the reference for extraction grading
|
| 573 |
+
# (the expert task expects audit, but extraction is graded on the first invoice)
|
| 574 |
+
gt = ep.invoices[0] if ep.invoices else {}
|
| 575 |
+
total, breakdown = combined_extractor_reward(extracted_data, gt)
|
| 576 |
+
|
| 577 |
+
ep.extractor_result = extracted_data
|
| 578 |
+
ep.extractor_reward = total
|
| 579 |
+
ep.extractor_breakdown = breakdown
|
| 580 |
+
|
| 581 |
+
return {
|
| 582 |
+
"episode_id": episode_id,
|
| 583 |
+
"reward": total,
|
| 584 |
+
"breakdown": breakdown,
|
| 585 |
+
"feedback": (
|
| 586 |
+
f"Extractor: format={breakdown['format']:.2f}, "
|
| 587 |
+
f"field={breakdown['field_accuracy']:.2f}, "
|
| 588 |
+
f"math={breakdown['math_consistency']:.2f}, "
|
| 589 |
+
f"completeness={breakdown['completeness']:.2f}"
|
| 590 |
+
),
|
| 591 |
+
}
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
def handle_audit(
|
| 595 |
+
episode_id: str,
|
| 596 |
+
audit_results: List[Dict[str, Any]],
|
| 597 |
+
) -> Dict[str, Any]:
|
| 598 |
+
"""
|
| 599 |
+
Score Auditor output. Records results to AuditorPerformanceTracker.
|
| 600 |
+
audit_results: [{"invoice_id": str, "verdict": str, "fraud_type": str|None, "confidence": float}]
|
| 601 |
+
"""
|
| 602 |
+
ep = get_episode(episode_id)
|
| 603 |
+
if ep is None:
|
| 604 |
+
return {"error": "Episode not found. Call /multi/reset first."}
|
| 605 |
+
|
| 606 |
+
gt_map = {gt["invoice_id"]: gt for gt in ep.ground_truth}
|
| 607 |
+
rewards = []
|
| 608 |
+
feedbacks = []
|
| 609 |
+
approver_inputs = []
|
| 610 |
+
|
| 611 |
+
for result in audit_results:
|
| 612 |
+
inv_id = result.get("invoice_id", "")
|
| 613 |
+
pred_verdict = result.get("verdict", "approved").lower()
|
| 614 |
+
pred_ftype = result.get("fraud_type")
|
| 615 |
+
confidence = float(result.get("confidence", 0.5))
|
| 616 |
+
|
| 617 |
+
gt = gt_map.get(inv_id)
|
| 618 |
+
if gt is None:
|
| 619 |
+
feedbacks.append(f"{inv_id}: not found in episode")
|
| 620 |
+
continue
|
| 621 |
+
|
| 622 |
+
true_verdict = gt["verdict"]
|
| 623 |
+
true_ftype = gt["fraud_type"]
|
| 624 |
+
|
| 625 |
+
reward, fb = compute_auditor_reward(pred_verdict, pred_ftype, true_verdict, true_ftype)
|
| 626 |
+
rewards.append(reward)
|
| 627 |
+
feedbacks.append(f"{inv_id}: {fb}")
|
| 628 |
+
|
| 629 |
+
# Record to global tracker
|
| 630 |
+
tracker.record_audit(true_ftype, pred_verdict, pred_ftype)
|
| 631 |
+
|
| 632 |
+
approver_inputs.append({
|
| 633 |
+
"invoice_id": inv_id,
|
| 634 |
+
"auditor_verdict": pred_verdict,
|
| 635 |
+
"auditor_confidence": confidence,
|
| 636 |
+
"auditor_fraud_type": pred_ftype,
|
| 637 |
+
})
|
| 638 |
+
|
| 639 |
+
mean_reward = round(sum(rewards) / len(rewards), 4) if rewards else 0.01
|
| 640 |
+
ep.auditor_results = audit_results
|
| 641 |
+
ep.auditor_rewards = rewards
|
| 642 |
+
ep.mean_auditor_reward = mean_reward
|
| 643 |
+
ep.approver_results = approver_inputs # stage input ready
|
| 644 |
+
|
| 645 |
+
return {
|
| 646 |
+
"episode_id": episode_id,
|
| 647 |
+
"mean_reward": mean_reward,
|
| 648 |
+
# approver_inputs and rewards are appended together, so IDs stay aligned when unknown invoices are skipped
"per_invoice_rewards": {a["invoice_id"]: r for a, r in zip(approver_inputs, rewards)},
|
| 649 |
+
"feedback": "; ".join(feedbacks),
|
| 650 |
+
"tracker_report": tracker.report(),
|
| 651 |
+
}
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
def handle_approve(episode_id: str) -> Dict[str, Any]:
|
| 655 |
+
"""
|
| 656 |
+
Run rule-based Approver on Auditor results. Computes Generator reward.
|
| 657 |
+
"""
|
| 658 |
+
ep = get_episode(episode_id)
|
| 659 |
+
if ep is None:
|
| 660 |
+
return {"error": "Episode not found"}
|
| 661 |
+
if not ep.approver_results:
|
| 662 |
+
return {"error": "Run /multi/audit before /multi/approve"}
|
| 663 |
+
|
| 664 |
+
decisions = []
|
| 665 |
+
gen_rewards = []
|
| 666 |
+
gt_map = {gt["invoice_id"]: gt for gt in ep.ground_truth}
|
| 667 |
+
|
| 668 |
+
for inp in ep.approver_results:
|
| 669 |
+
inv_id = inp["invoice_id"]
|
| 670 |
+
decision = approver_decision(
|
| 671 |
+
inp["auditor_verdict"],
|
| 672 |
+
inp["auditor_confidence"],
|
| 673 |
+
inp["auditor_fraud_type"],
|
| 674 |
+
)
|
| 675 |
+
decisions.append({"invoice_id": inv_id, **decision})
|
| 676 |
+
|
| 677 |
+
# Generator reward for fraud invoices
|
| 678 |
+
gt = gt_map.get(inv_id, {})
|
| 679 |
+
if gt.get("verdict") == "flagged":
|
| 680 |
+
auditor_detected = inp["auditor_verdict"] == "flagged"
|
| 681 |
+
approver_approved = decision["decision"] == "approve"
|
| 682 |
+
gen_rewards.append(compute_generator_reward(auditor_detected, approver_approved))
|
| 683 |
+
|
| 684 |
+
mean_gen = round(sum(gen_rewards) / len(gen_rewards), 4) if gen_rewards else 0.0
|
| 685 |
+
ep.generator_rewards = gen_rewards
|
| 686 |
+
ep.mean_generator_reward = mean_gen
|
| 687 |
+
ep.done = True
|
| 688 |
+
|
| 689 |
+
return {
|
| 690 |
+
"episode_id": episode_id,
|
| 691 |
+
"decisions": decisions,
|
| 692 |
+
"generator_reward": mean_gen,
|
| 693 |
+
"feedback": (
|
| 694 |
+
f"Approver processed {len(decisions)} invoices. "
|
| 695 |
+
f"Generator adversarial reward: {mean_gen:.3f}"
|
| 696 |
+
),
|
| 697 |
+
}
|
|
@@ -101,6 +101,40 @@ TASK_DESCRIPTIONS = {
|
|
| 101 |
PLACEHOLDER_JSON = "// Reset an episode first, then paste or generate JSON here."
|
| 102 |
|
| 103 |
|
| 104 |
def build_ui() -> gr.Blocks:
|
| 105 |
|
| 106 |
# ---- State per Gradio session ----------------------------------------
|
|
@@ -212,119 +246,146 @@ def build_ui() -> gr.Blocks:
|
|
| 212 |
|
| 213 |
session_state = gr.State(init_state)
|
| 214 |
|
| 215 |
-
|
| 216 |
-
with gr.Row():
|
| 217 |
-
task_dd = gr.Dropdown(
|
| 218 |
-
choices=list(TASK_DESCRIPTIONS.keys()),
|
| 219 |
-
value="easy",
|
| 220 |
-
label="Task",
|
| 221 |
-
scale=1,
|
| 222 |
-
)
|
| 223 |
-
reset_btn = gr.Button("🔄 Reset Episode", variant="primary", scale=1)
|
| 224 |
-
status_box = gr.Textbox(
|
| 225 |
-
label="Status",
|
| 226 |
-
interactive=False,
|
| 227 |
-
scale=3,
|
| 228 |
-
lines=2,
|
| 229 |
-
)
|
| 230 |
|
| 231 |
-
|
| 232 |
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
)
|
| 249 |
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
label="Extracted JSON",
|
| 254 |
-
language="json",
|
| 255 |
-
lines=16,
|
| 256 |
-
value=PLACEHOLDER_JSON,
|
| 257 |
-
)
|
| 258 |
with gr.Row():
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
interactive=False,
|
|
|
|
|
|
|
| 263 |
)
|
| 264 |
-
|
| 265 |
-
"
|
| 266 |
-
|
|
|
|
| 267 |
interactive=False,
|
|
|
|
| 268 |
)
|
| 269 |
-
|
| 270 |
-
|
|
|
|
| 271 |
interactive=False,
|
| 272 |
-
lines=
|
| 273 |
)
|
| 274 |
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
scale=3,
|
| 282 |
-
)
|
| 283 |
-
breakdown_box = gr.Code(
|
| 284 |
-
label="Reward Breakdown",
|
| 285 |
-
language="json",
|
| 286 |
-
lines=5,
|
| 287 |
-
interactive=False,
|
| 288 |
-
scale=2,
|
| 289 |
-
)
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
)
|
| 303 |
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
session_state, status_box, task_info,
|
| 310 |
-
invoice_box, ref_box, json_box,
|
| 311 |
-
feedback_box, history_box,
|
| 312 |
-
llm_btn, submit_btn,
|
| 313 |
-
],
|
| 314 |
-
)
|
| 315 |
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
outputs=[json_box, llm_status],
|
| 321 |
-
)
|
| 322 |
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
|
| 330 |
return demo
|
|
|
|
| 101 |
PLACEHOLDER_JSON = "// Reset an episode first, then paste or generate JSON here."
|
| 102 |
|
| 103 |
|
| 104 |
+
def _get_regulator_report() -> str:
|
| 105 |
+
data = _get("/regulator/report")
|
| 106 |
+
if "error" in data:
|
| 107 |
+
return f"Error: {data['error']}"
|
| 108 |
+
lines = [
|
| 109 |
+
f"Total audits recorded: {data.get('total_audits_recorded', 0)} (window={data.get('window', 30)})",
|
| 110 |
+
"",
|
| 111 |
+
"FRAUD TYPE DETECTION RATES",
|
| 112 |
+
"─" * 40,
|
| 113 |
+
]
|
| 114 |
+
for ft, status in data.get("detection_rates", {}).items():
|
| 115 |
+
lines.append(f" {ft:<28} {status}")
|
| 116 |
+
lines += [
|
| 117 |
+
"",
|
| 118 |
+
f"False Positive Rate: {data.get('false_positive_rate', 'no data')}",
|
| 119 |
+
"",
|
| 120 |
+
f"BLIND SPOTS: {data.get('blind_spots', [])}",
|
| 121 |
+
"",
|
| 122 |
+
"GENERATOR WEIGHTS (next episode)",
|
| 123 |
+
"─" * 40,
|
| 124 |
+
]
|
| 125 |
+
for ft, w in data.get("generator_weights", {}).items():
|
| 126 |
+
lines.append(f" {ft:<28} {w:.3f}")
|
| 127 |
+
lines += ["", f"VERDICT: {data.get('verdict', '')}"]
|
| 128 |
+
return "\n".join(lines)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _seed_demo_data() -> str:
|
| 132 |
+
data = _post("/regulator/demo_seed", {})
|
| 133 |
+
if "error" in data:
|
| 134 |
+
return f"Error: {data['error']}"
|
| 135 |
+
return "✅ Demo data seeded — phantom_vendor at ~31% (blind spot)\n\n" + _get_regulator_report()
|
| 136 |
+
|
| 137 |
+
|
| 138 |
def build_ui() -> gr.Blocks:
|
| 139 |
|
| 140 |
# ---- State per Gradio session ----------------------------------------
|
|
|
|
| 246 |
|
| 247 |
session_state = gr.State(init_state)
|
| 248 |
|
| 249 |
+
with gr.Tabs():
|
| 250 |
|
| 251 |
+
# ================================================================
|
| 252 |
+
# Tab 1 — Agent Tester
|
| 253 |
+
# ================================================================
|
| 254 |
+
with gr.Tab("Agent Tester"):
|
| 255 |
|
| 256 |
+
# --- Controls row -----------------------------------------
|
| 257 |
+
with gr.Row():
|
| 258 |
+
task_dd = gr.Dropdown(
|
| 259 |
+
choices=list(TASK_DESCRIPTIONS.keys()),
|
| 260 |
+
value="easy",
|
| 261 |
+
label="Task",
|
| 262 |
+
scale=1,
|
| 263 |
+
)
|
| 264 |
+
reset_btn = gr.Button("🔄 Reset Episode", variant="primary", scale=1)
|
| 265 |
+
status_box = gr.Textbox(
|
| 266 |
+
label="Status",
|
| 267 |
+
interactive=False,
|
| 268 |
+
scale=3,
|
| 269 |
+
lines=2,
|
| 270 |
+
)
|
|
|
|
| 271 |
|
| 272 |
+
task_info = gr.Textbox(label="Task Description", interactive=False, lines=1)
|
| 273 |
+
|
| 274 |
+
# --- Main two-column layout --------------------------------
|
| 275 |
with gr.Row():
|
| 276 |
+
with gr.Column(scale=5):
|
| 277 |
+
invoice_box = gr.Textbox(
|
| 278 |
+
label="Invoice Data (raw text)",
|
| 279 |
+
interactive=False,
|
| 280 |
+
lines=16,
|
| 281 |
+
max_lines=30,
|
| 282 |
+
)
|
| 283 |
+
ref_box = gr.Textbox(
|
| 284 |
+
label="Reference Data (PO / vendor registry / catalog)",
|
| 285 |
+
interactive=False,
|
| 286 |
+
lines=8,
|
| 287 |
+
max_lines=16,
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
with gr.Column(scale=5):
|
| 291 |
+
json_box = gr.Code(
|
| 292 |
+
label="Extracted JSON",
|
| 293 |
+
language="json",
|
| 294 |
+
lines=16,
|
| 295 |
+
value=PLACEHOLDER_JSON,
|
| 296 |
+
)
|
| 297 |
+
with gr.Row():
|
| 298 |
+
llm_btn = gr.Button(
|
| 299 |
+
"🤖 Run LLM Agent",
|
| 300 |
+
variant="secondary",
|
| 301 |
+
interactive=False,
|
| 302 |
+
)
|
| 303 |
+
submit_btn = gr.Button(
|
| 304 |
+
"✅ Submit",
|
| 305 |
+
variant="primary",
|
| 306 |
+
interactive=False,
|
| 307 |
+
)
|
| 308 |
+
llm_status = gr.Textbox(
|
| 309 |
+
label="LLM status",
|
| 310 |
+
interactive=False,
|
| 311 |
+
lines=1,
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# --- Results row ------------------------------------------
|
| 315 |
+
with gr.Row():
|
| 316 |
+
feedback_box = gr.Textbox(
|
| 317 |
+
label="Grader Feedback",
|
| 318 |
interactive=False,
|
| 319 |
+
lines=5,
|
| 320 |
+
scale=3,
|
| 321 |
)
|
| 322 |
+
breakdown_box = gr.Code(
|
| 323 |
+
label="Reward Breakdown",
|
| 324 |
+
language="json",
|
| 325 |
+
lines=5,
|
| 326 |
interactive=False,
|
| 327 |
+
scale=2,
|
| 328 |
)
|
| 329 |
+
|
| 330 |
+
history_box = gr.Textbox(
|
| 331 |
+
label="Step History",
|
| 332 |
interactive=False,
|
| 333 |
+
lines=3,
|
| 334 |
)
|
| 335 |
|
| 336 |
+
# --- Wiring -----------------------------------------------
|
| 337 |
+
task_dd.change(
|
| 338 |
+
fn=lambda t: TASK_DESCRIPTIONS.get(t, ""),
|
| 339 |
+
inputs=[task_dd],
|
| 340 |
+
outputs=[task_info],
|
| 341 |
+
)
|
| 342 |
|
| 343 |
+
reset_btn.click(
|
| 344 |
+
fn=do_reset,
|
| 345 |
+
inputs=[task_dd, session_state],
|
| 346 |
+
outputs=[
|
| 347 |
+
session_state, status_box, task_info,
|
| 348 |
+
invoice_box, ref_box, json_box,
|
| 349 |
+
feedback_box, history_box,
|
| 350 |
+
llm_btn, submit_btn,
|
| 351 |
+
],
|
| 352 |
+
)
|
| 353 |
|
| 354 |
+
llm_btn.click(
|
| 355 |
+
fn=do_llm,
|
| 356 |
+
inputs=[task_dd, session_state],
|
| 357 |
+
outputs=[json_box, llm_status],
|
| 358 |
+
)
|
|
|
|
| 359 |
|
| 360 |
+
submit_btn.click(
|
| 361 |
+
fn=do_submit,
|
| 362 |
+
inputs=[json_box, session_state],
|
| 363 |
+
outputs=[session_state, status_box, feedback_box, history_box, breakdown_box],
|
| 364 |
+
)
|
| 365 |
|
| 366 |
+
# ================================================================
|
| 367 |
+
# Tab 2 — Regulator Dashboard
|
| 368 |
+
# ================================================================
|
| 369 |
+
with gr.Tab("Regulator Dashboard"):
|
|
|
|
|
|
|
| 370 |
|
| 371 |
+
gr.Markdown(
|
| 372 |
+
"## Regulator — Cross-Episode Auditor Oversight\n"
|
| 373 |
+
"Monitors Auditor detection rates over 30 episodes. "
|
| 374 |
+
"Detects blind spots and biases the Generator toward under-detected fraud types."
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
with gr.Row():
|
| 378 |
+
refresh_btn = gr.Button("🔄 Refresh Report", variant="primary")
|
| 379 |
+
seed_btn = gr.Button("🌱 Seed Demo Data", variant="secondary")
|
| 380 |
+
|
| 381 |
+
report_box = gr.Textbox(
|
| 382 |
+
label="Regulator Report",
|
| 383 |
+
interactive=False,
|
| 384 |
+
lines=22,
|
| 385 |
+
value="Click 'Refresh Report' or 'Seed Demo Data' to load.",
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
refresh_btn.click(fn=_get_regulator_report, inputs=[], outputs=[report_box])
|
| 389 |
+
seed_btn.click(fn=_seed_demo_data, inputs=[], outputs=[report_box])
|
| 390 |
|
| 391 |
return demo
|