DevikaJ2005 committed on
Commit
e97acd1
·
1 Parent(s): 9672a3e

Add training-first RL architecture with tracking

Browse files
configs/colab_qlora_grpo.json CHANGED
@@ -24,13 +24,13 @@
24
  "checkpoint_dir": "artifacts/rl_runs/colab_qlora_grpo/checkpoints",
25
  "save_to_drive": true,
26
  "drive_dir": "/content/drive/MyDrive/fraudshield",
27
- "num_train_epochs": 2,
28
  "per_device_train_batch_size": 2,
29
  "gradient_accumulation_steps": 4,
30
- "learning_rate": 0.0001,
31
  "eval_every_steps": 10,
32
  "save_every_steps": 20,
33
- "warmstart_rollouts_per_task": 24,
34
  "rl_rollouts_per_task": 8,
35
  "max_prompt_tokens": 2048,
36
  "max_completion_tokens": 220,
@@ -41,7 +41,7 @@
41
  "run_name": "fraudshield-colab-run",
42
  "resume_from_checkpoint": null,
43
  "public_curriculum_dataset": "Phoenix21/mock_fraud-detection-dataset",
44
- "public_curriculum_rows": 2500
45
  },
46
  "evaluation": {
47
  "tasks": [
 
24
  "checkpoint_dir": "artifacts/rl_runs/colab_qlora_grpo/checkpoints",
25
  "save_to_drive": true,
26
  "drive_dir": "/content/drive/MyDrive/fraudshield",
27
+ "num_train_epochs": 3,
28
  "per_device_train_batch_size": 2,
29
  "gradient_accumulation_steps": 4,
30
+ "learning_rate": 5e-05,
31
  "eval_every_steps": 10,
32
  "save_every_steps": 20,
33
+ "warmstart_rollouts_per_task": 60,
34
  "rl_rollouts_per_task": 8,
35
  "max_prompt_tokens": 2048,
36
  "max_completion_tokens": 220,
 
41
  "run_name": "fraudshield-colab-run",
42
  "resume_from_checkpoint": null,
43
  "public_curriculum_dataset": "Phoenix21/mock_fraud-detection-dataset",
44
+ "public_curriculum_rows": 500
45
  },
46
  "evaluation": {
47
  "tasks": [
notebooks/fraudshield_trl_colab.ipynb CHANGED
@@ -8,14 +8,25 @@
8
  "source": [
9
  "# FraudShield Colab Training Notebook\n",
10
  "\n",
11
- "This notebook trains an **open-source LLM policy** for FraudShield using a two-stage curriculum:\n",
12
  "\n",
13
- "1. **Public fraud-data adaptation** from a Hugging Face dataset\n",
14
- "2. **FraudShield policy adaptation** from environment-compatible action traces\n",
 
 
15
  "\n",
 
16
  "The goal is to learn more than a static heuristic by giving the model broader fraud signals first, then teaching it how to act inside the FraudShield workflow.\n"
17
  ],
18
  "id": "Wadw-uDzhxuI"
 
 
 
 
 
 
 
 
19
  },
20
  {
21
  "cell_type": "code",
@@ -98,18 +109,24 @@
98
  ],
99
  "source": [
100
  "%pip uninstall -y unsloth unsloth_zoo trl transformers tokenizers\n",
101
- "%pip install -q openenv-core datasets peft accelerate sentencepiece matplotlib pandas\n",
102
- "%pip install -q \"transformers==4.51.3\" \"trl==0.19.1\"\n",
 
103
  "%pip install -q \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
104
  "\n",
105
  "%cd /content\n",
106
  "!rm -rf Fraudshield\n",
107
  "!git clone https://github.com/DevikaJ2005/Fraudshield.git\n",
108
  "%cd /content/Fraudshield\n",
 
109
  "!ls\n",
110
  "%pip install -q -e .\n"
111
  ],
112
  "id": "yqcGck2nhxuN"
 
 
 
 
113
  },
114
  {
115
  "cell_type": "code",
@@ -121,18 +138,22 @@
121
  "source": [
122
  "import os\n",
123
  "from getpass import getpass\n",
124
- "\n",
125
  "from huggingface_hub import login\n",
126
  "\n",
127
- "token = getpass('Enter your HF token (optional but recommended): ')\n",
128
  "if token.strip():\n",
129
  " os.environ['HF_TOKEN'] = token.strip()\n",
130
  " login(token=token.strip())\n",
131
  " print('HF login completed.')\n",
132
  "else:\n",
 
133
  " print('Skipping HF login for now.')\n"
134
  ],
135
  "id": "s4fNpOrHhxuP"
 
 
 
 
136
  },
137
  {
138
  "cell_type": "code",
@@ -143,15 +164,19 @@
143
  "outputs": [],
144
  "source": [
145
  "import torch\n",
146
- "\n",
147
  "print('cuda available:', torch.cuda.is_available())\n",
148
  "print('device count:', torch.cuda.device_count())\n",
149
  "if torch.cuda.is_available():\n",
150
  " print('gpu name:', torch.cuda.get_device_name(0))\n",
151
  "else:\n",
 
152
  " raise RuntimeError('GPU not available. In Colab, set Runtime > Change runtime type > GPU, then restart.')\n"
153
  ],
154
  "id": "ezOjPxWHhxuQ"
 
 
 
 
155
  },
156
  {
157
  "cell_type": "code",
@@ -162,30 +187,46 @@
162
  "outputs": [],
163
  "source": [
164
  "import json\n",
165
- "import os\n",
166
- "import random\n",
167
- "import subprocess\n",
168
- "from datetime import datetime\n",
169
- "\n",
170
- "import pandas as pd\n",
171
- "from datasets import Dataset, load_dataset\n",
172
  "\n",
173
- "from fraudshield_env import FraudShieldEnvironment\n",
174
- "from llm_agent import SnapshotCalibratedFraudDetectionAgent\n",
175
- "\n",
176
- "env = FraudShieldEnvironment(data_path='data', seed=42)\n",
177
- "assert env.load_data(), 'FraudShield snapshot failed to load.'\n",
178
- "print('FraudShield loaded:', env.data_loader.get_bundle_summary())\n",
179
- "\n",
180
- "random.seed(42)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  "\n",
182
- "CANONICAL_ALIASES = [\n",
183
- " 'merchant_profile',\n",
184
- " 'customer_profile',\n",
185
- " 'network_graph',\n",
186
- " 'payment_trace',\n",
187
- " 'policy_review',\n",
188
  "]\n",
 
189
  "\n",
190
  "\n",
191
  "def serialize_observation(observation):\n",
@@ -366,6 +407,14 @@
366
  "print(fraudshield_dataset[0]['text'][:900])\n"
367
  ],
368
  "id": "b6x653wbhxuR"
 
 
 
 
 
 
 
 
369
  },
370
  {
371
  "cell_type": "code",
@@ -375,11 +424,19 @@
375
  },
376
  "outputs": [],
377
  "source": [
378
- "from unsloth import FastLanguageModel\n",
 
379
  "\n",
380
- "MODEL_NAME = 'unsloth/Qwen2.5-1.5B-Instruct'\n",
381
- "MAX_SEQ_LENGTH = 2048\n",
 
 
 
 
 
 
382
  "\n",
 
383
  "model, tokenizer = FastLanguageModel.from_pretrained(\n",
384
  " model_name=MODEL_NAME,\n",
385
  " max_seq_length=MAX_SEQ_LENGTH,\n",
@@ -583,6 +640,11 @@
583
  "print('Artifacts saved: reward_curve.png, loss_curve.png, training_summary.json')\n"
584
  ],
585
  "id": "F-zz_6TYhxuV"
 
 
 
 
 
586
  }
