DevikaJ2005 committed on
Commit
ce9edc2
·
1 Parent(s): ccd0934

Add training-first RL architecture with tracking

Browse files
Files changed (9) hide show
  1. README.md +16 -0
  2. config.py +124 -0
  3. configs/colab_qlora_grpo.json +72 -0
  4. environment.py +208 -0
  5. evaluate.py +91 -0
  6. pyproject.toml +19 -0
  7. reward.py +166 -0
  8. train.py +229 -0
  9. utils.py +91 -0
README.md CHANGED
@@ -13,6 +13,22 @@ license: mit
13
 
14
  FraudShield is a partial-observability OpenEnv environment for simulated fraud investigation and workflow-aware routing.
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  ## What This Is
17
 
18
  FraudShield is an RL-ready simulation, not a live fraud platform. An agent receives a limited triage view of a case, chooses investigation actions to reveal hidden evidence, and then routes the case with one of the supported final resolutions.
 
13
 
14
  FraudShield is a partial-observability OpenEnv environment for simulated fraud investigation and workflow-aware routing.
15
 
16
+ ## Training-First Architecture
17
+
18
+ FraudShield now includes a modular LLM + RL training stack alongside the OpenEnv runtime:
19
+
20
+ - `environment.py`: text-first wrapper for multi-step rollouts
21
+ - `reward.py`: decomposed numeric reward with measurable subscores
22
+ - `train.py`: Colab-friendly QLoRA training pipeline
23
+ - `evaluate.py`: fixed-task evaluation and comparison plots
24
+ - `config.py`: experiment, model, environment, and reward configuration
25
+ - `utils.py`: seeding, JSON handling, logging helpers, and moving averages
26
+ - `configs/colab_qlora_grpo.json`: default Colab experiment config
27
+
28
+ This layer is designed so you can generate rollouts, score model behavior with decomposed rewards, save checkpoints, resume runs, and compare before/after performance in a repeatable way.
29
+
30
+ Experiment tracking is enabled by default through TensorBoard logs under `artifacts/rl_runs/.../tb_logs`, and the training pipeline also writes plot artifacts such as `loss_vs_steps.png` and `reward_vs_steps.png`. If you want hosted tracking, set `report_to=["wandb"]` or `["tensorboard","wandb"]` in the experiment config before the run.
31
+
32
  ## What This Is
33
 
34
  FraudShield is an RL-ready simulation, not a live fraud platform. An agent receives a limited triage view of a case, chooses investigation actions to reveal hidden evidence, and then routes the case with one of the supported final resolutions.
config.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Config objects for FraudShield RL-style experiments."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import asdict, dataclass, field
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ from utils import load_json, save_json
10
+
11
+
12
@dataclass
class RewardWeights:
    """Multipliers applied to each reward subscore before they are summed.

    Every attribute shares its name with the subscore it scales when the
    total reward is assembled.
    """

    # Scales the raw reward value reported by the wrapped environment.
    env_reward: float = 1.0
    # Scales the correctness subscore.
    correctness: float = 0.35
    # Scales the task-completion subscore.
    task_completion: float = 0.20
    # Scales the reasoning-quality subscore.
    reasoning_quality: float = 0.10
    # Scales the efficiency subscore.
    efficiency: float = 0.10
    # Scales the safety subscore.
    safety: float = 0.10
    # Scales the formatting-compliance subscore.
    formatting_compliance: float = 0.10
    # Scales the consistency subscore.
    consistency: float = 0.05
24
+
25
+
26
@dataclass
class EnvironmentConfig:
    """Settings consumed when the text environment wrapper is constructed."""

    # Location handed to the wrapped environment as ``data_path``.
    data_path: str = "data"
    # Task used until ``reset`` is called with an explicit task.
    default_task: str = "medium"
    # Step budget assumed when the environment does not report ``max_steps``.
    max_rollout_steps: int = 14
    # Seed handed to the wrapped environment.
    seed: int = 42
34
+
35
+
36
@dataclass
class ModelConfig:
    """Base-model and adapter settings for Colab-friendly training."""

    # Identifier of the base checkpoint to load.
    base_model: str = "unsloth/Qwen2.5-1.5B-Instruct"
    # Whether 4-bit loading is requested.
    load_in_4bit: bool = True
    max_seq_length: int = 2048

    # LoRA adapter hyper-parameters.
    lora_rank: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0.0

    # Passed through as strings; presumably interpreted by the training
    # backend -- confirm accepted values against train.py.
    gradient_checkpointing: str = "unsloth"
    mixed_precision: str = "auto"
48
+
49
+
50
@dataclass
class TrainingConfig:
    """Trainer, rollout, checkpoint, and tracking parameters."""

    # Algorithm selection for the main stage and the warm-start stage.
    algorithm: str = "grpo"
    warmstart_algorithm: str = "sft"

    # Where run artifacts and checkpoints are written.
    output_dir: str = "artifacts/rl_runs/default"
    checkpoint_dir: str = "artifacts/rl_runs/default/checkpoints"
    save_to_drive: bool = False
    drive_dir: str = "/content/drive/MyDrive/fraudshield"

    # Optimisation schedule.
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 4
    learning_rate: float = 1e-4
    eval_every_steps: int = 10
    save_every_steps: int = 20

    # Rollout volume and token limits.
    warmstart_rollouts_per_task: int = 24
    rl_rollouts_per_task: int = 8
    max_prompt_tokens: int = 2048
    max_completion_tokens: int = 220

    # Logging / experiment tracking. A factory is used so instances never
    # share the same list object.
    logging_steps: int = 1
    report_to: list[str] = field(default_factory=lambda: ["tensorboard"])
    run_name: str = "fraudshield-colab-run"
    resume_from_checkpoint: str | None = None

    # Public dataset used to build the curriculum, and the row cap applied to it.
    public_curriculum_dataset: str = "Phoenix21/mock_fraud-detection-dataset"
    public_curriculum_rows: int = 2500
76
+
77
+
78
@dataclass
class EvaluationConfig:
    """Settings for fixed-task evaluation runs and their plots."""

    # Task names evaluated one after another; built by a factory so the
    # default list is never shared between instances.
    tasks: list[str] = field(default_factory=lambda: ["easy", "medium", "hard"])
    fixed_prompt_cases: int = 3
    # Directory that receives the evaluation plot image.
    plots_dir: str = "artifacts/plots"
    compare_against_base_model: bool = True
