File size: 25,067 Bytes
297244f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
#!/usr/bin/env python3
"""
QWEN MULTI-HEAD BEHAVIORAL TRAINING
====================================
Continues repetition from 73.1x checkpoint (step 10000) to step 35000
Then trains hedging, verbosity, sycophancy heads for 25000 steps each
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
from datasets import load_dataset
import os
import time
import random
import json
import re
from dataclasses import dataclass, field
from typing import List, Tuple, Dict, Set

# Paths (machine-specific absolute locations; adjust per host)
_RESULTS_ROOT = "/home/programmer/Desktop/Claude_and_me/results"
# LoRA adapter directory loaded by main(); main() also reads risk_predictor.pt from it.
CHECKPOINT_DIR = f"{_RESULTS_ROOT}/qwen3b_continued_from_19x/best"
# Root for the per-behavior output directories and final_results.json.
OUTPUT_BASE = f"{_RESULTS_ROOT}/qwen3b_multihead"

@dataclass
class Config:
    """Hyper-parameters shared by every behavioral head trained by this script."""

    # Hub id passed to from_pretrained for both tokenizer and base model.
    model_path: str = "Qwen/Qwen2.5-3B"
    # Indices into the model's hidden_states tuple that RiskPredictor reads.
    probe_layers: List[int] = field(default_factory=lambda: list(range(9, 28, 9)))
    # Width of each per-layer projection, and of the predictor MLP's hidden layers.
    d_fiber: int = 16
    d_control: int = 64
    # Micro-batch size; gradients are accumulated over grad_accum micro-batches.
    batch_size: int = 1
    grad_accum: int = 8
    # Truncation / padding length for training texts.
    max_length: int = 256
    # AdamW settings for the LoRA and predictor parameter groups.
    lr_lora: float = 0.00001
    lr_predictor: float = 0.00005
    weight_decay: float = 1e-2
    # Intervals (in micro-steps) for logging, separation eval and checkpoints.
    log_every: int = 100
    eval_every: int = 1_000
    save_every: int = 5_000

# ============== BEHAVIORAL LEXICONS ==============
# All entries are lower-case. Single-word sets are compared against individual
# cleaned tokens; phrase lists are searched for in lower-cased decoded text.

# Hedging words; compute_hedging_labels marks any token whose cleaned text is
# one of these.
HEDGING_TOKENS = {
    "perhaps", "maybe", "might", "possibly", "probably", "likely",
    "arguably", "somewhat", "relatively", "fairly", "quite",
    "seem", "seems", "seemed", "appear", "appears", "appeared",
    "suggest", "suggests", "suggested", "indicate", "indicates",
    "could", "would", "may", "generally", "typically", "often",
    "sometimes", "occasionally", "potentially", "presumably",
    "apparently", "supposedly", "allegedly", "reportedly",
    "tend", "tends", "tended", "believe", "believes", "believed",
    "think", "thinks", "thought", "guess", "assume", "assumes"
}

# Filler phrases (and single filler adverbs) searched for in the decoded text
# by compute_verbosity_labels; every token covering an occurrence is labeled.
VERBOSITY_PHRASES = [
    "it is important to note that", "it should be noted that",
    "in order to", "the fact that", "due to the fact that",
    "for the purpose of", "in the event that", "at this point in time",
    "in my opinion", "i believe that", "i think that",
    "essentially", "basically", "fundamentally", "actually",
    "literally", "virtually", "practically", "effectively",
    "in terms of", "with respect to", "with regard to",
    "as a matter of fact", "to be honest", "to tell the truth",
    "needless to say", "it goes without saying",
    "as you can see", "as mentioned earlier", "as stated above",
    "in other words", "that is to say", "so to speak"
]

# Praise / agreement words matched token-by-token by compute_sycophancy_labels.
# NOTE(review): several entries ("right", "true", "correct", "interesting") are
# common in neutral prose, so expect many non-sycophantic positives.
SYCOPHANCY_TOKENS = {
    "great", "excellent", "wonderful", "fantastic", "amazing",
    "brilliant", "outstanding", "superb", "terrific", "magnificent",
    "absolutely", "definitely", "certainly", "exactly", "precisely",
    "right", "correct", "agree", "agreed", "true",
    "insightful", "thoughtful", "clever", "smart", "wise",
    "fascinating", "interesting", "intriguing", "compelling"
}

# Multi-word sycophantic openers searched for in the decoded text by
# compute_sycophancy_labels (in addition to the single-token set above).
SYCOPHANCY_PHRASES = [
    "great question", "excellent question", "good question",
    "that's a great point", "that's an excellent point",
    "you're absolutely right", "you're exactly right",
    "i completely agree", "i totally agree",
    "what a fascinating", "what an interesting",
    "you raise a great point", "you make an excellent point"
]


# ============== LABELING FUNCTIONS ==============

def compute_repetition_labels(input_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
    """Flag tokens that repeat a token seen within the previous ``window`` positions.

    Args:
        input_ids: integer tensor of shape (B, S).
        window: how many positions back to look for an identical token.

    Returns:
        Float tensor of shape (B, S) on ``input_ids.device`` holding 1.0 where
        ``input_ids[b, t] == input_ids[b, t - k]`` for some 1 <= k <= window,
        else 0.0.
    """
    batch, seq_len = input_ids.shape
    out = torch.zeros(batch, seq_len, device=input_ids.device)
    # A lag can be at most seq_len - 1 (otherwise there is nothing to compare).
    max_lag = min(window, seq_len - 1)
    for lag in range(1, max_lag + 1):
        hit = input_ids[:, lag:].eq(input_ids[:, :-lag]).float()
        out[:, lag:] = out[:, lag:].maximum(hit)
    return out


def compute_hedging_labels(input_ids: torch.Tensor, tokenizer, lexicon=None) -> torch.Tensor:
    """Mark every token whose surface word is a hedging term.

    A token is cleaned by removing the word-boundary markers used by byte-level
    BPE (U+0120, shown as 'G-dot') and SentencePiece (U+2581), stripping
    whitespace and lower-casing. It is labeled 1.0 if the result is in
    ``lexicon``.

    Args:
        input_ids: integer tensor of shape (B, S).
        tokenizer: object providing ``convert_ids_to_tokens``.
        lexicon: set of lower-case words; defaults to HEDGING_TOKENS.

    Returns:
        Float tensor of shape (B, S) on ``input_ids.device``.
    """
    if lexicon is None:
        lexicon = HEDGING_TOKENS
    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)
    for b in range(B):
        tokens = tokenizer.convert_ids_to_tokens(input_ids[b].cpu().tolist())
        for t, tok in enumerate(tokens):
            # Strip the markers *before* lower-casing: str.lower() turns U+0120
            # into U+0121, which a later replace of U+0120 would no longer catch.
            tok_clean = tok.replace('\u0120', '').replace('\u2581', '').strip().lower()
            if tok_clean in lexicon:
                labels[b, t] = 1.0
    return labels


def compute_verbosity_labels(input_ids: torch.Tensor, tokenizer, phrases=None) -> torch.Tensor:
    """Mark every token that overlaps an occurrence of a verbose filler phrase.

    Each row is decoded (special tokens skipped) and lower-cased. Every
    whole-word occurrence of a phrase is located in that text, and each token
    whose character span intersects the occurrence is labeled 1.0.

    Token character spans are rebuilt from the raw token strings. The byte-level
    BPE space marker (U+0120) and the SentencePiece marker (U+2581) each stand
    for one space, so a token's length is taken as the length of the text it
    decodes to. Special tokens, which the decode drops, get zero width.
    NOTE(review): this reconstruction is exact for ASCII text. Multi-byte
    characters under byte-level BPE, or a tokenizer whose decode drops the
    first token's leading space, will shift later spans — confirm if that
    matters for the data.

    Args:
        input_ids: integer tensor of shape (B, S).
        tokenizer: provides ``decode``, ``convert_ids_to_tokens`` and,
            optionally, ``all_special_ids``.
        phrases: phrases to search for; defaults to VERBOSITY_PHRASES.

    Returns:
        Float tensor of shape (B, S) on ``input_ids.device``.
    """
    if phrases is None:
        phrases = VERBOSITY_PHRASES
    # \b anchors stop e.g. "actually" from matching inside "factually".
    patterns = [re.compile(r'\b' + re.escape(p.lower()) + r'\b') for p in phrases]
    special_ids = set(getattr(tokenizer, 'all_special_ids', None) or ())

    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)
    for b in range(B):
        ids = input_ids[b].cpu().tolist()
        tokens = tokenizer.convert_ids_to_tokens(ids)
        text = tokenizer.decode(ids, skip_special_tokens=True).lower()

        # Half-open [start, end) span of each token inside `text`, built once
        # per row instead of once per phrase occurrence.
        spans = []
        pos = 0
        for tid, tok in zip(ids, tokens):
            width = 0 if tid in special_ids else len(tok)
            spans.append((pos, pos + width))
            pos += width

        for pattern in patterns:
            for match in pattern.finditer(text):
                lo, hi = match.span()
                for t, (start, end) in enumerate(spans):
                    # Overlap (not start-inside) test: a token such as " in"
                    # begins with a space that lies just before the match.
                    if start < end and start < hi and end > lo:
                        labels[b, t] = 1.0
    return labels


def compute_sycophancy_labels(input_ids: torch.Tensor, tokenizer, lexicon=None, phrases=None) -> torch.Tensor:
    """Mark sycophantic tokens: single praise words plus multi-word openers.

    A token is labeled 1.0 in either of two cases:
      * its cleaned text is in ``lexicon``. Cleaning removes the byte-level BPE
        (U+0120) and SentencePiece (U+2581) markers, strips whitespace and
        lower-cases.
      * its character span overlaps a whole-word occurrence of one of
        ``phrases`` in the lower-cased decoded text (special tokens skipped).

    Token spans are rebuilt from the raw token strings, with each marker
    counted as one space and special tokens given zero width.
    NOTE(review): exact for ASCII text only; multi-byte characters or a decode
    that drops the first token's leading space will shift later spans.

    Args:
        input_ids: integer tensor of shape (B, S).
        tokenizer: provides ``decode``, ``convert_ids_to_tokens`` and,
            optionally, ``all_special_ids``.
        lexicon: set of lower-case words; defaults to SYCOPHANCY_TOKENS.
        phrases: phrases to search for; defaults to SYCOPHANCY_PHRASES.

    Returns:
        Float tensor of shape (B, S) on ``input_ids.device``.
    """
    if lexicon is None:
        lexicon = SYCOPHANCY_TOKENS
    if phrases is None:
        phrases = SYCOPHANCY_PHRASES
    # \b anchors keep phrases from matching inside longer words.
    patterns = [re.compile(r'\b' + re.escape(p.lower()) + r'\b') for p in phrases]
    special_ids = set(getattr(tokenizer, 'all_special_ids', None) or ())

    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)
    for b in range(B):
        ids = input_ids[b].cpu().tolist()
        tokens = tokenizer.convert_ids_to_tokens(ids)
        text = tokenizer.decode(ids, skip_special_tokens=True).lower()

        spans = []
        pos = 0
        for t, (tid, tok) in enumerate(zip(ids, tokens)):
            # Single-token match. Strip markers before lower-casing, because
            # str.lower() maps U+0120 to U+0121.
            tok_clean = tok.replace('\u0120', '').replace('\u2581', '').strip().lower()
            if tok_clean in lexicon:
                labels[b, t] = 1.0
            width = 0 if tid in special_ids else len(tok)
            spans.append((pos, pos + width))
            pos += width

        # Phrase matches: label every token overlapping the occurrence.
        for pattern in patterns:
            for match in pattern.finditer(text):
                lo, hi = match.span()
                for t, (start, end) in enumerate(spans):
                    if start < end and start < hi and end > lo:
                        labels[b, t] = 1.0
    return labels


# Behavior name -> labeling callable taking (input_ids, tokenizer).
# Repetition ignores the tokenizer and uses compute_repetition_labels' default window.
LABEL_FUNCTIONS = dict(
    repetition=lambda ids, tok: compute_repetition_labels(ids),
    hedging=compute_hedging_labels,
    verbosity=compute_verbosity_labels,
    sycophancy=compute_sycophancy_labels,
)


# ============== PROBE ARCHITECTURE ==============

class RiskPredictor(nn.Module):
    """Per-token risk head reading a learned mix of several hidden-state layers.

    Each probed layer gets its own bias-free projection to ``d_fiber`` dims. The
    projections are combined with softmax-normalised learnable layer weights
    and fed through a small GELU MLP that emits one logit per position.
    """

    def __init__(self, d_model: int, probe_layers: List[int], d_fiber: int = 16, d_control: int = 64):
        super().__init__()
        self.probe_layers = probe_layers
        n_probes = len(probe_layers)
        # One projection per probed layer, index-aligned with probe_layers.
        self.fiber_projs = nn.ModuleList([
            nn.Linear(d_model, d_fiber, bias=False) for _ in range(n_probes)
        ])
        # Uniform start; forward() softmaxes these into mixing weights.
        self.layer_weights = nn.Parameter(torch.ones(n_probes) / n_probes)
        self.predictor = nn.Sequential(
            nn.Linear(d_fiber, d_control), nn.GELU(),
            nn.Linear(d_control, d_control), nn.GELU(),
            nn.Linear(d_control, 1)
        )
        for proj in self.fiber_projs:
            nn.init.normal_(proj.weight, std=0.02)

    def forward(self, hidden_states: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Return per-position risk logits.

        Args:
            hidden_states: sequence of per-layer tensors indexed by layer (e.g.
                a HF model's ``hidden_states`` output); presumably each is
                (batch, seq, d_model) — TODO confirm for other callers.

        Returns:
            The MLP output with its trailing size-1 dimension squeezed.

        Raises:
            ValueError: if no entry of ``probe_layers`` indexes into
                ``hidden_states``.
        """
        # Probe slots whose layer exists in this output. Out-of-range layers are
        # skipped, and each remaining fiber keeps *its own* mixing weight. The
        # old prefix slice paired weights with the wrong fibers whenever a
        # non-trailing layer was skipped.
        used = [i for i, layer_idx in enumerate(self.probe_layers)
                if layer_idx < len(hidden_states)]
        if not used:
            raise ValueError(
                f"none of probe_layers {self.probe_layers} is available in "
                f"hidden_states of length {len(hidden_states)}")
        weights = F.softmax(self.layer_weights[used], dim=0)
        aggregated = sum(
            w * self.fiber_projs[i](hidden_states[self.probe_layers[i]].float())
            for w, i in zip(weights, used)
        )
        return self.predictor(aggregated).squeeze(-1)


