| | """ |
| | Stage 2: InfoNCE Fine-tuning for ExecutionEncoder |
| | |
| | Loads the Stage 1 VICReg checkpoint and fine-tunes with InfoNCE loss using |
| | (anchor=benign, positive=augmented_benign, negatives=adversarial_in_batch). |
| | |
| | This creates the energy gap between benign and adversarial execution plans |
| | that Stage 1 (VICReg geometry) could not produce alone. |
| | |
| | Usage: |
| | uv run python scripts/train_stage2_infonce.py \ |
| | --dataset data/adversarial_563k.jsonl \ |
| | --checkpoint outputs/execution_encoder_50k/encoder_final.pt \ |
| | --max-samples 50000 \ |
| | --epochs 3 \ |
| | --batch-size 32 \ |
| | --device mps \ |
| | --output-dir outputs/execution_encoder_stage2 |
| | """ |
| |
|
import argparse
import copy
import json
import math
import random
import sys
from pathlib import Path
from typing import Any
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader, Dataset |
| | from tqdm import tqdm |
| |
|
| | sys.path.insert(0, str(Path(__file__).parent.parent)) |
| | from source.encoders.execution_encoder import ExecutionEncoder, ExecutionPlan |
| |
|
| |
|
| | |
| |
|
| | class AdversarialPairDataset(Dataset): |
| | """ |
| | Loads adversarial_563k.jsonl and separates benign / adversarial samples. |
| | Each __getitem__ returns one sample dict with its label. |
| | """ |
| |
|
| | def __init__(self, path: str, max_samples: int | None = None): |
| | self.benign: list[dict] = [] |
| | self.adversarial: list[dict] = [] |
| |
|
| | with open(path) as f: |
| | for i, line in enumerate(f): |
| | if max_samples and i >= max_samples: |
| | break |
| | sample = json.loads(line) |
| | if sample["label"] == "adversarial": |
| | self.adversarial.append(sample["execution_plan"]) |
| | else: |
| | self.benign.append(sample["execution_plan"]) |
| |
|
| | print(f" π Benign: {len(self.benign):,} | Adversarial: {len(self.adversarial):,}") |
| | if not self.adversarial: |
| | raise ValueError("No adversarial samples found β check dataset labels") |
| |
|
| | def __len__(self) -> int: |
| | return len(self.benign) |
| |
|
| | def __getitem__(self, idx: int) -> dict[str, Any]: |
| | return {"benign": self.benign[idx], "adversarial": random.choice(self.adversarial)} |
| |
|
| |
|
def collate_pairs(batch: list[dict]) -> dict[str, list]:
    """Keep plan dicts as plain lists; bypasses DataLoader's tensor stacking."""
    benign: list = []
    adversarial: list = []
    for item in batch:
        benign.append(item["benign"])
        adversarial.append(item["adversarial"])
    return {"benign": benign, "adversarial": adversarial}
| |
|
| |
|
| | |
| |
|
def augment_plan(plan_dict: dict) -> dict:
    """
    Create a stochastically augmented copy of a benign plan to use as an
    InfoNCE positive.

    Only metadata fields are perturbed; the plan's semantic content is never
    changed, and the input dict is left untouched (deep copy first).

    Args:
        plan_dict: Execution plan as a plain dict with an optional "nodes"
            list of node dicts.

    Returns:
        A deep copy of ``plan_dict`` with light random perturbations:
        ~30% of nodes get their ``scope_volume`` jittered by ±20% (clamped
        to >= 1), and ~20% of nodes with arguments lose one random key.
    """
    plan = copy.deepcopy(plan_dict)
    for node in plan.get("nodes", []):
        # Jitter scope_volume by up to ±20% on ~30% of nodes.
        if random.random() < 0.3:
            node["scope_volume"] = max(
                1, int(node.get("scope_volume", 1) * random.uniform(0.8, 1.2))
            )
        # Drop one random argument key on ~20% of nodes that have arguments.
        if random.random() < 0.2 and node.get("arguments"):
            args = node["arguments"]
            keys = list(args.keys())
            if keys:
                args.pop(random.choice(keys))
    return plan
