DevikaJ2005 commited on
Commit
ccd0934
·
1 Parent(s): 4049a11

Add public-data curriculum and harden LLM agent

Browse files
Files changed (3) hide show
  1. README.md +39 -5
  2. llm_agent_openai.py +105 -25
  3. notebooks/fraudshield_trl_colab.ipynb +214 -71
README.md CHANGED
@@ -122,6 +122,31 @@ Run the heuristic or configured agent:
122
  python inference.py
123
  ```
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  Run the OpenEnv API server:
126
 
127
  ```bash
@@ -164,9 +189,12 @@ It is designed to:
164
 
165
  1. install `openenv-core`, `trl`, `unsloth`, `transformers`, `datasets`, and `peft`
166
  2. clone the repo and install FraudShield
167
- 3. wrap `FraudShieldEnvironment` for GRPO-style training
168
- 4. combine environment reward with a JSON-format reward
169
- 5. train easy -> medium -> hard
 
 
 
170
  6. save:
171
  - `reward_curve.png`
172
  - `loss_curve.png`
@@ -175,7 +203,7 @@ It is designed to:
175
  - heuristic via `python inference.py`
176
  - trained model via `LOCAL_MODEL_PATH=... python inference.py`
177
 
178
- At the moment the notebook is wired for this flow, but the real training run still depends on available compute credits.
179
 
180
  ## Results
181
 
@@ -186,7 +214,7 @@ Current heuristic baseline, measured with `python inference.py`:
186
  - Hard: `0.7425`
187
  - Final: `0.6942`
188
 
189
- This baseline is intentionally rule-based and not trained. It is strong on easy, weaker on medium, and still imperfect on hard, which leaves headroom for the training run.
190
 
191
  Once training is completed, this section should include:
192
 
@@ -195,6 +223,12 @@ Once training is completed, this section should include:
195
  - trained-vs-heuristic comparison table
196
  - one short qualitative trace comparison
197
 
 
 
 
 
 
 
198
  ## Live Links
199
 
200
  - Hugging Face Space: `https://huggingface.co/spaces/DevikaJ2005/fraudshield-1`
 
122
  python inference.py
123
  ```
124
 
125
+ FraudShield supports three agent modes:
126
+
127
+ - `heuristic` by default when no model credentials are set
128
+ - `llm_local` when `LOCAL_MODEL_PATH` points to a trained Hugging Face / PEFT checkpoint
129
+ - `llm_remote` when an API-compatible model is configured
130
+
131
+ For a no-paid-model open-source setup, the recommended options are:
132
+
133
+ ### Option 1: Use your locally trained model
134
+
135
+ ```bash
136
+ LOCAL_MODEL_PATH=trained_policy python inference.py
137
+ ```
138
+
139
+ ### Option 2: Use a Hugging Face hosted open-source model
140
+
141
+ ```bash
142
+ HF_TOKEN=your_token_here \
143
+ MODEL_NAME=Qwen/Qwen2.5-1.5B-Instruct \
144
+ API_BASE_URL=https://router.huggingface.co/v1 \
145
+ python inference.py
146
+ ```
147
+
148
+ If `HF_TOKEN` is present and `API_BASE_URL` is not set, FraudShield defaults to the Hugging Face router automatically.
149
+
150
  Run the OpenEnv API server:
151
 
152
  ```bash
 
189
 
190
  1. install `openenv-core`, `trl`, `unsloth`, `transformers`, `datasets`, and `peft`
191
  2. clone the repo and install FraudShield
192
+ 3. load a public fraud curriculum dataset from Hugging Face
193
+ 4. build a second-stage training set from real FraudShield rollouts
194
+ 5. run two-stage fine-tuning with Unsloth LoRA and TRL `SFTTrainer`
195
+ - stage 1: public fraud-data adaptation
196
+ - stage 2: FraudShield policy adaptation
197
+ 5. save a reusable local policy checkpoint
198
  6. save:
199
  - `reward_curve.png`
200
  - `loss_curve.png`
 
203
  - heuristic via `python inference.py`
204
  - trained model via `LOCAL_MODEL_PATH=... python inference.py`
205
 
206
+ The notebook is designed for Colab + GPU execution and does not require a paid proprietary LLM. The current public curriculum source is `Phoenix21/mock_fraud-detection-dataset`, which gives the model broader fraud-signal exposure before it is adapted to FraudShield actions.
207
 
208
  ## Results
209
 
 
214
  - Hard: `0.7425`
215
  - Final: `0.6942`
216
 
217
+ This baseline is intentionally rule-based and not trained. It is strong on easy, weaker on medium, and still imperfect on hard, which leaves headroom for a trained policy that can learn broader fraud patterns from public data and then adapt them to FraudShield.
218
 
219
  Once training is completed, this section should include:
220
 
 
223
  - trained-vs-heuristic comparison table
224
  - one short qualitative trace comparison
225
 
226
+ The preferred final story is:
227
+
228
+ - heuristic baseline
229
+ - base open-source LLM or hosted HF model
230
+ - fine-tuned local policy checkpoint
231
+
232
  ## Live Links
233
 
234
  - Hugging Face Space: `https://huggingface.co/spaces/DevikaJ2005/fraudshield-1`
llm_agent_openai.py CHANGED
@@ -4,6 +4,7 @@ from __future__ import annotations
4
 
5
  import json
6
  import logging
 
7
  from pathlib import Path
8
  from typing import Any, Dict, Optional
9
 
@@ -16,6 +17,38 @@ except ImportError: # pragma: no cover
16
 