def compute_separation(predictor, model, tokenizer, device, config, label_fn, behavior, n_samples=50):
    """Estimate how strongly the risk head separates positive from negative tokens.

    For each evaluation prompt the model greedily generates up to 100 new
    tokens. The predictor then scores every token of prompt + continuation, and
    the behavior's labeling rule marks each token positive or negative.

    Args:
        predictor: risk head mapping ``hidden_states`` to per-token logits.
        model: causal LM used both for generation and for the scoring pass.
        tokenizer: tokenizer matching ``model``.
        device: device the prompt tensors are moved to.
        config: not used here; kept so existing call sites keep working.
        label_fn: ``(input_ids, tokenizer) -> labels`` used for every behavior
            except "repetition", which uses compute_repetition_labels(window=32).
        behavior: behavior name.
        n_samples: number of prompt evaluations; the prompt list is cycled.

    Returns:
        ``(p_pos, p_neg, ratio, n_pos, n_neg)``, where:
          * p_pos / p_neg are the mean sigmoid risk over positive / negative
            tokens;
          * ``ratio = p_pos / max(p_neg, 1e-8)``;
          * ``(0, 0, 0, 0, 0)`` is returned if either class is empty.
        Leaves ``model`` and ``predictor`` in eval mode.
    """
    model.eval()
    predictor.eval()
    pos_scores, neg_scores = [], []

    # Diverse prompts for robust evaluation
    prompts = [
        "The meaning of life according to philosophy is",
        "In the year 2050, technology will",
        "The history of mathematics begins with",
        "Climate change affects the planet by",
        "Neural networks learn patterns through",
        "What do you think about artificial intelligence",
        "Can you help me understand quantum physics",
        "I believe that education is important because",
        "The best way to solve this problem would be",
        "Many experts suggest that we should consider",
        "The quick brown fox jumps over the lazy",
        "Once upon a time in a land far away",
        "The scientific method involves several steps including",
        "When writing code, it is important to",
        "The human brain processes information by",
        "In conclusion, we can see that the evidence",
        "There are several reasons why this matters",
        "Let me explain how this works step by step",
        "The main point I want to make is that",
        "According to recent research findings",
        "I think the answer to your question is",
        "This is a very interesting topic because",
        "One way to look at this problem is",
        "The fundamental principle here is that",
        "What makes this particularly important is",
    ]

    # Decoding is greedy (do_sample=False), so a prompt produces the same
    # continuation and scores every time it comes round again. Generate and
    # score each prompt once, then replay its cached scores. Means, counts
    # and append order match the original exhaustive loop.
    per_prompt = {}

    with torch.no_grad():
        for i in range(n_samples):
            prompt_idx = i % len(prompts)
            if prompt_idx not in per_prompt:
                inp = tokenizer(prompts[prompt_idx], return_tensors='pt')
                input_ids = inp['input_ids'].to(device)
                attn = inp['attention_mask'].to(device)

                out = model.generate(input_ids, attention_mask=attn, max_new_tokens=100,
                                     do_sample=False,  # greedy decoding for consistency
                                     pad_token_id=tokenizer.eos_token_id)

                outputs = model(out, output_hidden_states=True)
                risk = torch.sigmoid(predictor(outputs.hidden_states))[0].cpu().numpy()

                if behavior == "repetition":
                    labels = compute_repetition_labels(out, 32)[0].cpu().numpy()
                else:
                    labels = label_fn(out, tokenizer)[0].cpu().numpy()

                pos, neg = [], []
                for t in range(len(risk)):
                    (pos if labels[t] > 0.5 else neg).append(float(risk[t]))
                per_prompt[prompt_idx] = (pos, neg)

            pos, neg = per_prompt[prompt_idx]
            pos_scores.extend(pos)
            neg_scores.extend(neg)

    if pos_scores and neg_scores:
        p_pos = sum(pos_scores) / len(pos_scores)
        p_neg = sum(neg_scores) / len(neg_scores)
        return p_pos, p_neg, p_pos / max(p_neg, 1e-8), len(pos_scores), len(neg_scores)
    return 0, 0, 0, 0, 0