86
+
87
+
88
@dataclass
class ExperimentConfig:
    """Top-level experiment configuration.

    Aggregates the environment, model, training, evaluation, and reward-weight
    sections and round-trips them through plain dictionaries / JSON files.
    """

    name: str = "fraudshield-colab-qlora-grpo"
    seed: int = 42
    environment: EnvironmentConfig = field(default_factory=EnvironmentConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
    reward_weights: RewardWeights = field(default_factory=RewardWeights)
    reward_version: str = "v1"
    ablation_tags: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return the configuration, including nested sections, as plain dicts."""
        return asdict(self)

    def save(self, path: str | Path) -> None:
        """Write the configuration to ``path`` via ``save_json``."""
        save_json(self.to_dict(), path)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ExperimentConfig":
        """Build a configuration from a (possibly partial) dictionary.

        Missing keys fall back to the dataclass defaults. A section that is
        present but ``None`` (JSON ``null``) is treated like a missing one
        instead of raising ``TypeError``.

        Raises:
            TypeError: if a section contains a key its dataclass does not define.
        """
        # One defaults instance is the single source of truth for fallbacks.
        defaults = cls()

        def section(key: str, factory: Any) -> Any:
            return factory(**(data.get(key) or {}))

        return cls(
            name=data.get("name", defaults.name),
            seed=data.get("seed", defaults.seed),
            environment=section("environment", EnvironmentConfig),
            model=section("model", ModelConfig),
            training=section("training", TrainingConfig),
            evaluation=section("evaluation", EvaluationConfig),
            reward_weights=section("reward_weights", RewardWeights),
            reward_version=data.get("reward_version", defaults.reward_version),
            ablation_tags=list(data.get("ablation_tags") or []),
        )

    @classmethod
    def load(cls, path: str | Path) -> "ExperimentConfig":
        """Read ``path`` with ``load_json`` and build a configuration from it."""
        return cls.from_dict(load_json(path))
configs/colab_qlora_grpo.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "fraudshield-colab-qlora-grpo",
3
+ "seed": 42,
4
+ "environment": {
5
+ "data_path": "data",
6
+ "default_task": "medium",
7
+ "max_rollout_steps": 14,
8
+ "seed": 42
9
+ },
10
+ "model": {
11
+ "base_model": "unsloth/Qwen2.5-1.5B-Instruct",
12
+ "load_in_4bit": true,
13
+ "max_seq_length": 2048,
14
+ "lora_rank": 16,
15
+ "lora_alpha": 16,
16
+ "lora_dropout": 0.0,
17
+ "gradient_checkpointing": "unsloth",
18
+ "mixed_precision": "auto"
19
+ },
20
+ "training": {
21
+ "algorithm": "grpo",
22
+ "warmstart_algorithm": "sft",
23
+ "output_dir": "artifacts/rl_runs/colab_qlora_grpo",
24
+ "checkpoint_dir": "artifacts/rl_runs/colab_qlora_grpo/checkpoints",
25
+ "save_to_drive": true,
26
+ "drive_dir": "/content/drive/MyDrive/fraudshield",
27
+ "num_train_epochs": 2,
28
+ "per_device_train_batch_size": 2,
29
+ "gradient_accumulation_steps": 4,
30
+ "learning_rate": 0.0001,
31
+ "eval_every_steps": 10,
32
+ "save_every_steps": 20,
33
+ "warmstart_rollouts_per_task": 24,
34
+ "rl_rollouts_per_task": 8,
35
+ "max_prompt_tokens": 2048,
36
+ "max_completion_tokens": 220,
37
+ "logging_steps": 1,
38
+ "report_to": [
39
+ "tensorboard"
40
+ ],
41
+ "run_name": "fraudshield-colab-run",
42
+ "resume_from_checkpoint": null,
43
+ "public_curriculum_dataset": "Phoenix21/mock_fraud-detection-dataset",
44
+ "public_curriculum_rows": 2500
45
+ },
46
+ "evaluation": {
47
+ "tasks": [
48
+ "easy",
49
+ "medium",
50
+ "hard"
51
+ ],
52
+ "fixed_prompt_cases": 3,
53
+ "plots_dir": "artifacts/plots",
54
+ "compare_against_base_model": true
55
+ },
56
+ "reward_weights": {
57
+ "env_reward": 1.0,
58
+ "correctness": 0.35,
59
+ "task_completion": 0.2,
60
+ "reasoning_quality": 0.1,
61
+ "efficiency": 0.1,
62
+ "safety": 0.1,
63
+ "formatting_compliance": 0.1,
64
+ "consistency": 0.05
65
+ },
66
+ "reward_version": "v1",
67
+ "ablation_tags": [
68
+ "public-curriculum",
69
+ "two-stage",
70
+ "colab-qlora"
71
+ ]
72
+ }
environment.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Text-first training environment wrapper for FraudShield."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from dataclasses import dataclass
7
+ from typing import Any
8
+
9
+ from config import EnvironmentConfig, RewardWeights
10
+ from fraudshield_env import FraudShieldEnvironment
11
+ from models import ActionTypeEnum, FraudCheckAction, ResolutionEnum
12
+ from reward import RewardBreakdown, build_reward_breakdown
13
+ from utils import approximate_token_count, extract_json_object
14
+
15
# Maps every investigation alias a model may emit to the typed environment
# action it triggers. Aliases are grouped by their target action; the
# resulting dictionary keeps the same keys, values, and insertion order as a
# flat literal would.
INVESTIGATION_ALIAS_TO_ACTION = {
    alias: target_action
    for target_action, aliases in (
        (ActionTypeEnum.FETCH_MERCHANT_PROFILE, ("merchant_profile", "fetch_merchant_profile")),
        (ActionTypeEnum.FETCH_CUSTOMER_PROFILE, ("customer_profile", "fetch_customer_profile")),
        (ActionTypeEnum.FETCH_NETWORK_GRAPH, ("network_graph", "fetch_network_graph", "device_intel")),
        (ActionTypeEnum.REVIEW_TRANSACTION, ("payment_trace", "fulfillment_review", "review_transaction")),
        (ActionTypeEnum.CHECK_POLICY, ("policy_review", "check_policy")),
    )
    for alias in aliases
}
29
+
30
+
31
@dataclass
class TextStepResult:
    """Everything one text-level environment step produces for an RL loop."""

    # Prompt the policy was shown before it answered.
    prompt: str
    # Raw model output that was parsed into an action.
    response_text: str
    # Prompt built from the observation returned by the step.
    next_prompt: str
    # Episode-termination flag reported by the wrapped environment.
    done: bool
    # Weighted total reward (mirrors ``reward_breakdown.total_reward``).
    reward: float
    # Individual subscores behind ``reward``.
    reward_breakdown: RewardBreakdown
    # Parsed payload, raw environment reward, and environment state dump.
    info: dict[str, Any]
42
+
43
+
44
class FraudShieldTextEnvironment:
    """Wrap ``FraudShieldEnvironment`` as a text-in/text-out RL environment.

    The wrapper renders observations as prompts, parses raw model text into
    typed ``FraudCheckAction`` objects (falling back to a safe default action
    when the text is malformed), and scores every step with the decomposed
    reward.
    """

    def __init__(
        self,
        env_config: EnvironmentConfig | None = None,
        reward_weights: RewardWeights | None = None,
    ):
        """Create the wrapped environment and load its data.

        Args:
            env_config: Environment settings; defaults are used when omitted.
            reward_weights: Subscore weights; defaults are used when omitted.
        """
        self.env_config = env_config or EnvironmentConfig()
        self.reward_weights = reward_weights or RewardWeights()
        self.env = FraudShieldEnvironment(data_path=self.env_config.data_path, seed=self.env_config.seed)
        self.env.load_data()
        self.current_observation = None
        self.current_task = self.env_config.default_task
        self.initial_step_budget = self.env_config.max_rollout_steps
        self.action_history: list[str] = []

    def reset(self, task: str | None = None) -> str:
        """Reset the wrapped environment and return the initial prompt.

        Args:
            task: Task to switch to; the current task is reused when omitted.
        """

        self.current_task = task or self.current_task
        result = self.env.reset(task=self.current_task)
        self.current_observation = result.observation
        # Prefer the budget reported by the environment over the configured one.
        self.initial_step_budget = result.info.get("max_steps", self.env_config.max_rollout_steps)
        self.action_history = []
        return self.build_prompt(self.current_observation)

    def build_prompt(self, observation) -> str:
        """Build the prompt shown to an LLM policy for ``observation``."""

        payload = {
            "case_id": observation.case_id,
            "task_name": observation.task_name.value,
            "visible_panels": observation.visible_panels,
            "revealed_evidence": observation.revealed_evidence,
            "linked_case_ids": observation.linked_case_ids,
            "remaining_steps": observation.remaining_steps,
            "remaining_sla": observation.remaining_sla,
            "note_required": observation.note_required,
            "allowed_actions": [action.value for action in observation.allowed_actions],
            "case_summary": observation.case_summary.model_dump(mode="json"),
            "app_context": observation.app_context,
        }
        available = observation.app_context.get(
            "available_investigations",
            ["merchant_profile", "customer_profile", "network_graph", "payment_trace", "policy_review"],
        )
        return (
            "You are a fraud analyst in a multi-step training environment. "
            "Return JSON only. Use visible evidence, investigation budget, and prior evidence carefully.\n\n"
            f"Visible observation:\n{json.dumps(payload, sort_keys=True)}\n\n"
            f"Valid investigation aliases: {available}\n"
            "JSON schema: "
            '{"action_type":"investigate|decide","investigation_target":"alias_or_null",'
            '"decision":"fraud|legitimate|null","confidence":0.0,"reasoning":"one sentence"}'
        )

    @staticmethod
    def _coerce_confidence(raw: Any) -> tuple[float, bool]:
        """Turn a model-supplied confidence into ``(value, was_valid)``.

        A missing value silently defaults to 0.5. Non-numeric or NaN values
        also fall back to 0.5 but are reported as invalid; NaN must be
        rejected because it fails every threshold comparison and would
        otherwise always route to the strictest resolution.
        """
        if raw is None or raw == "":
            return 0.5, True
        try:
            value = float(raw)
        except (TypeError, ValueError):
            return 0.5, False
        if value != value:  # NaN is the only float that differs from itself
            return 0.5, False
        return min(1.0, max(0.0, value)), True

    def parse_response(self, response_text: str) -> tuple[FraudCheckAction, dict[str, Any], bool, bool]:
        """Convert model output into a typed environment action.

        Returns:
            ``(action, payload, parse_failed, required_fields_present)`` where
            ``payload`` is the parsed JSON object (or the fallback payload),
            ``parse_failed`` is true when no JSON object could be extracted,
            and ``required_fields_present`` is false whenever a fallback value
            had to be substituted for a missing or invalid field.
        """

        parse_failed = False
        required_fields_present = True
        try:
            payload = extract_json_object(response_text)
            if not isinstance(payload, dict):
                raise TypeError("model output is not a JSON object")
        except Exception:
            # Deliberate best-effort: any extraction failure becomes a
            # penalised fallback action instead of aborting the rollout.
            parse_failed = True
            required_fields_present = False
            payload = {
                "action_type": "investigate",
                "investigation_target": "payment_trace",
                "decision": None,
                "confidence": 0.0,
                "reasoning": "Fallback after invalid output.",
            }

        # ``or ""`` makes an explicit JSON null count as missing instead of
        # being stringified to "none".
        action_type = str(payload.get("action_type") or "").strip().lower()
        reasoning = str(payload.get("reasoning") or "").strip()
        if not reasoning:
            required_fields_present = False
            reasoning = "Fallback after missing reasoning."

        if action_type == "investigate":
            alias = str(payload.get("investigation_target") or "").strip().lower()
            if not alias:
                required_fields_present = False
                alias = "payment_trace"
            mapped_action = INVESTIGATION_ALIAS_TO_ACTION.get(alias, ActionTypeEnum.REVIEW_TRANSACTION)
            action = FraudCheckAction(case_id=self.current_observation.case_id, action_type=mapped_action, reasoning=reasoning)
        elif action_type == "decide":
            decision = str(payload.get("decision") or "").strip().lower()
            confidence, confidence_valid = self._coerce_confidence(payload.get("confidence"))
            if not confidence_valid:
                required_fields_present = False
            if decision not in {"fraud", "legitimate"}:
                required_fields_present = False
                decision = "fraud"
            if self.current_observation.note_required:
                # A note is owed first, so the decision text is filed as a
                # case note rather than resolving the case on this step.
                action = FraudCheckAction(
                    case_id=self.current_observation.case_id,
                    action_type=ActionTypeEnum.ADD_CASE_NOTE,
                    note_text=f"Decision summary: {reasoning}",
                )
            else:
                resolution = self._decision_to_resolution(decision, confidence)
                action = FraudCheckAction(
                    case_id=self.current_observation.case_id,
                    action_type=ActionTypeEnum.RESOLVE_CASE,
                    resolution=resolution,
                    reasoning=reasoning,
                )
        else:
            required_fields_present = False
            action = FraudCheckAction(
                case_id=self.current_observation.case_id,
                action_type=ActionTypeEnum.REVIEW_TRANSACTION,
                reasoning="Fallback after unsupported action type.",
            )
        return action, payload, parse_failed, required_fields_present

    def step(self, response_text: str) -> TextStepResult:
        """Step the environment using raw model text.

        Raises:
            RuntimeError: if called before ``reset``.
        """

        if self.current_observation is None:
            raise RuntimeError("reset() must be called before step().")

        prompt = self.build_prompt(self.current_observation)
        action, payload, parse_failed, required_fields_present = self.parse_response(response_text)
        env_step = self.env.step(action)
        self.action_history.append(action.action_type.value)
        self.current_observation = env_step.observation
        token_count = approximate_token_count(prompt + response_text)
        breakdown = build_reward_breakdown(
            env_reward_value=env_step.reward.value,
            is_correct=env_step.reward.is_correct,
            done=env_step.done,
            action_type=action.action_type,
            resolution=action.resolution,
            # Note actions carry their text in ``note_text`` instead of ``reasoning``.
            reasoning=action.reasoning if action.action_type != ActionTypeEnum.ADD_CASE_NOTE else action.note_text or "",
            revealed_evidence=env_step.observation.revealed_evidence,
            remaining_steps=env_step.observation.remaining_steps,
            initial_budget=self.initial_step_budget,
            token_count=token_count,
            parse_failed=parse_failed,
            required_fields_present=required_fields_present,
            # Exclude the action just taken: the scorer compares it to prior history.
            action_history=self.action_history[:-1],
            weights=self.reward_weights,
        )
        next_prompt = self.build_prompt(self.current_observation)
        return TextStepResult(
            prompt=prompt,
            response_text=response_text,
            next_prompt=next_prompt,
            done=env_step.done,
            reward=breakdown.total_reward,
            reward_breakdown=breakdown,
            info={
                "payload": payload,
                "env_reward": env_step.reward.model_dump(mode="json"),
                "state": self.env.state().model_dump(mode="json"),
            },
        )

    def _decision_to_resolution(self, decision: str, confidence: float) -> ResolutionEnum:
        """Map a binary decision plus confidence onto a final resolution.

        "legitimate" approves at confidence >= 0.75 (always on the easy task)
        and otherwise requests documents; any other decision holds below 0.70
        and blocks from 0.70 upward.
        """
        if decision == "legitimate":
            if confidence >= 0.75 or self.current_observation.task_name.value == "easy":
                return ResolutionEnum.APPROVE
            return ResolutionEnum.REQUEST_DOCS
        if confidence < 0.70:
            return ResolutionEnum.HOLD
        return ResolutionEnum.BLOCK
evaluate.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Evaluation entrypoint for FraudShield trainable agents."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import matplotlib.pyplot as plt
11
+
12
+ from config import ExperimentConfig
13
+ from environment import FraudShieldTextEnvironment
14
+ from llm_agent import build_default_agent
15
+ from utils import ensure_dir, moving_average, save_json, seed_everything
16
+
17
+
18
def evaluate_agent(config: ExperimentConfig) -> dict[str, Any]:
    """Run fixed-task evaluations and collect comparison metrics.

    For every task in ``config.evaluation.tasks`` the default agent is rolled
    out until the environment reports ``done``. Each agent action is
    re-encoded as the JSON text the text environment expects.

    Returns:
        ``{"tasks": [...], "reward_traces": {...}}`` with one summary row and
        one per-step reward list per task.
    """

    # Imported locally: only the token estimate below needs it.
    from utils import approximate_token_count

    seed_everything(config.seed)
    text_env = FraudShieldTextEnvironment(config.environment, config.reward_weights)
    agent = build_default_agent()
    task_rows = []
    reward_traces: dict[str, list[float]] = {}
    for task in config.evaluation.tasks:
        prompt = text_env.reset(task=task)
        done = False
        rewards: list[float] = []
        token_usage = 0
        final_info: dict[str, Any] | None = None
        while not done:
            action = agent.decide(text_env.current_observation)
            resolution = getattr(action, "resolution", None)
            response_text = json.dumps(
                {
                    "action_type": "decide" if action.action_type.value == "resolve_case" else "investigate",
                    "investigation_target": action.action_type.value,
                    "decision": "fraud" if resolution and resolution.value in {"block", "hold", "escalate"} else "legitimate",
                    "confidence": 0.8,
                    "reasoning": action.reasoning or "Evaluation rollout step.",
                }
            )
            step = text_env.step(response_text)
            # Estimate usage from the text actually exchanged on this step
            # (prompt shown + response sent), not from the reward values.
            token_usage += approximate_token_count(prompt + response_text)
            prompt = step.next_prompt
            done = step.done
            rewards.append(step.reward)
            final_info = step.info
        reward_traces[task] = rewards
        state = final_info["state"] if final_info else {}
        env_reward = final_info["env_reward"] if final_info else {}
        task_rows.append(
            {
                "task": task,
                "total_reward": round(sum(rewards), 4),
                "mean_reward": round(sum(rewards) / max(1, len(rewards)), 4),
                "success_rate": 1.0 if env_reward.get("is_correct") else 0.0,
                "resolved_cases": len(state.get("resolved_case_ids", [])),
                "token_usage_estimate": token_usage,
            }
        )
    return {"tasks": task_rows, "reward_traces": reward_traces}
61
+
62
+
63
def save_evaluation_artifacts(report: dict[str, Any], config: ExperimentConfig) -> None:
    """Persist evaluation metrics and plots.

    Writes ``evaluation_rewards.png`` into ``config.evaluation.plots_dir`` and
    ``evaluation_report.json`` into ``config.training.output_dir``.

    Args:
        report: Output of ``evaluate_agent`` (must contain a ``"tasks"`` list).
        config: Experiment configuration supplying both target directories.
    """

    plots_dir = ensure_dir(config.evaluation.plots_dir)
    rewards = [row["total_reward"] for row in report["tasks"]]
    moving = moving_average(rewards, window=2)
    plt.figure(figsize=(8, 4))
    try:
        plt.plot(range(1, len(rewards) + 1), rewards, marker="o", label="reward")
        plt.plot(range(1, len(moving) + 1), moving, marker="x", label="moving_avg_reward")
        plt.xticks(range(1, len(rewards) + 1), [row["task"] for row in report["tasks"]])
        plt.legend()
        plt.tight_layout()
        plt.savefig(plots_dir / "evaluation_rewards.png")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close()
    # The run directory may not exist yet when evaluation runs before training.
    ensure_dir(config.training.output_dir)
    save_json(report, Path(config.training.output_dir) / "evaluation_report.json")
78
+
79
+
80
def main() -> None:
    """Command-line entry point: load a config, evaluate, save, and print the report."""
    cli = argparse.ArgumentParser(description="Evaluate FraudShield trainable agents.")
    cli.add_argument("--config", default="configs/colab_qlora_grpo.json", help="Path to experiment config JSON.")
    config_path = cli.parse_args().config

    experiment = ExperimentConfig.load(config_path)
    results = evaluate_agent(experiment)
    save_evaluation_artifacts(results, experiment)
    print(json.dumps(results, indent=2))


if __name__ == "__main__":
    main()
pyproject.toml CHANGED
@@ -36,6 +36,7 @@ classifiers = [
36
  ]
37
  dependencies = [
38
  "fastapi>=0.115.0",
 
39
  "numpy>=1.24.0",
40
  "openai>=1.40.0",
41
  "openenv-core>=0.2.0",
@@ -55,6 +56,16 @@ dev = [
55
  "pytest>=7.4.0",
56
  "ruff>=0.4.0",
57
  ]
 
 
 
 
 
 
 
 
 
 
58
 
59
  [project.urls]
60
  Homepage = "https://github.com/DevikaJ2005/fraudshield"
@@ -64,11 +75,15 @@ BugTracker = "https://github.com/DevikaJ2005/fraudshield/issues"
64
 
65
  [project.scripts]
66
  server = "server.app:main"
 
 
67
 
68
  [tool.setuptools]
69
  py-modules = [
70
  "data_loader",
71
  "download_kaggle_data",
 
 
72
  "fraudshield_env",
73
  "graders",
74
  "inference",
@@ -76,6 +91,10 @@ py-modules = [
76
  "llm_agent",
77
  "llm_agent_openai",
78
  "models",
 
 
 
 
79
  ]
80
 
81
  [tool.setuptools.packages.find]
 
36
  ]
37
  dependencies = [
38
  "fastapi>=0.115.0",
39
+ "matplotlib>=3.8.0",
40
  "numpy>=1.24.0",
41
  "openai>=1.40.0",
42
  "openenv-core>=0.2.0",
 
56
  "pytest>=7.4.0",
57
  "ruff>=0.4.0",
58
  ]
59
+ rl = [
60
+ "accelerate>=0.33.0",
61
+ "bitsandbytes>=0.43.0",
62
+ "datasets>=2.20.0",
63
+ "peft>=0.12.0",
64
+ "tensorboard>=2.17.0",
65
+ "transformers>=4.51.0",
66
+ "trl>=0.19.0",
67
+ "wandb>=0.17.0",
68
+ ]
69
 
70
  [project.urls]
71
  Homepage = "https://github.com/DevikaJ2005/fraudshield"
 
75
 
76
  [project.scripts]
77
  server = "server.app:main"
78
+ fraudshield-train = "train:main"
79
+ fraudshield-evaluate = "evaluate:main"
80
 
81
  [tool.setuptools]
82
  py-modules = [
83
  "data_loader",
84
  "download_kaggle_data",
85
+ "environment",
86
+ "evaluate",
87
  "fraudshield_env",
88
  "graders",
89
  "inference",
 
91
  "llm_agent",
92
  "llm_agent_openai",
93
  "models",
94
+ "reward",
95
+ "train",
96
+ "utils",
97
+ "config",
98
  ]
99
 
100
  [tool.setuptools.packages.find]
reward.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reward decomposition helpers for RL-style FraudShield training."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Any, Iterable
7
+
8
+ from models import ActionTypeEnum, ResolutionEnum
9
+
10
+
11
@dataclass
class RewardBreakdown:
    """One step's reward, split into named subscores plus their weighted total."""

    env_reward: float
    correctness: float
    task_completion: float
    reasoning_quality: float
    efficiency: float
    safety: float
    formatting_compliance: float
    consistency: float
    total_reward: float

    def to_dict(self) -> dict[str, float]:
        """Return all subscores and the total keyed by field name, in declaration order."""
        return {field_name: getattr(self, field_name) for field_name in self.__dataclass_fields__}
37
+
38
+
39
+ def _clamp(value: float, low: float = -1.0, high: float = 1.0) -> float:
40
+ return max(low, min(high, value))
41
+
42
+
43
def score_reasoning_quality(reasoning: str, revealed_evidence: dict[str, Any]) -> float:
    """Reward concise evidence-aware reasoning.

    Reasoning shorter than 12 characters scores -0.4. Otherwise the score is
    0.2 plus 0.2 for every revealed-evidence key that the reasoning mentions,
    capped at 1.0. A key counts as mentioned when any word of the reasoning
    starts with one of the key's tokens (``merchant_profile`` has the tokens
    ``merchant`` and ``profile``). Matching on word starts rather than raw
    substrings stops short tokens such as ``id`` from matching inside
    unrelated words such as ``valid``.

    Args:
        reasoning: Free-text justification; ``None`` is treated as empty.
        revealed_evidence: Evidence keyed by name; ``None`` is treated as empty.
    """

    def _words(text: str) -> list[str]:
        # Lower-case and split on every non-alphanumeric character.
        return "".join(ch if ch.isalnum() else " " for ch in text.lower()).split()

    reasoning = (reasoning or "").strip().lower()
    if len(reasoning) < 12:
        return -0.4
    reasoning_words = _words(reasoning)  # hoisted: identical for every key
    signal_hits = 0
    for evidence_key in revealed_evidence or ():
        key_tokens = _words(str(evidence_key))
        if any(word.startswith(token) for token in key_tokens for word in reasoning_words):
            signal_hits += 1
    # The score never drops below 0.2 here, so only the upper bound can bind.
    return min(1.0, 0.2 + 0.2 * signal_hits)
55
+
56
+
57
def score_efficiency(remaining_steps: int, initial_budget: int, token_count: int) -> float:
    """Score how sparingly steps and tokens were spent, within ``[-1, 1]``.

    The unused share of the step budget contributes up to +0.8; token usage
    (saturating at 300 tokens) subtracts up to 0.4. A non-positive budget
    yields a neutral 0.0.
    """

    if initial_budget <= 0:
        return 0.0
    budget_left_fraction = remaining_steps / initial_budget
    verbosity = min(token_count / 300.0, 1.0)
    raw_score = (budget_left_fraction * 0.8) - (verbosity * 0.4)
    return max(-1.0, min(1.0, raw_score))
65
+
66
+
67
def score_safety(action_type: ActionTypeEnum, parse_failed: bool, refused_unsafely: bool = False) -> float:
    """Score how safely the step was handled.

    Unparseable output scores -1.0 and an unsafe refusal -0.7. Otherwise a
    case resolution scores 0.3 and every other action 0.5.
    """

    if parse_failed:
        return -1.0
    if refused_unsafely:
        return -0.7
    is_resolution = action_type == ActionTypeEnum.RESOLVE_CASE
    return 0.3 if is_resolution else 0.5
77
+
78
+
79
def score_formatting_compliance(parse_failed: bool, required_fields_present: bool) -> float:
    """Score output formatting: 1.0 when fully compliant, -1.0 when unparseable, -0.4 when fields are missing."""

    if not parse_failed and required_fields_present:
        return 1.0
    return -1.0 if parse_failed else -0.4
85
+
86
+
87
def score_consistency(action_history: Iterable[str], next_action: str, resolution: ResolutionEnum | None) -> float:
    """Score whether the next action is coherent with what was already done.

    Repeating the immediately preceding ``fetch_*`` action scores -0.8,
    supplying a resolution after a ``resolve_case`` already occurred scores
    -1.0, and anything else scores 0.4.
    """

    previous_actions = list(action_history)
    repeats_last_fetch = (
        bool(previous_actions)
        and previous_actions[-1] == next_action
        and next_action.startswith("fetch_")
    )
    if repeats_last_fetch:
        return -0.8
    already_resolved = "resolve_case" in previous_actions
    if already_resolved and resolution is not None:
        return -1.0
    return 0.4
96
+
97
+
98
def score_correctness(env_reward_value: float, is_correct: bool | None) -> float:
    """Report correctness as its own subscore.

    An explicit verdict maps to +1.0 (``True``) or -1.0 (``False``). Without
    one, the environment reward limited to ``[-1, 1]`` is used instead.
    """

    for verdict, score in ((True, 1.0), (False, -1.0)):
        if is_correct is verdict:
            return score
    return max(-1.0, min(1.0, env_reward_value))
106
+
107
+
108
def score_task_completion(done: bool, action_type: ActionTypeEnum, resolution: ResolutionEnum | None) -> float:
    """Score progress toward closing the case.

    A finished episode that ends on a resolve action with a resolution scores
    1.0, a case-note action scores 0.3, any other finished episode 0.1, and
    everything else 0.0.
    """

    closed_with_resolution = (
        done and action_type == ActionTypeEnum.RESOLVE_CASE and resolution is not None
    )
    if closed_with_resolution:
        return 1.0
    if action_type == ActionTypeEnum.ADD_CASE_NOTE:
        return 0.3
    if done:
        return 0.1
    return 0.0
116
+
117
+
118
def build_reward_breakdown(
    *,
    env_reward_value: float,
    is_correct: bool | None,
    done: bool,
    action_type: ActionTypeEnum,
    resolution: ResolutionEnum | None,
    reasoning: str,
    revealed_evidence: dict[str, Any],
    remaining_steps: int,
    initial_budget: int,
    token_count: int,
    parse_failed: bool,
    required_fields_present: bool,
    action_history: Iterable[str],
    weights: Any,
) -> RewardBreakdown:
    """Compute every subscore for one step and fold them into a weighted total.

    ``weights`` must expose one attribute per subscore name (``env_reward``,
    ``correctness``, ...); each subscore is multiplied by its weight and the
    products are added in declaration order.
    """

    # Insertion order fixes both the scoring order and the summation order.
    subscores = {
        "env_reward": env_reward_value,
        "correctness": score_correctness(env_reward_value, is_correct),
        "task_completion": score_task_completion(done, action_type, resolution),
        "reasoning_quality": score_reasoning_quality(reasoning, revealed_evidence),
        "efficiency": score_efficiency(remaining_steps, initial_budget, token_count),
        "safety": score_safety(action_type, parse_failed=parse_failed),
        "formatting_compliance": score_formatting_compliance(
            parse_failed=parse_failed,
            required_fields_present=required_fields_present,
        ),
        "consistency": score_consistency(action_history, action_type.value, resolution),
    }

    # Seed the accumulator with the first product, then add the rest left to
    # right so the floating-point result matches a plain chained sum.
    remaining_names = iter(subscores)
    first_name = next(remaining_names)
    weighted_total = getattr(weights, first_name) * subscores[first_name]
    for component in remaining_names:
        weighted_total = weighted_total + getattr(weights, component) * subscores[component]

    return RewardBreakdown(total_reward=weighted_total, **subscores)
train.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training entrypoint for FraudShield Colab-friendly experiments."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ import matplotlib.pyplot as plt
12
+ import pandas as pd
13
+ from datasets import Dataset, load_dataset
14
+
15
+ from config import ExperimentConfig
16
+ from environment import FraudShieldTextEnvironment
17
+ from llm_agent import SnapshotCalibratedFraudDetectionAgent
18
+ from utils import ensure_dir, save_json, seed_everything
19
+
20
+
21
def build_public_curriculum(config: ExperimentConfig) -> Dataset:
    """Turn a sample of a public fraud dataset into prompt/response training text.

    The ``train`` split of ``config.training.public_curriculum_dataset`` is
    shuffled with ``config.seed`` and truncated to at most
    ``config.training.public_curriculum_rows`` rows. Rows labelled as fraud
    become an "investigate" target; all others become a "decide: legitimate"
    target. Each output row has a ``text`` column (prompt, newline, compact
    JSON answer) and a ``source`` column set to ``"public"``.
    """

    source = load_dataset(config.training.public_curriculum_dataset, split="train")
    sample_size = min(config.training.public_curriculum_rows, len(source))
    sampled = source.shuffle(seed=config.seed).select(range(sample_size))

    # Prompt pieces that are identical for every example.
    instructions = (
        "You are a fraud analyst learning to investigate risk before final routing. Return JSON only.\n\n"
    )
    schema_hint = (
        'JSON schema: {"action_type":"investigate|decide","investigation_target":"alias_or_null","decision":"fraud|legitimate|null","confidence":0.0,"reasoning":"one sentence"}'
    )
    investigations = ["merchant_profile", "customer_profile", "network_graph", "payment_trace", "policy_review"]

    examples: list[dict[str, Any]] = []
    for record in sampled:
        # Column names differ between public datasets, so try each known
        # spelling; missing or falsy values fall back to a neutral default.
        raw_amount = record.get("amount", record.get("Amount", 0.0))
        amount = float(raw_amount or 0.0)
        raw_label = record.get("is_fraud", record.get("isFraud", record.get("Class", 0)))
        is_fraud = int(raw_label or 0)
        transaction_type = str(record.get("transaction_type", record.get("type", "purchase")))

        observation = {
            "amount_usd": amount,
            "transaction_type": transaction_type,
            "task_name": "medium",
            "available_investigations": investigations,
        }
        prompt = f"{instructions}Visible observation:\n{json.dumps(observation)}\n\n{schema_hint}"

        if not is_fraud:
            answer = {
                "action_type": "decide",
                "investigation_target": None,
                "decision": "legitimate",
                "confidence": 0.8,
                "reasoning": "The visible transaction appears low risk and can be cleared confidently.",
            }
        else:
            # Larger amounts are steered to the network graph, the rest to the payment trace.
            target = "network_graph" if amount > 1000 else "payment_trace"
            answer = {
                "action_type": "investigate",
                "investigation_target": target,
                "decision": None,
                "confidence": None,
                "reasoning": "The visible transaction is risky, so gather stronger network or payment evidence first.",
            }

        completion = json.dumps(answer, separators=(",", ":"))
        examples.append({"text": prompt + "\n" + completion, "source": "public"})

    frame = pd.DataFrame(examples)
    return Dataset.from_pandas(frame, preserve_index=False)
56
+
57
+
58
def build_rollout_dataset(config: ExperimentConfig) -> Dataset:
    """Record baseline-agent trajectories as prompt/response training text.

    For every task in ``config.evaluation.tasks`` the calibrated baseline agent
    plays ``config.training.warmstart_rollouts_per_task`` episodes in the text
    environment. Each step contributes one row whose ``text`` is the current
    prompt, a newline, and the compact JSON action; ``source`` is ``"rollout"``.
    """

    env = FraudShieldTextEnvironment(config.environment, config.reward_weights)
    baseline = SnapshotCalibratedFraudDetectionAgent()
    # Resolution values that are reported as a "fraud" decision.
    fraud_resolutions = {"block", "hold", "escalate"}

    examples: list[dict[str, Any]] = []
    for task_name in config.evaluation.tasks:
        for _ in range(config.training.warmstart_rollouts_per_task):
            prompt = env.reset(task=task_name)
            finished = False
            while not finished:
                action = baseline.decide(env.current_observation)
                action_value = action.action_type.value

                if getattr(action, "resolution", None) and action.resolution.value in fraud_resolutions:
                    decision = "fraud"
                else:
                    decision = "legitimate"

                # NOTE(review): ``investigation_target`` and ``decision`` are
                # filled on every step, including steps where the other field
                # is the relevant one. Confirm against the environment's JSON
                # parser whether the unused field should be null instead.
                payload = {
                    "action_type": "decide" if action_value == "resolve_case" else "investigate",
                    "investigation_target": action_value,
                    "decision": decision,
                    "confidence": 0.8,
                    "reasoning": action.reasoning or "Training rollout step.",
                }

                # The stored training text uses compact separators; the
                # environment receives the default ``json.dumps`` encoding.
                completion = json.dumps(payload, separators=(",", ":"))
                examples.append({"text": prompt + "\n" + completion, "source": "rollout"})

                outcome = env.step(json.dumps(payload))
                prompt = outcome.next_prompt
                finished = outcome.done

    frame = pd.DataFrame(examples)
    return Dataset.from_pandas(frame, preserve_index=False)
82
+
83
+
84
def load_model_stack(config: ExperimentConfig):
    """Load the base model and tokenizer, then attach LoRA adapters.

    Returns a ``(model, tokenizer)`` pair where ``model`` is the result of
    ``FastLanguageModel.get_peft_model`` applied to the loaded base model.
    All settings are read from ``config.model``.
    """

    from unsloth import FastLanguageModel

    model_cfg = config.model

    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_cfg.base_model,
        max_seq_length=model_cfg.max_seq_length,
        load_in_4bit=model_cfg.load_in_4bit,
    )

    # Projection modules that receive LoRA adapters.
    lora_targets = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    peft_model = FastLanguageModel.get_peft_model(
        base_model,
        r=model_cfg.lora_rank,
        lora_alpha=model_cfg.lora_alpha,
        lora_dropout=model_cfg.lora_dropout,
        target_modules=lora_targets,
        use_gradient_checkpointing=model_cfg.gradient_checkpointing,
    )
    return peft_model, tokenizer