587
  ],
588
  "metadata": {
 
8
  "source": [
9
  "# FraudShield Colab Training Notebook\n",
10
  "\n",
11
+ "This notebook runs the **training-first FraudShield stack** directly from the repo.\n",
12
  "\n",
13
+ "It uses:\n",
14
+ "- `train.py` for Colab-friendly curriculum + QLoRA training\n",
15
+ "- `evaluate.py` for fixed-task evaluation\n",
16
+ "- `configs/colab_qlora_grpo.json` for reproducible settings\n",
17
  "\n",
18
+ <<<<<<< HEAD
19
  "The goal is to learn more than a static heuristic by giving the model broader fraud signals first, then teaching it how to act inside the FraudShield workflow.\n"
20
  ],
21
  "id": "Wadw-uDzhxuI"
22
+ =======
23
+ "The current setup is tuned to favor **FraudShield workflow learning** over generic imitation:\n",
24
+ "- fewer public curriculum rows\n",
25
+ "- more expert FraudShield rollouts\n",
26
+ "- lower learning rate\n",
27
+ "- longer stage-2 adaptation\n"
28
+ ]
29
+ >>>>>>> 43cc51d (Use expert teacher rollouts for stronger retraining)
30
  },
31
  {
32
  "cell_type": "code",
 
109
  ],
