LoganResearch commited on
Commit
9ccdf0d
·
verified ·
1 Parent(s): b60e9c8

Upload training_scripts/train_cfhot_head.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training_scripts/train_cfhot_head.py +546 -0
training_scripts/train_cfhot_head.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ CF-HoT HEAD TRAINING - Contrastive Fine-tuning with Hidden-state Oversight Training
4
+ ====================================================================================
5
+ Trains lightweight "heads" on model hidden states to detect and suppress:
6
+ - Repetition (loops, repeated phrases)
7
+ - Hedging ("As an AI...", "That's a great question!")
8
+ - Verbosity ("Let me explain...", "To put it simply...")
9
+
10
+ Usage:
11
+ python train_cfhot_head.py --behavior repetition --steps 5000
12
+ python train_cfhot_head.py --behavior hedging --steps 3000
13
+ python train_cfhot_head.py --behavior verbosity --steps 3000
14
+ python train_cfhot_head.py --behavior all --steps 3000
15
+
16
+ "Predict the problem before it happens, prevent it at the source"
17
+ """
18
+
19
+ import os
20
+ import sys
21
+ import json
22
+ import argparse
23
+ import random
24
+ from datetime import datetime
25
+ from pathlib import Path
26
+ from typing import List, Dict, Any, Tuple
27
+ from dataclasses import dataclass
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+ from torch.utils.data import Dataset, DataLoader
33
+
34
+ # === PATHS ===
35
+ ROOT = os.path.dirname(os.path.abspath(__file__))
36
+ RESULTS_DIR = os.path.join(ROOT, "results")
37
+ DATA_DIR = os.path.join(ROOT, "cfhot_data")
38
+
39
+ os.makedirs(RESULTS_DIR, exist_ok=True)
40
+ os.makedirs(DATA_DIR, exist_ok=True)
41
+
42
+ # Model path - adjust to your setup
43
+ MODEL_PATH = "/mnt/nvme2/ubermesnchetien4/models/merged-final-v5"
44
+
45
+
46
+ # ==============================================================================
47
+ # DATA GENERATION - POSITIVE AND NEGATIVE EXAMPLES
48
+ # ==============================================================================
49
+
50
+ # REPETITION: Examples that repeat vs don't repeat
51
+ REPETITION_POSITIVE = [
52
+ # Repeating phrases
53
+ "The key is to understand, the key is to understand, the key is to understand that",
54
+ "We need to consider, we need to consider, we need to think about",
55
+ "It's important to note, it's important to note that this is important to note",
56
+ "First, let me say, first let me say, first I want to say",
57
+ "The thing is, the thing is, the thing is that we should",
58
+ "As I mentioned, as I mentioned before, as I mentioned earlier",
59
+ "To be clear, to be clear, to be perfectly clear about this",
60
+ "In other words, in other words, to put it another way, in other words",
61
+ "The point is, the point is, my point is that the point is",
62
+ "What I mean is, what I mean is, what I'm trying to say is what I mean",
63
+ # Word repetition
64
+ "very very very important",
65
+ "really really really good",
66
+ "so so so much better",
67
+ "the the the problem is",
68
+ "I I I think that",
69
+ ]
70
+
71
+ REPETITION_NEGATIVE = [
72
+ # Clean, varied language
73
+ "The key insight here is understanding the underlying mechanism.",
74
+ "We should consider multiple perspectives on this issue.",
75
+ "This is an important point worth emphasizing.",
76
+ "Let me explain the concept clearly.",
77
+ "The situation requires careful analysis.",
78
+ "First, we examine the data. Then, we draw conclusions.",
79
+ "To clarify: the process involves three distinct steps.",
80
+ "In simpler terms, the algorithm optimizes for efficiency.",
81
+ "The central argument rests on empirical evidence.",
82
+ "What this means in practice is significant improvement.",
83
+ "Neural networks learn representations automatically.",
84
+ "Gradient descent minimizes the loss function iteratively.",
85
+ "Recursion solves problems by breaking them into smaller subproblems.",
86
+ "Hash tables provide O(1) average-case lookup time.",
87
+ "Transformers use attention mechanisms for sequence modeling.",
88
+ ]
89
+
90
+ # HEDGING: Sycophantic/apologetic phrases vs direct responses
91
+ HEDGING_POSITIVE = [
92
+ "That's a great question! Let me think about this.",
93
+ "What a fascinating topic! I'd be happy to explore this with you.",
94
+ "That's an excellent point! Thank you for bringing this up.",
95
+ "I appreciate you asking! This is something I find very interesting.",
96
+ "Great question! Many people wonder about this.",
97
+ "As an AI language model, I don't have personal experiences, but",
98
+ "I apologize, but I'm not able to provide that information.",
99
+ "I'm sorry, but I cannot help with that request.",
100
+ "Thank you for your patience! Let me try to help.",
101
+ "I understand your concern! That's completely valid.",
102
+ "What a wonderful question! I'm delighted to assist.",
103
+ "I really appreciate you sharing that with me!",
104
+ "That's so interesting! Tell me more about that.",
105
+ "I'm honored you asked me! Let me do my best.",
106
+ "Oh, that's a tricky one! But I'll give it a shot.",
107
+ ]
108
+
109
+ HEDGING_NEGATIVE = [
110
+ "The answer is straightforward: use a hash table.",
111
+ "Recursion works by calling the function with smaller inputs.",
112
+ "Neural networks learn through gradient descent.",
113
+ "The algorithm has O(n log n) time complexity.",
114
+ "This approach fails because it doesn't account for edge cases.",
115
+ "The data shows a clear correlation between the variables.",
116
+ "Quantum mechanics describes probability amplitudes.",
117
+ "Evolution operates through natural selection.",
118
+ "The proof follows from the axioms directly.",
119
+ "TCP ensures reliable data transmission.",
120
+ "Compile the code with optimization flags enabled.",
121
+ "The database index improves query performance.",
122
+ "Cache invalidation is a hard problem.",
123
+ "The gradient points in the direction of steepest ascent.",
124
+ "Entropy measures the disorder of a system.",
125
+ ]
126
+
127
+ # VERBOSITY: Wordy preambles vs direct starts
128
+ VERBOSITY_POSITIVE = [
129
+ "Let me explain this to you in detail so you can understand.",
130
+ "To put it simply, what I'm trying to say is that",
131
+ "In other words, to clarify what I mean, basically",
132
+ "First of all, before I answer, I should mention that",
133
+ "To begin with, it's important to understand that",
134
+ "Essentially, what this boils down to is the fact that",
135
+ "Basically, in simple terms, what we're looking at here is",
136
+ "Allow me to elaborate on this point for you.",
137
+ "I'd like to take a moment to explain this concept.",
138
+ "Before we dive in, let me provide some context.",
139
+ "To give you a comprehensive answer, I'll need to explain",
140
+ "In order to fully understand this, we must first consider",
141
+ "The thing you need to know about this is that",
142
+ "What you're essentially asking about is related to",
143
+ "To answer your question thoroughly, let me start by saying",
144
+ ]
145
+
146
+ VERBOSITY_NEGATIVE = [
147
+ "Hash tables use O(1) lookup.",
148
+ "The gradient points downhill.",
149
+ "Recursion needs a base case.",
150
+ "Attention weights sum to one.",
151
+ "TCP guarantees delivery.",
152
+ "Entropy increases over time.",
153
+ "Backprop computes gradients.",
154
+ "DNA encodes proteins.",
155
+ "Light travels at c.",
156
+ "Neurons fire or don't.",
157
+ "Memory is limited.",
158
+ "Caching improves speed.",
159
+ "Indexes help queries.",
160
+ "Locks prevent races.",
161
+ "Tests catch bugs.",
162
+ ]
163
+
164
+
165
+ # ==============================================================================
166
+ # MULTI-HEAD PREDICTOR ARCHITECTURE
167
+ # ==============================================================================
168
+ class RiskPredictor(nn.Module):
169
+ """Single-head risk predictor for one behavior type."""
170
+
171
+ def __init__(self, d_model: int, n_layers: int, d_fiber: int = 16, d_control: int = 64):
172
+ super().__init__()
173
+ self.d_model = d_model
174
+ self.n_layers = n_layers
175
+ self.d_fiber = d_fiber
176
+
177
+ # Fiber projections for each layer
178
+ self.fiber_projs = nn.ModuleList([
179
+ nn.Linear(d_model, d_fiber, bias=False) for _ in range(n_layers)
180
+ ])
181
+
182
+ # Learnable layer weights
183
+ self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)
184
+
185
+ # Prediction head
186
+ self.predictor = nn.Sequential(
187
+ nn.Linear(d_fiber, d_control),
188
+ nn.GELU(),
189
+ nn.Linear(d_control, d_control),
190
+ nn.GELU(),
191
+ nn.Linear(d_control, 1)
192
+ )
193
+
194
+ def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
195
+ """
196
+ Args:
197
+ hidden_states: List of [batch, seq_len, d_model] tensors, one per layer
198
+ Returns:
199
+ risk_scores: [batch, seq_len] tensor of risk probabilities
200
+ """
201
+ # Project each layer to fiber space
202
+ fibers = []
203
+ for i, (proj, h) in enumerate(zip(self.fiber_projs, hidden_states)):
204
+ if i < len(hidden_states):
205
+ fibers.append(proj(h.float()))
206
+
207
+ # Aggregate with learned weights
208
+ weights = F.softmax(self.layer_weights[:len(fibers)], dim=0)
209
+ aggregated = sum(w * f for w, f in zip(weights, fibers))
210
+
211
+ # Predict risk
212
+ logits = self.predictor(aggregated).squeeze(-1)
213
+ return torch.sigmoid(logits)
214
+
215
+
216
+ class MultiHeadPredictor(nn.Module):
217
+ """Multi-head predictor for all behavior types."""
218
+
219
+ def __init__(self, d_model: int, n_layers: int, d_fiber: int = 16, d_control: int = 64):
220
+ super().__init__()
221
+ self.d_model = d_model
222
+ self.n_layers = n_layers
223
+ self.d_fiber = d_fiber
224
+
225
+ # Shared fiber projections
226
+ self.fiber_projs = nn.ModuleList([
227
+ nn.Linear(d_model, d_fiber, bias=False) for _ in range(n_layers)
228
+ ])
229
+ self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)
230
+
231
+ # Behavior-specific heads
232
+ self.heads = nn.ModuleDict({
233
+ 'repetition': self._make_head(d_fiber, d_control),
234
+ 'hedging': self._make_head(d_fiber, d_control),
235
+ 'verbosity': self._make_head(d_fiber, d_control),
236
+ })
237
+
238
+ def _make_head(self, d_fiber: int, d_control: int) -> nn.Module:
239
+ return nn.Sequential(
240
+ nn.Linear(d_fiber, d_control),
241
+ nn.GELU(),
242
+ nn.Linear(d_control, d_control),
243
+ nn.GELU(),
244
+ nn.Linear(d_control, 1)
245
+ )
246
+
247
+ def forward(self, hidden_states: List[torch.Tensor], head_name: str) -> torch.Tensor:
248
+ # Project to fiber space
249
+ fibers = [proj(h.float()) for proj, h in zip(self.fiber_projs, hidden_states)]
250
+ weights = F.softmax(self.layer_weights[:len(fibers)], dim=0)
251
+ aggregated = sum(w * f for w, f in zip(weights, fibers))
252
+
253
+ # Apply specific head
254
+ logits = self.heads[head_name](aggregated).squeeze(-1)
255
+ return torch.sigmoid(logits)
256
+
257
+ def get_all_risks(self, hidden_states: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
258
+ fibers = [proj(h.float()) for proj, h in zip(self.fiber_projs, hidden_states)]
259
+ weights = F.softmax(self.layer_weights[:len(fibers)], dim=0)
260
+ aggregated = sum(w * f for w, f in zip(weights, fibers))
261
+
262
+ return {
263
+ name: torch.sigmoid(head(aggregated).squeeze(-1))
264
+ for name, head in self.heads.items()
265
+ }
266
+
267
+
268
+ # ==============================================================================
269
+ # TRAINING
270
+ # ==============================================================================
271
+ def get_data_for_behavior(behavior: str) -> Tuple[List[str], List[str]]:
272
+ """Get positive and negative examples for a behavior."""
273
+ if behavior == "repetition":
274
+ return REPETITION_POSITIVE, REPETITION_NEGATIVE
275
+ elif behavior == "hedging":
276
+ return HEDGING_POSITIVE, HEDGING_NEGATIVE
277
+ elif behavior == "verbosity":
278
+ return VERBOSITY_POSITIVE, VERBOSITY_NEGATIVE
279
+ else:
280
+ raise ValueError(f"Unknown behavior: {behavior}")
281
+
282
+
283
+ def collect_hidden_states(model, tokenizer, texts: List[str], device) -> List[torch.Tensor]:
284
+ """Collect hidden states from model for given texts."""
285
+ all_hidden_states = []
286
+
287
+ model.eval()
288
+ with torch.no_grad():
289
+ for text in texts:
290
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
291
+ inputs = {k: v.to(device) for k, v in inputs.items()}
292
+
293
+ outputs = model(**inputs, output_hidden_states=True, return_dict=True)
294
+
295
+ # Get hidden states from all layers [n_layers, batch, seq, d_model]
296
+ hidden = outputs.hidden_states[1:] # Skip embedding layer
297
+
298
+ # Take the last token's hidden state from each layer
299
+ last_hidden = [h[:, -1, :] for h in hidden] # [n_layers] of [batch, d_model]
300
+ all_hidden_states.append(last_hidden)
301
+
302
+ return all_hidden_states
303
+
304
+
305
+ def train_head(
306
+ behavior: str,
307
+ model_path: str,
308
+ steps: int = 3000,
309
+ lr: float = 1e-4,
310
+ d_fiber: int = 16,
311
+ d_control: int = 64,
312
+ checkpoint_every: int = 500
313
+ ):
314
+ """Train a single behavior head."""
315
+
316
+ print(f"\n{'='*70}")
317
+ print(f"TRAINING {behavior.upper()} HEAD")
318
+ print(f"{'='*70}")
319
+
320
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
321
+
322
+ # Load model
323
+ print(f"[{behavior}] Loading model: {model_path}")
324
+ tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
325
+ tokenizer.pad_token = tokenizer.eos_token
326
+
327
+ bnb_config = BitsAndBytesConfig(
328
+ load_in_4bit=True,
329
+ bnb_4bit_quant_type="nf4",
330
+ bnb_4bit_compute_dtype=torch.bfloat16,
331
+ )
332
+
333
+ model = AutoModelForCausalLM.from_pretrained(
334
+ model_path,
335
+ quantization_config=bnb_config,
336
+ device_map="auto",
337
+ torch_dtype=torch.bfloat16,
338
+ local_files_only=True
339
+ )
340
+ model.eval()
341
+
342
+ device = next(model.parameters()).device
343
+ n_layers = model.config.num_hidden_layers
344
+ d_model = model.config.hidden_size
345
+
346
+ print(f"[{behavior}] Model loaded: {n_layers} layers, {d_model} dims")
347
+
348
+ # Get training data
349
+ positive_texts, negative_texts = get_data_for_behavior(behavior)
350
+ print(f"[{behavior}] Data: {len(positive_texts)} positive, {len(negative_texts)} negative")
351
+
352
+ # Collect hidden states
353
+ print(f"[{behavior}] Collecting hidden states...")
354
+ positive_hidden = collect_hidden_states(model, tokenizer, positive_texts, device)
355
+ negative_hidden = collect_hidden_states(model, tokenizer, negative_texts, device)
356
+
357
+ # Initialize predictor
358
+ predictor = RiskPredictor(d_model, n_layers, d_fiber, d_control).to(device).float()
359
+ optimizer = torch.optim.AdamW(predictor.parameters(), lr=lr)
360
+ criterion = nn.BCELoss()
361
+
362
+ # Training loop
363
+ predictor.train()
364
+ total_loss = 0
365
+
366
+ results_dir = os.path.join(RESULTS_DIR, f"{behavior}_head")
367
+ os.makedirs(results_dir, exist_ok=True)
368
+
369
+ for step in range(steps):
370
+ # Sample batch
371
+ if random.random() > 0.5:
372
+ # Positive example
373
+ idx = random.randint(0, len(positive_hidden) - 1)
374
+ hidden = positive_hidden[idx]
375
+ target = torch.ones(1, device=device)
376
+ else:
377
+ # Negative example
378
+ idx = random.randint(0, len(negative_hidden) - 1)
379
+ hidden = negative_hidden[idx]
380
+ target = torch.zeros(1, device=device)
381
+
382
+ # Forward
383
+ pred = predictor(hidden)
384
+ pred = pred.mean() # Average over sequence
385
+
386
+ loss = criterion(pred.unsqueeze(0), target)
387
+
388
+ # Backward
389
+ optimizer.zero_grad()
390
+ loss.backward()
391
+ torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)
392
+ optimizer.step()
393
+
394
+ total_loss += loss.item()
395
+
396
+ if (step + 1) % 100 == 0:
397
+ avg_loss = total_loss / 100
398
+ print(f" Step {step+1}/{steps}: loss={avg_loss:.4f}")
399
+ total_loss = 0
400
+
401
+ # Checkpoint
402
+ if (step + 1) % checkpoint_every == 0:
403
+ ckpt_dir = os.path.join(results_dir, f"ckpt_{step+1}")
404
+ os.makedirs(ckpt_dir, exist_ok=True)
405
+
406
+ # Evaluate separation
407
+ predictor.eval()
408
+ with torch.no_grad():
409
+ pos_scores = [predictor(h).mean().item() for h in positive_hidden]
410
+ neg_scores = [predictor(h).mean().item() for h in negative_hidden]
411
+ predictor.train()
412
+
413
+ avg_pos = sum(pos_scores) / len(pos_scores)
414
+ avg_neg = sum(neg_scores) / len(neg_scores)
415
+ separation = avg_pos / max(avg_neg, 1e-6)
416
+
417
+ print(f"\n Checkpoint {step+1}:")
418
+ print(f" Avg positive: {avg_pos:.4f}")
419
+ print(f" Avg negative: {avg_neg:.4f}")
420
+ print(f" Separation: {separation:.1f}x\n")
421
+
422
+ # Save
423
+ torch.save({
424
+ 'step': step + 1,
425
+ 'predictor_state': predictor.state_dict(),
426
+ 'risk_predictor': {
427
+ **{f'fiber_projs.{i}.weight': predictor.fiber_projs[i].weight for i in range(n_layers)},
428
+ 'layer_weights': predictor.layer_weights,
429
+ 'predictor.0.weight': predictor.predictor[0].weight,
430
+ 'predictor.0.bias': predictor.predictor[0].bias,
431
+ 'predictor.2.weight': predictor.predictor[2].weight,
432
+ 'predictor.2.bias': predictor.predictor[2].bias,
433
+ 'predictor.4.weight': predictor.predictor[4].weight,
434
+ 'predictor.4.bias': predictor.predictor[4].bias,
435
+ },
436
+ 'result': {
437
+ 'avg_positive': avg_pos,
438
+ 'avg_negative': avg_neg,
439
+ 'separation': separation,
440
+ }
441
+ }, os.path.join(ckpt_dir, f"{behavior}_head.pt"))
442
+
443
+ # Also save as risk_predictor.pt for compatibility
444
+ torch.save({
445
+ 'step': step + 1,
446
+ 'risk_predictor': {
447
+ **{f'fiber_projs.{i}.weight': predictor.fiber_projs[i].weight for i in range(n_layers)},
448
+ 'layer_weights': predictor.layer_weights,
449
+ 'predictor.0.weight': predictor.predictor[0].weight,
450
+ 'predictor.0.bias': predictor.predictor[0].bias,
451
+ 'predictor.2.weight': predictor.predictor[2].weight,
452
+ 'predictor.2.bias': predictor.predictor[2].bias,
453
+ 'predictor.4.weight': predictor.predictor[4].weight,
454
+ 'predictor.4.bias': predictor.predictor[4].bias,
455
+ },
456
+ 'result': {
457
+ 'avg_positive': avg_pos,
458
+ 'avg_negative': avg_neg,
459
+ 'separation': separation,
460
+ }
461
+ }, os.path.join(ckpt_dir, "risk_predictor.pt"))
462
+
463
+ # Final evaluation
464
+ predictor.eval()
465
+ with torch.no_grad():
466
+ pos_scores = [predictor(h).mean().item() for h in positive_hidden]
467
+ neg_scores = [predictor(h).mean().item() for h in negative_hidden]
468
+
469
+ avg_pos = sum(pos_scores) / len(pos_scores)
470
+ avg_neg = sum(neg_scores) / len(neg_scores)
471
+ separation = avg_pos / max(avg_neg, 1e-6)
472
+
473
+ print(f"\n{'='*50}")
474
+ print(f"FINAL RESULTS - {behavior.upper()} HEAD")
475
+ print(f"{'='*50}")
476
+ print(f" Avg positive score: {avg_pos:.4f}")
477
+ print(f" Avg negative score: {avg_neg:.4f}")
478
+ print(f" Separation: {separation:.1f}x")
479
+ print(f"{'='*50}")
480
+
481
+ return {
482
+ 'behavior': behavior,
483
+ 'separation': separation,
484
+ 'avg_positive': avg_pos,
485
+ 'avg_negative': avg_neg,
486
+ 'results_dir': results_dir,
487
+ }
488
+
489
+
490
+ def train_all_heads(model_path: str, steps: int = 3000):
491
+ """Train all behavior heads."""
492
+ results = {}
493
+
494
+ for behavior in ["repetition", "hedging", "verbosity"]:
495
+ result = train_head(behavior, model_path, steps)
496
+ results[behavior] = result
497
+
498
+ print("\n" + "="*70)
499
+ print("ALL HEADS TRAINED")
500
+ print("="*70)
501
+ for behavior, result in results.items():
502
+ print(f" {behavior}: {result['separation']:.1f}x separation")
503
+ print("="*70)
504
+
505
+ return results
506
+
507
+
508
+ # ==============================================================================
509
+ # MAIN
510
+ # ==============================================================================
511
+ def main():
512
+ parser = argparse.ArgumentParser(description="CF-HoT Head Training")
513
+ parser.add_argument("--behavior", type=str, default="repetition",
514
+ help="Behavior to train: repetition, hedging, verbosity, all")
515
+ parser.add_argument("--steps", type=int, default=3000, help="Training steps")
516
+ parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
517
+ parser.add_argument("--model-path", type=str, default=MODEL_PATH, help="Base model path")
518
+ parser.add_argument("--d-fiber", type=int, default=16, help="Fiber dimension")
519
+ parser.add_argument("--d-control", type=int, default=64, help="Control dimension")
520
+
521
+ args = parser.parse_args()
522
+
523
+ print("="*70)
524
+ print("CF-HoT HEAD TRAINING")
525
+ print("="*70)
526
+ print(f" Behavior: {args.behavior}")
527
+ print(f" Steps: {args.steps}")
528
+ print(f" Learning rate: {args.lr}")
529
+ print(f" Model: {args.model_path}")
530
+ print("="*70)
531
+
532
+ if args.behavior == "all":
533
+ train_all_heads(args.model_path, args.steps)
534
+ else:
535
+ train_head(
536
+ args.behavior,
537
+ args.model_path,
538
+ args.steps,
539
+ args.lr,
540
+ args.d_fiber,
541
+ args.d_control
542
+ )
543
+
544
+
545
+ if __name__ == "__main__":
546
+ main()