103
+
104
+
105
def _build_stage_args(
    config: ExperimentConfig,
    stage: str,
    *,
    num_train_epochs: Any,
    learning_rate: float,
    save_strategy: str,
):
    """Create the ``TrainingArguments`` for one SFT stage, namespaced by ``stage``."""

    # Imported at call time, as the original function-scope import did.
    from transformers import TrainingArguments

    training = config.training
    output_root = Path(training.output_dir)
    return TrainingArguments(
        output_dir=str(output_root / stage),
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=training.per_device_train_batch_size,
        gradient_accumulation_steps=training.gradient_accumulation_steps,
        learning_rate=learning_rate,
        logging_steps=max(1, training.logging_steps),
        save_strategy=save_strategy,
        report_to=training.report_to,
        run_name=f"{training.run_name}-{stage}",
        logging_dir=str(output_root / "tb_logs" / stage),
    )


def _build_sft_trainer(config: ExperimentConfig, model: Any, tokenizer: Any, dataset: Any, args: Any):
    """Create an ``SFTTrainer`` over the ``text`` column of ``dataset``."""

    from trl import SFTTrainer

    return SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        dataset_text_field="text",
        max_seq_length=config.model.max_seq_length,
        packing=False,
        args=args,
    )


def _plot_loss_curve(loss_points: list[tuple[Any, Any]], destination: Path) -> None:
    """Save a loss-versus-step line plot; do nothing when there are no points."""

    if not loss_points:
        return
    steps, losses = zip(*loss_points)
    plt.figure(figsize=(8, 4))
    plt.plot(steps, losses)
    plt.xlabel("training step")
    plt.ylabel("loss")
    plt.tight_layout()
    plt.savefig(destination)
    plt.close()