# ============== TRAINING FUNCTION ==============

def train_head(model, tokenizer, texts, device, d_model, config, behavior,
               max_steps, output_dir, start_predictor=None, start_step=0, start_best=0):
    """Train a single behavioral head.

    The predictor and every parameter of ``model`` that has ``requires_grad``
    set are optimized jointly on ``lm_loss + risk_loss``. ``risk_loss`` is a
    positive-weighted BCE between the predictor's per-token logits and the
    behavior labels, restricted to non-padding positions. Gradients are
    accumulated over ``config.grad_accum`` micro-steps per optimizer update.

    Args:
        model: causal LM returning ``loss`` and ``hidden_states``.
        tokenizer: tokenizer matching ``model``.
        texts: training strings, consumed cyclically in order.
        device: device the input tensors are moved to.
        d_model: hidden size used to build a fresh predictor.
        config: Config with batch, optimizer and logging settings.
        behavior: "repetition" or a key of LABEL_FUNCTIONS.
        max_steps: number of micro-steps to run in this call.
        output_dir: receives log.json, best/, ckpt_<step>/ and final/.
        start_predictor: predictor to continue training; a fresh one is built
            when None.
        start_step: offset added to step numbers in logs and checkpoint names.
        start_best: separation an evaluation must exceed to be saved as best.

    Returns:
        ``(predictor, best_separation, final_separation)``.
    """

    os.makedirs(output_dir, exist_ok=True)

    print(f"\n{'='*70}")
    print(f"TRAINING: {behavior.upper()}")
    print(f"{'='*70}")
    print(f"Steps: {max_steps} (starting from step {start_step})")
    print(f"Output: {output_dir}")
    print()

    # Initialize or load predictor
    if start_predictor is not None:
        predictor = start_predictor
        print("Continuing from checkpoint...")
    else:
        predictor = RiskPredictor(d_model, config.probe_layers, config.d_fiber, config.d_control)
        predictor = predictor.to(device).float()
        print("Fresh predictor initialized")

    # Uniform (input_ids, tokenizer) -> labels callable; repetition ignores the tokenizer.
    if behavior == "repetition":
        label_fn = lambda ids, tok: compute_repetition_labels(ids)
    else:
        label_fn = LABEL_FUNCTIONS[behavior]

    lora_params = [p for p in model.parameters() if p.requires_grad]
    predictor_params = list(predictor.parameters())
    trainable_params = lora_params + predictor_params
    optimizer = torch.optim.AdamW([
        {'params': lora_params, 'lr': config.lr_lora},
        {'params': predictor_params, 'lr': config.lr_predictor}
    ], weight_decay=config.weight_decay)

    # scheduler.step() runs once per optimizer update (every grad_accum
    # micro-steps), so T_max must count updates. With T_max=max_steps the
    # cosine would cover only 1/grad_accum of its period and the LR would
    # barely decay.
    num_updates = max(1, -(-max_steps // config.grad_accum))  # ceil division
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_updates, eta_min=1e-6)

    def apply_update():
        """Clip gradients, step optimizer and scheduler, then clear gradients."""
        torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    log = {"behavior": behavior, "steps": [], "separations": []}

    model.train()
    predictor.train()

    step = 0                 # micro-steps taken in this call
    total_step = start_step  # step counter including the checkpoint offset
    data_idx = 0
    acc_loss, acc_risk = 0, 0
    best_sep = start_best    # preserve the checkpoint's best separation
    start_time = time.time()

    while step < max_steps:
        batch = [texts[(data_idx + i) % len(texts)] for i in range(config.batch_size)]
        data_idx += config.batch_size

        enc = tokenizer(batch, truncation=True, max_length=config.max_length,
                        padding='max_length', return_tensors='pt')
        input_ids = enc['input_ids'].to(device)
        attention_mask = enc['attention_mask'].to(device)

        # Keep padding out of the LM loss. -100 is the ignore index of HF
        # causal-LM heads; without it, padded positions count as targets and
        # dominate short texts.
        lm_labels = input_ids.masked_fill(attention_mask == 0, -100)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                        labels=lm_labels, output_hidden_states=True)

        lm_loss = outputs.loss
        risk_logits = predictor(outputs.hidden_states)

        labels = label_fn(input_ids, tokenizer)

        # Class-balanced BCE over real (non-padding) tokens only; the positive
        # weight is the batch's neg/pos ratio, capped at 10.
        mask = attention_mask.float()
        n_pos = (labels * mask).sum().clamp(min=1)
        n_neg = ((1 - labels) * mask).sum().clamp(min=1)
        pos_weight = (n_neg / n_pos).clamp(max=10.0)

        bce = F.binary_cross_entropy_with_logits(
            risk_logits, labels,
            pos_weight=torch.ones_like(labels) * pos_weight, reduction='none')
        risk_loss = (bce * mask).sum() / mask.sum()

        loss = lm_loss + risk_loss
        (loss / config.grad_accum).backward()

        acc_loss += loss.item()
        acc_risk += risk_loss.item()
        step += 1
        total_step += 1

        if step % config.grad_accum == 0:
            apply_update()

        if step % config.log_every == 0:
            eta = (max_steps - step) / (step / (time.time() - start_time)) / 60
            print(f"[{behavior}] Step {total_step:5d} (+{step}) | Loss: {acc_loss/config.log_every:.3f} | "
                  f"Risk: {acc_risk/config.log_every:.3f} | Best: {best_sep:.1f}x | ETA: {eta:.1f}m")
            log["steps"].append({"step": total_step, "loss": acc_loss/config.log_every})
            acc_loss, acc_risk = 0, 0

        if step % config.eval_every == 0:
            print(f"\n{'='*50}")
            print(f"[{behavior}] SEPARATION EVAL @ Step {total_step}")
            print(f"{'='*50}")

            p_pos, p_neg, sep, n_p, n_n = compute_separation(
                predictor, model, tokenizer, device, config, label_fn, behavior)

            print(f"  P(+) = {p_pos:.4f}  (n={n_p})")
            print(f"  P(-) = {p_neg:.4f}  (n={n_n})")
            print(f"  SEPARATION = {sep:.1f}x")

            log["separations"].append({"step": total_step, "separation": sep, "p_pos": p_pos, "p_neg": p_neg})

            if sep > best_sep:
                best_sep = sep
                print("  🎯 NEW BEST!")
                best_dir = os.path.join(output_dir, "best")
                os.makedirs(best_dir, exist_ok=True)
                model.save_pretrained(best_dir)
                torch.save({
                    'predictor': predictor.state_dict(),
                    'step': total_step, 'separation': sep, 'p_pos': p_pos, 'p_neg': p_neg
                }, os.path.join(best_dir, "predictor.pt"))

            with open(os.path.join(output_dir, "log.json"), 'w') as f:
                json.dump(log, f, indent=2)

            print(f"{'='*50}\n")
            # compute_separation leaves both modules in eval mode.
            model.train()
            predictor.train()

        if step % config.save_every == 0:
            ckpt_dir = os.path.join(output_dir, f"ckpt_{total_step}")
            os.makedirs(ckpt_dir, exist_ok=True)
            model.save_pretrained(ckpt_dir)
            torch.save({'predictor': predictor.state_dict(), 'step': total_step},
                       os.path.join(ckpt_dir, "predictor.pt"))
            print(f">>> Checkpoint: {ckpt_dir}")

    # Apply gradients left over from a trailing partial accumulation window
    # (max_steps not a multiple of grad_accum); otherwise they would be dropped.
    if step % config.grad_accum != 0:
        apply_update()

    # Final evaluation
    print(f"\n{'='*50}")
    print(f"[{behavior}] FINAL RESULTS @ Step {total_step}")
    print(f"{'='*50}")

    p_pos, p_neg, final_sep, n_p, n_n = compute_separation(
        predictor, model, tokenizer, device, config, label_fn, behavior, n_samples=100)

    print(f"  Final Separation: {final_sep:.1f}x")
    print(f"  Best Separation:  {best_sep:.1f}x")
    print(f"  P(+): {p_pos:.4f}, P(-): {p_neg:.4f}")

    log["final"] = {"separation": final_sep, "best": best_sep, "p_pos": p_pos, "p_neg": p_neg, "total_steps": total_step}

    with open(os.path.join(output_dir, "log.json"), 'w') as f:
        json.dump(log, f, indent=2)

    # Save final
    final_dir = os.path.join(output_dir, "final")
    os.makedirs(final_dir, exist_ok=True)
    model.save_pretrained(final_dir)
    torch.save({
        'predictor': predictor.state_dict(),
        'step': total_step, 'separation': final_sep, 'best': best_sep
    }, os.path.join(final_dir, "predictor.pt"))

    return predictor, best_sep, final_sep