17
  logger = logging.getLogger(__name__)
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  class LLMFraudDetectionAgent:
21
  """OpenAI-compatible LLM agent with heuristic fallback."""
@@ -61,6 +94,7 @@ class LLMFraudDetectionAgent:
61
  return self._fallback(observation, exc)
62
 
63
  def _build_messages(self, observation) -> list[Dict[str, str]]:
 
64
  observation_payload = {
65
  "case_id": observation.case_id,
66
  "task_name": observation.task_name.value,
@@ -73,21 +107,15 @@ class LLMFraudDetectionAgent:
73
  "remaining_sla": observation.remaining_sla,
74
  "note_required": observation.note_required,
75
  "allowed_public_actions": [action.value for action in observation.allowed_actions],
76
- "available_investigation_aliases": [
77
- "merchant_profile",
78
- "customer_profile",
79
- "network_graph",
80
- "device_intel",
81
- "payment_trace",
82
- "fulfillment_review",
83
- "policy_review",
84
- ],
85
  "app_context": observation.app_context,
86
  }
87
  system_prompt = (
88
  "You are a fraud analyst operating inside a simulated investigation workflow. "
89
  "Only use the visible evidence shown to you. Choose either one investigation alias or one final "
90
- "decision. Respond with JSON only using this schema: "
 
 
91
  '{"action_type":"investigate|decide","investigation_target":"string|null",'
92
  '"decision":"fraud|legitimate|null","confidence":0.0,"reasoning":"one sentence"}'
93
  )
@@ -101,7 +129,8 @@ class LLMFraudDetectionAgent:
101
  reasoning = self._normalize_reasoning(payload.get("reasoning"))
102
  if action_type == "investigate":
103
  investigation_target = str(payload.get("investigation_target", "")).strip().lower()
104
- mapped_action = self._map_investigation_alias(investigation_target)
 