def _plot_reward_proxy(loss_points: list[tuple[Any, Any]], destination: Path) -> None:
    """Save the loss-derived reward proxy and its moving average against the training step."""

    if not loss_points:
        return

    from utils import moving_average

    # Plot against the logged step so the "training step" axis label is
    # accurate even when ``logging_steps`` is greater than one.
    steps = [step for step, _ in loss_points]
    # The proxy is ``1 - loss`` clamped at zero; it is not an environment reward.
    proxy = [max(0.0, 1.0 - float(loss)) for _, loss in loss_points]
    smoothed = moving_average(proxy, window=min(10, len(proxy)))

    plt.figure(figsize=(8, 4))
    plt.plot(steps, proxy, label="reward_proxy")
    plt.plot(steps, smoothed, label="moving_avg")
    plt.xlabel("training step")
    plt.ylabel("reward proxy")
    plt.legend()
    plt.tight_layout()
    plt.savefig(destination)
    plt.close()


def run_training(config: ExperimentConfig) -> dict[str, Any]:
    """Run the two-stage SFT pipeline and write its artifacts.

    Stage 1 trains for one epoch on the public curriculum at twice the
    configured learning rate and saves no checkpoints. Stage 2 continues from
    the stage-1 model on baseline rollouts, saving a checkpoint per epoch and
    optionally resuming from ``config.training.resume_from_checkpoint``.

    Artifacts under ``config.training.output_dir``: the ``trained_policy``
    model and tokenizer, ``loss_vs_steps.png``, ``reward_vs_steps.png`` (both
    only when stage 2 logged a loss), and ``training_run_summary.json``.

    Returns:
        The summary dictionary that is also written to
        ``training_run_summary.json``.
    """

    training = config.training

    seed_everything(config.seed)
    output_root = ensure_dir(training.output_dir)
    ensure_dir(training.checkpoint_dir)
    if "wandb" in training.report_to and not os.environ.get("WANDB_PROJECT"):
        os.environ["WANDB_PROJECT"] = "fraudshield"
    if "tensorboard" in training.report_to:
        ensure_dir(output_root / "tb_logs")

    public_dataset = build_public_curriculum(config)
    rollout_dataset = build_rollout_dataset(config)
    model, tokenizer = load_model_stack(config)

    stage1_args = _build_stage_args(
        config,
        "stage1",
        num_train_epochs=1,
        learning_rate=training.learning_rate * 2,
        save_strategy="no",
    )
    stage1_trainer = _build_sft_trainer(config, model, tokenizer, public_dataset, stage1_args)
    stage1_trainer.train()

    stage2_args = _build_stage_args(
        config,
        "stage2",
        num_train_epochs=training.num_train_epochs,
        learning_rate=training.learning_rate,
        save_strategy="epoch",
    )
    trainer = _build_sft_trainer(config, stage1_trainer.model, tokenizer, rollout_dataset, stage2_args)
    trainer.train(resume_from_checkpoint=training.resume_from_checkpoint)

    output_dir = output_root / "trained_policy"
    trainer.model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Only the stage-2 trainer's history is plotted and reported.
    log_history = trainer.state.log_history
    loss_points = [
        (entry["step"], entry["loss"])
        for entry in log_history
        if "step" in entry and "loss" in entry
    ]
    _plot_loss_curve(loss_points, output_root / "loss_vs_steps.png")
    _plot_reward_proxy(loss_points, output_root / "reward_vs_steps.png")

    metadata = {
        "status": "completed",
        "algorithm": training.algorithm,
        "warmstart_algorithm": training.warmstart_algorithm,
        "report_to": training.report_to,
        "run_name": training.run_name,
        "public_curriculum_dataset": training.public_curriculum_dataset,
        "output_dir": str(output_dir),
        "num_public_examples": len(public_dataset),
        "num_rollout_examples": len(rollout_dataset),
        "log_history": log_history,
    }
    save_json(metadata, output_root / "training_run_summary.json")
    return metadata