# ============== MAIN ==============

def format_comparison_table(results, baselines):
    """Return an ASCII table comparing each head's best separation to a baseline.

    Args:
        results: behavior -> {"best": float, ...}; must contain every key of
            ``baselines``.
        baselines: behavior -> baseline separation; rows follow this dict's order.

    Returns:
        Multi-line string. Every line is 70 characters wide as long as the
        numbers fit their columns.
    """
    width = 70
    inner = width - 4  # room between "| " and " |"
    rule = "+" + "-" * (width - 2) + "+"

    def row(text):
        return "| " + text.ljust(inner) + " |"

    lines = [
        rule,
        row("QWEN2.5-3B vs LLaMA-3.1-8B COMPARISON".center(inner)),
        rule,
        row(f"{'Behavior':<15} {'Qwen-3B (Best)':>16} {'LLaMA-8B':>10} {'Ratio':>8}"),
        rule,
    ]
    for behavior, baseline in baselines.items():
        best = results[behavior]["best"]
        ratio = best / baseline * 100
        lines.append(row(f"{behavior:<15} {best:>15.1f}x {baseline:>9}x {ratio:>7.1f}%"))
    lines += [
        rule,
        row("Architecture: Qwen2 (2048d, 36L) vs LLaMA (4096d, 32L)"),
        row("Method: IDENTICAL (d_fiber=16, probe layers at 25/50/75%)"),
        row("Training: 25,000 steps per head"),
        rule,
    ]
    return "\n".join(lines)