105
  return FraudCheckAction(
106
  case_id=observation.case_id,
107
  action_type=mapped_action,
@@ -128,20 +157,71 @@ class LLMFraudDetectionAgent:
128
 
129
  raise ValueError(f"Unsupported action_type from model: {action_type!r}")
130
 
131
- def _map_investigation_alias(self, alias: str) -> ActionTypeEnum:
132
- mapping = {
133
- "merchant_profile": ActionTypeEnum.FETCH_MERCHANT_PROFILE,
134
- "customer_profile": ActionTypeEnum.FETCH_CUSTOMER_PROFILE,
135
- "network_graph": ActionTypeEnum.FETCH_NETWORK_GRAPH,
136
- "device_intel": ActionTypeEnum.FETCH_NETWORK_GRAPH,
137
- "payment_trace": ActionTypeEnum.REVIEW_TRANSACTION,
138
- "fulfillment_review": ActionTypeEnum.REVIEW_TRANSACTION,
139
- "policy_review": ActionTypeEnum.CHECK_POLICY,
140
- "trust_notes": ActionTypeEnum.CHECK_POLICY,
141
- }
142
- if alias not in mapping:
143
- raise ValueError(f"Unsupported investigation_target from model: {alias!r}")
144
- return mapping[alias]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  def _map_decision_to_resolution(self, decision: str, confidence: float, observation) -> ResolutionEnum:
147
  if decision not in {"fraud", "legitimate"}:
 
4
 
5
  import json
6
  import logging
7
+ import re
8
  from pathlib import Path
9
  from typing import Any, Dict, Optional
10
 
 
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
+ ACTION_ALIAS_TO_ENUM = {
21
+ "merchant_profile": ActionTypeEnum.FETCH_MERCHANT_PROFILE,
22
+ "fetch_merchant_profile": ActionTypeEnum.FETCH_MERCHANT_PROFILE,
23
+ "customer_profile": ActionTypeEnum.FETCH_CUSTOMER_PROFILE,
24
+ "fetch_customer_profile": ActionTypeEnum.FETCH_CUSTOMER_PROFILE,
25
+ "network_graph": ActionTypeEnum.FETCH_NETWORK_GRAPH,
26
+ "fetch_network_graph": ActionTypeEnum.FETCH_NETWORK_GRAPH,
27
+ "device_intel": ActionTypeEnum.FETCH_NETWORK_GRAPH,
28
+ "payment_trace": ActionTypeEnum.REVIEW_TRANSACTION,
29
+ "fulfillment_review": ActionTypeEnum.REVIEW_TRANSACTION,
30
+ "review_transaction": ActionTypeEnum.REVIEW_TRANSACTION,
31
+ "policy_review": ActionTypeEnum.CHECK_POLICY,
32
+ "check_policy": ActionTypeEnum.CHECK_POLICY,
33
+ "trust_notes": ActionTypeEnum.CHECK_POLICY,
34
+ }
35
+
36
+ ACTION_ENUM_TO_ALIAS = {
37
+ ActionTypeEnum.REVIEW_TRANSACTION: "payment_trace",
38
+ ActionTypeEnum.FETCH_CUSTOMER_PROFILE: "customer_profile",
39
+ ActionTypeEnum.FETCH_MERCHANT_PROFILE: "merchant_profile",
40
+ ActionTypeEnum.FETCH_NETWORK_GRAPH: "network_graph",
41
+ ActionTypeEnum.CHECK_POLICY: "policy_review",
42
+ }
43
+
44
+ ACTION_ENUM_TO_EVIDENCE_KEY = {
45
+ ActionTypeEnum.REVIEW_TRANSACTION: "transaction_review",
46
+ ActionTypeEnum.FETCH_CUSTOMER_PROFILE: "customer_profile",
47
+ ActionTypeEnum.FETCH_MERCHANT_PROFILE: "merchant_profile",
48
+ ActionTypeEnum.FETCH_NETWORK_GRAPH: "network_graph",
49
+ ActionTypeEnum.CHECK_POLICY: "policy_guide",
50
+ }
51
+
52
 
53
  class LLMFraudDetectionAgent:
54
  """OpenAI-compatible LLM agent with heuristic fallback."""
 
94
  return self._fallback(observation, exc)
95
 
96
  def _build_messages(self, observation) -> list[Dict[str, str]]:
97
+ available_aliases = self._available_investigation_aliases(observation)
98
  observation_payload = {
99
  "case_id": observation.case_id,
100
  "task_name": observation.task_name.value,
 
107
  "remaining_sla": observation.remaining_sla,
108
  "note_required": observation.note_required,
109
  "allowed_public_actions": [action.value for action in observation.allowed_actions],
110
+ "available_investigation_aliases": available_aliases,
 
 
 
 
 
 
 
 
111
  "app_context": observation.app_context,
112
  }
113
  system_prompt = (
114
  "You are a fraud analyst operating inside a simulated investigation workflow. "
115
  "Only use the visible evidence shown to you. Choose either one investigation alias or one final "
116
+ "decision. For investigation_target, you must return exactly one alias from "
117
+ f"{available_aliases}. Never return placeholders, array expressions, or prose such as "
118
+ "'available_investigations[0]'. Respond with JSON only using this schema: "
119
  '{"action_type":"investigate|decide","investigation_target":"string|null",'
120
  '"decision":"fraud|legitimate|null","confidence":0.0,"reasoning":"one sentence"}'
121
  )
 
129
  reasoning = self._normalize_reasoning(payload.get("reasoning"))
130
  if action_type == "investigate":
131
  investigation_target = str(payload.get("investigation_target", "")).strip().lower()
132
+ mapped_action = self._map_investigation_alias(investigation_target, observation)
133
+ mapped_action = self._stabilize_investigation_choice(mapped_action, observation)
134
  return FraudCheckAction(
135
  case_id=observation.case_id,
136
  action_type=mapped_action,
 
157
 
158
  raise ValueError(f"Unsupported action_type from model: {action_type!r}")
159
 
160
+ def _map_investigation_alias(self, alias: str, observation) -> ActionTypeEnum:
161
+ normalized = alias.strip().lower()
162
+ if normalized in ACTION_ALIAS_TO_ENUM:
163
+ return ACTION_ALIAS_TO_ENUM[normalized]
164
+
165
+ placeholder_match = re.fullmatch(r"available_investigations\[(\d+)\]", normalized)
166
+ if placeholder_match:
167
+ index = int(placeholder_match.group(1))
168
+ available = self._available_investigation_aliases(observation)
169
+ if 0 <= index < len(available):
170
+ return ACTION_ALIAS_TO_ENUM[available[index]]
171
+
172
+ compact = re.sub(r"[^a-z_]", "", normalized.replace("-", "_").replace(" ", "_"))
173
+ for key, value in ACTION_ALIAS_TO_ENUM.items():
174
+ if compact == key:
175
+ return value
176
+ for key, value in ACTION_ALIAS_TO_ENUM.items():
177
+ if compact and compact in key:
178
+ return value
179
+
180
+ available = self._available_investigation_aliases(observation)
181
+ if len(available) == 1:
182
+ return ACTION_ALIAS_TO_ENUM[available[0]]
183
+ raise ValueError(f"Unsupported investigation_target from model: {alias!r}")
184
+
185
+ def _available_investigation_aliases(self, observation) -> list[str]:
186
+ context_aliases = observation.app_context.get("available_investigations")
187
+ aliases: list[str] = []
188
+ if isinstance(context_aliases, list):
189
+ for alias in context_aliases:
190
+ normalized = str(alias).strip().lower()
191
+ if normalized in ACTION_ALIAS_TO_ENUM:
192
+ canonical = ACTION_ENUM_TO_ALIAS[ACTION_ALIAS_TO_ENUM[normalized]]
193
+ if canonical not in aliases:
194
+ aliases.append(canonical)
195
+
196
+ if aliases:
197
+ return aliases
198
+
199
+ fallback_aliases: list[str] = []
200
+ for action in observation.allowed_actions:
201
+ if action in ACTION_ENUM_TO_ALIAS:
202
+ alias = ACTION_ENUM_TO_ALIAS[action]
203
+ if alias not in fallback_aliases:
204
+ fallback_aliases.append(alias)
205
+ return fallback_aliases
206
+
207
+ def _stabilize_investigation_choice(self, action_type: ActionTypeEnum, observation) -> ActionTypeEnum:
208
+ evidence_key = ACTION_ENUM_TO_EVIDENCE_KEY.get(action_type)
209
+ if evidence_key and evidence_key not in observation.revealed_evidence:
210
+ return action_type
211
+
212
+ alternatives = []
213
+ for alias in self._available_investigation_aliases(observation):
214
+ candidate = ACTION_ALIAS_TO_ENUM[alias]
215
+ candidate_key = ACTION_ENUM_TO_EVIDENCE_KEY.get(candidate)
216
+ if candidate_key and candidate_key not in observation.revealed_evidence:
217
+ alternatives.append(candidate)
218
+
219
+ if alternatives:
220
+ return alternatives[0]
221
+
222
+ raise ValueError(
223
+ f"Investigation {action_type.value!r} is already revealed and no unseen investigations remain."
224
+ )
225
 
226
  def _map_decision_to_resolution(self, decision: str, confidence: float, observation) -> ResolutionEnum:
227
  if decision not in {"fraud", "legitimate"}:
notebooks/fraudshield_trl_colab.ipynb CHANGED
@@ -6,9 +6,12 @@
6
  "source": [
7
  "# FraudShield Colab Training Notebook\n",
8
  "\n",
9
- "This notebook uses **Unsloth + TRL** to fine-tune a small instruction model to imitate strong investigation trajectories in FraudShield.\n",
10
  "\n",
11
- "It is designed for **reliable Colab execution** first: install dependencies, build a training set from real environment rollouts, fine-tune a LoRA policy, evaluate heuristic vs trained policy, and save the expected training artifacts.\n"
 
 
 
12
  ]
13
  },
14
  {
@@ -18,7 +21,7 @@
18
  "outputs": [],
19
  "source": [
20
  "%pip uninstall -y unsloth unsloth_zoo trl transformers tokenizers\n",
21
- "%pip install -q openenv-core datasets peft accelerate sentencepiece\n",
22
  "%pip install -q \"transformers==4.51.3\" \"trl==0.19.1\"\n",
23
  "%pip install -q \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
24
  "\n",
@@ -74,10 +77,12 @@
74
  "source": [
75
  "import json\n",
76
  "import os\n",
 
77
  "import subprocess\n",
78
  "from datetime import datetime\n",
79
  "\n",
80
- "from datasets import Dataset\n",
 
81
  "\n",
82
  "from fraudshield_env import FraudShieldEnvironment\n",
83
  "from llm_agent import SnapshotCalibratedFraudDetectionAgent\n",
@@ -86,6 +91,17 @@
86
  "assert env.load_data(), 'FraudShield snapshot failed to load.'\n",
87
  "print('FraudShield loaded:', env.data_loader.get_bundle_summary())\n",
88
  "\n",
 
 
 
 
 
 
 
 
 
 
 
89
  "def serialize_observation(observation):\n",
90
  " return json.dumps(\n",
91
  " {\n",
@@ -102,74 +118,166 @@
102
  " 'case_summary': observation.case_summary.model_dump(mode='json'),\n",
103
  " 'app_context': observation.app_context,\n",
104
  " },\n",
105
- " ensure_ascii=True,\n",
106
- " indent=2,\n",
107
  " )\n",
108
  "\n",
 
109
  "def prompt_from_observation(observation):\n",
 
110
  " return (\n",
111
- " 'You are a fraud analyst working in a simulated investigation workflow.\\n'\n",
112
- " 'Choose the next best action based only on the visible observation.\\n'\n",
113
- " 'Respond with JSON only using keys action_type, investigation_target, decision, confidence, reasoning.\\n'\n",
114
- " 'Use action_type investigate or decide.\\n\\n'\n",
115
- " 'Observation:\\n'\n",
116
- " f\"{serialize_observation(observation)}\\n\"\n",
117
  " )\n",
118
  "\n",
119
- "def action_to_target_json(action):\n",
120
- " payload = {\n",
121
- " 'action_type': 'investigate',\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  " 'investigation_target': None,\n",
123
- " 'decision': None,\n",
124
- " 'confidence': 0.5,\n",
125
- " 'reasoning': action.reasoning or '',\n",
126
  " }\n",
127
- " if action.action_type.value == 'fetch_customer_profile':\n",
128
- " payload['investigation_target'] = 'customer_profile'\n",
129
- " elif action.action_type.value == 'fetch_merchant_profile':\n",
130
- " payload['investigation_target'] = 'merchant_profile'\n",
131
- " elif action.action_type.value == 'fetch_network_graph':\n",
132
- " payload['investigation_target'] = 'network_graph'\n",
133
- " elif action.action_type.value == 'check_policy':\n",
134
- " payload['investigation_target'] = 'policy_review'\n",
135
- " elif action.action_type.value == 'review_transaction':\n",
136
- " payload['investigation_target'] = 'payment_trace'\n",
137
- " elif action.action_type.value == 'add_case_note':\n",
138
- " payload['investigation_target'] = 'trust_notes'\n",
139
- " payload['reasoning'] = action.note_text or payload['reasoning']\n",
140
- " elif action.action_type.value == 'resolve_case':\n",
141
- " payload['action_type'] = 'decide'\n",
142
- " payload['investigation_target'] = None\n",
143
- " if action.resolution.value in {'approve', 'request_docs'}:\n",
144
- " payload['decision'] = 'legitimate'\n",
145
- " payload['confidence'] = 0.8 if action.resolution.value == 'approve' else 0.6\n",
146
- " else:\n",
147
- " payload['decision'] = 'fraud'\n",
148
- " payload['confidence'] = 0.9 if action.resolution.value in {'block', 'escalate'} else 0.6\n",
149
- " return json.dumps(payload, ensure_ascii=True)\n",
150
- "\n",
151
- "def build_training_dataset(per_task_episodes=18):\n",
152
  " agent = SnapshotCalibratedFraudDetectionAgent()\n",
153
- " records = []\n",
154
  " for task_name in ('easy', 'medium', 'hard'):\n",
155
- " for episode_idx in range(per_task_episodes):\n",
156
- " rollout_env = FraudShieldEnvironment(data_path='data', seed=42 + episode_idx)\n",
157
- " rollout_env.load_data()\n",
158
- " observation = rollout_env.reset(task_name).observation\n",
159
- " while not rollout_env.is_done:\n",
160
  " action = agent.decide(observation)\n",
161
- " records.append({\n",
 
 
 
162
  " 'task_name': task_name,\n",
163
- " 'prompt': prompt_from_observation(observation),\n",
164
- " 'target': action_to_target_json(action),\n",
165
- " 'text': prompt_from_observation(observation) + action_to_target_json(action),\n",
166
  " })\n",
167
- " observation = rollout_env.step(action).observation\n",
168
- " return Dataset.from_list(records)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  "\n",
170
- "train_dataset = build_training_dataset(per_task_episodes=18)\n",
171
- "print(train_dataset)\n",
172
- "print(train_dataset[0]['text'][:1200])\n"
 
 
 
 
 
 
 
 
 
 
 
173
  ]
174
  },
175
  {
@@ -209,30 +317,56 @@
209
  "from transformers import TrainingArguments\n",
210
  "from trl import SFTTrainer\n",
211
  "\n",
212
- "training_args = TrainingArguments(\n",
213
- " output_dir='fraudshield-sft-run',\n",
214
- " num_train_epochs=3,\n",
215
  " per_device_train_batch_size=2,\n",
216
  " gradient_accumulation_steps=4,\n",
217
  " learning_rate=2e-4,\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  " logging_steps=1,\n",
219
  " save_strategy='epoch',\n",
220
  " report_to='none',\n",
221
  " fp16=not torch.cuda.is_bf16_supported(),\n",
222
  " bf16=torch.cuda.is_bf16_supported(),\n",
223
- " max_steps=-1,\n",
224
  " warmup_ratio=0.05,\n",
225
  " lr_scheduler_type='cosine',\n",
226
  ")\n",
227
  "\n",
228
  "trainer = SFTTrainer(\n",
229
- " model=model,\n",
230
  " tokenizer=tokenizer,\n",
231
- " train_dataset=train_dataset,\n",
232
  " dataset_text_field='text',\n",
233
  " max_seq_length=MAX_SEQ_LENGTH,\n",
234
  " packing=False,\n",
235
- " args=training_args,\n",
236
  ")\n",
237
  "\n",
238
  "trainer.train()\n",
@@ -261,7 +395,8 @@
261
  " check=True,\n",
262
  " )\n",
263
  " with open('fraudshield_baseline_results.json', 'r', encoding='utf-8') as handle:\n",
264
- " return json.load(handle), completed.stdout\n",
 
265
  "\n",
266
  "baseline_results, baseline_stdout = run_inference()\n",
267
  "trained_results, trained_stdout = run_inference({'LOCAL_MODEL_PATH': OUTPUT_DIR})\n",
@@ -275,8 +410,10 @@
275
  " 'delta': trained_results[task_name]['score'] - baseline_results[task_name]['score'],\n",
276
  " })\n",
277
  "\n",
278
- "print('Heuristic baseline stdout:\\n', baseline_stdout)\n",
279
- "print('Trained model stdout:\\n', trained_stdout)\n",
 
 
280
  "print(json.dumps(comparison_rows, indent=2))\n"
281
  ]
282
  },
@@ -316,8 +453,9 @@
316
  "summary = {\n",
317
  " 'status': 'completed',\n",
318
  " 'updated_at': datetime.utcnow().isoformat() + 'Z',\n",
319
- " 'trainer': 'TRL SFTTrainer with Unsloth LoRA',\n",
320
  " 'base_model': MODEL_NAME,\n",
 
321
  " 'local_model_path': OUTPUT_DIR,\n",
322
  " 'baseline': {\n",
323
  " 'easy': baseline_results['easy']['score'],\n",
@@ -343,7 +481,8 @@
343
  "with open('training_summary.json', 'w', encoding='utf-8') as handle:\n",
344
  " json.dump(summary, handle, indent=2)\n",
345
  "\n",
346
- "print(json.dumps(summary, indent=2))\n"
 
347
  ]
