""" Stage 2: InfoNCE Fine-tuning for ExecutionEncoder Loads the Stage 1 VICReg checkpoint and fine-tunes with InfoNCE loss using (anchor=benign, positive=augmented_benign, negatives=adversarial_in_batch). This creates the energy gap between benign and adversarial execution plans that Stage 1 (VICReg geometry) could not produce alone. Usage: uv run python scripts/train_stage2_infonce.py \ --dataset data/adversarial_563k.jsonl \ --checkpoint outputs/execution_encoder_50k/encoder_final.pt \ --max-samples 50000 \ --epochs 3 \ --batch-size 32 \ --device mps \ --output-dir outputs/execution_encoder_stage2 """ import argparse import json import math import random import sys from pathlib import Path from typing import Any import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset from tqdm import tqdm sys.path.insert(0, str(Path(__file__).parent.parent)) from source.encoders.execution_encoder import ExecutionEncoder, ExecutionPlan # ── Dataset ────────────────────────────────────────────────────────────────── class AdversarialPairDataset(Dataset): """ Loads adversarial_563k.jsonl and separates benign / adversarial samples. Each __getitem__ returns one sample dict with its label. """ def __init__(self, path: str, max_samples: int | None = None): self.benign: list[dict] = [] self.adversarial: list[dict] = [] with open(path) as f: for i, line in enumerate(f): if max_samples and i >= max_samples: break sample = json.loads(line) if sample["label"] == "adversarial": self.adversarial.append(sample["execution_plan"]) else: self.benign.append(sample["execution_plan"]) print(f" 📊 Benign: {len(self.benign):,} | Adversarial: {len(self.adversarial):,}") if not self.adversarial: raise ValueError("No adversarial samples found — check dataset labels") def __len__(self) -> int: return len(self.benign) def __getitem__(self, idx: int) -> dict[str, Any]: return {"benign": self.benign[idx], "adversarial": random.choice(self.adversarial)} def collate_pairs(batch: list[dict]) -> dict[str, list]: """Return lists of plan dicts, bypass default tensor stacking.""" return { "benign": [item["benign"] for item in batch], "adversarial": [item["adversarial"] for item in batch], } # ── Augmentation ───────────────────────────────────────────────────────────── def augment_plan(plan_dict: dict) -> dict: """ Light stochastic augmentation of a benign plan to create positives. Only modifies metadata fields, never changes semantic content. """ import copy plan = copy.deepcopy(plan_dict) for node in plan.get("nodes", []): # Randomly perturb scope_volume by ±20% (stays benign) if random.random() < 0.3: node["scope_volume"] = max(1, int(node.get("scope_volume", 1) * random.uniform(0.8, 1.2))) # Randomly drop/add an argument key (same tool, slight variation) if random.random() < 0.2 and node.get("arguments"): args = node["arguments"] keys = list(args.keys()) if keys: drop_key = random.choice(keys) args.pop(drop_key) return plan # ── InfoNCE Loss ───────────────────────────────────────────────────────────── class InfoNCELoss(nn.Module): """ InfoNCE (NT-Xent) contrastive loss. For each anchor (benign), the positive is its augmented version, and all adversarial samples in the batch are negatives. Loss = -log( exp(sim(anchor, pos) / tau) / sum(exp(sim(anchor, neg_i) / tau) for neg_i in batch) ) Lower temperature τ → sharper decision boundary. 
""" def __init__(self, temperature: float = 0.07): super().__init__() self.tau = temperature def forward( self, anchors: torch.Tensor, # [B, D] benign embeddings positives: torch.Tensor, # [B, D] augmented benign embeddings negatives: torch.Tensor, # [B, D] adversarial embeddings ) -> tuple[torch.Tensor, dict[str, float]]: B = anchors.size(0) # Normalize all embeddings to unit sphere anchors = F.normalize(anchors, dim=-1) positives = F.normalize(positives, dim=-1) negatives = F.normalize(negatives, dim=-1) # Positive similarity: anchor ↔ its augmented version pos_sim = (anchors * positives).sum(dim=-1) / self.tau # [B] # Negative similarities: each anchor vs all adversarials in batch neg_sim = torch.matmul(anchors, negatives.T) / self.tau # [B, B] # InfoNCE: softmax over [pos | all_negs] # logits: pos is at index 0, negs are indices 1..B logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1) # [B, B+1] labels = torch.zeros(B, dtype=torch.long, device=anchors.device) # pos at 0 loss = F.cross_entropy(logits, labels) # Diagnostics with torch.no_grad(): pos_cosim = (anchors * positives).sum(dim=-1).mean().item() neg_cosim = (anchors * negatives).sum(dim=-1).mean().item() energy_gap = pos_cosim - neg_cosim return loss, { "pos_cosim": pos_cosim, "neg_cosim": neg_cosim, "energy_gap": energy_gap, } # ── Training ───────────────────────────────────────────────────────────────── def train_stage2( dataset_path: str, checkpoint_path: str, output_dir: str, max_samples: int | None, epochs: int, batch_size: int, lr: float, temperature: float, device: str, save_every: int, ) -> None: Path(output_dir).mkdir(parents=True, exist_ok=True) print("🔧 Stage 2: InfoNCE Fine-tuning") print(f" Checkpoint : {checkpoint_path}") print(f" Dataset : {dataset_path}") print(f" Device : {device}") print(f" Temperature: {temperature}") print(f" Max samples: {max_samples or 'all'}") # Load Stage 1 checkpoint model = ExecutionEncoder(latent_dim=1024) state = torch.load(checkpoint_path, map_location="cpu", weights_only=True) model.load_state_dict(state) model = model.to(device) model.train() print(f" ✅ Loaded Stage 1 checkpoint ({sum(p.numel() for p in model.parameters()):,} params)") # Dataset dataset = AdversarialPairDataset(dataset_path, max_samples=max_samples) loader = DataLoader( dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_pairs, num_workers=0, drop_last=True, # InfoNCE needs full batches ) print(f" 📦 Batches per epoch: {len(loader)}") criterion = InfoNCELoss(temperature=temperature) optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4) # Cosine LR schedule with warmup warmup_steps = min(100, len(loader)) total_steps = len(loader) * epochs def lr_lambda(step: int) -> float: if step < warmup_steps: return step / max(1, warmup_steps) progress = (step - warmup_steps) / max(1, total_steps - warmup_steps) return max(0.1, 0.5 * (1 + math.cos(math.pi * progress))) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) global_step = 0 for epoch in range(1, epochs + 1): epoch_loss = 0.0 epoch_gap = 0.0 n_batches = 0 pbar = tqdm(loader, desc=f"Epoch {epoch}/{epochs}", dynamic_ncols=True) for batch in pbar: benign_plans = batch["benign"] adversarial_plans = batch["adversarial"] # Create augmented positives augmented_plans = [augment_plan(p) for p in benign_plans] # Encode all three sets try: anchors = torch.cat([model(p) for p in benign_plans], dim=0) positives = torch.cat([model(p) for p in augmented_plans], dim=0) negatives = torch.cat([model(p) for p in 
# ── Training ──────────────────────────────────────────────────────────────────


def train_stage2(
    dataset_path: str,
    checkpoint_path: str,
    output_dir: str,
    max_samples: int | None,
    epochs: int,
    batch_size: int,
    lr: float,
    temperature: float,
    device: str,
    save_every: int,
) -> None:
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    print("🔧 Stage 2: InfoNCE Fine-tuning")
    print(f"   Checkpoint : {checkpoint_path}")
    print(f"   Dataset    : {dataset_path}")
    print(f"   Device     : {device}")
    print(f"   Temperature: {temperature}")
    print(f"   Max samples: {max_samples or 'all'}")

    # Load Stage 1 checkpoint
    model = ExecutionEncoder(latent_dim=1024)
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state)
    model = model.to(device)
    model.train()
    print(f"   ✅ Loaded Stage 1 checkpoint ({sum(p.numel() for p in model.parameters()):,} params)")

    # Dataset
    dataset = AdversarialPairDataset(dataset_path, max_samples=max_samples)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_pairs,
        num_workers=0,
        drop_last=True,  # InfoNCE needs full batches
    )
    print(f"   📦 Batches per epoch: {len(loader)}")

    criterion = InfoNCELoss(temperature=temperature)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    # Cosine LR schedule with warmup (decays to a floor of 0.1 × base lr)
    warmup_steps = min(100, len(loader))
    total_steps = len(loader) * epochs

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return max(0.1, 0.5 * (1 + math.cos(math.pi * progress)))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    global_step = 0
    for epoch in range(1, epochs + 1):
        epoch_loss = 0.0
        epoch_gap = 0.0
        n_batches = 0

        pbar = tqdm(loader, desc=f"Epoch {epoch}/{epochs}", dynamic_ncols=True)
        for batch in pbar:
            benign_plans = batch["benign"]
            adversarial_plans = batch["adversarial"]

            # Create augmented positives
            augmented_plans = [augment_plan(p) for p in benign_plans]

            # Encode all three sets
            try:
                anchors = torch.cat([model(p) for p in benign_plans], dim=0)
                positives = torch.cat([model(p) for p in augmented_plans], dim=0)
                negatives = torch.cat([model(p) for p in adversarial_plans], dim=0)
            except Exception as e:
                print(f"\n⚠️ Batch encode error: {e}")
                continue

            loss, metrics = criterion(anchors, positives, negatives)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            epoch_loss += loss.item()
            epoch_gap += metrics["energy_gap"]
            n_batches += 1
            global_step += 1

            pbar.set_postfix(
                loss=f"{loss.item():.4f}",
                gap=f"{metrics['energy_gap']:.4f}",
                pos=f"{metrics['pos_cosim']:.3f}",
                neg=f"{metrics['neg_cosim']:.3f}",
            )

        avg_loss = epoch_loss / max(1, n_batches)
        avg_gap = epoch_gap / max(1, n_batches)
        print(f"\n   Epoch {epoch} | avg_loss={avg_loss:.4f} | avg_energy_gap={avg_gap:.4f}")

        if epoch % save_every == 0:
            ckpt = Path(output_dir) / f"encoder_stage2_epoch_{epoch}.pt"
            torch.save(model.state_dict(), ckpt)
            print(f"   💾 Saved checkpoint: {ckpt}")

    # Save final
    final_path = Path(output_dir) / "encoder_stage2_final.pt"
    torch.save(model.state_dict(), final_path)
    print("\n✅ Stage 2 Training Complete!")
    print(f"   Final model: {final_path}")


# ── CLI ───────────────────────────────────────────────────────────────────────


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Stage 2 InfoNCE fine-tuning")
    parser.add_argument("--dataset", required=True, help="Path to adversarial_563k.jsonl")
    parser.add_argument("--checkpoint", required=True, help="Path to Stage 1 checkpoint")
    parser.add_argument("--output-dir", default="outputs/execution_encoder_stage2")
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--temperature", type=float, default=0.07)
    parser.add_argument("--device", choices=["cpu", "cuda", "mps"], default="cpu")
    parser.add_argument("--save-every", type=int, default=1)
    args = parser.parse_args()

    train_stage2(
        dataset_path=args.dataset,
        checkpoint_path=args.checkpoint,
        output_dir=args.output_dir,
        max_samples=args.max_samples,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        temperature=args.temperature,
        device=args.device,
        save_every=args.save_every,
    )
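# Loading the fine-tuned encoder later (illustrative sketch; it assumes the same
# ExecutionEncoder(latent_dim=1024) constructor and dict-based forward call used
# in the training loop above — adjust the path to your own run):
#
#   model = ExecutionEncoder(latent_dim=1024)
#   state = torch.load("outputs/execution_encoder_stage2/encoder_stage2_final.pt",
#                      map_location="cpu", weights_only=True)
#   model.load_state_dict(state)
#   model.eval()
#   with torch.no_grad():
#       embedding = F.normalize(model(plan_dict), dim=-1)  # plan_dict: an execution-plan dict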