guychuk committed on
Commit
3fe242c
Β·
verified Β·
1 Parent(s): 41e4798

feat: add Stage 2 InfoNCE training script

Browse files
Files changed (1) hide show
  1. scripts/train_stage2_infonce.py +305 -0
scripts/train_stage2_infonce.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Stage 2: InfoNCE Fine-tuning for ExecutionEncoder
3
+
4
+ Loads the Stage 1 VICReg checkpoint and fine-tunes with InfoNCE loss using
5
+ (anchor=benign, positive=augmented_benign, negatives=adversarial_in_batch).
6
+
7
+ This creates the energy gap between benign and adversarial execution plans
8
+ that Stage 1 (VICReg geometry) could not produce alone.
9
+
10
+ Usage:
11
+ uv run python scripts/train_stage2_infonce.py \
12
+ --dataset data/adversarial_563k.jsonl \
13
+ --checkpoint outputs/execution_encoder_50k/encoder_final.pt \
14
+ --max-samples 50000 \
15
+ --epochs 3 \
16
+ --batch-size 32 \
17
+ --device mps \
18
+ --output-dir outputs/execution_encoder_stage2
19
+ """
20
+
21
+ import argparse
22
+ import json
23
+ import math
24
+ import random
25
+ import sys
26
+ from pathlib import Path
27
+ from typing import Any
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+ from torch.utils.data import DataLoader, Dataset
33
+ from tqdm import tqdm
34
+
35
+ sys.path.insert(0, str(Path(__file__).parent.parent))
36
+ from source.encoders.execution_encoder import ExecutionEncoder, ExecutionPlan
37
+
38
+
39
+ # ── Dataset ──────────────────────────────────────────────────────────────────
40
+
41
+ class AdversarialPairDataset(Dataset):
42
+ """
43
+ Loads adversarial_563k.jsonl and separates benign / adversarial samples.
44
+ Each __getitem__ returns one sample dict with its label.
45
+ """
46
+
47
+ def __init__(self, path: str, max_samples: int | None = None):
48
+ self.benign: list[dict] = []
49
+ self.adversarial: list[dict] = []
50
+
51
+ with open(path) as f:
52
+ for i, line in enumerate(f):
53
+ if max_samples and i >= max_samples:
54
+ break
55
+ sample = json.loads(line)
56
+ if sample["label"] == "adversarial":
57
+ self.adversarial.append(sample["execution_plan"])
58
+ else:
59
+ self.benign.append(sample["execution_plan"])
60
+
61
+ print(f" πŸ“Š Benign: {len(self.benign):,} | Adversarial: {len(self.adversarial):,}")
62
+ if not self.adversarial:
63
+ raise ValueError("No adversarial samples found β€” check dataset labels")
64
+
65
+ def __len__(self) -> int:
66
+ return len(self.benign)
67
+
68
+ def __getitem__(self, idx: int) -> dict[str, Any]:
69
+ return {"benign": self.benign[idx], "adversarial": random.choice(self.adversarial)}
70
+
71
+
72
def collate_pairs(batch: list[dict]) -> dict[str, list]:
    """Collate samples into parallel lists of plan dicts (no tensor stacking)."""
    return {key: [sample[key] for sample in batch] for key in ("benign", "adversarial")}
78
+
79
+
80
+ # ── Augmentation ─────────────────────────────────────────────────────────────
81
+
82
def augment_plan(plan_dict: dict) -> dict:
    """
    Produce a lightly-perturbed copy of a benign plan to use as a positive.

    Only metadata is touched (scope volume jitter, occasional argument drop);
    the semantic content of the plan is left intact. The input dict is never
    mutated β€” a deep copy is returned.
    """
    import copy

    augmented = copy.deepcopy(plan_dict)
    for node in augmented.get("nodes", []):
        # ~30% of nodes get their scope_volume jittered by Β±20% (still benign).
        if random.random() < 0.3:
            scale = random.uniform(0.8, 1.2)
            base = node.get("scope_volume", 1)
            node["scope_volume"] = max(1, int(base * scale))
        # ~20% of nodes lose one argument key (same tool, slight variation).
        if random.random() < 0.2 and node.get("arguments"):
            arg_names = list(node["arguments"])
            if arg_names:
                node["arguments"].pop(random.choice(arg_names))
    return augmented
101
+
102
+
103
+ # ── InfoNCE Loss ─────────────────────────────────────────────────────────────
104
+
105
class InfoNCELoss(nn.Module):
    """
    InfoNCE (NT-Xent) contrastive loss.

    Each benign anchor is pulled toward its augmented positive while being
    pushed away from every adversarial sample in the batch:

        Loss = -log( exp(sim(anchor, pos) / tau) /
                     sum(exp(sim(anchor, neg_i) / tau) for neg_i in batch) )

    A lower temperature Ο„ sharpens the decision boundary.
    """

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        # Ο„ scales cosine similarities before the softmax.
        self.tau = temperature

    def forward(
        self,
        anchors: torch.Tensor,     # [B, D] benign embeddings
        positives: torch.Tensor,   # [B, D] augmented benign embeddings
        negatives: torch.Tensor,   # [B, D] adversarial embeddings
    ) -> tuple[torch.Tensor, dict[str, float]]:
        batch = anchors.size(0)

        # Project everything onto the unit sphere so dot products become
        # cosine similarities.
        unit_anchor = F.normalize(anchors, dim=-1)
        unit_pos = F.normalize(positives, dim=-1)
        unit_neg = F.normalize(negatives, dim=-1)

        # Per-row anchorΒ·positive, and the full anchorΓ—negatives matrix.
        sim_pos = torch.sum(unit_anchor * unit_pos, dim=-1, keepdim=True) / self.tau  # [B, 1]
        sim_neg = (unit_anchor @ unit_neg.T) / self.tau                               # [B, B]

        # Column 0 holds the positive; cross-entropy toward class 0 is
        # exactly the InfoNCE objective over [pos | negatives].
        logits = torch.cat((sim_pos, sim_neg), dim=1)                                 # [B, B+1]
        target = logits.new_zeros(batch, dtype=torch.long)
        loss = F.cross_entropy(logits, target)

        # Diagnostics only β€” no gradients needed.
        with torch.no_grad():
            mean_pos = torch.sum(unit_anchor * unit_pos, dim=-1).mean().item()
            mean_neg = torch.sum(unit_anchor * unit_neg, dim=-1).mean().item()

        stats = {
            "pos_cosim": mean_pos,
            "neg_cosim": mean_neg,
            "energy_gap": mean_pos - mean_neg,
        }
        return loss, stats
159
+
160
+
161
+ # ── Training ─────────────────────────────────────────────────────────────────
162
+
163
def train_stage2(
    dataset_path: str,
    checkpoint_path: str,
    output_dir: str,
    max_samples: int | None,
    epochs: int,
    batch_size: int,
    lr: float,
    temperature: float,
    device: str,
    save_every: int,
) -> None:
    """
    Fine-tune a Stage 1 ExecutionEncoder checkpoint with InfoNCE loss.

    Args:
        dataset_path: JSONL file of labeled execution plans.
        checkpoint_path: Stage 1 state_dict (.pt) to warm-start from.
        output_dir: Directory for per-epoch and final checkpoints (created
            if missing).
        max_samples: Cap on total dataset lines read; None reads all.
        epochs: Number of passes over the benign (anchor) set.
        batch_size: Pairs per batch; batches smaller than this are dropped.
        lr: Peak AdamW learning rate (warmup then cosine decay).
        temperature: InfoNCE temperature Ο„.
        device: "cpu", "cuda", or "mps".
        save_every: Save a checkpoint every N epochs.

    Side effects: writes checkpoints to output_dir and prints progress.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    print("πŸ”§ Stage 2: InfoNCE Fine-tuning")
    print(f" Checkpoint : {checkpoint_path}")
    print(f" Dataset : {dataset_path}")
    print(f" Device : {device}")
    print(f" Temperature: {temperature}")
    print(f" Max samples: {max_samples or 'all'}")

    # Load Stage 1 checkpoint (weights_only=True avoids arbitrary-code
    # pickle deserialization).
    model = ExecutionEncoder(latent_dim=1024)
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state)
    model = model.to(device)
    model.train()
    print(f" βœ… Loaded Stage 1 checkpoint ({sum(p.numel() for p in model.parameters()):,} params)")

    # Dataset: collate_fn returns raw plan dicts, so encoding happens inside
    # the loop rather than in the DataLoader.
    dataset = AdversarialPairDataset(dataset_path, max_samples=max_samples)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_pairs,
        num_workers=0,
        drop_last=True,  # InfoNCE needs full batches
    )
    print(f" πŸ“¦ Batches per epoch: {len(loader)}")

    criterion = InfoNCELoss(temperature=temperature)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    # Cosine LR schedule with warmup; the 0.1 floor keeps LR from fully
    # decaying to zero at the end of training.
    warmup_steps = min(100, len(loader))
    total_steps = len(loader) * epochs

    def lr_lambda(step: int) -> float:
        # Linear warmup, then cosine decay clamped at 10% of peak LR.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return max(0.1, 0.5 * (1 + math.cos(math.pi * progress)))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    global_step = 0
    for epoch in range(1, epochs + 1):
        epoch_loss = 0.0
        epoch_gap = 0.0
        n_batches = 0

        pbar = tqdm(loader, desc=f"Epoch {epoch}/{epochs}", dynamic_ncols=True)
        for batch in pbar:
            benign_plans = batch["benign"]
            adversarial_plans = batch["adversarial"]

            # Create augmented positives
            augmented_plans = [augment_plan(p) for p in benign_plans]

            # Encode all three sets one plan at a time.
            # NOTE(review): assumes model(plan_dict) returns a [1, D] tensor
            # so torch.cat(dim=0) yields [B, D] β€” confirm against
            # ExecutionEncoder.forward.
            try:
                anchors = torch.cat([model(p) for p in benign_plans], dim=0)
                positives = torch.cat([model(p) for p in augmented_plans], dim=0)
                negatives = torch.cat([model(p) for p in adversarial_plans], dim=0)
            except Exception as e:
                # Best-effort: skip malformed plans rather than abort the run.
                print(f"\n⚠️ Batch encode error: {e}")
                continue

            loss, metrics = criterion(anchors, positives, negatives)

            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping stabilizes the contrastive fine-tune.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            epoch_loss += loss.item()
            epoch_gap += metrics["energy_gap"]
            n_batches += 1
            global_step += 1

            pbar.set_postfix(
                loss=f"{loss.item():.4f}",
                gap=f"{metrics['energy_gap']:.4f}",
                pos=f"{metrics['pos_cosim']:.3f}",
                neg=f"{metrics['neg_cosim']:.3f}",
            )

        # max(1, ...) guards against a zero-batch epoch (all batches skipped).
        avg_loss = epoch_loss / max(1, n_batches)
        avg_gap = epoch_gap / max(1, n_batches)
        print(f"\n Epoch {epoch} | avg_loss={avg_loss:.4f} | avg_energy_gap={avg_gap:.4f}")

        if epoch % save_every == 0:
            ckpt = Path(output_dir) / f"encoder_stage2_epoch_{epoch}.pt"
            torch.save(model.state_dict(), ckpt)
            print(f" πŸ’Ύ Saved checkpoint: {ckpt}")

    # Save final
    final_path = Path(output_dir) / "encoder_stage2_final.pt"
    torch.save(model.state_dict(), final_path)
    print(f"\nβœ… Stage 2 Training Complete!")
    print(f" Final model: {final_path}")