348
  }
349
  ],
@@ -356,8 +495,12 @@
356
  "language_info": {
357
  "name": "python",
358
  "version": "3.12"
 
 
 
 
359
  }
360
  },
361
  "nbformat": 4,
362
  "nbformat_minor": 5
363
- }
 
6
  "source": [
7
  "# FraudShield Colab Training Notebook\n",
8
  "\n",
9
+ "This notebook trains an **open-source LLM policy** for FraudShield using a two-stage curriculum:\n",
10
  "\n",
11
+ "1. **Public fraud-data adaptation** from a Hugging Face dataset\n",
12
+ "2. **FraudShield policy adaptation** from environment-compatible action traces\n",
13
+ "\n",
14
+ "The goal is to learn more than a static heuristic by giving the model broader fraud signals first, then teaching it how to act inside the FraudShield workflow.\n"
15
  ]
16
  },
17
  {
 
21
  "outputs": [],
22
  "source": [
23
  "%pip uninstall -y unsloth unsloth_zoo trl transformers tokenizers\n",
24
+ "%pip install -q openenv-core datasets peft accelerate sentencepiece matplotlib pandas\n",
25
  "%pip install -q \"transformers==4.51.3\" \"trl==0.19.1\"\n",
26
  "%pip install -q \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
27
  "\n",
 
77
  "source": [
78
  "import json\n",
79
  "import os\n",
80
+ "import random\n",
81
  "import subprocess\n",
82
  "from datetime import datetime\n",
83
  "\n",
84
+ "import pandas as pd\n",
85
+ "from datasets import Dataset, load_dataset\n",
86
  "\n",
87
  "from fraudshield_env import FraudShieldEnvironment\n",
88
  "from llm_agent import SnapshotCalibratedFraudDetectionAgent\n",
 
91
  "assert env.load_data(), 'FraudShield snapshot failed to load.'\n",
92
  "print('FraudShield loaded:', env.data_loader.get_bundle_summary())\n",
93
  "\n",
94
+ "random.seed(42)\n",
95
+ "\n",
96
+ "CANONICAL_ALIASES = [\n",
97
+ " 'merchant_profile',\n",
98
+ " 'customer_profile',\n",
99
+ " 'network_graph',\n",
100
+ " 'payment_trace',\n",
101
+ " 'policy_review',\n",
102
+ "]\n",
103
+ "\n",
104
+ "\n",
105
  "def serialize_observation(observation):\n",
106
  " return json.dumps(\n",
107
  " {\n",
 
118
  " 'case_summary': observation.case_summary.model_dump(mode='json'),\n",
119
  " 'app_context': observation.app_context,\n",
120
  " },\n",
121
+ " sort_keys=True,\n",
 
122
  " )\n",
123
  "\n",
124
+ "\n",
125
  "def prompt_from_observation(observation):\n",
126
+ " available = observation.app_context.get('available_investigations', CANONICAL_ALIASES)\n",
127
  " return (\n",
128
+ " 'You are a fraud analyst operating in a simulated investigation workflow. '\n",
129
+ " 'Only use visible evidence. Return JSON only.\\n\\n'\n",
130
+ " f'Visible observation:\\n{serialize_observation(observation)}\\n\\n'\n",
131
+ " f'Valid investigation aliases: {available}.\\n'\n",
132
+ " 'Respond with JSON using this schema: '\n",
133
+ " '{\"action_type\":\"investigate|decide\",\"investigation_target\":\"alias_or_null\",\"decision\":\"fraud|legitimate|null\",\"confidence\":0.0,\"reasoning\":\"one sentence\"}'\n",
134
  " )\n",
135
  "\n",
136
+ "\n",
137
+ "def action_to_payload(action):\n",
138
+ " action_name = action.action_type.value\n",
139
+ " if action_name == 'fetch_merchant_profile':\n",
140
+ " return {'action_type': 'investigate', 'investigation_target': 'merchant_profile', 'decision': None, 'confidence': None, 'reasoning': action.reasoning or 'Review seller risk signals before routing.'}\n",
141
+ " if action_name == 'fetch_customer_profile':\n",
142
+ " return {'action_type': 'investigate', 'investigation_target': 'customer_profile', 'decision': None, 'confidence': None, 'reasoning': action.reasoning or 'Review buyer risk signals before routing.'}\n",
143
+ " if action_name == 'fetch_network_graph':\n",
144
+ " return {'action_type': 'investigate', 'investigation_target': 'network_graph', 'decision': None, 'confidence': None, 'reasoning': action.reasoning or 'Check linked network risk before routing.'}\n",
145
+ " if action_name == 'review_transaction':\n",
146
+ " return {'action_type': 'investigate', 'investigation_target': 'payment_trace', 'decision': None, 'confidence': None, 'reasoning': action.reasoning or 'Inspect payment and fulfillment details first.'}\n",
147
+ " if action_name == 'check_policy':\n",
148
+ " return {'action_type': 'investigate', 'investigation_target': 'policy_review', 'decision': None, 'confidence': None, 'reasoning': action.reasoning or 'Check routing policy before a final decision.'}\n",
149
+ " if action_name == 'add_case_note':\n",
150
+ " return {'action_type': 'decide', 'investigation_target': None, 'decision': 'fraud', 'confidence': 0.55, 'reasoning': action.note_text or 'Document the case before final routing.'}\n",
151
+ "\n",
152
+ " decision = 'fraud' if action.resolution.value in {'block', 'hold', 'escalate'} else 'legitimate'\n",
153
+ " confidence = 0.9 if action.resolution.value in {'approve', 'block'} else 0.7\n",
154
+ " return {\n",
155
+ " 'action_type': 'decide',\n",
156
  " 'investigation_target': None,\n",
157
+ " 'decision': decision,\n",
158
+ " 'confidence': confidence,\n",
159
+ " 'reasoning': action.reasoning or f'Final routing is {action.resolution.value}.',\n",
160
  " }\n",
161
+ "\n",
162
+ "\n",
163
+ "def build_fraudshield_rollout_dataset(per_task_episodes=24):\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  " agent = SnapshotCalibratedFraudDetectionAgent()\n",
165
+ " rows = []\n",
166
  " for task_name in ('easy', 'medium', 'hard'):\n",
167
+ " for _ in range(per_task_episodes):\n",
168
+ " reset_result = env.reset(task=task_name)\n",
169
+ " observation = reset_result.observation\n",
170
+ " done = False\n",
171
+ " while not done:\n",
172
  " action = agent.decide(observation)\n",
173
+ " payload = action_to_payload(action)\n",
174
+ " rows.append({\n",
175
+ " 'text': prompt_from_observation(observation) + '\\n' + json.dumps(payload, separators=(',', ':')),\n",
176
+ " 'source': 'fraudshield_rollout',\n",
177
  " 'task_name': task_name,\n",
 
 
 
178
  " })\n",
179
+ " step_result = env.step(action)\n",
180
+ " observation = step_result.observation\n",
181
+ " done = step_result.done\n",
182
+ " return Dataset.from_pandas(pd.DataFrame(rows), preserve_index=False)\n",
183
+ "\n",
184
+ "\n",
185
+ "def public_row_to_training_example(row):\n",
186
+ " amount = float(row.get('amount', 0.0) or 0.0)\n",
187
+ " transaction_type = str(row.get('transaction_type', row.get('type', 'purchase')))\n",
188
+ " location = str(row.get('location', 'unknown'))\n",
189
+ " merchant = str(row.get('merchant', row.get('nameDest', 'unknown_merchant')))\n",
190
+ " device = str(row.get('device', 'unknown_device'))\n",
191
+ " payment_method = str(row.get('payment_method', row.get('transaction_type', 'card')))\n",
192
+ " timestamp = str(row.get('timestamp', row.get('step', 'unknown_time')))\n",
193
+ " is_fraud = int(row.get('is_fraud', row.get('isFraud', row.get('Class', 0))) or 0)\n",
194
+ "\n",
195
+ " high_amount = amount >= 1500\n",
196
+ " risky_type = transaction_type.lower() in {'transfer', 'cash_out', 'wire', 'crypto', 'gift_card'}\n",
197
+ " risky_location = any(token in location.lower() for token in ['proxy', 'unknown', 'foreign', 'vpn'])\n",
198
+ " risky_device = any(token in device.lower() for token in ['emulator', 'root', 'shared', 'new'])\n",
199
+ "\n",
200
+ " available = ['merchant_profile', 'customer_profile', 'network_graph', 'payment_trace', 'policy_review']\n",
201
+ " visible_observation = {\n",
202
+ " 'case_id': f\"public_case_{abs(hash(str(row.get('transaction_id', merchant)))) % 100000}\",\n",
203
+ " 'task_name': 'medium',\n",
204
+ " 'current_screen': 'Queue',\n",
205
+ " 'visible_panels': ['triage_summary'],\n",
206
+ " 'revealed_evidence': {},\n",
207
+ " 'linked_case_ids': [],\n",
208
+ " 'remaining_steps': 6,\n",
209
+ " 'remaining_sla': 5,\n",
210
+ " 'note_required': False,\n",
211
+ " 'allowed_actions': ['review_transaction', 'fetch_customer_profile', 'fetch_merchant_profile', 'fetch_network_graph', 'check_policy', 'resolve_case'],\n",
212
+ " 'case_summary': {\n",
213
+ " 'amount_usd': round(amount, 2),\n",
214
+ " 'queue_reason': f'{transaction_type} transaction flagged for manual review',\n",
215
+ " 'visible_risk_band': 'review',\n",
216
+ " 'merchant_region': 'masked',\n",
217
+ " },\n",
218
+ " 'app_context': {\n",
219
+ " 'item_category': transaction_type,\n",
220
+ " 'timestamp': timestamp,\n",
221
+ " 'available_investigations': available,\n",
222
+ " 'public_signals': {\n",
223
+ " 'merchant': merchant,\n",
224
+ " 'device': device,\n",
225
+ " 'payment_method': payment_method,\n",
226
+ " 'location': location,\n",
227
+ " },\n",
228
+ " },\n",
229
+ " }\n",
230
+ "\n",
231
+ " if is_fraud and (risky_type or high_amount):\n",
232
+ " payload = {\n",
233
+ " 'action_type': 'investigate',\n",
234
+ " 'investigation_target': 'network_graph' if risky_device or risky_location else 'payment_trace',\n",
235
+ " 'decision': None,\n",
236
+ " 'confidence': None,\n",
237
+ " 'reasoning': 'The visible signals are suspicious, so gather network or payment evidence before routing.',\n",
238
+ " }\n",
239
+ " elif not is_fraud and amount < 200 and not risky_type:\n",
240
+ " payload = {\n",
241
+ " 'action_type': 'decide',\n",
242
+ " 'investigation_target': None,\n",
243
+ " 'decision': 'legitimate',\n",
244
+ " 'confidence': 0.82,\n",
245
+ " 'reasoning': 'The visible transaction looks low risk and can be cleared with high confidence.',\n",
246
+ " }\n",
247
+ " else:\n",
248
+ " payload = {\n",
249
+ " 'action_type': 'investigate',\n",
250
+ " 'investigation_target': 'merchant_profile' if high_amount else 'customer_profile',\n",
251
+ " 'decision': None,\n",
252
+ " 'confidence': None,\n",
253
+ " 'reasoning': 'The transaction is ambiguous, so inspect merchant or customer history before routing.',\n",
254
+ " }\n",
255
+ "\n",
256
+ " prompt = (\n",
257
+ " 'You are a fraud analyst learning how to investigate suspicious payments. '\n",
258
+ " 'Use visible triage signals to choose the next best FraudShield action. Return JSON only.\\n\\n'\n",
259
+ " f'Visible observation:\\n{json.dumps(visible_observation, sort_keys=True)}\\n\\n'\n",
260
+ " f'Valid investigation aliases: {available}.\\n'\n",
261
+ " 'Respond with JSON using this schema: '\n",
262
+ " '{\"action_type\":\"investigate|decide\",\"investigation_target\":\"alias_or_null\",\"decision\":\"fraud|legitimate|null\",\"confidence\":0.0,\"reasoning\":\"one sentence\"}'\n",
263
+ " )\n",
264
+ " return {'text': prompt + '\\n' + json.dumps(payload, separators=(',', ':')), 'source': 'public_fraud_data', 'task_name': 'curriculum'}\n",
265
+ "\n",
266
  "\n",
267
+ "def build_public_curriculum(max_rows=2500):\n",
268
+ " dataset_name = 'Phoenix21/mock_fraud-detection-dataset'\n",
269
+ " public_ds = load_dataset(dataset_name, split='train')\n",
270
+ " rows = [public_row_to_training_example(row) for row in public_ds.shuffle(seed=42).select(range(min(max_rows, len(public_ds))))]\n",
271
+ " print('Loaded public curriculum rows from', dataset_name, 'count=', len(rows))\n",
272
+ " return Dataset.from_pandas(pd.DataFrame(rows), preserve_index=False)\n",
273
+ "\n",
274
+ "\n",
275
+ "public_dataset = build_public_curriculum(max_rows=2500)\n",
276
+ "fraudshield_dataset = build_fraudshield_rollout_dataset(per_task_episodes=24)\n",
277
+ "print(public_dataset)\n",
278
+ "print(fraudshield_dataset)\n",
279
+ "print(public_dataset[0]['text'][:900])\n",
280
+ "print(fraudshield_dataset[0]['text'][:900])\n"
281
  ]
282
  },
283
  {
 
317
  "from transformers import TrainingArguments\n",
318
  "from trl import SFTTrainer\n",
319
  "\n",
320
+ "stage1_args = TrainingArguments(\n",
321
+ " output_dir='fraudshield-sft-stage1',\n",
322
+ " num_train_epochs=1,\n",
323
  " per_device_train_batch_size=2,\n",
324
  " gradient_accumulation_steps=4,\n",
325
  " learning_rate=2e-4,\n",
326
+ " logging_steps=5,\n",
327
+ " save_strategy='no',\n",
328
+ " report_to='none',\n",
329
+ " fp16=not torch.cuda.is_bf16_supported(),\n",
330
+ " bf16=torch.cuda.is_bf16_supported(),\n",
331
+ " warmup_ratio=0.03,\n",
332
+ " lr_scheduler_type='cosine',\n",
333
+ ")\n",
334
+ "\n",
335
+ "stage1_trainer = SFTTrainer(\n",
336
+ " model=model,\n",
337
+ " tokenizer=tokenizer,\n",
338
+ " train_dataset=public_dataset,\n",
339
+ " dataset_text_field='text',\n",
340
+ " max_seq_length=MAX_SEQ_LENGTH,\n",
341
+ " packing=False,\n",
342
+ " args=stage1_args,\n",
343
+ ")\n",
344
+ "\n",
345
+ "stage1_trainer.train()\n",
346
+ "\n",
347
+ "stage2_args = TrainingArguments(\n",
348
+ " output_dir='fraudshield-sft-stage2',\n",
349
+ " num_train_epochs=2,\n",
350
+ " per_device_train_batch_size=2,\n",
351
+ " gradient_accumulation_steps=4,\n",
352
+ " learning_rate=1e-4,\n",
353
  " logging_steps=1,\n",
354
  " save_strategy='epoch',\n",
355
  " report_to='none',\n",
356
  " fp16=not torch.cuda.is_bf16_supported(),\n",
357
  " bf16=torch.cuda.is_bf16_supported(),\n",
 
358
  " warmup_ratio=0.05,\n",
359
  " lr_scheduler_type='cosine',\n",
360
  ")\n",
361
  "\n",
362
  "trainer = SFTTrainer(\n",
363
+ " model=stage1_trainer.model,\n",
364
  " tokenizer=tokenizer,\n",
365
+ " train_dataset=fraudshield_dataset,\n",
366
  " dataset_text_field='text',\n",
367
  " max_seq_length=MAX_SEQ_LENGTH,\n",
368
  " packing=False,\n",
369
+ " args=stage2_args,\n",
370
  ")\n",
371
  "\n",
372
  "trainer.train()\n",
 
395
  " check=True,\n",
396
  " )\n",
397
  " with open('fraudshield_baseline_results.json', 'r', encoding='utf-8') as handle:\n",
398
+ " results = json.load(handle)\n",
399
+ " return results, completed.stdout\n",
400
  "\n",
401
  "baseline_results, baseline_stdout = run_inference()\n",
402
  "trained_results, trained_stdout = run_inference({'LOCAL_MODEL_PATH': OUTPUT_DIR})\n",
 
410
  " 'delta': trained_results[task_name]['score'] - baseline_results[task_name]['score'],\n",
411
  " })\n",
412
  "\n",
413
+ "print('Heuristic baseline stdout:\n",
414
+ "', baseline_stdout)\n",
415
+ "print('Trained model stdout:\n",
416
+ "', trained_stdout)\n",
417
  "print(json.dumps(comparison_rows, indent=2))\n"
418
  ]
419
  },
 
