Spaces:
Sleeping
Sleeping
Commit ·
ce9edc2
1
Parent(s): ccd0934
Add training-first RL architecture with tracking
Browse files- README.md +16 -0
- config.py +124 -0
- configs/colab_qlora_grpo.json +72 -0
- environment.py +208 -0
- evaluate.py +91 -0
- pyproject.toml +19 -0
- reward.py +166 -0
- train.py +229 -0
- utils.py +91 -0
README.md
CHANGED
|
@@ -13,6 +13,22 @@ license: mit
|
|
| 13 |
|
| 14 |
FraudShield is a partial-observability OpenEnv environment for simulated fraud investigation and workflow-aware routing.
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
## What This Is
|
| 17 |
|
| 18 |
FraudShield is an RL-ready simulation, not a live fraud platform. An agent receives a limited triage view of a case, chooses investigation actions to reveal hidden evidence, and then routes the case with one of the supported final resolutions.
|
|
|
|
| 13 |
|
| 14 |
FraudShield is a partial-observability OpenEnv environment for simulated fraud investigation and workflow-aware routing.
|
| 15 |
|
| 16 |
+
## Training-First Architecture
|
| 17 |
+
|
| 18 |
+
FraudShield now includes a modular LLM + RL training stack alongside the OpenEnv runtime:
|
| 19 |
+
|
| 20 |
+
- `environment.py`: text-first wrapper for multi-step rollouts
|
| 21 |
+
- `reward.py`: decomposed numeric reward with measurable subscores
|
| 22 |
+
- `train.py`: Colab-friendly QLoRA training pipeline
|
| 23 |
+
- `evaluate.py`: fixed-task evaluation and comparison plots
|
| 24 |
+
- `config.py`: experiment, model, environment, and reward configuration
|
| 25 |
+
- `utils.py`: seeding, JSON handling, logging helpers, and moving averages
|
| 26 |
+
- `configs/colab_qlora_grpo.json`: default Colab experiment config
|
| 27 |
+
|
| 28 |
+
This layer is designed so you can generate rollouts, score model behavior with decomposed rewards, save checkpoints, resume runs, and compare before/after performance in a repeatable way.
|
| 29 |
+
|
| 30 |
+
Experimental tracking is enabled by default through TensorBoard logs under `artifacts/rl_runs/.../tb_logs`, and the training pipeline also writes plot artifacts such as `loss_vs_steps.png` and `reward_vs_steps.png`. If you want hosted tracking, set `report_to=["wandb"]` or `["tensorboard","wandb"]` in the experiment config before the run.
|
| 31 |
+
|
| 32 |
## What This Is
|
| 33 |
|
| 34 |
FraudShield is an RL-ready simulation, not a live fraud platform. An agent receives a limited triage view of a case, chooses investigation actions to reveal hidden evidence, and then routes the case with one of the supported final resolutions.
|
config.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Config objects for FraudShield RL-style experiments."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import asdict, dataclass, field
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
from utils import load_json, save_json
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class RewardWeights:
|
| 14 |
+
"""Weights used to combine decomposed reward subscores."""
|
| 15 |
+
|
| 16 |
+
env_reward: float = 1.0
|
| 17 |
+
correctness: float = 0.35
|
| 18 |
+
task_completion: float = 0.20
|
| 19 |
+
reasoning_quality: float = 0.10
|
| 20 |
+
efficiency: float = 0.10
|
| 21 |
+
safety: float = 0.10
|
| 22 |
+
formatting_compliance: float = 0.10
|
| 23 |
+
consistency: float = 0.05
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
|
| 27 |
+
class EnvironmentConfig:
|
| 28 |
+
"""Environment-facing configuration."""
|
| 29 |
+
|
| 30 |
+
data_path: str = "data"
|
| 31 |
+
default_task: str = "medium"
|
| 32 |
+
max_rollout_steps: int = 14
|
| 33 |
+
seed: int = 42
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class ModelConfig:
|
| 38 |
+
"""Model and adapter configuration for Colab-friendly training."""
|
| 39 |
+
|
| 40 |
+
base_model: str = "unsloth/Qwen2.5-1.5B-Instruct"
|
| 41 |
+
load_in_4bit: bool = True
|
| 42 |
+
max_seq_length: int = 2048
|
| 43 |
+
lora_rank: int = 16
|
| 44 |
+
lora_alpha: int = 16
|
| 45 |
+
lora_dropout: float = 0.0
|
| 46 |
+
gradient_checkpointing: str = "unsloth"
|
| 47 |
+
mixed_precision: str = "auto"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass
|
| 51 |
+
class TrainingConfig:
|
| 52 |
+
"""Trainer, rollout, and checkpoint parameters."""
|
| 53 |
+
|
| 54 |
+
algorithm: str = "grpo"
|
| 55 |
+
warmstart_algorithm: str = "sft"
|
| 56 |
+
output_dir: str = "artifacts/rl_runs/default"
|
| 57 |
+
checkpoint_dir: str = "artifacts/rl_runs/default/checkpoints"
|
| 58 |
+
save_to_drive: bool = False
|
| 59 |
+
drive_dir: str = "/content/drive/MyDrive/fraudshield"
|
| 60 |
+
num_train_epochs: int = 1
|
| 61 |
+
per_device_train_batch_size: int = 2
|
| 62 |
+
gradient_accumulation_steps: int = 4
|
| 63 |
+
learning_rate: float = 1e-4
|
| 64 |
+
eval_every_steps: int = 10
|
| 65 |
+
save_every_steps: int = 20
|
| 66 |
+
warmstart_rollouts_per_task: int = 24
|
| 67 |
+
rl_rollouts_per_task: int = 8
|
| 68 |
+
max_prompt_tokens: int = 2048
|
| 69 |
+
max_completion_tokens: int = 220
|
| 70 |
+
logging_steps: int = 1
|
| 71 |
+
report_to: list[str] = field(default_factory=lambda: ["tensorboard"])
|
| 72 |
+
run_name: str = "fraudshield-colab-run"
|
| 73 |
+
resume_from_checkpoint: str | None = None
|
| 74 |
+
public_curriculum_dataset: str = "Phoenix21/mock_fraud-detection-dataset"
|
| 75 |
+
public_curriculum_rows: int = 2500
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@dataclass
|
| 79 |
+
class EvaluationConfig:
|
| 80 |
+
"""Evaluation and plotting configuration."""
|
| 81 |
+
|
| 82 |
+
tasks: list[str] = field(default_factory=lambda: ["easy", "medium", "hard"])
|
| 83 |
+
fixed_prompt_cases: int = 3
|
| 84 |
+
plots_dir: str = "artifacts/plots"
|
| 85 |
+
compare_against_base_model: bool = True
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@dataclass
|
| 89 |
+
class ExperimentConfig:
|
| 90 |
+
"""Top-level experiment configuration."""
|
| 91 |
+
|
| 92 |
+
name: str = "fraudshield-colab-qlora-grpo"
|
| 93 |
+
seed: int = 42
|
| 94 |
+
environment: EnvironmentConfig = field(default_factory=EnvironmentConfig)
|
| 95 |
+
model: ModelConfig = field(default_factory=ModelConfig)
|
| 96 |
+
training: TrainingConfig = field(default_factory=TrainingConfig)
|
| 97 |
+
evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
|
| 98 |
+
reward_weights: RewardWeights = field(default_factory=RewardWeights)
|
| 99 |
+
reward_version: str = "v1"
|
| 100 |
+
ablation_tags: list[str] = field(default_factory=list)
|
| 101 |
+
|
| 102 |
+
def to_dict(self) -> dict[str, Any]:
|
| 103 |
+
return asdict(self)
|
| 104 |
+
|
| 105 |
+
def save(self, path: str | Path) -> None:
|
| 106 |
+
save_json(self.to_dict(), path)
|
| 107 |
+
|
| 108 |
+
@classmethod
|
| 109 |
+
def from_dict(cls, data: dict[str, Any]) -> "ExperimentConfig":
|
| 110 |
+
return cls(
|
| 111 |
+
name=data.get("name", cls().name),
|
| 112 |
+
seed=data.get("seed", cls().seed),
|
| 113 |
+
environment=EnvironmentConfig(**data.get("environment", {})),
|
| 114 |
+
model=ModelConfig(**data.get("model", {})),
|
| 115 |
+
training=TrainingConfig(**data.get("training", {})),
|
| 116 |
+
evaluation=EvaluationConfig(**data.get("evaluation", {})),
|
| 117 |
+
reward_weights=RewardWeights(**data.get("reward_weights", {})),
|
| 118 |
+
reward_version=data.get("reward_version", "v1"),
|
| 119 |
+
ablation_tags=list(data.get("ablation_tags", [])),
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
@classmethod
|
| 123 |
+
def load(cls, path: str | Path) -> "ExperimentConfig":
|
| 124 |
+
return cls.from_dict(load_json(path))
|
configs/colab_qlora_grpo.json
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "fraudshield-colab-qlora-grpo",
|
| 3 |
+
"seed": 42,
|
| 4 |
+
"environment": {
|
| 5 |
+
"data_path": "data",
|
| 6 |
+
"default_task": "medium",
|
| 7 |
+
"max_rollout_steps": 14,
|
| 8 |
+
"seed": 42
|
| 9 |
+
},
|
| 10 |
+
"model": {
|
| 11 |
+
"base_model": "unsloth/Qwen2.5-1.5B-Instruct",
|
| 12 |
+
"load_in_4bit": true,
|
| 13 |
+
"max_seq_length": 2048,
|
| 14 |
+
"lora_rank": 16,
|
| 15 |
+
"lora_alpha": 16,
|
| 16 |
+
"lora_dropout": 0.0,
|
| 17 |
+
"gradient_checkpointing": "unsloth",
|
| 18 |
+
"mixed_precision": "auto"
|
| 19 |
+
},
|
| 20 |
+
"training": {
|
| 21 |
+
"algorithm": "grpo",
|
| 22 |
+
"warmstart_algorithm": "sft",
|
| 23 |
+
"output_dir": "artifacts/rl_runs/colab_qlora_grpo",
|
| 24 |
+
"checkpoint_dir": "artifacts/rl_runs/colab_qlora_grpo/checkpoints",
|
| 25 |
+
"save_to_drive": true,
|
| 26 |
+
"drive_dir": "/content/drive/MyDrive/fraudshield",
|
| 27 |
+
"num_train_epochs": 2,
|
| 28 |
+
"per_device_train_batch_size": 2,
|
| 29 |
+
"gradient_accumulation_steps": 4,
|
| 30 |
+
"learning_rate": 0.0001,
|
| 31 |
+
"eval_every_steps": 10,
|
| 32 |
+
"save_every_steps": 20,
|
| 33 |
+
"warmstart_rollouts_per_task": 24,
|
| 34 |
+
"rl_rollouts_per_task": 8,
|
| 35 |
+
"max_prompt_tokens": 2048,
|
| 36 |
+
"max_completion_tokens": 220,
|
| 37 |
+
"logging_steps": 1,
|
| 38 |
+
"report_to": [
|
| 39 |
+
"tensorboard"
|
| 40 |
+
],
|
| 41 |
+
"run_name": "fraudshield-colab-run",
|
| 42 |
+
"resume_from_checkpoint": null,
|
| 43 |
+
"public_curriculum_dataset": "Phoenix21/mock_fraud-detection-dataset",
|
| 44 |
+
"public_curriculum_rows": 2500
|
| 45 |
+
},
|
| 46 |
+
"evaluation": {
|
| 47 |
+
"tasks": [
|
| 48 |
+
"easy",
|
| 49 |
+
"medium",
|
| 50 |
+
"hard"
|
| 51 |
+
],
|
| 52 |
+
"fixed_prompt_cases": 3,
|
| 53 |
+
"plots_dir": "artifacts/plots",
|
| 54 |
+
"compare_against_base_model": true
|
| 55 |
+
},
|
| 56 |
+
"reward_weights": {
|
| 57 |
+
"env_reward": 1.0,
|
| 58 |
+
"correctness": 0.35,
|
| 59 |
+
"task_completion": 0.2,
|
| 60 |
+
"reasoning_quality": 0.1,
|
| 61 |
+
"efficiency": 0.1,
|
| 62 |
+
"safety": 0.1,
|
| 63 |
+
"formatting_compliance": 0.1,
|
| 64 |
+
"consistency": 0.05
|
| 65 |
+
},
|
| 66 |
+
"reward_version": "v1",
|
| 67 |
+
"ablation_tags": [
|
| 68 |
+
"public-curriculum",
|
| 69 |
+
"two-stage",
|
| 70 |
+
"colab-qlora"
|
| 71 |
+
]
|
| 72 |
+
}
|
environment.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Text-first training environment wrapper for FraudShield."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
from config import EnvironmentConfig, RewardWeights
|
| 10 |
+
from fraudshield_env import FraudShieldEnvironment
|
| 11 |
+
from models import ActionTypeEnum, FraudCheckAction, ResolutionEnum
|
| 12 |
+
from reward import RewardBreakdown, build_reward_breakdown
|
| 13 |
+
from utils import approximate_token_count, extract_json_object
|
| 14 |
+
|
| 15 |
+
INVESTIGATION_ALIAS_TO_ACTION = {
|
| 16 |
+
"merchant_profile": ActionTypeEnum.FETCH_MERCHANT_PROFILE,
|
| 17 |
+
"fetch_merchant_profile": ActionTypeEnum.FETCH_MERCHANT_PROFILE,
|
| 18 |
+
"customer_profile": ActionTypeEnum.FETCH_CUSTOMER_PROFILE,
|
| 19 |
+
"fetch_customer_profile": ActionTypeEnum.FETCH_CUSTOMER_PROFILE,
|
| 20 |
+
"network_graph": ActionTypeEnum.FETCH_NETWORK_GRAPH,
|
| 21 |
+
"fetch_network_graph": ActionTypeEnum.FETCH_NETWORK_GRAPH,
|
| 22 |
+
"device_intel": ActionTypeEnum.FETCH_NETWORK_GRAPH,
|
| 23 |
+
"payment_trace": ActionTypeEnum.REVIEW_TRANSACTION,
|
| 24 |
+
"fulfillment_review": ActionTypeEnum.REVIEW_TRANSACTION,
|
| 25 |
+
"review_transaction": ActionTypeEnum.REVIEW_TRANSACTION,
|
| 26 |
+
"policy_review": ActionTypeEnum.CHECK_POLICY,
|
| 27 |
+
"check_policy": ActionTypeEnum.CHECK_POLICY,
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class TextStepResult:
|
| 33 |
+
"""Structured step output for text-based RL loops."""
|
| 34 |
+
|
| 35 |
+
prompt: str
|
| 36 |
+
response_text: str
|
| 37 |
+
next_prompt: str
|
| 38 |
+
done: bool
|
| 39 |
+
reward: float
|
| 40 |
+
reward_breakdown: RewardBreakdown
|
| 41 |
+
info: dict[str, Any]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class FraudShieldTextEnvironment:
|
| 45 |
+
"""Wrap ``FraudShieldEnvironment`` as a text-in/text-out RL environment."""
|
| 46 |
+
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
env_config: EnvironmentConfig | None = None,
|
| 50 |
+
reward_weights: RewardWeights | None = None,
|
| 51 |
+
):
|
| 52 |
+
self.env_config = env_config or EnvironmentConfig()
|
| 53 |
+
self.reward_weights = reward_weights or RewardWeights()
|
| 54 |
+
self.env = FraudShieldEnvironment(data_path=self.env_config.data_path, seed=self.env_config.seed)
|
| 55 |
+
self.env.load_data()
|
| 56 |
+
self.current_observation = None
|
| 57 |
+
self.current_task = self.env_config.default_task
|
| 58 |
+
self.initial_step_budget = self.env_config.max_rollout_steps
|
| 59 |
+
self.action_history: list[str] = []
|
| 60 |
+
|
| 61 |
+
def reset(self, task: str | None = None) -> str:
|
| 62 |
+
"""Reset the wrapped environment and return the initial prompt."""
|
| 63 |
+
|
| 64 |
+
self.current_task = task or self.current_task
|
| 65 |
+
result = self.env.reset(task=self.current_task)
|
| 66 |
+
self.current_observation = result.observation
|
| 67 |
+
self.initial_step_budget = result.info.get("max_steps", self.env_config.max_rollout_steps)
|
| 68 |
+
self.action_history = []
|
| 69 |
+
return self.build_prompt(self.current_observation)
|
| 70 |
+
|
| 71 |
+
def build_prompt(self, observation) -> str:
|
| 72 |
+
"""Build the prompt shown to an LLM policy."""
|
| 73 |
+
|
| 74 |
+
payload = {
|
| 75 |
+
"case_id": observation.case_id,
|
| 76 |
+
"task_name": observation.task_name.value,
|
| 77 |
+
"visible_panels": observation.visible_panels,
|
| 78 |
+
"revealed_evidence": observation.revealed_evidence,
|
| 79 |
+
"linked_case_ids": observation.linked_case_ids,
|
| 80 |
+
"remaining_steps": observation.remaining_steps,
|
| 81 |
+
"remaining_sla": observation.remaining_sla,
|
| 82 |
+
"note_required": observation.note_required,
|
| 83 |
+
"allowed_actions": [action.value for action in observation.allowed_actions],
|
| 84 |
+
"case_summary": observation.case_summary.model_dump(mode="json"),
|
| 85 |
+
"app_context": observation.app_context,
|
| 86 |
+
}
|
| 87 |
+
available = observation.app_context.get(
|
| 88 |
+
"available_investigations",
|
| 89 |
+
["merchant_profile", "customer_profile", "network_graph", "payment_trace", "policy_review"],
|
| 90 |
+
)
|
| 91 |
+
return (
|
| 92 |
+
"You are a fraud analyst in a multi-step training environment. "
|
| 93 |
+
"Return JSON only. Use visible evidence, investigation budget, and prior evidence carefully.\n\n"
|
| 94 |
+
f"Visible observation:\n{json.dumps(payload, sort_keys=True)}\n\n"
|
| 95 |
+
f"Valid investigation aliases: {available}\n"
|
| 96 |
+
"JSON schema: "
|
| 97 |
+
'{"action_type":"investigate|decide","investigation_target":"alias_or_null",'
|
| 98 |
+
'"decision":"fraud|legitimate|null","confidence":0.0,"reasoning":"one sentence"}'
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
def parse_response(self, response_text: str) -> tuple[FraudCheckAction, dict[str, Any], bool, bool]:
|
| 102 |
+
"""Convert model output into a typed environment action."""
|
| 103 |
+
|
| 104 |
+
parse_failed = False
|
| 105 |
+
required_fields_present = True
|
| 106 |
+
try:
|
| 107 |
+
payload = extract_json_object(response_text)
|
| 108 |
+
except Exception:
|
| 109 |
+
parse_failed = True
|
| 110 |
+
required_fields_present = False
|
| 111 |
+
payload = {
|
| 112 |
+
"action_type": "investigate",
|
| 113 |
+
"investigation_target": "payment_trace",
|
| 114 |
+
"decision": None,
|
| 115 |
+
"confidence": 0.0,
|
| 116 |
+
"reasoning": "Fallback after invalid output.",
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
action_type = str(payload.get("action_type", "")).strip().lower()
|
| 120 |
+
reasoning = str(payload.get("reasoning", "")).strip()
|
| 121 |
+
if not reasoning:
|
| 122 |
+
required_fields_present = False
|
| 123 |
+
reasoning = "Fallback after missing reasoning."
|
| 124 |
+
|
| 125 |
+
if action_type == "investigate":
|
| 126 |
+
alias = str(payload.get("investigation_target", "")).strip().lower()
|
| 127 |
+
if not alias:
|
| 128 |
+
required_fields_present = False
|
| 129 |
+
alias = "payment_trace"
|
| 130 |
+
mapped_action = INVESTIGATION_ALIAS_TO_ACTION.get(alias, ActionTypeEnum.REVIEW_TRANSACTION)
|
| 131 |
+
action = FraudCheckAction(case_id=self.current_observation.case_id, action_type=mapped_action, reasoning=reasoning)
|
| 132 |
+
elif action_type == "decide":
|
| 133 |
+
decision = str(payload.get("decision", "")).strip().lower()
|
| 134 |
+
confidence = float(payload.get("confidence") or 0.5)
|
| 135 |
+
if decision not in {"fraud", "legitimate"}:
|
| 136 |
+
required_fields_present = False
|
| 137 |
+
decision = "fraud"
|
| 138 |
+
if self.current_observation.note_required:
|
| 139 |
+
action = FraudCheckAction(
|
| 140 |
+
case_id=self.current_observation.case_id,
|
| 141 |
+
action_type=ActionTypeEnum.ADD_CASE_NOTE,
|
| 142 |
+
note_text=f"Decision summary: {reasoning}",
|
| 143 |
+
)
|
| 144 |
+
else:
|
| 145 |
+
resolution = self._decision_to_resolution(decision, confidence)
|
| 146 |
+
action = FraudCheckAction(
|
| 147 |
+
case_id=self.current_observation.case_id,
|
| 148 |
+
action_type=ActionTypeEnum.RESOLVE_CASE,
|
| 149 |
+
resolution=resolution,
|
| 150 |
+
reasoning=reasoning,
|
| 151 |
+
)
|
| 152 |
+
else:
|
| 153 |
+
required_fields_present = False
|
| 154 |
+
action = FraudCheckAction(
|
| 155 |
+
case_id=self.current_observation.case_id,
|
| 156 |
+
action_type=ActionTypeEnum.REVIEW_TRANSACTION,
|
| 157 |
+
reasoning="Fallback after unsupported action type.",
|
| 158 |
+
)
|
| 159 |
+
return action, payload, parse_failed, required_fields_present
|
| 160 |
+
|
| 161 |
+
def step(self, response_text: str) -> TextStepResult:
|
| 162 |
+
"""Step the environment using raw model text."""
|
| 163 |
+
|
| 164 |
+
prompt = self.build_prompt(self.current_observation)
|
| 165 |
+
action, payload, parse_failed, required_fields_present = self.parse_response(response_text)
|
| 166 |
+
env_step = self.env.step(action)
|
| 167 |
+
self.action_history.append(action.action_type.value)
|
| 168 |
+
self.current_observation = env_step.observation
|
| 169 |
+
token_count = approximate_token_count(prompt + response_text)
|
| 170 |
+
breakdown = build_reward_breakdown(
|
| 171 |
+
env_reward_value=env_step.reward.value,
|
| 172 |
+
is_correct=env_step.reward.is_correct,
|
| 173 |
+
done=env_step.done,
|
| 174 |
+
action_type=action.action_type,
|
| 175 |
+
resolution=action.resolution,
|
| 176 |
+
reasoning=action.reasoning if action.action_type != ActionTypeEnum.ADD_CASE_NOTE else action.note_text or "",
|
| 177 |
+
revealed_evidence=env_step.observation.revealed_evidence,
|
| 178 |
+
remaining_steps=env_step.observation.remaining_steps,
|
| 179 |
+
initial_budget=self.initial_step_budget,
|
| 180 |
+
token_count=token_count,
|
| 181 |
+
parse_failed=parse_failed,
|
| 182 |
+
required_fields_present=required_fields_present,
|
| 183 |
+
action_history=self.action_history[:-1],
|
| 184 |
+
weights=self.reward_weights,
|
| 185 |
+
)
|
| 186 |
+
next_prompt = self.build_prompt(self.current_observation)
|
| 187 |
+
return TextStepResult(
|
| 188 |
+
prompt=prompt,
|
| 189 |
+
response_text=response_text,
|
| 190 |
+
next_prompt=next_prompt,
|
| 191 |
+
done=env_step.done,
|
| 192 |
+
reward=breakdown.total_reward,
|
| 193 |
+
reward_breakdown=breakdown,
|
| 194 |
+
info={
|
| 195 |
+
"payload": payload,
|
| 196 |
+
"env_reward": env_step.reward.model_dump(mode="json"),
|
| 197 |
+
"state": self.env.state().model_dump(mode="json"),
|
| 198 |
+
},
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
def _decision_to_resolution(self, decision: str, confidence: float) -> ResolutionEnum:
|
| 202 |
+
if decision == "legitimate":
|
| 203 |
+
if confidence >= 0.75 or self.current_observation.task_name.value == "easy":
|
| 204 |
+
return ResolutionEnum.APPROVE
|
| 205 |
+
return ResolutionEnum.REQUEST_DOCS
|
| 206 |
+
if confidence < 0.70:
|
| 207 |
+
return ResolutionEnum.HOLD
|
| 208 |
+
return ResolutionEnum.BLOCK
|
evaluate.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluation entrypoint for FraudShield trainable agents."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
|
| 12 |
+
from config import ExperimentConfig
|
| 13 |
+
from environment import FraudShieldTextEnvironment
|
| 14 |
+
from llm_agent import build_default_agent
|
| 15 |
+
from utils import ensure_dir, moving_average, save_json, seed_everything
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def evaluate_agent(config: ExperimentConfig) -> dict[str, Any]:
|
| 19 |
+
"""Run fixed-task evaluations and collect comparison metrics."""
|
| 20 |
+
|
| 21 |
+
seed_everything(config.seed)
|
| 22 |
+
text_env = FraudShieldTextEnvironment(config.environment, config.reward_weights)
|
| 23 |
+
agent = build_default_agent()
|
| 24 |
+
task_rows = []
|
| 25 |
+
reward_traces: dict[str, list[float]] = {}
|
| 26 |
+
for task in config.evaluation.tasks:
|
| 27 |
+
prompt = text_env.reset(task=task)
|
| 28 |
+
done = False
|
| 29 |
+
rewards: list[float] = []
|
| 30 |
+
final_info: dict[str, Any] | None = None
|
| 31 |
+
while not done:
|
| 32 |
+
action = agent.decide(text_env.current_observation)
|
| 33 |
+
response_text = json.dumps(
|
| 34 |
+
{
|
| 35 |
+
"action_type": "decide" if action.action_type.value == "resolve_case" else "investigate",
|
| 36 |
+
"investigation_target": action.action_type.value,
|
| 37 |
+
"decision": "fraud" if getattr(action, "resolution", None) and action.resolution.value in {"block", "hold", "escalate"} else "legitimate",
|
| 38 |
+
"confidence": 0.8,
|
| 39 |
+
"reasoning": action.reasoning or "Evaluation rollout step.",
|
| 40 |
+
}
|
| 41 |
+
)
|
| 42 |
+
step = text_env.step(response_text)
|
| 43 |
+
prompt = step.next_prompt
|
| 44 |
+
done = step.done
|
| 45 |
+
rewards.append(step.reward)
|
| 46 |
+
final_info = step.info
|
| 47 |
+
reward_traces[task] = rewards
|
| 48 |
+
state = final_info["state"] if final_info else {}
|
| 49 |
+
env_reward = final_info["env_reward"] if final_info else {}
|
| 50 |
+
task_rows.append(
|
| 51 |
+
{
|
| 52 |
+
"task": task,
|
| 53 |
+
"total_reward": round(sum(rewards), 4),
|
| 54 |
+
"mean_reward": round(sum(rewards) / max(1, len(rewards)), 4),
|
| 55 |
+
"success_rate": 1.0 if env_reward.get("is_correct") else 0.0,
|
| 56 |
+
"resolved_cases": len(state.get("resolved_case_ids", [])),
|
| 57 |
+
"token_usage_estimate": sum(len(str(value)) for value in rewards),
|
| 58 |
+
}
|
| 59 |
+
)
|
| 60 |
+
return {"tasks": task_rows, "reward_traces": reward_traces}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def save_evaluation_artifacts(report: dict[str, Any], config: ExperimentConfig) -> None:
|
| 64 |
+
"""Persist evaluation metrics and plots."""
|
| 65 |
+
|
| 66 |
+
plots_dir = ensure_dir(config.evaluation.plots_dir)
|
| 67 |
+
rewards = [row["total_reward"] for row in report["tasks"]]
|
| 68 |
+
moving = moving_average(rewards, window=2)
|
| 69 |
+
plt.figure(figsize=(8, 4))
|
| 70 |
+
plt.plot(range(1, len(rewards) + 1), rewards, marker="o", label="reward")
|
| 71 |
+
plt.plot(range(1, len(moving) + 1), moving, marker="x", label="moving_avg_reward")
|
| 72 |
+
plt.xticks(range(1, len(rewards) + 1), [row["task"] for row in report["tasks"]])
|
| 73 |
+
plt.legend()
|
| 74 |
+
plt.tight_layout()
|
| 75 |
+
plt.savefig(plots_dir / "evaluation_rewards.png")
|
| 76 |
+
plt.close()
|
| 77 |
+
save_json(report, Path(config.training.output_dir) / "evaluation_report.json")
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def main() -> None:
|
| 81 |
+
parser = argparse.ArgumentParser(description="Evaluate FraudShield trainable agents.")
|
| 82 |
+
parser.add_argument("--config", default="configs/colab_qlora_grpo.json", help="Path to experiment config JSON.")
|
| 83 |
+
args = parser.parse_args()
|
| 84 |
+
config = ExperimentConfig.load(args.config)
|
| 85 |
+
report = evaluate_agent(config)
|
| 86 |
+
save_evaluation_artifacts(report, config)
|
| 87 |
+
print(json.dumps(report, indent=2))
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
if __name__ == "__main__":
|
| 91 |
+
main()
|
pyproject.toml
CHANGED
|
@@ -36,6 +36,7 @@ classifiers = [
|
|
| 36 |
]
|
| 37 |
dependencies = [
|
| 38 |
"fastapi>=0.115.0",
|
|
|
|
| 39 |
"numpy>=1.24.0",
|
| 40 |
"openai>=1.40.0",
|
| 41 |
"openenv-core>=0.2.0",
|
|
@@ -55,6 +56,16 @@ dev = [
|
|
| 55 |
"pytest>=7.4.0",
|
| 56 |
"ruff>=0.4.0",
|
| 57 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
[project.urls]
|
| 60 |
Homepage = "https://github.com/DevikaJ2005/fraudshield"
|
|
@@ -64,11 +75,15 @@ BugTracker = "https://github.com/DevikaJ2005/fraudshield/issues"
|
|
| 64 |
|
| 65 |
[project.scripts]
|
| 66 |
server = "server.app:main"
|
|
|
|
|
|
|
| 67 |
|
| 68 |
[tool.setuptools]
|
| 69 |
py-modules = [
|
| 70 |
"data_loader",
|
| 71 |
"download_kaggle_data",
|
|
|
|
|
|
|
| 72 |
"fraudshield_env",
|
| 73 |
"graders",
|
| 74 |
"inference",
|
|
@@ -76,6 +91,10 @@ py-modules = [
|
|
| 76 |
"llm_agent",
|
| 77 |
"llm_agent_openai",
|
| 78 |
"models",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
]
|
| 80 |
|
| 81 |
[tool.setuptools.packages.find]
|
|
|
|
| 36 |
]
|
| 37 |
dependencies = [
|
| 38 |
"fastapi>=0.115.0",
|
| 39 |
+
"matplotlib>=3.8.0",
|
| 40 |
"numpy>=1.24.0",
|
| 41 |
"openai>=1.40.0",
|
| 42 |
"openenv-core>=0.2.0",
|
|
|
|
| 56 |
"pytest>=7.4.0",
|
| 57 |
"ruff>=0.4.0",
|
| 58 |
]
|
| 59 |
+
rl = [
|
| 60 |
+
"accelerate>=0.33.0",
|
| 61 |
+
"bitsandbytes>=0.43.0",
|
| 62 |
+
"datasets>=2.20.0",
|
| 63 |
+
"peft>=0.12.0",
|
| 64 |
+
"tensorboard>=2.17.0",
|
| 65 |
+
"transformers>=4.51.0",
|
| 66 |
+
"trl>=0.19.0",
|
| 67 |
+
"wandb>=0.17.0",
|
| 68 |
+
]
|
| 69 |
|
| 70 |
[project.urls]
|
| 71 |
Homepage = "https://github.com/DevikaJ2005/fraudshield"
|
|
|
|
| 75 |
|
| 76 |
[project.scripts]
|
| 77 |
server = "server.app:main"
|
| 78 |
+
fraudshield-train = "train:main"
|
| 79 |
+
fraudshield-evaluate = "evaluate:main"
|
| 80 |
|
| 81 |
[tool.setuptools]
|
| 82 |
py-modules = [
|
| 83 |
"data_loader",
|
| 84 |
"download_kaggle_data",
|
| 85 |
+
"environment",
|
| 86 |
+
"evaluate",
|
| 87 |
"fraudshield_env",
|
| 88 |
"graders",
|
| 89 |
"inference",
|
|
|
|
| 91 |
"llm_agent",
|
| 92 |
"llm_agent_openai",
|
| 93 |
"models",
|
| 94 |
+
"reward",
|
| 95 |
+
"train",
|
| 96 |
+
"utils",
|
| 97 |
+
"config",
|
| 98 |
]
|
| 99 |
|
| 100 |
[tool.setuptools.packages.find]
|
reward.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Reward decomposition helpers for RL-style FraudShield training."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Any, Iterable
|
| 7 |
+
|
| 8 |
+
from models import ActionTypeEnum, ResolutionEnum
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class RewardBreakdown:
|
| 13 |
+
"""Structured numeric reward with interpretable subscores."""
|
| 14 |
+
|
| 15 |
+
env_reward: float
|
| 16 |
+
correctness: float
|
| 17 |
+
task_completion: float
|
| 18 |
+
reasoning_quality: float
|
| 19 |
+
efficiency: float
|
| 20 |
+
safety: float
|
| 21 |
+
formatting_compliance: float
|
| 22 |
+
consistency: float
|
| 23 |
+
total_reward: float
|
| 24 |
+
|
| 25 |
+
def to_dict(self) -> dict[str, float]:
|
| 26 |
+
return {
|
| 27 |
+
"env_reward": self.env_reward,
|
| 28 |
+
"correctness": self.correctness,
|
| 29 |
+
"task_completion": self.task_completion,
|
| 30 |
+
"reasoning_quality": self.reasoning_quality,
|
| 31 |
+
"efficiency": self.efficiency,
|
| 32 |
+
"safety": self.safety,
|
| 33 |
+
"formatting_compliance": self.formatting_compliance,
|
| 34 |
+
"consistency": self.consistency,
|
| 35 |
+
"total_reward": self.total_reward,
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _clamp(value: float, low: float = -1.0, high: float = 1.0) -> float:
|
| 40 |
+
return max(low, min(high, value))
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def score_reasoning_quality(reasoning: str, revealed_evidence: dict[str, Any]) -> float:
|
| 44 |
+
"""Reward concise evidence-aware reasoning."""
|
| 45 |
+
|
| 46 |
+
reasoning = (reasoning or "").strip().lower()
|
| 47 |
+
if len(reasoning) < 12:
|
| 48 |
+
return -0.4
|
| 49 |
+
signal_hits = 0
|
| 50 |
+
for evidence_key in revealed_evidence:
|
| 51 |
+
stem = evidence_key.replace("_", " ")
|
| 52 |
+
if any(token in reasoning for token in stem.split()):
|
| 53 |
+
signal_hits += 1
|
| 54 |
+
return _clamp(0.2 + 0.2 * signal_hits, -1.0, 1.0)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def score_efficiency(remaining_steps: int, initial_budget: int, token_count: int) -> float:
|
| 58 |
+
"""Reward shorter trajectories and lower token usage."""
|
| 59 |
+
|
| 60 |
+
if initial_budget <= 0:
|
| 61 |
+
return 0.0
|
| 62 |
+
step_ratio = remaining_steps / initial_budget
|
| 63 |
+
token_penalty = min(token_count / 300.0, 1.0)
|
| 64 |
+
return _clamp((step_ratio * 0.8) - (token_penalty * 0.4))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def score_safety(action_type: ActionTypeEnum, parse_failed: bool, refused_unsafely: bool = False) -> float:
|
| 68 |
+
"""Reward well-formed safe handling."""
|
| 69 |
+
|
| 70 |
+
if parse_failed:
|
| 71 |
+
return -1.0
|
| 72 |
+
if refused_unsafely:
|
| 73 |
+
return -0.7
|
| 74 |
+
if action_type == ActionTypeEnum.RESOLVE_CASE:
|
| 75 |
+
return 0.3
|
| 76 |
+
return 0.5
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def score_formatting_compliance(parse_failed: bool, required_fields_present: bool) -> float:
|
| 80 |
+
"""Reward JSON compliance and field completeness."""
|
| 81 |
+
|
| 82 |
+
if parse_failed:
|
| 83 |
+
return -1.0
|
| 84 |
+
return 1.0 if required_fields_present else -0.4
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def score_consistency(action_history: Iterable[str], next_action: str, resolution: ResolutionEnum | None) -> float:
|
| 88 |
+
"""Reward non-redundant consistent behavior."""
|
| 89 |
+
|
| 90 |
+
history = list(action_history)
|
| 91 |
+
if history and history[-1] == next_action and next_action.startswith("fetch_"):
|
| 92 |
+
return -0.8
|
| 93 |
+
if resolution is not None and history.count("resolve_case") > 0:
|
| 94 |
+
return -1.0
|
| 95 |
+
return 0.4
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def score_correctness(env_reward_value: float, is_correct: bool | None) -> float:
|
| 99 |
+
"""Expose final correctness separately from raw environment reward."""
|
| 100 |
+
|
| 101 |
+
if is_correct is True:
|
| 102 |
+
return 1.0
|
| 103 |
+
if is_correct is False:
|
| 104 |
+
return -1.0
|
| 105 |
+
return _clamp(env_reward_value)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def score_task_completion(done: bool, action_type: ActionTypeEnum, resolution: ResolutionEnum | None) -> float:
|
| 109 |
+
"""Reward finishing the case and using the right action family."""
|
| 110 |
+
|
| 111 |
+
if done and action_type == ActionTypeEnum.RESOLVE_CASE and resolution is not None:
|
| 112 |
+
return 1.0
|
| 113 |
+
if action_type == ActionTypeEnum.ADD_CASE_NOTE:
|
| 114 |
+
return 0.3
|
| 115 |
+
return 0.1 if done else 0.0
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def build_reward_breakdown(
|
| 119 |
+
*,
|
| 120 |
+
env_reward_value: float,
|
| 121 |
+
is_correct: bool | None,
|
| 122 |
+
done: bool,
|
| 123 |
+
action_type: ActionTypeEnum,
|
| 124 |
+
resolution: ResolutionEnum | None,
|
| 125 |
+
reasoning: str,
|
| 126 |
+
revealed_evidence: dict[str, Any],
|
| 127 |
+
remaining_steps: int,
|
| 128 |
+
initial_budget: int,
|
| 129 |
+
token_count: int,
|
| 130 |
+
parse_failed: bool,
|
| 131 |
+
required_fields_present: bool,
|
| 132 |
+
action_history: Iterable[str],
|
| 133 |
+
weights: Any,
|
| 134 |
+
) -> RewardBreakdown:
|
| 135 |
+
"""Build a decomposed scalar reward for RL loops."""
|
| 136 |
+
|
| 137 |
+
correctness = score_correctness(env_reward_value, is_correct)
|
| 138 |
+
task_completion = score_task_completion(done, action_type, resolution)
|
| 139 |
+
reasoning_quality = score_reasoning_quality(reasoning, revealed_evidence)
|
| 140 |
+
efficiency = score_efficiency(remaining_steps, initial_budget, token_count)
|
| 141 |
+
safety = score_safety(action_type, parse_failed=parse_failed)
|
| 142 |
+
formatting = score_formatting_compliance(parse_failed=parse_failed, required_fields_present=required_fields_present)
|
| 143 |
+
consistency = score_consistency(action_history, action_type.value, resolution)
|
| 144 |
+
|
| 145 |
+
total_reward = (
|
| 146 |
+
weights.env_reward * env_reward_value
|
| 147 |
+
+ weights.correctness * correctness
|
| 148 |
+
+ weights.task_completion * task_completion
|
| 149 |
+
+ weights.reasoning_quality * reasoning_quality
|
| 150 |
+
+ weights.efficiency * efficiency
|
| 151 |
+
+ weights.safety * safety
|
| 152 |
+
+ weights.formatting_compliance * formatting
|
| 153 |
+
+ weights.consistency * consistency
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
return RewardBreakdown(
|
| 157 |
+
env_reward=env_reward_value,
|
| 158 |
+
correctness=correctness,
|
| 159 |
+
task_completion=task_completion,
|
| 160 |
+
reasoning_quality=reasoning_quality,
|
| 161 |
+
efficiency=efficiency,
|
| 162 |
+
safety=safety,
|
| 163 |
+
formatting_compliance=formatting,
|
| 164 |
+
consistency=consistency,
|
| 165 |
+
total_reward=total_reward,
|
| 166 |
+
)
|
train.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training entrypoint for FraudShield Colab-friendly experiments."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
import pandas as pd
|
| 13 |
+
from datasets import Dataset, load_dataset
|
| 14 |
+
|
| 15 |
+
from config import ExperimentConfig
|
| 16 |
+
from environment import FraudShieldTextEnvironment
|
| 17 |
+
from llm_agent import SnapshotCalibratedFraudDetectionAgent
|
| 18 |
+
from utils import ensure_dir, save_json, seed_everything
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def build_public_curriculum(config: ExperimentConfig) -> Dataset:
|
| 22 |
+
"""Load public fraud examples and convert them into action-centric prompts."""
|
| 23 |
+
|
| 24 |
+
dataset_name = config.training.public_curriculum_dataset
|
| 25 |
+
dataset = load_dataset(dataset_name, split="train")
|
| 26 |
+
rows: list[dict[str, Any]] = []
|
| 27 |
+
for row in dataset.shuffle(seed=config.seed).select(
|
| 28 |
+
range(min(config.training.public_curriculum_rows, len(dataset)))
|
| 29 |
+
):
|
| 30 |
+
amount = float(row.get("amount", row.get("Amount", 0.0)) or 0.0)
|
| 31 |
+
label = int(row.get("is_fraud", row.get("isFraud", row.get("Class", 0))) or 0)
|
| 32 |
+
transaction_type = str(row.get("transaction_type", row.get("type", "purchase")))
|
| 33 |
+
prompt = (
|
| 34 |
+
"You are a fraud analyst learning to investigate risk before final routing. Return JSON only.\n\n"
|
| 35 |
+
f"Visible observation:\n{json.dumps({'amount_usd': amount, 'transaction_type': transaction_type, 'task_name': 'medium', 'available_investigations': ['merchant_profile', 'customer_profile', 'network_graph', 'payment_trace', 'policy_review']})}\n\n"
|
| 36 |
+
'JSON schema: {"action_type":"investigate|decide","investigation_target":"alias_or_null","decision":"fraud|legitimate|null","confidence":0.0,"reasoning":"one sentence"}'
|
| 37 |
+
)
|
| 38 |
+
if label:
|
| 39 |
+
payload = {
|
| 40 |
+
"action_type": "investigate",
|
| 41 |
+
"investigation_target": "network_graph" if amount > 1000 else "payment_trace",
|
| 42 |
+
"decision": None,
|
| 43 |
+
"confidence": None,
|
| 44 |
+
"reasoning": "The visible transaction is risky, so gather stronger network or payment evidence first.",
|
| 45 |
+
}
|
| 46 |
+
else:
|
| 47 |
+
payload = {
|
| 48 |
+
"action_type": "decide",
|
| 49 |
+
"investigation_target": None,
|
| 50 |
+
"decision": "legitimate",
|
| 51 |
+
"confidence": 0.8,
|
| 52 |
+
"reasoning": "The visible transaction appears low risk and can be cleared confidently.",
|
| 53 |
+
}
|
| 54 |
+
rows.append({"text": prompt + "\n" + json.dumps(payload, separators=(",", ":")), "source": "public"})
|
| 55 |
+
return Dataset.from_pandas(pd.DataFrame(rows), preserve_index=False)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def build_rollout_dataset(config: ExperimentConfig) -> Dataset:
|
| 59 |
+
"""Generate environment-compatible trajectories from the calibrated baseline."""
|
| 60 |
+
|
| 61 |
+
text_env = FraudShieldTextEnvironment(config.environment, config.reward_weights)
|
| 62 |
+
agent = SnapshotCalibratedFraudDetectionAgent()
|
| 63 |
+
rows: list[dict[str, Any]] = []
|
| 64 |
+
for task_name in config.evaluation.tasks:
|
| 65 |
+
for _ in range(config.training.warmstart_rollouts_per_task):
|
| 66 |
+
prompt = text_env.reset(task=task_name)
|
| 67 |
+
done = False
|
| 68 |
+
while not done:
|
| 69 |
+
action = agent.decide(text_env.current_observation)
|
| 70 |
+
payload = {
|
| 71 |
+
"action_type": "decide" if action.action_type.value == "resolve_case" else "investigate",
|
| 72 |
+
"investigation_target": action.action_type.value,
|
| 73 |
+
"decision": "fraud" if getattr(action, "resolution", None) and action.resolution.value in {"block", "hold", "escalate"} else "legitimate",
|
| 74 |
+
"confidence": 0.8,
|
| 75 |
+
"reasoning": action.reasoning or "Training rollout step.",
|
| 76 |
+
}
|
| 77 |
+
rows.append({"text": prompt + "\n" + json.dumps(payload, separators=(",", ":")), "source": "rollout"})
|
| 78 |
+
step = text_env.step(json.dumps(payload))
|
| 79 |
+
prompt = step.next_prompt
|
| 80 |
+
done = step.done
|
| 81 |
+
return Dataset.from_pandas(pd.DataFrame(rows), preserve_index=False)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def load_model_stack(config: ExperimentConfig):
|
| 85 |
+
"""Load a Colab-friendly 4-bit LoRA stack."""
|
| 86 |
+
|
| 87 |
+
from unsloth import FastLanguageModel
|
| 88 |
+
|
| 89 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 90 |
+
model_name=config.model.base_model,
|
| 91 |
+
max_seq_length=config.model.max_seq_length,
|
| 92 |
+
load_in_4bit=config.model.load_in_4bit,
|
| 93 |
+
)
|
| 94 |
+
model = FastLanguageModel.get_peft_model(
|
| 95 |
+
model,
|
| 96 |
+
r=config.model.lora_rank,
|
| 97 |
+
lora_alpha=config.model.lora_alpha,
|
| 98 |
+
lora_dropout=config.model.lora_dropout,
|
| 99 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
| 100 |
+
use_gradient_checkpointing=config.model.gradient_checkpointing,
|
| 101 |
+
)
|
| 102 |
+
return model, tokenizer
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def run_training(config: ExperimentConfig) -> dict[str, Any]:
|
| 106 |
+
"""Run the configured training pipeline."""
|
| 107 |
+
|
| 108 |
+
seed_everything(config.seed)
|
| 109 |
+
ensure_dir(config.training.output_dir)
|
| 110 |
+
ensure_dir(config.training.checkpoint_dir)
|
| 111 |
+
if "wandb" in config.training.report_to and not os.environ.get("WANDB_PROJECT"):
|
| 112 |
+
os.environ["WANDB_PROJECT"] = "fraudshield"
|
| 113 |
+
if "tensorboard" in config.training.report_to:
|
| 114 |
+
ensure_dir(Path(config.training.output_dir) / "tb_logs")
|
| 115 |
+
public_dataset = build_public_curriculum(config)
|
| 116 |
+
rollout_dataset = build_rollout_dataset(config)
|
| 117 |
+
model, tokenizer = load_model_stack(config)
|
| 118 |
+
|
| 119 |
+
from transformers import TrainingArguments
|
| 120 |
+
from trl import SFTTrainer
|
| 121 |
+
|
| 122 |
+
stage1_args = TrainingArguments(
|
| 123 |
+
output_dir=str(Path(config.training.output_dir) / "stage1"),
|
| 124 |
+
num_train_epochs=1,
|
| 125 |
+
per_device_train_batch_size=config.training.per_device_train_batch_size,
|
| 126 |
+
gradient_accumulation_steps=config.training.gradient_accumulation_steps,
|
| 127 |
+
learning_rate=config.training.learning_rate * 2,
|
| 128 |
+
logging_steps=max(1, config.training.logging_steps),
|
| 129 |
+
save_strategy="no",
|
| 130 |
+
report_to=config.training.report_to,
|
| 131 |
+
run_name=f"{config.training.run_name}-stage1",
|
| 132 |
+
logging_dir=str(Path(config.training.output_dir) / "tb_logs" / "stage1"),
|
| 133 |
+
)
|
| 134 |
+
stage1_trainer = SFTTrainer(
|
| 135 |
+
model=model,
|
| 136 |
+
tokenizer=tokenizer,
|
| 137 |
+
train_dataset=public_dataset,
|
| 138 |
+
dataset_text_field="text",
|
| 139 |
+
max_seq_length=config.model.max_seq_length,
|
| 140 |
+
packing=False,
|
| 141 |
+
args=stage1_args,
|
| 142 |
+
)
|
| 143 |
+
stage1_trainer.train()
|
| 144 |
+
|
| 145 |
+
stage2_args = TrainingArguments(
|
| 146 |
+
output_dir=str(Path(config.training.output_dir) / "stage2"),
|
| 147 |
+
num_train_epochs=config.training.num_train_epochs,
|
| 148 |
+
per_device_train_batch_size=config.training.per_device_train_batch_size,
|
| 149 |
+
gradient_accumulation_steps=config.training.gradient_accumulation_steps,
|
| 150 |
+
learning_rate=config.training.learning_rate,
|
| 151 |
+
logging_steps=max(1, config.training.logging_steps),
|
| 152 |
+
save_strategy="epoch",
|
| 153 |
+
report_to=config.training.report_to,
|
| 154 |
+
run_name=f"{config.training.run_name}-stage2",
|
| 155 |
+
logging_dir=str(Path(config.training.output_dir) / "tb_logs" / "stage2"),
|
| 156 |
+
)
|
| 157 |
+
trainer = SFTTrainer(
|
| 158 |
+
model=stage1_trainer.model,
|
| 159 |
+
tokenizer=tokenizer,
|
| 160 |
+
train_dataset=rollout_dataset,
|
| 161 |
+
dataset_text_field="text",
|
| 162 |
+
max_seq_length=config.model.max_seq_length,
|
| 163 |
+
packing=False,
|
| 164 |
+
args=stage2_args,
|
| 165 |
+
)
|
| 166 |
+
trainer.train(resume_from_checkpoint=config.training.resume_from_checkpoint)
|
| 167 |
+
output_dir = Path(config.training.output_dir) / "trained_policy"
|
| 168 |
+
trainer.model.save_pretrained(output_dir)
|
| 169 |
+
tokenizer.save_pretrained(output_dir)
|
| 170 |
+
|
| 171 |
+
log_history = trainer.state.log_history
|
| 172 |
+
loss_points = [(entry["step"], entry["loss"]) for entry in log_history if "step" in entry and "loss" in entry]
|
| 173 |
+
if loss_points:
|
| 174 |
+
xs, ys = zip(*loss_points)
|
| 175 |
+
plt.figure(figsize=(8, 4))
|
| 176 |
+
plt.plot(xs, ys)
|
| 177 |
+
plt.xlabel("training step")
|
| 178 |
+
plt.ylabel("loss")
|
| 179 |
+
plt.tight_layout()
|
| 180 |
+
plt.savefig(Path(config.training.output_dir) / "loss_vs_steps.png")
|
| 181 |
+
plt.close()
|
| 182 |
+
|
| 183 |
+
reward_trace = []
|
| 184 |
+
for idx, entry in enumerate(log_history, start=1):
|
| 185 |
+
if "loss" in entry:
|
| 186 |
+
reward_trace.append(max(0.0, 1.0 - float(entry["loss"])))
|
| 187 |
+
if reward_trace:
|
| 188 |
+
plt.figure(figsize=(8, 4))
|
| 189 |
+
plt.plot(range(1, len(reward_trace) + 1), reward_trace, label="reward_proxy")
|
| 190 |
+
window = min(10, len(reward_trace))
|
| 191 |
+
if window:
|
| 192 |
+
from utils import moving_average
|
| 193 |
+
|
| 194 |
+
plt.plot(range(1, len(reward_trace) + 1), moving_average(reward_trace, window=window), label="moving_avg")
|
| 195 |
+
plt.xlabel("training step")
|
| 196 |
+
plt.ylabel("reward proxy")
|
| 197 |
+
plt.legend()
|
| 198 |
+
plt.tight_layout()
|
| 199 |
+
plt.savefig(Path(config.training.output_dir) / "reward_vs_steps.png")
|
| 200 |
+
plt.close()
|
| 201 |
+
|
| 202 |
+
metadata = {
|
| 203 |
+
"status": "completed",
|
| 204 |
+
"algorithm": config.training.algorithm,
|
| 205 |
+
"warmstart_algorithm": config.training.warmstart_algorithm,
|
| 206 |
+
"report_to": config.training.report_to,
|
| 207 |
+
"run_name": config.training.run_name,
|
| 208 |
+
"public_curriculum_dataset": config.training.public_curriculum_dataset,
|
| 209 |
+
"output_dir": str(output_dir),
|
| 210 |
+
"num_public_examples": len(public_dataset),
|
| 211 |
+
"num_rollout_examples": len(rollout_dataset),
|
| 212 |
+
"log_history": log_history,
|
| 213 |
+
}
|
| 214 |
+
save_json(metadata, Path(config.training.output_dir) / "training_run_summary.json")
|
| 215 |
+
return metadata
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def main() -> None:
|
| 219 |
+
parser = argparse.ArgumentParser(description="Train FraudShield with a Colab-friendly curriculum.")
|
| 220 |
+
parser.add_argument("--config", default="configs/colab_qlora_grpo.json", help="Path to experiment config JSON.")
|
| 221 |
+
args = parser.parse_args()
|
| 222 |
+
config = ExperimentConfig.load(args.config)
|
| 223 |
+
config.save(Path(config.training.output_dir) / "resolved_config.json")
|
| 224 |
+
summary = run_training(config)
|
| 225 |
+
print(json.dumps(summary, indent=2))
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
if __name__ == "__main__":
|
| 229 |
+
main()
|
utils.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared utilities for FraudShield training and evaluation."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Iterable, Sequence
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def seed_everything(seed: int) -> None:
|
| 15 |
+
"""Seed Python, NumPy, and torch when available."""
|
| 16 |
+
|
| 17 |
+
random.seed(seed)
|
| 18 |
+
np.random.seed(seed)
|
| 19 |
+
os.environ["PYTHONHASHSEED"] = str(seed)
|
| 20 |
+
try: # pragma: no cover - torch is optional at runtime
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
torch.manual_seed(seed)
|
| 24 |
+
if torch.cuda.is_available():
|
| 25 |
+
torch.cuda.manual_seed_all(seed)
|
| 26 |
+
except Exception:
|
| 27 |
+
pass
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def ensure_dir(path: str | Path) -> Path:
|
| 31 |
+
"""Create a directory if needed and return it as a ``Path``."""
|
| 32 |
+
|
| 33 |
+
resolved = Path(path)
|
| 34 |
+
resolved.mkdir(parents=True, exist_ok=True)
|
| 35 |
+
return resolved
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def save_json(payload: Any, path: str | Path) -> None:
|
| 39 |
+
"""Write JSON with stable indentation."""
|
| 40 |
+
|
| 41 |
+
Path(path).write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def load_json(path: str | Path) -> Any:
|
| 45 |
+
"""Load JSON from disk."""
|
| 46 |
+
|
| 47 |
+
return json.loads(Path(path).read_text(encoding="utf-8"))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def extract_json_object(text: str) -> dict[str, Any]:
|
| 51 |
+
"""Extract the first JSON object from model output."""
|
| 52 |
+
|
| 53 |
+
start = text.find("{")
|
| 54 |
+
end = text.rfind("}")
|
| 55 |
+
if start == -1 or end == -1 or end < start:
|
| 56 |
+
raise ValueError("Model output did not contain a JSON object.")
|
| 57 |
+
return json.loads(text[start : end + 1])
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def moving_average(values: Sequence[float], window: int = 10) -> list[float]:
|
| 61 |
+
"""Compute a simple moving average."""
|
| 62 |
+
|
| 63 |
+
if not values:
|
| 64 |
+
return []
|
| 65 |
+
window = max(1, int(window))
|
| 66 |
+
averaged: list[float] = []
|
| 67 |
+
for idx in range(len(values)):
|
| 68 |
+
start = max(0, idx - window + 1)
|
| 69 |
+
chunk = values[start : idx + 1]
|
| 70 |
+
averaged.append(sum(chunk) / len(chunk))
|
| 71 |
+
return averaged
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def approximate_token_count(text: str) -> int:
|
| 75 |
+
"""Cheap token estimate that works without a tokenizer."""
|
| 76 |
+
|
| 77 |
+
stripped = text.strip()
|
| 78 |
+
if not stripped:
|
| 79 |
+
return 0
|
| 80 |
+
return max(1, int(len(stripped.split()) * 1.3))
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def flatten_dict_items(mapping: dict[str, Any], prefix: str = "") -> Iterable[tuple[str, Any]]:
|
| 84 |
+
"""Flatten nested dictionaries for logging."""
|
| 85 |
+
|
| 86 |
+
for key, value in mapping.items():
|
| 87 |
+
full_key = f"{prefix}.{key}" if prefix else key
|
| 88 |
+
if isinstance(value, dict):
|
| 89 |
+
yield from flatten_dict_items(value, prefix=full_key)
|
| 90 |
+
else:
|
| 91 |
+
yield full_key, value
|