Spaces:
Sleeping
Sleeping
| """Training entrypoint for FraudShield Colab-friendly experiments.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Any | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| from datasets import Dataset, load_dataset | |
| from config import ExperimentConfig | |
| from environment import ACTION_TYPE_TO_CANONICAL_ALIAS, FraudShieldTextEnvironment | |
| from models import ActionTypeEnum, FraudCheckAction | |
| from utils import ensure_dir, save_json, seed_everything | |
class ExpertCurriculumTeacher:
    """Scripted expert used to produce demonstration trajectories.

    The teacher looks up the current case in ``text_env.env.workflow_cases`` and
    reads fields such as ``role`` and ``correct_resolution`` from it, so its
    choices rely on information beyond the observation itself.
    """

    def decide(self, text_env: FraudShieldTextEnvironment) -> FraudCheckAction:
        """Return the next action for the case currently exposed by ``text_env``.

        Priority order:
          1. review the transaction if ``transaction_review`` is not revealed yet;
          2. fetch the first planned evidence item that is still hidden, as long
             as ``investigation_budget_remaining`` is positive;
          3. add a case note when the observation says one is required;
          4. otherwise resolve the case with its ``correct_resolution``.
        """
        obs = text_env.current_observation
        case_id = obs.case_id
        seen = obs.revealed_evidence
        case = text_env.env.workflow_cases[case_id]
        budget_left = int(obs.app_context.get("investigation_budget_remaining", 0))

        if "transaction_review" not in seen:
            return FraudCheckAction(
                case_id=case_id,
                action_type=ActionTypeEnum.REVIEW_TRANSACTION,
                reasoning="Open the transaction details before taking any deeper investigative step.",
            )

        plan = self._planned_evidence_sequence(case)
        # First planned step whose evidence is still hidden; never matches once
        # the budget is exhausted.
        pending = next(
            (step for step in plan if step[0] not in seen and budget_left > 0),
            None,
        )
        if pending is not None:
            _, action_type, reasoning = pending
            return FraudCheckAction(case_id=case_id, action_type=action_type, reasoning=reasoning)

        if obs.note_required:
            return FraudCheckAction(
                case_id=case_id,
                action_type=ActionTypeEnum.ADD_CASE_NOTE,
                note_text=self._case_note(case),
            )

        return FraudCheckAction(
            case_id=case_id,
            action_type=ActionTypeEnum.RESOLVE_CASE,
            resolution=case["correct_resolution"],
            reasoning=self._resolution_reasoning(case),
        )

    def _planned_evidence_sequence(self, case: dict[str, Any]) -> list[tuple[str, ActionTypeEnum, str]]:
        """Return the ordered ``(evidence_key, action_type, reasoning)`` steps for ``case``.

        The order depends on ``case["role"]`` and, for ``"single"`` cases, on
        whether the correct resolution is ``request_docs``. Any other case only
        plans a merchant-profile lookup.
        """
        role = case["role"]

        customer = (
            "customer_profile",
            ActionTypeEnum.FETCH_CUSTOMER_PROFILE,
            "Customer history is needed to understand whether this pattern reflects risky buyer behavior.",
        )
        merchant = (
            "merchant_profile",
            ActionTypeEnum.FETCH_MERCHANT_PROFILE,
            "Merchant health helps explain whether the case risk comes from the seller side.",
        )
        network = (
            "network_graph",
            ActionTypeEnum.FETCH_NETWORK_GRAPH,
            "Linked-activity evidence is needed to confirm whether this case participates in a broader cluster.",
        )
        policy = (
            "policy_guide",
            ActionTypeEnum.CHECK_POLICY,
            "Policy guidance is required before choosing the final route.",
        )

        if role == "single" and case["correct_resolution"].value == "request_docs":
            return [customer, policy, merchant]
        if role == "primary":
            return [network, merchant, policy]
        if role == "secondary":
            return [network, customer, policy]
        return [merchant]

    def _case_note(self, case: dict[str, Any]) -> str:
        """Return the canned note text matching the case role / resolution."""
        role = case["role"]
        if role == "primary":
            return "Reviewed the transaction trace, graph evidence, merchant signals, and policy guidance before escalating the linked primary case."
        if role == "secondary":
            return "Reviewed the transaction trace, graph evidence, customer history, and policy guidance before finalizing the linked secondary case."

        # The resolution is only consulted for cases that are neither primary
        # nor secondary.
        needs_documents = case["correct_resolution"].value == "request_docs"
        if needs_documents:
            return "Reviewed transaction, customer, merchant, and policy evidence before requesting more supporting documents."
        return "Reviewed the transaction evidence and documented the case before final routing."

    def _resolution_reasoning(self, case: dict[str, Any]) -> str:
        """Return the canned reasoning sentence for the case's correct resolution.

        Raises:
            KeyError: if the resolution value is not one of the five known ones.
        """
        reasoning_by_resolution = {
            "approve": "The collected evidence supports approval without additional intervention.",
            "block": "The combined evidence supports blocking the transaction as high risk.",
            "hold": "The evidence remains risky enough to hold the case for more controlled handling.",
            "request_docs": "The case is ambiguous enough that supporting documents are the safest next step.",
            "escalate": "The linked-cluster evidence and loss risk justify escalation to a higher-touch reviewer.",
        }
        resolution_value = case["correct_resolution"].value
        return reasoning_by_resolution[resolution_value]
def build_public_curriculum(config: ExperimentConfig) -> Dataset:
    """Turn a public fraud dataset into prompt + JSON-answer training texts.

    Loads the ``train`` split of ``config.training.public_curriculum_dataset``,
    shuffles it with ``config.seed`` and keeps at most
    ``config.training.public_curriculum_rows`` rows. Each kept row becomes one
    record with a ``text`` field (prompt, newline, compact JSON target) and a
    ``source`` field set to ``"public"``.

    Rows labelled as fraud are answered with an ``investigate`` action; all
    other rows are answered with a ``decide``/``legitimate`` action.
    """
    source = load_dataset(config.training.public_curriculum_dataset, split="train")
    sample_size = min(config.training.public_curriculum_rows, len(source))
    sampled = source.shuffle(seed=config.seed).select(range(sample_size))

    def to_example(row: Any) -> dict[str, Any]:
        # Column names differ between public datasets, so several spellings are
        # tried before falling back to a default.
        raw_amount = row.get("amount", row.get("Amount", 0.0))
        amount = float(raw_amount or 0.0)
        raw_label = row.get("is_fraud", row.get("isFraud", row.get("Class", 0)))
        is_fraud = int(raw_label or 0)
        transaction_type = str(row.get("transaction_type", row.get("type", "purchase")))

        visible = {
            "amount_usd": amount,
            "transaction_type": transaction_type,
            "task_name": "medium",
            "available_investigations": [
                "merchant_profile",
                "customer_profile",
                "network_graph",
                "payment_trace",
                "policy_review",
            ],
        }
        prompt = "".join(
            [
                "You are a fraud analyst learning to investigate risk before final routing. Return JSON only.\n\n",
                f"Visible observation:\n{json.dumps(visible)}\n\n",
                'JSON schema: {"action_type":"investigate|decide","investigation_target":"alias_or_null","decision":"fraud|legitimate|null","confidence":0.0,"reasoning":"one sentence"}',
            ]
        )

        if is_fraud:
            # Larger amounts are steered to the network graph, the rest to the
            # payment trace.
            target = {
                "action_type": "investigate",
                "investigation_target": "network_graph" if amount > 1000 else "payment_trace",
                "decision": None,
                "confidence": None,
                "reasoning": "The visible transaction is risky, so gather stronger network or payment evidence first.",
            }
        else:
            target = {
                "action_type": "decide",
                "investigation_target": None,
                "decision": "legitimate",
                "confidence": 0.8,
                "reasoning": "The visible transaction appears low risk and can be cleared confidently.",
            }

        completion = json.dumps(target, separators=(",", ":"))
        return {"text": f"{prompt}\n{completion}", "source": "public"}

    records = [to_example(row) for row in sampled]
    return Dataset.from_pandas(pd.DataFrame(records), preserve_index=False)
def build_rollout_dataset(config: ExperimentConfig) -> Dataset:
    """Roll out the expert teacher in the text environment and record each step.

    For every task in ``config.evaluation.tasks`` the environment is reset
    ``config.training.warmstart_rollouts_per_task`` times. Each teacher action
    is converted to the JSON payload format, stored together with the prompt
    that preceded it (``source`` = ``"rollout"``), and then fed back into the
    environment until the environment reports the episode as done.
    """
    env = FraudShieldTextEnvironment(config.environment, config.reward_weights)
    teacher = ExpertCurriculumTeacher()
    fraud_resolutions = {"block", "hold", "escalate"}
    records: list[dict[str, Any]] = []

    for task in config.evaluation.tasks:
        for _episode in range(config.training.warmstart_rollouts_per_task):
            current_prompt = env.reset(task=task)
            while True:
                action = teacher.decide(env)

                is_resolution = action.action_type.value == "resolve_case"
                resolution = getattr(action, "resolution", None)
                # NOTE(review): any action without a resolution (i.e. every
                # investigation step) is labelled "legitimate" with confidence
                # 0.8 here - confirm this is the intended training target.
                if resolution and resolution.value in fraud_resolutions:
                    decision = "fraud"
                else:
                    decision = "legitimate"

                payload = {
                    "action_type": "decide" if is_resolution else "investigate",
                    "investigation_target": ACTION_TYPE_TO_CANONICAL_ALIAS.get(action.action_type),
                    "decision": decision,
                    "confidence": 0.8,
                    "reasoning": action.reasoning or "Training rollout step.",
                }

                # The stored target uses compact separators; the environment
                # receives the default json.dumps formatting.
                completion = json.dumps(payload, separators=(",", ":"))
                records.append({"text": f"{current_prompt}\n{completion}", "source": "rollout"})

                outcome = env.step(json.dumps(payload))
                current_prompt = outcome.next_prompt
                if outcome.done:
                    break

    return Dataset.from_pandas(pd.DataFrame(records), preserve_index=False)
def load_model_stack(config: ExperimentConfig):
    """Load the base model through Unsloth and wrap it with LoRA adapters.

    All sizes and switches (base model name, sequence length, 4-bit loading,
    LoRA rank/alpha/dropout, gradient checkpointing) come from ``config.model``.

    Returns:
        A ``(model, tokenizer)`` pair, where ``model`` is the PEFT-wrapped model.
    """
    # Imported lazily so the module can be imported without Unsloth installed.
    from unsloth import FastLanguageModel

    model_cfg = config.model
    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_cfg.base_model,
        max_seq_length=model_cfg.max_seq_length,
        load_in_4bit=model_cfg.load_in_4bit,
    )

    # Attention and MLP projection layers that receive LoRA adapters.
    lora_targets = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    peft_model = FastLanguageModel.get_peft_model(
        base_model,
        r=model_cfg.lora_rank,
        lora_alpha=model_cfg.lora_alpha,
        lora_dropout=model_cfg.lora_dropout,
        target_modules=lora_targets,
        use_gradient_checkpointing=model_cfg.gradient_checkpointing,
    )
    return peft_model, tokenizer
def _stage_training_args(
    config: ExperimentConfig,
    stage: str,
    *,
    num_train_epochs: Any,
    learning_rate: Any,
    save_strategy: str,
):
    """Build the ``TrainingArguments`` for one training stage (``stage1``/``stage2``)."""
    # Imported lazily, after the model stack has been loaded by the caller.
    from transformers import TrainingArguments

    training = config.training
    output_root = Path(training.output_dir)
    return TrainingArguments(
        output_dir=str(output_root / stage),
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=training.per_device_train_batch_size,
        gradient_accumulation_steps=training.gradient_accumulation_steps,
        learning_rate=learning_rate,
        logging_steps=max(1, training.logging_steps),
        save_strategy=save_strategy,
        report_to=training.report_to,
        run_name=f"{training.run_name}-{stage}",
        logging_dir=str(output_root / "tb_logs" / stage),
    )


def _save_loss_plot(log_history: list[dict[str, Any]], path: Path) -> None:
    """Plot loss against training step and save it to ``path`` (no-op without loss entries)."""
    points = [(entry["step"], entry["loss"]) for entry in log_history if "step" in entry and "loss" in entry]
    if not points:
        return
    steps, losses = zip(*points)
    plt.figure(figsize=(8, 4))
    try:
        plt.plot(steps, losses)
        plt.xlabel("training step")
        plt.ylabel("loss")
        plt.tight_layout()
        plt.savefig(path)
    finally:
        # Close the figure even when plotting or saving fails.
        plt.close()


def _save_reward_proxy_plot(log_history: list[dict[str, Any]], path: Path) -> None:
    """Plot ``max(0, 1 - loss)`` per logged entry plus its moving average, saved to ``path``."""
    reward_trace = [max(0.0, 1.0 - float(entry["loss"])) for entry in log_history if "loss" in entry]
    if not reward_trace:
        return

    from utils import moving_average

    positions = range(1, len(reward_trace) + 1)
    # The trace is non-empty here, so the window is always at least 1.
    window = min(10, len(reward_trace))
    plt.figure(figsize=(8, 4))
    try:
        plt.plot(positions, reward_trace, label="reward_proxy")
        plt.plot(positions, moving_average(reward_trace, window=window), label="moving_avg")
        plt.xlabel("training step")
        plt.ylabel("reward proxy")
        plt.legend()
        plt.tight_layout()
        plt.savefig(path)
    finally:
        plt.close()


def run_training(config: ExperimentConfig) -> dict[str, Any]:
    """Run the two-stage supervised training pipeline and return a run summary.

    Stage 1 trains for one epoch on the public curriculum at twice the
    configured learning rate and saves no checkpoints. Stage 2 continues from
    the stage-1 model on the teacher rollouts, using the configured epochs and
    learning rate with a checkpoint per epoch. Afterwards the model and
    tokenizer are saved under ``<output_dir>/trained_policy``, loss and
    reward-proxy plots are written from the stage-2 log history, and the summary
    is stored as ``training_run_summary.json``.

    Args:
        config: Experiment configuration; reads ``seed``, ``model`` and ``training``.

    Returns:
        The metadata dictionary that was also written to disk.
    """
    training = config.training
    output_root = Path(training.output_dir)

    seed_everything(config.seed)
    ensure_dir(training.output_dir)
    ensure_dir(training.checkpoint_dir)
    # Provide a default W&B project only when none (or an empty one) is set.
    if "wandb" in training.report_to and not os.environ.get("WANDB_PROJECT"):
        os.environ["WANDB_PROJECT"] = "fraudshield"
    if "tensorboard" in training.report_to:
        ensure_dir(output_root / "tb_logs")

    public_dataset = build_public_curriculum(config)
    rollout_dataset = build_rollout_dataset(config)
    model, tokenizer = load_model_stack(config)

    stage1_args = _stage_training_args(
        config,
        "stage1",
        num_train_epochs=1,
        learning_rate=training.learning_rate * 2,
        save_strategy="no",
    )

    # Kept after load_model_stack, matching the original import position.
    from trl import SFTTrainer

    stage1_trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=public_dataset,
        dataset_text_field="text",
        max_seq_length=config.model.max_seq_length,
        packing=False,
        args=stage1_args,
    )
    stage1_trainer.train()

    stage2_args = _stage_training_args(
        config,
        "stage2",
        num_train_epochs=training.num_train_epochs,
        learning_rate=training.learning_rate,
        save_strategy="epoch",
    )
    trainer = SFTTrainer(
        model=stage1_trainer.model,
        tokenizer=tokenizer,
        train_dataset=rollout_dataset,
        dataset_text_field="text",
        max_seq_length=config.model.max_seq_length,
        packing=False,
        args=stage2_args,
    )
    trainer.train(resume_from_checkpoint=training.resume_from_checkpoint)

    output_dir = output_root / "trained_policy"
    trainer.model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Only the stage-2 trainer's history is plotted and reported.
    log_history = trainer.state.log_history
    _save_loss_plot(log_history, output_root / "loss_vs_steps.png")
    _save_reward_proxy_plot(log_history, output_root / "reward_vs_steps.png")

    metadata = {
        "status": "completed",
        "algorithm": training.algorithm,
        "warmstart_algorithm": training.warmstart_algorithm,
        "report_to": training.report_to,
        "run_name": training.run_name,
        "public_curriculum_dataset": training.public_curriculum_dataset,
        "output_dir": str(output_dir),
        "num_public_examples": len(public_dataset),
        "num_rollout_examples": len(rollout_dataset),
        "log_history": log_history,
    }
    save_json(metadata, output_root / "training_run_summary.json")
    return metadata
def main() -> None:
    """Parse ``--config``, persist the resolved config, train, and print the summary."""
    parser = argparse.ArgumentParser(description="Train FraudShield with a Colab-friendly curriculum.")
    parser.add_argument("--config", default="configs/colab_qlora_grpo.json", help="Path to experiment config JSON.")
    args = parser.parse_args()

    config = ExperimentConfig.load(args.config)
    # The output directory is otherwise only created inside run_training, which
    # runs after this save; create it first so a fresh run cannot fail here.
    # (Whether config.save creates missing parents is not visible from this file.)
    ensure_dir(config.training.output_dir)
    config.save(Path(config.training.output_dir) / "resolved_config.json")

    summary = run_training(config)
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()