Spaces:
Sleeping
Sleeping
Commit ·
e97acd1
1
Parent(s): 9672a3e
Add training-first RL architecture with tracking
Browse files- configs/colab_qlora_grpo.json +4 -4
- notebooks/fraudshield_trl_colab.ipynb +94 -32
- train.py +106 -4
configs/colab_qlora_grpo.json
CHANGED
|
@@ -24,13 +24,13 @@
|
|
| 24 |
"checkpoint_dir": "artifacts/rl_runs/colab_qlora_grpo/checkpoints",
|
| 25 |
"save_to_drive": true,
|
| 26 |
"drive_dir": "/content/drive/MyDrive/fraudshield",
|
| 27 |
-
"num_train_epochs":
|
| 28 |
"per_device_train_batch_size": 2,
|
| 29 |
"gradient_accumulation_steps": 4,
|
| 30 |
-
"learning_rate":
|
| 31 |
"eval_every_steps": 10,
|
| 32 |
"save_every_steps": 20,
|
| 33 |
-
"warmstart_rollouts_per_task":
|
| 34 |
"rl_rollouts_per_task": 8,
|
| 35 |
"max_prompt_tokens": 2048,
|
| 36 |
"max_completion_tokens": 220,
|
|
@@ -41,7 +41,7 @@
|
|
| 41 |
"run_name": "fraudshield-colab-run",
|
| 42 |
"resume_from_checkpoint": null,
|
| 43 |
"public_curriculum_dataset": "Phoenix21/mock_fraud-detection-dataset",
|
| 44 |
-
"public_curriculum_rows":
|
| 45 |
},
|
| 46 |
"evaluation": {
|
| 47 |
"tasks": [
|
|
|
|
| 24 |
"checkpoint_dir": "artifacts/rl_runs/colab_qlora_grpo/checkpoints",
|
| 25 |
"save_to_drive": true,
|
| 26 |
"drive_dir": "/content/drive/MyDrive/fraudshield",
|
| 27 |
+
"num_train_epochs": 3,
|
| 28 |
"per_device_train_batch_size": 2,
|
| 29 |
"gradient_accumulation_steps": 4,
|
| 30 |
+
"learning_rate": 5e-05,
|
| 31 |
"eval_every_steps": 10,
|
| 32 |
"save_every_steps": 20,
|
| 33 |
+
"warmstart_rollouts_per_task": 60,
|
| 34 |
"rl_rollouts_per_task": 8,
|
| 35 |
"max_prompt_tokens": 2048,
|
| 36 |
"max_completion_tokens": 220,
|
|
|
|
| 41 |
"run_name": "fraudshield-colab-run",
|
| 42 |
"resume_from_checkpoint": null,
|
| 43 |
"public_curriculum_dataset": "Phoenix21/mock_fraud-detection-dataset",
|
| 44 |
+
"public_curriculum_rows": 500
|
| 45 |
},
|
| 46 |
"evaluation": {
|
| 47 |
"tasks": [
|
notebooks/fraudshield_trl_colab.ipynb
CHANGED
|
@@ -8,14 +8,25 @@
|
|
| 8 |
"source": [
|
| 9 |
"# FraudShield Colab Training Notebook\n",
|
| 10 |
"\n",
|
| 11 |
-
"This notebook
|
| 12 |
"\n",
|
| 13 |
-
"
|
| 14 |
-
"
|
|
|
|
|
|
|
| 15 |
"\n",
|
|
|
|
| 16 |
"The goal is to learn more than a static heuristic by giving the model broader fraud signals first, then teaching it how to act inside the FraudShield workflow.\n"
|
| 17 |
],
|
| 18 |
"id": "Wadw-uDzhxuI"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
},
|
| 20 |
{
|
| 21 |
"cell_type": "code",
|
|
@@ -98,18 +109,24 @@
|
|
| 98 |
],
|
| 99 |
"source": [
|
| 100 |
"%pip uninstall -y unsloth unsloth_zoo trl transformers tokenizers\n",
|
| 101 |
-
"%pip install -q
|
| 102 |
-
"%pip install -q
|
|
|
|
| 103 |
"%pip install -q \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
|
| 104 |
"\n",
|
| 105 |
"%cd /content\n",
|
| 106 |
"!rm -rf Fraudshield\n",
|
| 107 |
"!git clone https://github.com/DevikaJ2005/Fraudshield.git\n",
|
| 108 |
"%cd /content/Fraudshield\n",
|
|
|
|
| 109 |
"!ls\n",
|
| 110 |
"%pip install -q -e .\n"
|
| 111 |
],
|
| 112 |
"id": "yqcGck2nhxuN"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
},
|
| 114 |
{
|
| 115 |
"cell_type": "code",
|
|
@@ -121,18 +138,22 @@
|
|
| 121 |
"source": [
|
| 122 |
"import os\n",
|
| 123 |
"from getpass import getpass\n",
|
| 124 |
-
"\n",
|
| 125 |
"from huggingface_hub import login\n",
|
| 126 |
"\n",
|
| 127 |
-
"token = getpass('Enter your HF token (
|
| 128 |
"if token.strip():\n",
|
| 129 |
" os.environ['HF_TOKEN'] = token.strip()\n",
|
| 130 |
" login(token=token.strip())\n",
|
| 131 |
" print('HF login completed.')\n",
|
| 132 |
"else:\n",
|
|
|
|
| 133 |
" print('Skipping HF login for now.')\n"
|
| 134 |
],
|
| 135 |
"id": "s4fNpOrHhxuP"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
},
|
| 137 |
{
|
| 138 |
"cell_type": "code",
|
|
@@ -143,15 +164,19 @@
|
|
| 143 |
"outputs": [],
|
| 144 |
"source": [
|
| 145 |
"import torch\n",
|
| 146 |
-
"\n",
|
| 147 |
"print('cuda available:', torch.cuda.is_available())\n",
|
| 148 |
"print('device count:', torch.cuda.device_count())\n",
|
| 149 |
"if torch.cuda.is_available():\n",
|
| 150 |
" print('gpu name:', torch.cuda.get_device_name(0))\n",
|
| 151 |
"else:\n",
|
|
|
|
| 152 |
" raise RuntimeError('GPU not available. In Colab, set Runtime > Change runtime type > GPU, then restart.')\n"
|
| 153 |
],
|
| 154 |
"id": "ezOjPxWHhxuQ"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
},
|
| 156 |
{
|
| 157 |
"cell_type": "code",
|
|
@@ -162,30 +187,46 @@
|
|
| 162 |
"outputs": [],
|
| 163 |
"source": [
|
| 164 |
"import json\n",
|
| 165 |
-
"import
|
| 166 |
-
"import random\n",
|
| 167 |
-
"import subprocess\n",
|
| 168 |
-
"from datetime import datetime\n",
|
| 169 |
-
"\n",
|
| 170 |
-
"import pandas as pd\n",
|
| 171 |
-
"from datasets import Dataset, load_dataset\n",
|
| 172 |
"\n",
|
| 173 |
-
"
|
| 174 |
-
"
|
| 175 |
-
"\n"
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
"\n",
|
| 182 |
-
"
|
| 183 |
-
" '
|
| 184 |
-
" '
|
| 185 |
-
" '
|
| 186 |
-
" 'payment_trace',\n",
|
| 187 |
-
" 'policy_review',\n",
|
| 188 |
"]\n",
|
|
|
|
| 189 |
"\n",
|
| 190 |
"\n",
|
| 191 |
"def serialize_observation(observation):\n",
|
|
@@ -366,6 +407,14 @@
|
|
| 366 |
"print(fraudshield_dataset[0]['text'][:900])\n"
|
| 367 |
],
|
| 368 |
"id": "b6x653wbhxuR"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
},
|
| 370 |
{
|
| 371 |
"cell_type": "code",
|
|
@@ -375,11 +424,19 @@
|
|
| 375 |
},
|
| 376 |
"outputs": [],
|
| 377 |
"source": [
|
| 378 |
-
"
|
|
|
|
| 379 |
"\n",
|
| 380 |
-
"
|
| 381 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
"\n",
|
|
|
|
| 383 |
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 384 |
" model_name=MODEL_NAME,\n",
|
| 385 |
" max_seq_length=MAX_SEQ_LENGTH,\n",
|
|
@@ -583,6 +640,11 @@
|
|
| 583 |
"print('Artifacts saved: reward_curve.png, loss_curve.png, training_summary.json')\n"
|
| 584 |
],
|
| 585 |
"id": "F-zz_6TYhxuV"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 586 |
}
|
| 587 |
],
|
| 588 |
"metadata": {
|
|
|
|
| 8 |
"source": [
|
| 9 |
"# FraudShield Colab Training Notebook\n",
|
| 10 |
"\n",
|
| 11 |
+
"This notebook runs the **training-first FraudShield stack** directly from the repo.\n",
|
| 12 |
"\n",
|
| 13 |
+
"It uses:\n",
|
| 14 |
+
"- `train.py` for Colab-friendly curriculum + QLoRA training\n",
|
| 15 |
+
"- `evaluate.py` for fixed-task evaluation\n",
|
| 16 |
+
"- `configs/colab_qlora_grpo.json` for reproducible settings\n",
|
| 17 |
"\n",
|
| 18 |
+
<<<<<<< HEAD
|
| 19 |
"The goal is to learn more than a static heuristic by giving the model broader fraud signals first, then teaching it how to act inside the FraudShield workflow.\n"
|
| 20 |
],
|
| 21 |
"id": "Wadw-uDzhxuI"
|
| 22 |
+
=======
|
| 23 |
+
"The current setup is tuned to favor **FraudShield workflow learning** over generic imitation:\n",
|
| 24 |
+
"- fewer public curriculum rows\n",
|
| 25 |
+
"- more expert FraudShield rollouts\n",
|
| 26 |
+
"- lower learning rate\n",
|
| 27 |
+
"- longer stage-2 adaptation\n"
|
| 28 |
+
]
|
| 29 |
+
>>>>>>> 43cc51d (Use expert teacher rollouts for stronger retraining)
|
| 30 |
},
|
| 31 |
{
|
| 32 |
"cell_type": "code",
|
|
|
|
| 109 |
],
|
| 110 |
"source": [
|
| 111 |
"%pip uninstall -y unsloth unsloth_zoo trl transformers tokenizers\n",
|
| 112 |
+
"%pip install -q -U pip\n",
|
| 113 |
+
"%pip install -q openenv-core matplotlib pandas\n",
|
| 114 |
+
"%pip install -q -e \".[rl]\"\n",
|
| 115 |
"%pip install -q \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
|
| 116 |
"\n",
|
| 117 |
"%cd /content\n",
|
| 118 |
"!rm -rf Fraudshield\n",
|
| 119 |
"!git clone https://github.com/DevikaJ2005/Fraudshield.git\n",
|
| 120 |
"%cd /content/Fraudshield\n",
|
| 121 |
+
<<<<<<< HEAD
|
| 122 |
"!ls\n",
|
| 123 |
"%pip install -q -e .\n"
|
| 124 |
],
|
| 125 |
"id": "yqcGck2nhxuN"
|
| 126 |
+
=======
|
| 127 |
+
"!pip install -q -e \".[rl]\"\n"
|
| 128 |
+
]
|
| 129 |
+
>>>>>>> 43cc51d (Use expert teacher rollouts for stronger retraining)
|
| 130 |
},
|
| 131 |
{
|
| 132 |
"cell_type": "code",
|
|
|
|
| 138 |
"source": [
|
| 139 |
"import os\n",
|
| 140 |
"from getpass import getpass\n",
|
|
|
|
| 141 |
"from huggingface_hub import login\n",
|
| 142 |
"\n",
|
| 143 |
+
"token = getpass('Enter your HF token (recommended): ')\n",
|
| 144 |
"if token.strip():\n",
|
| 145 |
" os.environ['HF_TOKEN'] = token.strip()\n",
|
| 146 |
" login(token=token.strip())\n",
|
| 147 |
" print('HF login completed.')\n",
|
| 148 |
"else:\n",
|
| 149 |
+
<<<<<<< HEAD
|
| 150 |
" print('Skipping HF login for now.')\n"
|
| 151 |
],
|
| 152 |
"id": "s4fNpOrHhxuP"
|
| 153 |
+
=======
|
| 154 |
+
" print('No HF token provided.')\n"
|
| 155 |
+
]
|
| 156 |
+
>>>>>>> 43cc51d (Use expert teacher rollouts for stronger retraining)
|
| 157 |
},
|
| 158 |
{
|
| 159 |
"cell_type": "code",
|
|
|
|
| 164 |
"outputs": [],
|
| 165 |
"source": [
|
| 166 |
"import torch\n",
|
|
|
|
| 167 |
"print('cuda available:', torch.cuda.is_available())\n",
|
| 168 |
"print('device count:', torch.cuda.device_count())\n",
|
| 169 |
"if torch.cuda.is_available():\n",
|
| 170 |
" print('gpu name:', torch.cuda.get_device_name(0))\n",
|
| 171 |
"else:\n",
|
| 172 |
+
<<<<<<< HEAD
|
| 173 |
" raise RuntimeError('GPU not available. In Colab, set Runtime > Change runtime type > GPU, then restart.')\n"
|
| 174 |
],
|
| 175 |
"id": "ezOjPxWHhxuQ"
|
| 176 |
+
=======
|
| 177 |
+
" raise RuntimeError('GPU not available. Enable GPU in Runtime > Change runtime type, then restart.')\n"
|
| 178 |
+
]
|
| 179 |
+
>>>>>>> 43cc51d (Use expert teacher rollouts for stronger retraining)
|
| 180 |
},
|
| 181 |
{
|
| 182 |
"cell_type": "code",
|
|
|
|
| 187 |
"outputs": [],
|
| 188 |
"source": [
|
| 189 |
"import json\n",
|
| 190 |
+
"from pathlib import Path\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
"\n",
|
| 192 |
+
"config_path = Path('configs/colab_qlora_grpo.json')\n",
|
| 193 |
+
"config = json.loads(config_path.read_text(encoding='utf-8'))\n",
|
| 194 |
+
"print(json.dumps(config, indent=2))\n"
|
| 195 |
+
]
|
| 196 |
+
},
|
| 197 |
+
{
|
| 198 |
+
"cell_type": "code",
|
| 199 |
+
"execution_count": null,
|
| 200 |
+
"metadata": {},
|
| 201 |
+
"outputs": [],
|
| 202 |
+
"source": [
|
| 203 |
+
"!python train.py --config configs/colab_qlora_grpo.json"
|
| 204 |
+
]
|
| 205 |
+
},
|
| 206 |
+
{
|
| 207 |
+
"cell_type": "code",
|
| 208 |
+
"execution_count": null,
|
| 209 |
+
"metadata": {},
|
| 210 |
+
"outputs": [],
|
| 211 |
+
"source": [
|
| 212 |
+
"!python evaluate.py --config configs/colab_qlora_grpo.json"
|
| 213 |
+
]
|
| 214 |
+
},
|
| 215 |
+
{
|
| 216 |
+
"cell_type": "code",
|
| 217 |
+
"execution_count": null,
|
| 218 |
+
"metadata": {},
|
| 219 |
+
"outputs": [],
|
| 220 |
+
"source": [
|
| 221 |
+
"from IPython.display import Image, display\n",
|
| 222 |
+
"!find artifacts -name \"*.png\" -o -name \"*.json\" | sort\n",
|
| 223 |
"\n",
|
| 224 |
+
"paths = [\n",
|
| 225 |
+
" 'artifacts/rl_runs/colab_qlora_grpo/loss_vs_steps.png',\n",
|
| 226 |
+
" 'artifacts/rl_runs/colab_qlora_grpo/reward_vs_steps.png',\n",
|
| 227 |
+
" 'artifacts/plots/evaluation_rewards.png',\n",
|
|
|
|
|
|
|
| 228 |
"]\n",
|
| 229 |
+
<<<<<<< HEAD
|
| 230 |
"\n",
|
| 231 |
"\n",
|
| 232 |
"def serialize_observation(observation):\n",
|
|
|
|
| 407 |
"print(fraudshield_dataset[0]['text'][:900])\n"
|
| 408 |
],
|
| 409 |
"id": "b6x653wbhxuR"
|
| 410 |
+
=======
|
| 411 |
+
"for path in paths:\n",
|
| 412 |
+
" try:\n",
|
| 413 |
+
" display(Image(path))\n",
|
| 414 |
+
" except Exception as exc:\n",
|
| 415 |
+
" print('Could not display', path, exc)\n"
|
| 416 |
+
]
|
| 417 |
+
>>>>>>> 43cc51d (Use expert teacher rollouts for stronger retraining)
|
| 418 |
},
|
| 419 |
{
|
| 420 |
"cell_type": "code",
|
|
|
|
| 424 |
},
|
| 425 |
"outputs": [],
|
| 426 |
"source": [
|
| 427 |
+
"import json\n",
|
| 428 |
+
"from pathlib import Path\n",
|
| 429 |
"\n",
|
| 430 |
+
"summary_candidates = [\n",
|
| 431 |
+
" Path('artifacts/rl_runs/colab_qlora_grpo/training_run_summary.json'),\n",
|
| 432 |
+
" Path('artifacts/rl_runs/colab_qlora_grpo/evaluation_report.json'),\n",
|
| 433 |
+
"]\n",
|
| 434 |
+
"for candidate in summary_candidates:\n",
|
| 435 |
+
" if candidate.exists():\n",
|
| 436 |
+
" print(f'===== {candidate} =====')\n",
|
| 437 |
+
" print(json.dumps(json.loads(candidate.read_text(encoding='utf-8')), indent=2)[:12000])\n",
|
| 438 |
"\n",
|
| 439 |
+
<<<<<<< HEAD
|
| 440 |
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 441 |
" model_name=MODEL_NAME,\n",
|
| 442 |
" max_seq_length=MAX_SEQ_LENGTH,\n",
|
|
|
|
| 640 |
"print('Artifacts saved: reward_curve.png, loss_curve.png, training_summary.json')\n"
|
| 641 |
],
|
| 642 |
"id": "F-zz_6TYhxuV"
|
| 643 |
+
=======
|
| 644 |
+
"!zip -r fraudshield_training_outputs.zip artifacts/rl_runs/colab_qlora_grpo artifacts/plots\n",
|
| 645 |
+
"print('Created fraudshield_training_outputs.zip')\n"
|
| 646 |
+
]
|
| 647 |
+
>>>>>>> 43cc51d (Use expert teacher rollouts for stronger retraining)
|
| 648 |
}
|
| 649 |
],
|
| 650 |
"metadata": {
|
train.py
CHANGED
|
@@ -14,9 +14,111 @@ from datasets import Dataset, load_dataset
|
|
| 14 |
|
| 15 |
from config import ExperimentConfig
|
| 16 |
from environment import FraudShieldTextEnvironment
|
| 17 |
-
from
|
| 18 |
from utils import ensure_dir, save_json, seed_everything
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
def build_public_curriculum(config: ExperimentConfig) -> Dataset:
|
| 22 |
"""Load public fraud examples and convert them into action-centric prompts."""
|
|
@@ -56,17 +158,17 @@ def build_public_curriculum(config: ExperimentConfig) -> Dataset:
|
|
| 56 |
|
| 57 |
|
| 58 |
def build_rollout_dataset(config: ExperimentConfig) -> Dataset:
|
| 59 |
-
"""Generate environment-compatible trajectories from
|
| 60 |
|
| 61 |
text_env = FraudShieldTextEnvironment(config.environment, config.reward_weights)
|
| 62 |
-
agent =
|
| 63 |
rows: list[dict[str, Any]] = []
|
| 64 |
for task_name in config.evaluation.tasks:
|
| 65 |
for _ in range(config.training.warmstart_rollouts_per_task):
|
| 66 |
prompt = text_env.reset(task=task_name)
|
| 67 |
done = False
|
| 68 |
while not done:
|
| 69 |
-
action = agent.decide(text_env
|
| 70 |
payload = {
|
| 71 |
"action_type": "decide" if action.action_type.value == "resolve_case" else "investigate",
|
| 72 |
"investigation_target": action.action_type.value,
|
|
|
|
| 14 |
|
| 15 |
from config import ExperimentConfig
|
| 16 |
from environment import FraudShieldTextEnvironment
|
| 17 |
+
from models import ActionTypeEnum, FraudCheckAction
|
| 18 |
from utils import ensure_dir, save_json, seed_everything
|
| 19 |
|
| 20 |
+
class ExpertCurriculumTeacher:
|
| 21 |
+
"""Teacher policy that uses hidden task structure to generate stronger trajectories."""
|
| 22 |
+
|
| 23 |
+
def decide(self, text_env: FraudShieldTextEnvironment) -> FraudCheckAction:
|
| 24 |
+
observation = text_env.current_observation
|
| 25 |
+
case_id = observation.case_id
|
| 26 |
+
revealed = observation.revealed_evidence
|
| 27 |
+
case = text_env.env.workflow_cases[case_id]
|
| 28 |
+
budget = int(observation.app_context.get("investigation_budget_remaining", 0))
|
| 29 |
+
|
| 30 |
+
if "transaction_review" not in revealed:
|
| 31 |
+
return FraudCheckAction(
|
| 32 |
+
case_id=case_id,
|
| 33 |
+
action_type=ActionTypeEnum.REVIEW_TRANSACTION,
|
| 34 |
+
reasoning="Open the transaction details before taking any deeper investigative step.",
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
planned_sequence = self._planned_evidence_sequence(case)
|
| 38 |
+
for evidence_key, action_type, reasoning in planned_sequence:
|
| 39 |
+
if evidence_key not in revealed and budget > 0:
|
| 40 |
+
return FraudCheckAction(case_id=case_id, action_type=action_type, reasoning=reasoning)
|
| 41 |
+
|
| 42 |
+
if observation.note_required:
|
| 43 |
+
return FraudCheckAction(
|
| 44 |
+
case_id=case_id,
|
| 45 |
+
action_type=ActionTypeEnum.ADD_CASE_NOTE,
|
| 46 |
+
note_text=self._case_note(case),
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
return FraudCheckAction(
|
| 50 |
+
case_id=case_id,
|
| 51 |
+
action_type=ActionTypeEnum.RESOLVE_CASE,
|
| 52 |
+
resolution=case["correct_resolution"],
|
| 53 |
+
reasoning=self._resolution_reasoning(case),
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
def _planned_evidence_sequence(self, case: dict[str, Any]) -> list[tuple[str, ActionTypeEnum, str]]:
|
| 57 |
+
role = case["role"]
|
| 58 |
+
task_specific = [
|
| 59 |
+
(
|
| 60 |
+
"customer_profile",
|
| 61 |
+
ActionTypeEnum.FETCH_CUSTOMER_PROFILE,
|
| 62 |
+
"Customer history is needed to understand whether this pattern reflects risky buyer behavior.",
|
| 63 |
+
),
|
| 64 |
+
(
|
| 65 |
+
"merchant_profile",
|
| 66 |
+
ActionTypeEnum.FETCH_MERCHANT_PROFILE,
|
| 67 |
+
"Merchant health helps explain whether the case risk comes from the seller side.",
|
| 68 |
+
),
|
| 69 |
+
(
|
| 70 |
+
"network_graph",
|
| 71 |
+
ActionTypeEnum.FETCH_NETWORK_GRAPH,
|
| 72 |
+
"Linked-activity evidence is needed to confirm whether this case participates in a broader cluster.",
|
| 73 |
+
),
|
| 74 |
+
(
|
| 75 |
+
"policy_guide",
|
| 76 |
+
ActionTypeEnum.CHECK_POLICY,
|
| 77 |
+
"Policy guidance is required before choosing the final route.",
|
| 78 |
+
),
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
if role == "single" and case["correct_resolution"].value == "request_docs":
|
| 82 |
+
return [
|
| 83 |
+
task_specific[0],
|
| 84 |
+
task_specific[3],
|
| 85 |
+
task_specific[1],
|
| 86 |
+
]
|
| 87 |
+
if role == "primary":
|
| 88 |
+
return [
|
| 89 |
+
task_specific[2],
|
| 90 |
+
task_specific[1],
|
| 91 |
+
task_specific[3],
|
| 92 |
+
]
|
| 93 |
+
if role == "secondary":
|
| 94 |
+
return [
|
| 95 |
+
task_specific[2],
|
| 96 |
+
task_specific[0],
|
| 97 |
+
task_specific[3],
|
| 98 |
+
]
|
| 99 |
+
return [
|
| 100 |
+
task_specific[1],
|
| 101 |
+
]
|
| 102 |
+
|
| 103 |
+
def _case_note(self, case: dict[str, Any]) -> str:
|
| 104 |
+
if case["role"] == "primary":
|
| 105 |
+
return "Reviewed the transaction trace, graph evidence, merchant signals, and policy guidance before escalating the linked primary case."
|
| 106 |
+
if case["role"] == "secondary":
|
| 107 |
+
return "Reviewed the transaction trace, graph evidence, customer history, and policy guidance before finalizing the linked secondary case."
|
| 108 |
+
if case["correct_resolution"].value == "request_docs":
|
| 109 |
+
return "Reviewed transaction, customer, merchant, and policy evidence before requesting more supporting documents."
|
| 110 |
+
return "Reviewed the transaction evidence and documented the case before final routing."
|
| 111 |
+
|
| 112 |
+
def _resolution_reasoning(self, case: dict[str, Any]) -> str:
|
| 113 |
+
mapping = {
|
| 114 |
+
"approve": "The collected evidence supports approval without additional intervention.",
|
| 115 |
+
"block": "The combined evidence supports blocking the transaction as high risk.",
|
| 116 |
+
"hold": "The evidence remains risky enough to hold the case for more controlled handling.",
|
| 117 |
+
"request_docs": "The case is ambiguous enough that supporting documents are the safest next step.",
|
| 118 |
+
"escalate": "The linked-cluster evidence and loss risk justify escalation to a higher-touch reviewer.",
|
| 119 |
+
}
|
| 120 |
+
return mapping[case["correct_resolution"].value]
|
| 121 |
+
|
| 122 |
|
| 123 |
def build_public_curriculum(config: ExperimentConfig) -> Dataset:
|
| 124 |
"""Load public fraud examples and convert them into action-centric prompts."""
|
|
|
|
| 158 |
|
| 159 |
|
| 160 |
def build_rollout_dataset(config: ExperimentConfig) -> Dataset:
|
| 161 |
+
"""Generate environment-compatible trajectories from an expert teacher."""
|
| 162 |
|
| 163 |
text_env = FraudShieldTextEnvironment(config.environment, config.reward_weights)
|
| 164 |
+
agent = ExpertCurriculumTeacher()
|
| 165 |
rows: list[dict[str, Any]] = []
|
| 166 |
for task_name in config.evaluation.tasks:
|
| 167 |
for _ in range(config.training.warmstart_rollouts_per_task):
|
| 168 |
prompt = text_env.reset(task=task_name)
|
| 169 |
done = False
|
| 170 |
while not done:
|
| 171 |
+
action = agent.decide(text_env)
|
| 172 |
payload = {
|
| 173 |
"action_type": "decide" if action.action_type.value == "resolve_case" else "investigate",
|
| 174 |
"investigation_target": action.action_type.value,
|