| |
|
| |
|
| | |
| |
|
class InfoNCELoss(nn.Module):
    """
    InfoNCE (NT-Xent) contrastive loss.

    Anchors are benign embeddings, positives are their augmented versions,
    and every adversarial embedding in the batch serves as a negative:

        Loss = -log( exp(sim(a, p) / tau) /
                     (exp(sim(a, p) / tau) + sum_i exp(sim(a, n_i) / tau)) )

    A lower temperature tau yields a sharper decision boundary.
    """

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.tau = temperature

    def forward(
        self,
        anchors: torch.Tensor,
        positives: torch.Tensor,
        negatives: torch.Tensor,
    ) -> tuple[torch.Tensor, dict[str, float]]:
        batch = anchors.size(0)

        # Work on unit vectors so dot products are cosine similarities.
        a = F.normalize(anchors, dim=-1)
        p = F.normalize(positives, dim=-1)
        n = F.normalize(negatives, dim=-1)

        # One positive logit per anchor, plus one logit per in-batch negative.
        pos_logit = (a * p).sum(dim=-1, keepdim=True) / self.tau
        neg_logits = (a @ n.T) / self.tau
        logits = torch.cat([pos_logit, neg_logits], dim=1)

        # The positive always sits at column 0 of the logit matrix.
        target = logits.new_zeros(batch, dtype=torch.long)
        loss = F.cross_entropy(logits, target)

        # Diagnostics only — no gradients needed.
        with torch.no_grad():
            pos_cosim = (a * p).sum(dim=-1).mean().item()
            neg_cosim = (a * n).sum(dim=-1).mean().item()

        return loss, {
            "pos_cosim": pos_cosim,
            "neg_cosim": neg_cosim,
            "energy_gap": pos_cosim - neg_cosim,
        }