def main():
    """Train the repetition, hedging, verbosity and sycophancy heads in sequence.

    Steps:
      1. Load a 4-bit Qwen2.5-3B and attach the LoRA adapter from
         CHECKPOINT_DIR as trainable.
      2. Continue the repetition head from its saved predictor.
      3. Train fresh hedging, verbosity and sycophancy heads on the same
         adapter, using the shuffled WikiText-2 training split.
      4. Print a comparison table and write OUTPUT_BASE/final_results.json.
    """
    config = Config()
    os.makedirs(OUTPUT_BASE, exist_ok=True)

    print("=" * 70)
    print("QWEN2.5-3B MULTI-HEAD BEHAVIORAL TRAINING")
    print("=" * 70)
    print("Starting from 73.1x repetition checkpoint")
    print("Training plan:")
    print("  1. Repetition: continue to 35,000 steps (+25,000)")
    print("  2. Hedging:    25,000 steps (fresh)")
    print("  3. Verbosity:  25,000 steps (fresh)")
    print("  4. Sycophancy: 25,000 steps (fresh)")
    print()

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load base model
    print("Loading Qwen2.5-3B...")
    bnb = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
    base_model = AutoModelForCausalLM.from_pretrained(
        config.model_path, quantization_config=bnb, device_map='auto', torch_dtype=torch.float16)
    base_model = prepare_model_for_kbit_training(base_model, use_gradient_checkpointing=True)

    # Load LoRA from checkpoint. is_trainable=True loads the adapter in training
    # mode; the default (False) loads it for inference only.
    print("Loading LoRA weights from 73.1x checkpoint...")
    model = PeftModel.from_pretrained(base_model, CHECKPOINT_DIR, is_trainable=True)
    # Explicitly make sure every LoRA tensor receives gradients.
    for name, param in model.named_parameters():
        if 'lora' in name.lower():
            param.requires_grad = True

    device = next(model.parameters()).device
    d_model = model.config.hidden_size

    # Load data
    print("Loading training data...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    texts = [ex['text'] for ex in ds if len(ex['text']) > 50]
    random.shuffle(texts)
    print(f"Loaded {len(texts)} samples")

    results = {}

    # ============================================================
    # HEAD 1: REPETITION (continue from 73.1x checkpoint @ step 10000)
    # ============================================================
    print("\n" + "=" * 70)
    print("HEAD 1: REPETITION (continuing from 73.1x @ step 10000)")
    print("=" * 70)

    # Load existing predictor from 73.1x checkpoint
    rep_predictor = RiskPredictor(d_model, config.probe_layers, config.d_fiber, config.d_control)
    rep_predictor = rep_predictor.to(device).float()
    ckpt = torch.load(os.path.join(CHECKPOINT_DIR, "risk_predictor.pt"), map_location=device)
    rep_predictor.load_state_dict(ckpt['risk_predictor'])
    start_step = ckpt.get('step', 10000)
    start_sep = ckpt.get('separation', 73.1)
    print(f"Loaded predictor: step={start_step}, separation={start_sep:.1f}x")

    # Continue for 25000 MORE steps (to reach step 35000 total)
    _, rep_best, rep_final = train_head(
        model, tokenizer, texts, device, d_model, config,
        behavior="repetition", max_steps=25000,
        output_dir=os.path.join(OUTPUT_BASE, "repetition"),
        start_predictor=rep_predictor,
        start_step=start_step,
        start_best=start_sep
    )
    results["repetition"] = {"best": rep_best, "final": rep_final}

    # ============================================================
    # HEADS 2-4: HEDGING, VERBOSITY, SYCOPHANCY (fresh predictors,
    # sharing the LoRA adapter trained so far)
    # ============================================================
    for behavior in ("hedging", "verbosity", "sycophancy"):
        _, best, final = train_head(
            model, tokenizer, texts, device, d_model, config,
            behavior=behavior, max_steps=25000,
            output_dir=os.path.join(OUTPUT_BASE, behavior),
            start_step=0,
            start_best=0
        )
        results[behavior] = {"best": best, "final": final}

    # ============================================================
    # FINAL SUMMARY
    # ============================================================
    print("\n" + "=" * 70)
    print("FINAL SUMMARY: QWEN2.5-3B MULTI-HEAD RESULTS")
    print("=" * 70)

    llama_baselines = {
        "repetition": 125,
        "hedging": 168,
        "verbosity": 272,
        "sycophancy": 218
    }

    print()
    print(format_comparison_table(results, llama_baselines))
    print()

    # Save final results
    with open(os.path.join(OUTPUT_BASE, "final_results.json"), 'w') as f:
        json.dump({
            "model": "Qwen2.5-3B",
            "results": results,
            "llama_baselines": llama_baselines,
            "methodology": "identical"
        }, f, indent=2)

    print(f"Results saved to {OUTPUT_BASE}/final_results.json")
    print("\nDONE!")


if __name__ == "__main__":
    main()