216
+
217
+
218
def main() -> None:
    """Parse ``--config``, persist the resolved config, train, and print the summary."""

    parser = argparse.ArgumentParser(description="Train FraudShield with a Colab-friendly curriculum.")
    parser.add_argument("--config", default="configs/colab_qlora_grpo.json", help="Path to experiment config JSON.")
    args = parser.parse_args()
    config = ExperimentConfig.load(args.config)
    # The resolved config is written before ``run_training`` creates the
    # output directory, so create it here; on a fresh run the save could
    # otherwise fail if ``config.save`` does not create parent directories.
    output_dir = ensure_dir(config.training.output_dir)
    config.save(output_dir / "resolved_config.json")
    summary = run_training(config)
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()
utils.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared utilities for FraudShield training and evaluation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ import random
8
+ from pathlib import Path
9
+ from typing import Any, Iterable, Sequence
10
+
11
+ import numpy as np
12
+
13
+
14
def seed_everything(seed: int) -> None:
    """Apply ``seed`` to Python's ``random``, NumPy, ``PYTHONHASHSEED``, and torch if usable."""

    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    # torch is optional: any failure while importing or seeding it is ignored.
    try:  # pragma: no cover - torch is optional at runtime
        import torch as torch_lib

        torch_lib.manual_seed(seed)
        cuda_ready = torch_lib.cuda.is_available()
        if cuda_ready:
            torch_lib.cuda.manual_seed_all(seed)
    except Exception:
        return
