# fraudshield-1 / train.py
# DevikaJ2005's picture
# Align local policy prompts and canonical action targets
# 45cf526
"""Training entrypoint for FraudShield Colab-friendly experiments."""
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import pandas as pd
from datasets import Dataset, load_dataset
from config import ExperimentConfig
from environment import ACTION_TYPE_TO_CANONICAL_ALIAS, FraudShieldTextEnvironment
from models import ActionTypeEnum, FraudCheckAction
from utils import ensure_dir, save_json, seed_everything
class ExpertCurriculumTeacher:
    """Scripted teacher that picks the next action from the case's hidden ground truth.

    The teacher reads the workflow case record (including its ``role`` and
    ``correct_resolution``) straight from the wrapped environment, so it is a
    data generator for warm-start trajectories rather than a fair policy.
    """

    def decide(self, text_env: FraudShieldTextEnvironment) -> FraudCheckAction:
        """Return the next action for the case currently shown by ``text_env``.

        Priority order: review the transaction, gather the next planned piece
        of evidence (only while investigation budget remains), add a case note
        when the observation requires one, and finally resolve the case with
        its known correct resolution.
        """
        obs = text_env.current_observation
        current_case_id = obs.case_id
        seen_evidence = obs.revealed_evidence
        case_record = text_env.env.workflow_cases[current_case_id]
        remaining_budget = int(obs.app_context.get("investigation_budget_remaining", 0))

        # The transaction review always comes first, regardless of budget.
        if "transaction_review" not in seen_evidence:
            return FraudCheckAction(
                case_id=current_case_id,
                action_type=ActionTypeEnum.REVIEW_TRANSACTION,
                reasoning="Open the transaction details before taking any deeper investigative step.",
            )

        plan = self._planned_evidence_sequence(case_record)
        if remaining_budget > 0:
            next_step = next((step for step in plan if step[0] not in seen_evidence), None)
            if next_step is not None:
                _, step_action, step_reasoning = next_step
                return FraudCheckAction(
                    case_id=current_case_id,
                    action_type=step_action,
                    reasoning=step_reasoning,
                )

        if not obs.note_required:
            return FraudCheckAction(
                case_id=current_case_id,
                action_type=ActionTypeEnum.RESOLVE_CASE,
                resolution=case_record["correct_resolution"],
                reasoning=self._resolution_reasoning(case_record),
            )

        return FraudCheckAction(
            case_id=current_case_id,
            action_type=ActionTypeEnum.ADD_CASE_NOTE,
            note_text=self._case_note(case_record),
        )

    def _planned_evidence_sequence(self, case: dict[str, Any]) -> list[tuple[str, ActionTypeEnum, str]]:
        """Return the ordered ``(evidence_key, action_type, reasoning)`` steps for ``case``.

        The order depends on the case role; a ``single`` case only gets the
        longer plan when its correct resolution is ``request_docs``.
        """
        case_role = case["role"]

        customer_step = (
            "customer_profile",
            ActionTypeEnum.FETCH_CUSTOMER_PROFILE,
            "Customer history is needed to understand whether this pattern reflects risky buyer behavior.",
        )
        merchant_step = (
            "merchant_profile",
            ActionTypeEnum.FETCH_MERCHANT_PROFILE,
            "Merchant health helps explain whether the case risk comes from the seller side.",
        )
        network_step = (
            "network_graph",
            ActionTypeEnum.FETCH_NETWORK_GRAPH,
            "Linked-activity evidence is needed to confirm whether this case participates in a broader cluster.",
        )
        policy_step = (
            "policy_guide",
            ActionTypeEnum.CHECK_POLICY,
            "Policy guidance is required before choosing the final route.",
        )

        if case_role == "primary":
            return [network_step, merchant_step, policy_step]
        if case_role == "secondary":
            return [network_step, customer_step, policy_step]
        if case_role == "single" and case["correct_resolution"].value == "request_docs":
            return [customer_step, policy_step, merchant_step]
        # Every other case only checks the merchant before resolving.
        return [merchant_step]

    def _case_note(self, case: dict[str, Any]) -> str:
        """Return the canned case-note text matching the case role and resolution."""
        linked_case_notes = (
            (
                "primary",
                "Reviewed the transaction trace, graph evidence, merchant signals, and policy guidance before escalating the linked primary case.",
            ),
            (
                "secondary",
                "Reviewed the transaction trace, graph evidence, customer history, and policy guidance before finalizing the linked secondary case.",
            ),
        )
        case_role = case["role"]
        for linked_role, note in linked_case_notes:
            if case_role == linked_role:
                return note

        # The resolution is only consulted for cases that are not linked.
        wants_documents = case["correct_resolution"].value == "request_docs"
        if wants_documents:
            return "Reviewed transaction, customer, merchant, and policy evidence before requesting more supporting documents."
        return "Reviewed the transaction evidence and documented the case before final routing."

    def _resolution_reasoning(self, case: dict[str, Any]) -> str:
        """Return the canned reasoning sentence for the case's correct resolution.

        Raises ``KeyError`` when the resolution value has no canned sentence.
        """
        reasoning_by_resolution = {
            "approve": "The collected evidence supports approval without additional intervention.",
            "block": "The combined evidence supports blocking the transaction as high risk.",
            "hold": "The evidence remains risky enough to hold the case for more controlled handling.",
            "request_docs": "The case is ambiguous enough that supporting documents are the safest next step.",
            "escalate": "The linked-cluster evidence and loss risk justify escalation to a higher-touch reviewer.",
        }
        resolution_key = case["correct_resolution"].value
        return reasoning_by_resolution[resolution_key]