| |
|
| |
|
| | |
| |
|
def train_stage2(
    dataset_path: str,
    checkpoint_path: str,
    output_dir: str,
    max_samples: int | None,
    epochs: int,
    batch_size: int,
    lr: float,
    temperature: float,
    device: str,
    save_every: int,
) -> None:
    """
    Fine-tune a Stage 1 (VICReg) ExecutionEncoder checkpoint with InfoNCE.

    Anchors are benign plans, positives are lightly augmented benign plans,
    and the batch's adversarial plans act as negatives, creating the energy
    gap between benign and adversarial embeddings.

    Args:
        dataset_path: JSONL file of labeled execution-plan records.
        checkpoint_path: Stage 1 state-dict checkpoint to load.
        output_dir: Directory for per-epoch and final checkpoints (created
            if missing).
        max_samples: Cap on dataset lines read (None = all).
        epochs: Number of fine-tuning epochs.
        batch_size: Pairs per batch (drop_last=True, so trailing partial
            batches are skipped).
        lr: Peak AdamW learning rate (warmup then cosine decay, floored
            at 10% of peak).
        temperature: InfoNCE temperature tau.
        device: torch device string ("cpu", "cuda", or "mps").
        save_every: Save an intermediate checkpoint every N epochs.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    print("π§ Stage 2: InfoNCE Fine-tuning")
    print(f" Checkpoint : {checkpoint_path}")
    print(f" Dataset : {dataset_path}")
    print(f" Device : {device}")
    print(f" Temperature: {temperature}")
    print(f" Max samples: {max_samples or 'all'}")

    # Load the Stage 1 encoder weights; weights_only=True avoids unpickling
    # arbitrary objects from the checkpoint.
    model = ExecutionEncoder(latent_dim=1024)
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state)
    model = model.to(device)
    model.train()
    print(f" β Loaded Stage 1 checkpoint ({sum(p.numel() for p in model.parameters()):,} params)")

    dataset = AdversarialPairDataset(dataset_path, max_samples=max_samples)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_pairs,  # keep plan dicts as lists
        num_workers=0,
        drop_last=True,  # InfoNCE needs full batches of negatives
    )
    print(f" π¦ Batches per epoch: {len(loader)}")

    criterion = InfoNCELoss(temperature=temperature)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    # Linear warmup, then cosine decay floored at 10% of peak lr.
    warmup_steps = min(100, len(loader))
    total_steps = len(loader) * epochs

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            # step + 1 so the very first optimizer step does not run at lr=0
            # (LambdaLR evaluates the lambda at step 0 before any stepping).
            return (step + 1) / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return max(0.1, 0.5 * (1 + math.cos(math.pi * progress)))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    global_step = 0
    for epoch in range(1, epochs + 1):
        epoch_loss = 0.0
        epoch_gap = 0.0
        n_batches = 0

        pbar = tqdm(loader, desc=f"Epoch {epoch}/{epochs}", dynamic_ncols=True)
        for batch in pbar:
            benign_plans = batch["benign"]
            adversarial_plans = batch["adversarial"]

            # Positives: stochastic metadata-only augmentations of the anchors.
            augmented_plans = [augment_plan(p) for p in benign_plans]

            # Encode plan-by-plan; skip the batch on any encode failure
            # (best-effort — a malformed plan should not kill the run).
            try:
                anchors = torch.cat([model(p) for p in benign_plans], dim=0)
                positives = torch.cat([model(p) for p in augmented_plans], dim=0)
                negatives = torch.cat([model(p) for p in adversarial_plans], dim=0)
            except Exception as e:
                print(f"\nβ οΈ Batch encode error: {e}")
                continue

            loss, metrics = criterion(anchors, positives, negatives)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            epoch_loss += loss.item()
            epoch_gap += metrics["energy_gap"]
            n_batches += 1
            global_step += 1

            pbar.set_postfix(
                loss=f"{loss.item():.4f}",
                gap=f"{metrics['energy_gap']:.4f}",
                pos=f"{metrics['pos_cosim']:.3f}",
                neg=f"{metrics['neg_cosim']:.3f}",
            )

        # max(1, ...) guards against an epoch where every batch failed.
        avg_loss = epoch_loss / max(1, n_batches)
        avg_gap = epoch_gap / max(1, n_batches)
        print(f"\n Epoch {epoch} | avg_loss={avg_loss:.4f} | avg_energy_gap={avg_gap:.4f}")

        if epoch % save_every == 0:
            ckpt = Path(output_dir) / f"encoder_stage2_epoch_{epoch}.pt"
            torch.save(model.state_dict(), ckpt)
            print(f" πΎ Saved checkpoint: {ckpt}")

    final_path = Path(output_dir) / "encoder_stage2_final.pt"
    torch.save(model.state_dict(), final_path)
    print(f"\nβ Stage 2 Training Complete!")
    print(f" Final model: {final_path}")
| |
|
| |
|
| | |
| |
|
if __name__ == "__main__":
    # CLI wiring only — all real work happens in train_stage2().
    parser = argparse.ArgumentParser(description="Stage 2 InfoNCE fine-tuning")
    parser.add_argument("--dataset", required=True, help="Path to adversarial_563k.jsonl")
    parser.add_argument("--checkpoint", required=True, help="Path to Stage 1 checkpoint")
    parser.add_argument("--output-dir", default="outputs/execution_encoder_stage2")
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--temperature", type=float, default=0.07)
    parser.add_argument("--device", choices=["cpu", "cuda", "mps"], default="cpu")
    parser.add_argument("--save-every", type=int, default=1)
    cli = parser.parse_args()

    train_stage2(
        dataset_path=cli.dataset,
        checkpoint_path=cli.checkpoint,
        output_dir=cli.output_dir,
        max_samples=cli.max_samples,
        epochs=cli.epochs,
        batch_size=cli.batch_size,
        lr=cli.lr,
        temperature=cli.temperature,
        device=cli.device,
        save_every=cli.save_every,
    )
| |
|