110
  "source": [
111
  "%pip uninstall -y unsloth unsloth_zoo trl transformers tokenizers\n",
112
+ "%pip install -q -U pip\n",
113
+ "%pip install -q openenv-core matplotlib pandas\n",
114
+ "%pip install -q -e \".[rl]\"\n",
115
  "%pip install -q \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
116
  "\n",
117
  "%cd /content\n",
118
  "!rm -rf Fraudshield\n",
119
  "!git clone https://github.com/DevikaJ2005/Fraudshield.git\n",
120
  "%cd /content/Fraudshield\n",
121
+ <<<<<<< HEAD
122
  "!ls\n",
123
  "%pip install -q -e .\n"
124
  ],
125
  "id": "yqcGck2nhxuN"
126
+ =======
127
+ "!pip install -q -e \".[rl]\"\n"
128
+ ]
129
+ >>>>>>> 43cc51d (Use expert teacher rollouts for stronger retraining)
130
  },
131
  {
132
  "cell_type": "code",
 
138
  "source": [
139
  "import os\n",
140
  "from getpass import getpass\n",
 
141
  "from huggingface_hub import login\n",
142
  "\n",
143
+ "token = getpass('Enter your HF token (recommended): ')\n",
144
  "if token.strip():\n",
145
  " os.environ['HF_TOKEN'] = token.strip()\n",
146
  " login(token=token.strip())\n",
147
  " print('HF login completed.')\n",
148
  "else:\n",
149
+ <<<<<<< HEAD
150
  " print('Skipping HF login for now.')\n"
151
  ],
152
  "id": "s4fNpOrHhxuP"
153
+ =======
154
+ " print('No HF token provided.')\n"
155
+ ]
156
+ >>>>>>> 43cc51d (Use expert teacher rollouts for stronger retraining)
157
  },
158
  {
159
  "cell_type": "code",
 
164
  "outputs": [],
165
  "source": [
166
  "import torch\n",
 
167
  "print('cuda available:', torch.cuda.is_available())\n",
168
  "print('device count:', torch.cuda.device_count())\n",
169
  "if torch.cuda.is_available():\n",
170
  " print('gpu name:', torch.cuda.get_device_name(0))\n",
171
  "else:\n",
172
+ <<<<<<< HEAD
173
  " raise RuntimeError('GPU not available. In Colab, set Runtime > Change runtime type > GPU, then restart.')\n"
174
  ],
175
  "id": "ezOjPxWHhxuQ"
176
+ =======
177
+ " raise RuntimeError('GPU not available. Enable GPU in Runtime > Change runtime type, then restart.')\n"
178
+ ]
179
+ >>>>>>> 43cc51d (Use expert teacher rollouts for stronger retraining)
180
  },