def build_public_curriculum(config: ExperimentConfig) -> Dataset:
    """Turn a sample of a public fraud dataset into prompt/target training rows.

    A seeded shuffle of the ``train`` split is truncated to at most
    ``config.training.public_curriculum_rows`` rows. Each row becomes one
    ``{"text", "source"}`` record whose text is the prompt followed by a
    compact JSON target: fraud-labelled rows ask for a further investigation,
    all other rows are cleared as legitimate.
    """
    instruction = (
        "You are a fraud analyst learning to investigate risk before final routing. Return JSON only.\n\n"
    )
    schema_hint = 'JSON schema: {"action_type":"investigate|decide","investigation_target":"alias_or_null","decision":"fraud|legitimate|null","confidence":0.0,"reasoning":"one sentence"}'
    investigations = ["merchant_profile", "customer_profile", "network_graph", "payment_trace", "policy_review"]

    def _target_for(is_fraud: int, amount_usd: float) -> dict[str, Any]:
        # Fraud rows are routed to a deeper investigation instead of a verdict.
        if is_fraud:
            return {
                "action_type": "investigate",
                "investigation_target": "network_graph" if amount_usd > 1000 else "payment_trace",
                "decision": None,
                "confidence": None,
                "reasoning": "The visible transaction is risky, so gather stronger network or payment evidence first.",
            }
        return {
            "action_type": "decide",
            "investigation_target": None,
            "decision": "legitimate",
            "confidence": 0.8,
            "reasoning": "The visible transaction appears low risk and can be cleared confidently.",
        }

    def _to_example(record: Any) -> dict[str, Any]:
        # Column names differ between public datasets, so try the known spellings in turn.
        amount_usd = float(record.get("amount", record.get("Amount", 0.0)) or 0.0)
        is_fraud = int(record.get("is_fraud", record.get("isFraud", record.get("Class", 0))) or 0)
        kind = str(record.get("transaction_type", record.get("type", "purchase")))
        visible = {
            "amount_usd": amount_usd,
            "transaction_type": kind,
            "task_name": "medium",
            "available_investigations": investigations,
        }
        prompt_text = f"{instruction}Visible observation:\n{json.dumps(visible)}\n\n{schema_hint}"
        target_json = json.dumps(_target_for(is_fraud, amount_usd), separators=(",", ":"))
        return {"text": prompt_text + "\n" + target_json, "source": "public"}

    source = load_dataset(config.training.public_curriculum_dataset, split="train")
    shuffled = source.shuffle(seed=config.seed)
    sample_size = min(config.training.public_curriculum_rows, len(source))
    examples: list[dict[str, Any]] = [_to_example(record) for record in shuffled.select(range(sample_size))]
    return Dataset.from_pandas(pd.DataFrame(examples), preserve_index=False)