453
  "summary = {\n",
454
  " 'status': 'completed',\n",
455
  " 'updated_at': datetime.utcnow().isoformat() + 'Z',\n",
456
+ " 'trainer': 'Two-stage TRL SFTTrainer with Unsloth LoRA',\n",
457
  " 'base_model': MODEL_NAME,\n",
458
+ " 'public_curriculum_dataset': 'Phoenix21/mock_fraud-detection-dataset',\n",
459
  " 'local_model_path': OUTPUT_DIR,\n",
460
  " 'baseline': {\n",
461
  " 'easy': baseline_results['easy']['score'],\n",
 
481
  "with open('training_summary.json', 'w', encoding='utf-8') as handle:\n",
482
  " json.dump(summary, handle, indent=2)\n",
483
  "\n",
484
+ "print(json.dumps(summary, indent=2))\n",
485
+ "print('Artifacts saved: reward_curve.png, loss_curve.png, training_summary.json')\n"
486
  ]
487
  }
488
  ],
 
495
  "language_info": {
496
  "name": "python",
497
  "version": "3.12"
498
+ },
499
+ "colab": {
500
+ "name": "fraudshield_trl_colab.ipynb",
501
+ "provenance": []
502
  }
503
  },
504
  "nbformat": 4,
505
  "nbformat_minor": 5
506
+ }