181
  {
182
  "cell_type": "code",
 
187
  "outputs": [],
188
  "source": [
189
  "import json\n",
190
+ "from pathlib import Path\n",
 
 
 
 
 
 
191
  "\n",
192
+ "config_path = Path('configs/colab_qlora_grpo.json')\n",
193
+ "config = json.loads(config_path.read_text(encoding='utf-8'))\n",
194
+ "print(json.dumps(config, indent=2))\n"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "code",
199
+ "execution_count": null,
200
+ "metadata": {},
201
+ "outputs": [],
202
+ "source": [
203
+ "!python train.py --config configs/colab_qlora_grpo.json"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "execution_count": null,
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "!python evaluate.py --config configs/colab_qlora_grpo.json"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "code",
217
+ "execution_count": null,
218
+ "metadata": {},
219
+ "outputs": [],
220
+ "source": [
221
+ "from IPython.display import Image, display\n",
222
+ "!find artifacts -name \"*.png\" -o -name \"*.json\" | sort\n",
223
  "\n",
224
+ "paths = [\n",
225
+ " 'artifacts/rl_runs/colab_qlora_grpo/loss_vs_steps.png',\n",
226
+ " 'artifacts/rl_runs/colab_qlora_grpo/reward_vs_steps.png',\n",
227
+ " 'artifacts/plots/evaluation_rewards.png',\n",
 
 
228
  "]\n",
229
+ <<<<<<< HEAD
230
  "\n",
231
  "\n",
232
  "def serialize_observation(observation):\n",
 
407
  "print(fraudshield_dataset[0]['text'][:900])\n"
408
  ],
409
  "id": "b6x653wbhxuR"
410
+ =======
411
+ "for path in paths:\n",
412
+ " try:\n",
413
+ " display(Image(path))\n",
414
+ " except Exception as exc:\n",
415
+ " print('Could not display', path, exc)\n"
416
+ ]
417
+ >>>>>>> 43cc51d (Use expert teacher rollouts for stronger retraining)
418
  },
419
  {
420
  "cell_type": "code",
 
424
  },
425
  "outputs": [],
426
  "source": [
427
+ "import json\n",
428
+ "from pathlib import Path\n",
429
  "\n",
430
+ "summary_candidates = [\n",
431
+ " Path('artifacts/rl_runs/colab_qlora_grpo/training_run_summary.json'),\n",
432
+ " Path('artifacts/rl_runs/colab_qlora_grpo/evaluation_report.json'),\n",
433
+ "]\n",
434
+ "for candidate in summary_candidates:\n",
435
+ " if candidate.exists():\n",
436
+ " print(f'===== {candidate} =====')\n",
437
+ " print(json.dumps(json.loads(candidate.read_text(encoding='utf-8')), indent=2)[:12000])\n",
438
  "\n",
439
+ <<<<<<< HEAD
440
  "model, tokenizer = FastLanguageModel.from_pretrained(\n",
441
  " model_name=MODEL_NAME,\n",
442
  " max_seq_length=MAX_SEQ_LENGTH,\n",
 
640
  "print('Artifacts saved: reward_curve.png, loss_curve.png, training_summary.json')\n"
641
  ],
642
  "id": "F-zz_6TYhxuV"
643
+ =======
644
+ "!zip -r fraudshield_training_outputs.zip artifacts/rl_runs/colab_qlora_grpo artifacts/plots\n",
645
+ "print('Created fraudshield_training_outputs.zip')\n"
646
+ ]
647
+ >>>>>>> 43cc51d (Use expert teacher rollouts for stronger retraining)
648
  }
649
  ],