def build_rollout_dataset(config: ExperimentConfig) -> Dataset:
    """Generate environment-compatible trajectories from an expert teacher.

    For every task in ``config.evaluation.tasks`` the teacher plays
    ``config.training.warmstart_rollouts_per_task`` episodes. Each step is
    stored as one ``{"text", "source"}`` row whose text is the current prompt
    followed by the compact JSON action, and the same action is then sent to
    the environment to obtain the next prompt.

    Only ``resolve_case`` steps carry a ``decision``; every other step emits
    ``None`` so the targets match the ``fraud|legitimate|null`` schema used by
    the public curriculum instead of labelling investigations "legitimate".
    """
    text_env = FraudShieldTextEnvironment(config.environment, config.reward_weights)
    agent = ExpertCurriculumTeacher()
    # Resolutions that count as a "fraud" verdict in the text action space.
    fraud_resolutions = {"block", "hold", "escalate"}
    rows: list[dict[str, Any]] = []
    for task_name in config.evaluation.tasks:
        for _ in range(config.training.warmstart_rollouts_per_task):
            prompt = text_env.reset(task=task_name)
            done = False
            while not done:
                action = agent.decide(text_env)
                is_resolution = action.action_type.value == "resolve_case"
                if is_resolution:
                    resolution = getattr(action, "resolution", None)
                    is_fraud = bool(resolution) and resolution.value in fraud_resolutions
                    decision = "fraud" if is_fraud else "legitimate"
                else:
                    # No verdict exists yet on investigation / note steps.
                    decision = None
                payload = {
                    "action_type": "decide" if is_resolution else "investigate",
                    "investigation_target": ACTION_TYPE_TO_CANONICAL_ALIAS.get(action.action_type),
                    "decision": decision,
                    "confidence": 0.8,
                    "reasoning": action.reasoning or "Training rollout step.",
                }
                rows.append({"text": prompt + "\n" + json.dumps(payload, separators=(",", ":")), "source": "rollout"})
                step = text_env.step(json.dumps(payload))
                prompt = step.next_prompt
                done = step.done
    return Dataset.from_pandas(pd.DataFrame(rows), preserve_index=False)
def load_model_stack(config: ExperimentConfig):
    """Load the base model and tokenizer via unsloth and attach LoRA adapters.

    Model name, sequence length, 4-bit loading and all LoRA hyperparameters
    come from ``config.model``. Returns ``(model, tokenizer)``.
    """
    # Imported lazily so the module can be imported without unsloth installed.
    from unsloth import FastLanguageModel

    model_cfg = config.model
    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_cfg.base_model,
        max_seq_length=model_cfg.max_seq_length,
        load_in_4bit=model_cfg.load_in_4bit,
    )

    # Attention and MLP projection layers that receive LoRA adapters.
    lora_targets = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    peft_model = FastLanguageModel.get_peft_model(
        base_model,
        r=model_cfg.lora_rank,
        lora_alpha=model_cfg.lora_alpha,
        lora_dropout=model_cfg.lora_dropout,
        target_modules=lora_targets,
        use_gradient_checkpointing=model_cfg.gradient_checkpointing,
    )
    return peft_model, tokenizer
def _stage_training_args(config: ExperimentConfig, stage: str, **stage_overrides: Any):
    """Build the ``TrainingArguments`` for one SFT stage (``stage1`` / ``stage2``)."""
    from transformers import TrainingArguments

    output_root = Path(config.training.output_dir)
    return TrainingArguments(
        output_dir=str(output_root / stage),
        per_device_train_batch_size=config.training.per_device_train_batch_size,
        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
        logging_steps=max(1, config.training.logging_steps),
        report_to=config.training.report_to,
        run_name=f"{config.training.run_name}-{stage}",
        logging_dir=str(output_root / "tb_logs" / stage),
        **stage_overrides,
    )


def _build_sft_trainer(config: ExperimentConfig, model: Any, tokenizer: Any, dataset: Dataset, args: Any):
    """Create an unpacked ``SFTTrainer`` over the ``text`` column of ``dataset``."""
    from trl import SFTTrainer

    return SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        dataset_text_field="text",
        max_seq_length=config.model.max_seq_length,
        packing=False,
        args=args,
    )


def _save_loss_plot(loss_points: list[tuple[Any, Any]], destination: Path) -> None:
    """Plot logged loss against training step and save it to ``destination``."""
    steps, losses = zip(*loss_points)
    plt.figure(figsize=(8, 4))
    try:
        plt.plot(steps, losses)
        plt.xlabel("training step")
        plt.ylabel("loss")
        plt.tight_layout()
        plt.savefig(destination)
    finally:
        # Always release the figure, even if saving fails.
        plt.close()