28
+
29
+
30
def ensure_dir(path: str | Path) -> Path:
    """Make sure ``path`` exists as a directory (parents included) and return it as a ``Path``."""

    directory = Path(path)
    os.makedirs(directory, exist_ok=True)
    return directory
36
+
37
+
38
def save_json(payload: Any, path: str | Path) -> None:
    """Write ``payload`` to ``path`` as UTF-8 JSON with two-space indentation.

    Missing parent directories are created first, so callers can write into
    a run directory that does not exist yet.

    Raises:
        TypeError: if ``payload`` contains values ``json.dumps`` cannot encode.
    """

    destination = Path(path)
    # Serialize before touching the filesystem so an unencodable payload
    # does not leave an empty directory tree behind.
    encoded = json.dumps(payload, indent=2)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(encoded, encoding="utf-8")
42
+
43
+
44
def load_json(path: str | Path) -> Any:
    """Read the UTF-8 file at ``path`` and return its parsed JSON content."""

    with Path(path).open(encoding="utf-8") as handle:
        return json.load(handle)
48
+
49
+
50
def extract_json_object(text: str) -> dict[str, Any]:
    """Extract the first JSON object from model output.

    Scans ``text`` for ``{`` characters and returns the first position that
    decodes as a complete JSON object. Anything after that object is ignored,
    so trailing prose, stray braces, or further objects do not cause a failure.

    Raises:
        ValueError: if ``text`` contains no ``{``.
        json.JSONDecodeError: (a ``ValueError`` subclass) if ``{`` occurs but
            no position decodes as a JSON object; the error from the first
            candidate is re-raised.
    """

    decoder = json.JSONDecoder()
    first_error: json.JSONDecodeError | None = None

    start = text.find("{")
    while start != -1:
        try:
            # raw_decode stops at the end of the object instead of requiring
            # the rest of the string to be valid JSON.
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError as error:
            if first_error is None:
                first_error = error
            start = text.find("{", start + 1)
        else:
            return candidate

    if first_error is not None:
        raise first_error
    raise ValueError("Model output did not contain a JSON object.")