650
  "metadata": {
train.py CHANGED
@@ -14,9 +14,111 @@ from datasets import Dataset, load_dataset
14
 
15
  from config import ExperimentConfig
16
  from environment import FraudShieldTextEnvironment
17
- from llm_agent import SnapshotCalibratedFraudDetectionAgent
18
  from utils import ensure_dir, save_json, seed_everything
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def build_public_curriculum(config: ExperimentConfig) -> Dataset:
22
  """Load public fraud examples and convert them into action-centric prompts."""
@@ -56,17 +158,17 @@ def build_public_curriculum(config: ExperimentConfig) -> Dataset:
56
 
57
 
58
  def build_rollout_dataset(config: ExperimentConfig) -> Dataset:
59
- """Generate environment-compatible trajectories from the calibrated baseline."""
60
 
61
  text_env = FraudShieldTextEnvironment(config.environment, config.reward_weights)
62
- agent = SnapshotCalibratedFraudDetectionAgent()
63
  rows: list[dict[str, Any]] = []
64
  for task_name in config.evaluation.tasks:
65
  for _ in range(config.training.warmstart_rollouts_per_task):
66
  prompt = text_env.reset(task=task_name)
67
  done = False
68
  while not done:
69
- action = agent.decide(text_env.current_observation)
70
  payload = {
71
  "action_type": "decide" if action.action_type.value == "resolve_case" else "investigate",
72
  "investigation_target": action.action_type.value,
 
14
 
15
  from config import ExperimentConfig
16
  from environment import FraudShieldTextEnvironment
17
+ from models import ActionTypeEnum, FraudCheckAction
18
  from utils import ensure_dir, save_json, seed_everything
19
 
20
+ class ExpertCurriculumTeacher:
21
+ """Teacher policy that uses hidden task structure to generate stronger trajectories."""
22
+
23
+ def decide(self, text_env: FraudShieldTextEnvironment) -> FraudCheckAction:
24
+ observation = text_env.current_observation
25
+ case_id = observation.case_id
26
+ revealed = observation.revealed_evidence
27
+ case = text_env.env.workflow_cases[case_id]
28
+ budget = int(observation.app_context.get("investigation_budget_remaining", 0))
29
+
30
+ if "transaction_review" not in revealed:
31
+ return FraudCheckAction(
32
+ case_id=case_id,
33
+ action_type=ActionTypeEnum.REVIEW_TRANSACTION,
34
+ reasoning="Open the transaction details before taking any deeper investigative step.",
35
+ )
36
+
37
+ planned_sequence = self._planned_evidence_sequence(case)
38
+ for evidence_key, action_type, reasoning in planned_sequence:
39
+ if evidence_key not in revealed and budget > 0:
40
+ return FraudCheckAction(case_id=case_id, action_type=action_type, reasoning=reasoning)
41
+
42
+ if observation.note_required:
43
+ return FraudCheckAction(
44
+ case_id=case_id,
45
+ action_type=ActionTypeEnum.ADD_CASE_NOTE,
46
+ note_text=self._case_note(case),
47
+ )
48
+
49
+ return FraudCheckAction(
50
+ case_id=case_id,
51
+ action_type=ActionTypeEnum.RESOLVE_CASE,
52
+ resolution=case["correct_resolution"],
53
+ reasoning=self._resolution_reasoning(case),
54
+ )
55
+
56
+ def _planned_evidence_sequence(self, case: dict[str, Any]) -> list[tuple[str, ActionTypeEnum, str]]:
57
+ role = case["role"]
58
+ task_specific = [
59
+ (
60
+ "customer_profile",
61
+ ActionTypeEnum.FETCH_CUSTOMER_PROFILE,
62
+ "Customer history is needed to understand whether this pattern reflects risky buyer behavior.",
63
+ ),
64
+ (
65
+ "merchant_profile",
66
+ ActionTypeEnum.FETCH_MERCHANT_PROFILE,
67
+ "Merchant health helps explain whether the case risk comes from the seller side.",
68
+ ),
69
+ (
70
+ "network_graph",
71
+ ActionTypeEnum.FETCH_NETWORK_GRAPH,
72
+ "Linked-activity evidence is needed to confirm whether this case participates in a broader cluster.",
73
+ ),
74
+ (
75
+ "policy_guide",
76
+ ActionTypeEnum.CHECK_POLICY,
77
+ "Policy guidance is required before choosing the final route.",
78
+ ),
79
+ ]
80
+
81
+ if role == "single" and case["correct_resolution"].value == "request_docs":
82
+ return [
83
+ task_specific[0],
84
+ task_specific[3],
85
+ task_specific[1],
86
+ ]
87
+ if role == "primary":
88
+ return [
89
+ task_specific[2],
90
+ task_specific[1],
91
+ task_specific[3],
92
+ ]
93
+ if role == "secondary":
94
+ return [
95
+ task_specific[2],
96
+ task_specific[0],
97
+ task_specific[3],
98
+ ]
99
+ return [
100
+ task_specific[1],
101
+ ]
102
+
103
+ def _case_note(self, case: dict[str, Any]) -> str:
104
+ if case["role"] == "primary":
105
+ return "Reviewed the transaction trace, graph evidence, merchant signals, and policy guidance before escalating the linked primary case."
106
+ if case["role"] == "secondary":
107
+ return "Reviewed the transaction trace, graph evidence, customer history, and policy guidance before finalizing the linked secondary case."
108
+ if case["correct_resolution"].value == "request_docs":
109
+ return "Reviewed transaction, customer, merchant, and policy evidence before requesting more supporting documents."
110
+ return "Reviewed the transaction evidence and documented the case before final routing."
111
+
112
+ def _resolution_reasoning(self, case: dict[str, Any]) -> str:
113
+ mapping = {
114
+ "approve": "The collected evidence supports approval without additional intervention.",
115
+ "block": "The combined evidence supports blocking the transaction as high risk.",
116
+ "hold": "The evidence remains risky enough to hold the case for more controlled handling.",
117
+ "request_docs": "The case is ambiguous enough that supporting documents are the safest next step.",
118
+ "escalate": "The linked-cluster evidence and loss risk justify escalation to a higher-touch reviewer.",
119
+ }
120
+ return mapping[case["correct_resolution"].value]
121
+
122
 
123
  def build_public_curriculum(config: ExperimentConfig) -> Dataset:
124
  """Load public fraud examples and convert them into action-centric prompts."""
 
158
 
159
 
160
  def build_rollout_dataset(config: ExperimentConfig) -> Dataset:
161
+ """Generate environment-compatible trajectories from an expert teacher."""
162
 
163
  text_env = FraudShieldTextEnvironment(config.environment, config.reward_weights)
164
+ agent = ExpertCurriculumTeacher()
165
  rows: list[dict[str, Any]] = []
166
  for task_name in config.evaluation.tasks:
167
  for _ in range(config.training.warmstart_rollouts_per_task):
168
  prompt = text_env.reset(task=task_name)
169
  done = False
170
  while not done:
171
+ action = agent.decide(text_env)
172
  payload = {
173
  "action_type": "decide" if action.action_type.value == "resolve_case" else "investigate",
174
  "investigation_target": action.action_type.value,