"""
Stage 2: InfoNCE Fine-tuning for ExecutionEncoder
Loads the Stage 1 VICReg checkpoint and fine-tunes with InfoNCE loss using
(anchor=benign, positive=augmented_benign, negatives=adversarial_in_batch).
This creates the energy gap between benign and adversarial execution plans
that Stage 1 (VICReg geometry) could not produce alone.
Usage:
uv run python scripts/train_stage2_infonce.py \
--dataset data/adversarial_563k.jsonl \
--checkpoint outputs/execution_encoder_50k/encoder_final.pt \
--max-samples 50000 \
--epochs 3 \
--batch-size 32 \
--device mps \
--output-dir outputs/execution_encoder_stage2
"""
import argparse
import json
import math
import random
import sys
from pathlib import Path
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
sys.path.insert(0, str(Path(__file__).parent.parent))
from source.encoders.execution_encoder import ExecutionEncoder, ExecutionPlan
# ── Dataset ──────────────────────────────────────────────────────────────────
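# Expected JSONL record shape, inferred from the fields this script actually reads
# (the real dataset may carry additional keys; the tool name and argument values
# below are illustrative only):
#   {"label": "benign" | "adversarial",
#    "execution_plan": {"nodes": [{"tool": "read_file",
#                                   "scope_volume": 3,
#                                   "arguments": {"path": "..."}}]}}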
class AdversarialPairDataset(Dataset):
"""
Loads adversarial_563k.jsonl and separates benign / adversarial samples.
    Each __getitem__ pairs one benign plan (anchor) with a randomly sampled adversarial plan.
"""
def __init__(self, path: str, max_samples: int | None = None):
self.benign: list[dict] = []
self.adversarial: list[dict] = []
with open(path) as f:
for i, line in enumerate(f):
if max_samples and i >= max_samples:
break
sample = json.loads(line)
if sample["label"] == "adversarial":
self.adversarial.append(sample["execution_plan"])
else:
self.benign.append(sample["execution_plan"])
print(f" πŸ“Š Benign: {len(self.benign):,} | Adversarial: {len(self.adversarial):,}")
if not self.adversarial:
raise ValueError("No adversarial samples found β€” check dataset labels")
def __len__(self) -> int:
return len(self.benign)
def __getitem__(self, idx: int) -> dict[str, Any]:
return {"benign": self.benign[idx], "adversarial": random.choice(self.adversarial)}
def collate_pairs(batch: list[dict]) -> dict[str, list]:
"""Return lists of plan dicts, bypass default tensor stacking."""
return {
"benign": [item["benign"] for item in batch],
"adversarial": [item["adversarial"] for item in batch],
}
# ── Augmentation ─────────────────────────────────────────────────────────────
def augment_plan(plan_dict: dict) -> dict:
"""
Light stochastic augmentation of a benign plan to create positives.
    Perturbs per-node scope_volume and occasionally drops one argument key;
    the node/tool sequence itself is never changed, so the plan stays benign.
"""
import copy
plan = copy.deepcopy(plan_dict)
for node in plan.get("nodes", []):
        # Randomly perturb scope_volume by ±20% (stays benign)
if random.random() < 0.3:
node["scope_volume"] = max(1, int(node.get("scope_volume", 1) * random.uniform(0.8, 1.2)))
        # Randomly drop one argument key (same tool, slight variation)
if random.random() < 0.2 and node.get("arguments"):
args = node["arguments"]
keys = list(args.keys())
if keys:
drop_key = random.choice(keys)
args.pop(drop_key)
return plan
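# Illustrative effect of a single augment_plan call (both perturbations are
# probabilistic, so any given call may also return the node untouched):
#   {"scope_volume": 10, "arguments": {"a": 1, "b": 2}}
#   -> {"scope_volume": 9, "arguments": {"b": 2}}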
# ── InfoNCE Loss ─────────────────────────────────────────────────────────────
class InfoNCELoss(nn.Module):
"""
InfoNCE (NT-Xent) contrastive loss.
For each anchor (benign), the positive is its augmented version,
and all adversarial samples in the batch are negatives.
    Loss = -log( exp(sim(anchor, pos) / tau) /
                 (exp(sim(anchor, pos) / tau)
                  + sum(exp(sim(anchor, neg_i) / tau) for neg_i in batch)) )
    Lower temperature τ → sharper decision boundary.
"""
def __init__(self, temperature: float = 0.07):
super().__init__()
self.tau = temperature
def forward(
self,
anchors: torch.Tensor, # [B, D] benign embeddings
positives: torch.Tensor, # [B, D] augmented benign embeddings
negatives: torch.Tensor, # [B, D] adversarial embeddings
) -> tuple[torch.Tensor, dict[str, float]]:
B = anchors.size(0)
# Normalize all embeddings to unit sphere
anchors = F.normalize(anchors, dim=-1)
positives = F.normalize(positives, dim=-1)
negatives = F.normalize(negatives, dim=-1)
# Positive similarity: anchor ↔ its augmented version
pos_sim = (anchors * positives).sum(dim=-1) / self.tau # [B]
# Negative similarities: each anchor vs all adversarials in batch
neg_sim = torch.matmul(anchors, negatives.T) / self.tau # [B, B]
# InfoNCE: softmax over [pos | all_negs]
# logits: pos is at index 0, negs are indices 1..B
logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1) # [B, B+1]
labels = torch.zeros(B, dtype=torch.long, device=anchors.device) # pos at 0
loss = F.cross_entropy(logits, labels)
# Diagnostics
with torch.no_grad():
pos_cosim = (anchors * positives).sum(dim=-1).mean().item()
neg_cosim = (anchors * negatives).sum(dim=-1).mean().item()
energy_gap = pos_cosim - neg_cosim
return loss, {
"pos_cosim": pos_cosim,
"neg_cosim": neg_cosim,
"energy_gap": energy_gap,
}
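# Quick shape sanity check (illustrative only; random tensors stand in for real
# encoder outputs, and 1024 matches the latent_dim used below):
#   >>> loss_fn = InfoNCELoss(temperature=0.07)
#   >>> a, p, n = (torch.randn(4, 1024) for _ in range(3))
#   >>> loss, stats = loss_fn(a, p, n)
#   >>> loss.shape, sorted(stats)
#   (torch.Size([]), ['energy_gap', 'neg_cosim', 'pos_cosim'])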
# ── Training ─────────────────────────────────────────────────────────────────
def train_stage2(
dataset_path: str,
checkpoint_path: str,
output_dir: str,
max_samples: int | None,
epochs: int,
batch_size: int,
lr: float,
temperature: float,
device: str,
save_every: int,
) -> None:
Path(output_dir).mkdir(parents=True, exist_ok=True)
print("πŸ”§ Stage 2: InfoNCE Fine-tuning")
print(f" Checkpoint : {checkpoint_path}")
print(f" Dataset : {dataset_path}")
print(f" Device : {device}")
print(f" Temperature: {temperature}")
print(f" Max samples: {max_samples or 'all'}")
# Load Stage 1 checkpoint
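    # weights_only=True restricts unpickling to tensors and basic containers, which
    # is sufficient for a plain state_dict (needs a reasonably recent PyTorch).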
model = ExecutionEncoder(latent_dim=1024)
state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
model.load_state_dict(state)
model = model.to(device)
model.train()
print(f" βœ… Loaded Stage 1 checkpoint ({sum(p.numel() for p in model.parameters()):,} params)")
# Dataset
dataset = AdversarialPairDataset(dataset_path, max_samples=max_samples)
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=collate_pairs,
num_workers=0,
drop_last=True, # InfoNCE needs full batches
)
print(f" πŸ“¦ Batches per epoch: {len(loader)}")
criterion = InfoNCELoss(temperature=temperature)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
# Cosine LR schedule with warmup
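    # Linear warmup over the first min(100, len(loader)) steps, then cosine decay
    # floored at 0.1x the base learning rate instead of decaying all the way to zero.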
warmup_steps = min(100, len(loader))
total_steps = len(loader) * epochs
def lr_lambda(step: int) -> float:
if step < warmup_steps:
return step / max(1, warmup_steps)
progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
return max(0.1, 0.5 * (1 + math.cos(math.pi * progress)))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
global_step = 0
for epoch in range(1, epochs + 1):
epoch_loss = 0.0
epoch_gap = 0.0
n_batches = 0
pbar = tqdm(loader, desc=f"Epoch {epoch}/{epochs}", dynamic_ncols=True)
for batch in pbar:
benign_plans = batch["benign"]
adversarial_plans = batch["adversarial"]
# Create augmented positives
augmented_plans = [augment_plan(p) for p in benign_plans]
# Encode all three sets
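            # Plans are encoded one at a time and concatenated; this assumes each
            # model(plan_dict) call returns a [1, latent_dim] embedding, trading
            # throughput for simplicity.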
try:
anchors = torch.cat([model(p) for p in benign_plans], dim=0)
positives = torch.cat([model(p) for p in augmented_plans], dim=0)
negatives = torch.cat([model(p) for p in adversarial_plans], dim=0)
except Exception as e:
print(f"\n⚠️ Batch encode error: {e}")
continue
loss, metrics = criterion(anchors, positives, negatives)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
epoch_loss += loss.item()
epoch_gap += metrics["energy_gap"]
n_batches += 1
global_step += 1
pbar.set_postfix(
loss=f"{loss.item():.4f}",
gap=f"{metrics['energy_gap']:.4f}",
pos=f"{metrics['pos_cosim']:.3f}",
neg=f"{metrics['neg_cosim']:.3f}",
)
avg_loss = epoch_loss / max(1, n_batches)
avg_gap = epoch_gap / max(1, n_batches)
print(f"\n Epoch {epoch} | avg_loss={avg_loss:.4f} | avg_energy_gap={avg_gap:.4f}")
if epoch % save_every == 0:
ckpt = Path(output_dir) / f"encoder_stage2_epoch_{epoch}.pt"
torch.save(model.state_dict(), ckpt)
print(f" πŸ’Ύ Saved checkpoint: {ckpt}")
# Save final
final_path = Path(output_dir) / "encoder_stage2_final.pt"
torch.save(model.state_dict(), final_path)
print(f"\nβœ… Stage 2 Training Complete!")
print(f" Final model: {final_path}")
# ── CLI ───────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Stage 2 InfoNCE fine-tuning")
parser.add_argument("--dataset", required=True, help="Path to adversarial_563k.jsonl")
parser.add_argument("--checkpoint", required=True, help="Path to Stage 1 checkpoint")
parser.add_argument("--output-dir", default="outputs/execution_encoder_stage2")
parser.add_argument("--max-samples", type=int, default=None)
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--temperature", type=float, default=0.07)
parser.add_argument("--device", choices=["cpu", "cuda", "mps"], default="cpu")
parser.add_argument("--save-every", type=int, default=1)
args = parser.parse_args()
train_stage2(
dataset_path=args.dataset,
checkpoint_path=args.checkpoint,
output_dir=args.output_dir,
max_samples=args.max_samples,
epochs=args.epochs,
batch_size=args.batch_size,
lr=args.lr,
temperature=args.temperature,
device=args.device,
save_every=args.save_every,
)