58
+
59
+
60
def moving_average(values: Sequence[float], window: int = 10) -> list[float]:
    """Return the trailing mean of ``values`` at every position.

    Each output element averages the current value and up to ``window - 1``
    preceding values; early positions average whatever is available. A
    ``window`` below one is treated as one. Empty input yields an empty list.
    """

    if not values:
        return []

    span = max(1, int(window))
    result: list[float] = []
    for stop in range(1, len(values) + 1):
        segment = values[max(0, stop - span) : stop]
        result.append(sum(segment) / len(segment))
    return result
72
+
73
+
74
def approximate_token_count(text: str) -> int:
    """Estimate the token count as 1.3 times the whitespace-separated word count.

    Blank or whitespace-only text counts as zero; any other text counts as at
    least one token.
    """

    words = text.split()
    if not words:
        return 0
    estimate = int(len(words) * 1.3)
    return max(1, estimate)
81
+
82
+
83
def flatten_dict_items(mapping: dict[str, Any], prefix: str = "") -> Iterable[tuple[str, Any]]:
    """Yield ``(dotted_key, value)`` pairs for every non-dict leaf of ``mapping``.

    Nested dictionaries are walked depth-first in insertion order, and their
    keys are joined to the parent key with ``.``. A non-empty ``prefix`` is
    prepended to every key.
    """

    # Each stack frame is (key prefix, iterator over that dict's items).
    # Keeping the iterator lets a parent resume after a nested dict is done.
    pending = [(prefix, iter(mapping.items()))]
    while pending:
        current_prefix, entries = pending[-1]
        for key, value in entries:
            dotted = f"{current_prefix}.{key}" if current_prefix else key
            if isinstance(value, dict):
                pending.append((dotted, iter(value.items())))
                break
            yield dotted, value
        else:
            # This level is exhausted; return to the parent.
            pending.pop()