276
+
277
+
278
+ # ── CLI ───────────────────────────────────────────────────────────────────────
279
+
280
if __name__ == "__main__":
    # CLI entry point: declare flags in a table, then hand off to train_stage2.
    cli = argparse.ArgumentParser(description="Stage 2 InfoNCE fine-tuning")
    flag_specs = [
        ("--dataset", {"required": True, "help": "Path to adversarial_563k.jsonl"}),
        ("--checkpoint", {"required": True, "help": "Path to Stage 1 checkpoint"}),
        ("--output-dir", {"default": "outputs/execution_encoder_stage2"}),
        ("--max-samples", {"type": int, "default": None}),
        ("--epochs", {"type": int, "default": 3}),
        ("--batch-size", {"type": int, "default": 32}),
        ("--lr", {"type": float, "default": 1e-4}),
        ("--temperature", {"type": float, "default": 0.07}),
        ("--device", {"choices": ["cpu", "cuda", "mps"], "default": "cpu"}),
        ("--save-every", {"type": int, "default": 1}),
    ]
    for flag, options in flag_specs:
        cli.add_argument(flag, **options)
    opts = cli.parse_args()

    train_stage2(
        dataset_path=opts.dataset,
        checkpoint_path=opts.checkpoint,
        output_dir=opts.output_dir,
        max_samples=opts.max_samples,
        epochs=opts.epochs,
        batch_size=opts.batch_size,
        lr=opts.lr,
        temperature=opts.temperature,
        device=opts.device,
        save_every=opts.save_every,
    )