def _save_reward_proxy_plot(loss_points: list[tuple[Any, Any]], destination: Path) -> None:
    """Plot the ``max(0, 1 - loss)`` reward proxy and its moving average against step."""
    from utils import moving_average

    steps = [step for step, _ in loss_points]
    reward_trace = [max(0.0, 1.0 - float(loss)) for _, loss in loss_points]
    # Callers only pass non-empty input, so the window is always at least 1.
    window = min(10, len(reward_trace))
    plt.figure(figsize=(8, 4))
    try:
        # Use the logged step (not the log-entry index) so the x axis matches its label.
        plt.plot(steps, reward_trace, label="reward_proxy")
        plt.plot(steps, moving_average(reward_trace, window=window), label="moving_avg")
        plt.xlabel("training step")
        plt.ylabel("reward proxy")
        plt.legend()
        plt.tight_layout()
        plt.savefig(destination)
    finally:
        plt.close()


def run_training(config: ExperimentConfig) -> dict[str, Any]:
    """Run the configured two-stage SFT pipeline and return a run summary.

    Stage 1 trains for one epoch on the public curriculum at twice the
    configured learning rate; stage 2 continues from the stage-1 model on the
    teacher rollouts, checkpointing every epoch and optionally resuming from
    ``config.training.resume_from_checkpoint``. The trained policy, the
    loss / reward-proxy plots and ``training_run_summary.json`` are written
    under ``config.training.output_dir``.

    Returns:
        The metadata dictionary that is also saved as the run summary,
        including the stage-2 ``log_history``.
    """
    seed_everything(config.seed)
    output_root = Path(config.training.output_dir)
    ensure_dir(config.training.output_dir)
    ensure_dir(config.training.checkpoint_dir)
    if "wandb" in config.training.report_to and not os.environ.get("WANDB_PROJECT"):
        os.environ["WANDB_PROJECT"] = "fraudshield"
    if "tensorboard" in config.training.report_to:
        ensure_dir(output_root / "tb_logs")

    public_dataset = build_public_curriculum(config)
    rollout_dataset = build_rollout_dataset(config)
    # The model stack is loaded before transformers / trl are first imported by
    # the helpers below, preserving the original import order.
    # NOTE(review): presumably so unsloth can patch those libraries - confirm.
    model, tokenizer = load_model_stack(config)

    stage1_args = _stage_training_args(
        config,
        "stage1",
        num_train_epochs=1,
        learning_rate=config.training.learning_rate * 2,
        save_strategy="no",
    )
    stage1_trainer = _build_sft_trainer(config, model, tokenizer, public_dataset, stage1_args)
    stage1_trainer.train()

    stage2_args = _stage_training_args(
        config,
        "stage2",
        num_train_epochs=config.training.num_train_epochs,
        learning_rate=config.training.learning_rate,
        save_strategy="epoch",
    )
    trainer = _build_sft_trainer(config, stage1_trainer.model, tokenizer, rollout_dataset, stage2_args)
    trainer.train(resume_from_checkpoint=config.training.resume_from_checkpoint)

    output_dir = output_root / "trained_policy"
    trainer.model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Only stage-2 logs are plotted and reported.
    log_history = trainer.state.log_history
    loss_points = [(entry["step"], entry["loss"]) for entry in log_history if "step" in entry and "loss" in entry]
    if loss_points:
        _save_loss_plot(loss_points, output_root / "loss_vs_steps.png")
        _save_reward_proxy_plot(loss_points, output_root / "reward_vs_steps.png")

    metadata = {
        "status": "completed",
        "algorithm": config.training.algorithm,
        "warmstart_algorithm": config.training.warmstart_algorithm,
        "report_to": config.training.report_to,
        "run_name": config.training.run_name,
        "public_curriculum_dataset": config.training.public_curriculum_dataset,
        "output_dir": str(output_dir),
        "num_public_examples": len(public_dataset),
        "num_rollout_examples": len(rollout_dataset),
        "log_history": log_history,
    }
    save_json(metadata, output_root / "training_run_summary.json")
    return metadata
def main() -> None:
    """Parse CLI arguments, run training, and print the run summary as JSON.

    The resolved configuration is written to ``resolved_config.json`` inside
    the configured output directory before training starts.
    """
    parser = argparse.ArgumentParser(description="Train FraudShield with a Colab-friendly curriculum.")
    parser.add_argument("--config", default="configs/colab_qlora_grpo.json", help="Path to experiment config JSON.")
    args = parser.parse_args()
    config = ExperimentConfig.load(args.config)
    # run_training() creates the output directory, but only after this save,
    # so create it here to keep a fresh run from failing on a missing folder.
    ensure_dir(config.training.output_dir)
    config.save(Path(config.training.output_dir) / "resolved_config.json")
    summary = run_training(config)
    print(json.dumps(summary, indent=2))
# Run the training CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()