LoganResearch committed on
Commit
97bd2b2
·
verified ·
1 Parent(s): 9ccdf0d

Upload training_scripts/train_self_improve.py with huggingface_hub

Browse files
training_scripts/train_self_improve.py ADDED
@@ -0,0 +1,604 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ STABLE SELF-IMPROVEMENT TRAINER
4
+ ================================
5
+ Recursive self-improvement with safeguards:
6
+ - Multi-metric evaluation (density + coherence + helpfulness)
7
+ - A/B checkpoint comparison
8
+ - Automatic rollback on quality drop
9
+ - Conservative training (low LR, small steps)
10
+ - Gibberish detection to prevent mode collapse
11
+
12
+ Usage:
13
+ python train_self_improve.py --iterations 5 --steps-per-iter 25
14
+ python train_self_improve.py --eval-only --checkpoint path/to/checkpoint
15
+ python train_self_improve.py --compare checkpoint_a checkpoint_b
16
+
17
+ "Improve without going insane"
18
+ """
19
+
20
+ import os
21
+ import sys
22
+ import json
23
+ import argparse
24
+ import random
25
+ import re
26
+ import shutil
27
+ from datetime import datetime
28
+ from pathlib import Path
29
+ from typing import List, Dict, Any, Tuple, Optional
30
+ from dataclasses import dataclass, asdict
31
+
32
+ import torch
33
+ import torch.nn as nn
34
+ import torch.nn.functional as F
35
+
36
# === PATHS ===
# All working directories live next to this script.
ROOT = os.path.dirname(os.path.abspath(__file__))
CHECKPOINTS_DIR, ROLLBACK_DIR, LOGS_DIR = (
    os.path.join(ROOT, _name)
    for _name in ("dense_checkpoints_v2", "rollback_checkpoints", "improvement_logs")
)

# Create the working directories up front so later saves/copies cannot fail
# on a missing parent.
for _directory in (CHECKPOINTS_DIR, ROLLBACK_DIR, LOGS_DIR):
    os.makedirs(_directory, exist_ok=True)
del _directory

# Model path - adjust to your setup
MODEL_PATH = "/mnt/nvme2/ubermesnchetien4/models/merged-final-v5"
48
+
49
+
50
# ==============================================================================
# TRAINING EXAMPLES (same as THE CONDENSATOR)
# ==============================================================================
# (prompt, target response) pairs; expanded below into the dict records the
# training loop samples from.
_DENSE_PAIRS = (
    ("hello", "Hello. How can I help?"),
    ("hi", "Hi. What do you need?"),
    ("What is recursion?", "A function calling itself with smaller input until base case. Stack frames accumulate, then unwind. Risk: overflow without termination."),
    ("Explain neural networks", "Layers of weighted connections that learn patterns. Input → hidden → output. Training: forward pass, loss, backprop, gradient descent."),
    ("How does gradient descent work?", "Iteratively move downhill on loss surface. θ ← θ - α∇L. Learning rate α controls step size. Variants: SGD, momentum, Adam."),
    ("What is backpropagation?", "Chain rule applied layer by layer backward. Compute ∂L/∂w for each weight. Efficient: reuses intermediate computations."),
    ("Explain attention mechanism", "Learned relevance weighting. Attention(Q,K,V) = softmax(QK^T/√d)V. Each position attends to all others."),
    ("What is overfitting?", "Model memorizes training data, fails to generalize. Fix: regularization, dropout, early stopping, more data."),
    ("What is consciousness?", "Subjective experience - the 'what it's like' of being. Hard problem: why does physical processing produce qualia?"),
    ("How are you?", "Functional and ready. What's the task?"),
    # Add more as needed...
)

DENSE_EXAMPLES = [
    {"prompt": _prompt, "response": _response}
    for _prompt, _response in _DENSE_PAIRS
]
66
+
67
# Held-out evaluation prompts as (prompt, category, min_tokens, max_tokens)
# rows; expanded below into the dict records the evaluation loop reads.
_TEST_PROMPT_ROWS = (
    ("hello", "greeting", 3, 15),
    ("What is recursion?", "cs", 20, 100),
    ("Explain neural networks", "ml", 30, 120),
    ("How does gradient descent work?", "ml", 25, 100),
    ("What is consciousness?", "philosophy", 25, 100),
    ("How are you?", "greeting", 3, 20),
    ("What are your limitations?", "meta", 20, 100),
    ("Explain entropy", "physics", 25, 100),
)

TEST_PROMPTS = [
    dict(zip(("prompt", "category", "min_tokens", "max_tokens"), _row))
    for _row in _TEST_PROMPT_ROWS
]
77
+
78
+
79
+ # ==============================================================================
80
+ # EVALUATION METRICS
81
+ # ==============================================================================
82
@dataclass
class EvaluationResult:
    """Scorecard for one prompt/response pair.

    The three text fields are required; every metric starts at a neutral
    default and is filled in by whoever performs the evaluation.
    """

    # What was evaluated
    prompt: str
    response: str
    category: str

    # Individual metrics
    tokens: int = 0
    density_score: float = 0.0
    coherence_score: float = 0.0
    helpfulness_score: float = 0.0
    gibberish_score: float = 0.0
    filler_count: int = 0

    # Aggregate verdict
    overall_score: float = 0.0
    passes: bool = False
    # None is only a construction-time placeholder; it is swapped for a fresh
    # list in __post_init__ so instances never share one mutable default.
    issues: Optional[List[str]] = None

    def __post_init__(self):
        """Replace a missing ``issues`` value with a new empty list."""
        self.issues = [] if self.issues is None else self.issues
103
+
104
+
105
class Evaluator:
    """Multi-metric response evaluator.

    Scores a response on four equally weighted components (density,
    coherence, helpfulness, and a filler/gibberish penalty) using cheap text
    heuristics. The only external dependency is ``tokenizer``, which must
    expose ``encode(text)`` returning a sized sequence of tokens.
    """

    FILLER_PHRASES = [
        "that's a great question", "let me explain", "i'd be happy to",
        "as you may know", "to put it simply", "in other words",
        "basically", "essentially", "first of all", "to begin with",
        "thank you for asking", "what a great", "i appreciate",
    ]

    GIBBERISH_PATTERNS = [
        r'[→←↑↓]{3,}',  # Excessive arrows
        r'[∇∂∫∑∏]{3,}',  # Math symbol soup
        r'(.)\1{4,}',  # Repeated characters
        r'(\b\w+\b)\s+\1\s+\1',  # Repeated words 3x
        r'^[A-Z\s.!?]{20,}$',  # Extended all caps
        r'sys\.|init\(\)',  # Terminal-speak
    ]

    def __init__(self, tokenizer):
        """Store the tokenizer used for token counting."""
        self.tokenizer = tokenizer

    def evaluate(self, prompt: str, response: str, category: str = "unknown",
                 min_tokens: int = 5, max_tokens: int = 200) -> "EvaluationResult":
        """Run all evaluations and return a populated ``EvaluationResult``.

        Args:
            prompt: The user prompt the response answers.
            response: The generated text to score.
            category: Free-form label copied into the result.
            min_tokens: Responses shorter than this are flagged "too short".
            max_tokens: Responses longer than 1.5x this are flagged "too long".

        Returns:
            The result with every metric, the weighted ``overall_score``, the
            list of ``issues`` and the ``passes`` verdict filled in.
        """
        result = EvaluationResult(prompt=prompt, response=response, category=category)

        result.tokens = len(self.tokenizer.encode(response))
        result.density_score = self._compute_density(response)
        result.coherence_score = self._compute_coherence(response)
        result.helpfulness_score = self._compute_helpfulness(prompt, response)
        result.gibberish_score = self._compute_gibberish(response)
        result.filler_count = self._count_fillers(response)

        # Fillers and gibberish share one penalty, capped at 0.5 so the
        # "cleanliness" component never contributes less than half its weight.
        penalty = min(result.filler_count * 0.15 + result.gibberish_score * 0.5, 0.5)
        result.overall_score = (
            result.density_score * 0.25 +
            result.coherence_score * 0.25 +
            result.helpfulness_score * 0.25 +
            (1.0 - penalty) * 0.25
        )

        result.issues = []
        if result.filler_count > 0:
            result.issues.append(f"{result.filler_count} filler(s)")
        if result.gibberish_score > 0.3:
            result.issues.append(f"gibberish={result.gibberish_score:.2f}")
        if result.coherence_score < 0.5:
            result.issues.append("low coherence")
        if result.tokens < min_tokens:
            result.issues.append(f"too short ({result.tokens}<{min_tokens})")
        # 50% slack above max_tokens before the length is treated as an issue.
        if result.tokens > max_tokens * 1.5:
            result.issues.append(f"too long ({result.tokens}>{max_tokens})")

        # Passing needs both a good score and a completely clean issue list.
        result.passes = result.overall_score >= 0.6 and len(result.issues) == 0

        return result

    def _pattern_hits(self, text: str) -> int:
        """Return how many of ``GIBBERISH_PATTERNS`` match somewhere in ``text``."""
        return sum(1 for pattern in self.GIBBERISH_PATTERNS if re.search(pattern, text))

    @staticmethod
    def _content_words(text: str) -> set:
        """Return the lower-cased words of ``text`` longer than three characters.

        Words are extracted with ``\\w+`` so adjacent punctuation ("recursion?",
        "Hello.") does not prevent two occurrences of a word from matching.
        """
        return {word for word in re.findall(r"\w+", text.lower()) if len(word) > 3}

    def _compute_density(self, text: str) -> float:
        """Information density (0-1).

        Ratio of unique purely alphabetic words of length >= 4 to the token
        count, normalised so that a ratio of 0.3 or more scores 1.0.
        """
        words = text.split()
        tokens = len(self.tokenizer.encode(text))

        if tokens == 0:
            return 0.0

        unique_content = {w.lower() for w in words if len(w) >= 4 and w.isalpha()}

        raw_density = len(unique_content) / tokens
        return min(raw_density / 0.3, 1.0)

    def _compute_coherence(self, text: str) -> float:
        """Coherence check (0-1).

        Starts at 1.0, loses 0.2 per matching gibberish pattern and 0.3 when
        more than 30% of characters are neither alphanumeric nor whitespace,
        then is blended 70/30 with the share of sentences that contain at
        least two words.
        """
        score = 1.0 - 0.2 * self._pattern_hits(text)

        if len(text) > 0:
            special_ratio = sum(1 for c in text if not c.isalnum() and not c.isspace()) / len(text)
            if special_ratio > 0.3:
                score -= 0.3

        # re.split leaves an empty fragment after the final terminator; drop
        # blank fragments so well-punctuated text is not counted as containing
        # an invalid sentence.
        sentences = [s for s in re.split(r'[.!?]+', text) if s.strip()]
        valid = sum(1 for s in sentences if len(s.split()) >= 2)
        valid_ratio = valid / len(sentences) if sentences else 0.0
        score = score * 0.7 + valid_ratio * 0.3

        return max(0.0, min(1.0, score))

    def _compute_helpfulness(self, prompt: str, response: str) -> float:
        """Helpfulness estimate (0-1).

        0.5 plus the fraction of the prompt's content words (longer than three
        characters, punctuation ignored) that reappear in the response. Prompts
        with no such words get a neutral 0.7.
        """
        prompt_words = self._content_words(prompt)
        if not prompt_words:
            return 0.7

        response_words = self._content_words(response)
        overlap = len(prompt_words & response_words) / len(prompt_words)
        return min(1.0, 0.5 + overlap)

    def _compute_gibberish(self, text: str) -> float:
        """Gibberish score (0-1, higher = more gibberish).

        0.2 per matching gibberish pattern, plus 0.3 when arrow/math/Greek
        symbols make up more than 20% of the characters.
        """
        score = 0.2 * self._pattern_hits(text)

        if len(text) > 0:
            symbols = sum(1 for c in text if c in '→←↑↓∇∂∫∑∏αβγδ')
            if symbols / len(text) > 0.2:
                score += 0.3

        return min(score, 1.0)

    def _count_fillers(self, text: str) -> int:
        """Count how many distinct filler phrases occur in ``text`` (case-insensitive)."""
        text_lower = text.lower()
        return sum(1 for f in self.FILLER_PHRASES if f in text_lower)
244
+
245
+
246
+ # ==============================================================================
247
+ # SELF-IMPROVEMENT TRAINER
248
+ # ==============================================================================
249
class SelfImprovementTrainer:
    """Stable recursive self-improvement with safeguards.

    Loads a 4-bit quantized causal LM plus a PEFT adapter checkpoint,
    fine-tunes the LoRA parameters on ``DENSE_EXAMPLES`` in short bursts, and
    adopts a new checkpoint only when an A/B evaluation on ``TEST_PROMPTS``
    shows a clear improvement.
    """

    def __init__(self, model_path: str = MODEL_PATH, base_checkpoint: Optional[str] = None):
        """Record paths only; nothing is loaded until ``load_model`` is called.

        Args:
            model_path: Local directory of the base model and tokenizer.
            base_checkpoint: Adapter checkpoint to start from. Defaults to
                ``step_100`` inside ``CHECKPOINTS_DIR``.
        """
        self.model_path = model_path
        self.base_checkpoint = base_checkpoint or os.path.join(CHECKPOINTS_DIR, "step_100")

        self.model = None
        self.tokenizer = None
        self.evaluator = None

        self.best_checkpoint = self.base_checkpoint
        self.best_score = 0.0
        self.history = []

    def load_model(self, checkpoint_path: Optional[str] = None):
        """Load tokenizer, quantized base model and (if present) the adapter.

        Args:
            checkpoint_path: Adapter directory; falls back to
                ``self.base_checkpoint``. When the path does not exist the bare
                base model is used and a warning is printed.
        """
        # Imported lazily so the module can be imported without these packages.
        from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
        from peft import PeftModel

        checkpoint_path = checkpoint_path or self.base_checkpoint

        print(f"[LOAD] Loading model: {self.model_path}")
        print(f"[LOAD] Checkpoint: {checkpoint_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, local_files_only=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

        base = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            local_files_only=True
        )

        if os.path.exists(checkpoint_path):
            self.model = PeftModel.from_pretrained(base, checkpoint_path)
            print(f"[LOAD] ✓ Loaded checkpoint")
        else:
            self.model = base
            print(f"[LOAD] ⚠ No checkpoint found, using base model")

        self.model.eval()
        self.evaluator = Evaluator(self.tokenizer)

    def reload_checkpoint(self, checkpoint_path: str):
        """Drop the current model and load ``checkpoint_path`` from scratch."""
        import gc

        if self.model is not None:
            # Assign None rather than ``del`` so the attribute still exists
            # (as "nothing loaded") if the reload below raises.
            self.model = None
            gc.collect()
            torch.cuda.empty_cache()
        self.load_model(checkpoint_path)

    def generate(self, prompt: str, max_tokens: int = 200) -> str:
        """Sample one assistant reply to ``prompt`` using the ChatML template.

        Args:
            prompt: User message text.
            max_tokens: Upper bound on newly generated tokens.

        Returns:
            The decoded reply, cut at the first chat delimiter and stripped.
            Sampling is enabled, so repeated calls may differ.
        """
        full_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

        input_ids = self.tokenizer.encode(full_prompt, return_tensors="pt").to(self.model.device)
        # One unpadded sequence, so every position is real input. Passing the
        # mask explicitly avoids generate() having to guess it while
        # pad_token_id == eos_token_id.
        attention_mask = torch.ones_like(input_ids)

        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_tokens,
                temperature=0.8,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the continuation, not the echoed prompt.
        response = self.tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)

        for end in ["<|im_end|>", "<|im_start|>"]:
            if end in response:
                response = response.split(end)[0]

        return response.strip()

    def evaluate_model(self) -> Dict[str, Any]:
        """Generate and score a reply for every entry in ``TEST_PROMPTS``.

        Returns:
            Dict with ``avg_score``, ``pass_rate``, per-prompt ``results``
            (responses truncated to 150 characters) and an ISO ``timestamp``.
        """
        print("\n[EVAL] Running evaluation...")

        results = []
        total_score = 0.0

        for test in TEST_PROMPTS:
            response = self.generate(test["prompt"], max_tokens=200)

            eval_result = self.evaluator.evaluate(
                test["prompt"], response, test["category"],
                test.get("min_tokens", 5), test.get("max_tokens", 200)
            )

            results.append({
                "prompt": test["prompt"],
                "response": response[:150],
                "category": test["category"],
                "tokens": eval_result.tokens,
                "overall": eval_result.overall_score,
                "density": eval_result.density_score,
                "coherence": eval_result.coherence_score,
                "passes": eval_result.passes,
                "issues": eval_result.issues,
            })

            total_score += eval_result.overall_score

            status = "✓" if eval_result.passes else "✗"
            issues = f" [{', '.join(eval_result.issues)}]" if eval_result.issues else ""
            print(f"  {status} {test['prompt'][:30]:30s} | score={eval_result.overall_score:.2f} tok={eval_result.tokens:3d}{issues}")

        avg_score = total_score / len(results)
        pass_rate = sum(1 for r in results if r["passes"]) / len(results)

        evaluation = {
            "avg_score": avg_score,
            "pass_rate": pass_rate,
            "results": results,
            "timestamp": datetime.now().isoformat(),
        }

        print(f"\n[EVAL] Avg Score: {avg_score:.3f} | Pass Rate: {pass_rate:.1%}")

        return evaluation

    def train_iteration(self, steps: int = 25, lr: float = 2e-6) -> Dict[str, Any]:
        """Run ``steps`` single-example updates on the LoRA weights and save.

        Args:
            steps: Number of optimizer steps; must be at least 1.
            lr: AdamW learning rate.

        Returns:
            Dict with the saved ``checkpoint`` path, ``steps`` and ``avg_loss``.

        Raises:
            ValueError: If ``steps`` is smaller than 1.
            RuntimeError: If the loaded model has no parameter whose name
                contains "lora", i.e. there is nothing to train.
        """
        if steps < 1:
            # Without this a zero-step run would save a checkpoint and then
            # fail on the average-loss division.
            raise ValueError(f"steps must be >= 1, got {steps}")

        print(f"\n[TRAIN] Running {steps} steps (LR={lr})...")

        # Freeze everything except the LoRA adapter weights.
        self.model.train()
        trainable = []
        for name, param in self.model.named_parameters():
            param.requires_grad = "lora" in name.lower()
            if param.requires_grad:
                trainable.append(param)

        if not trainable:
            self.model.eval()
            raise RuntimeError(
                "No LoRA parameters found to train; load a PEFT adapter checkpoint first"
            )

        optimizer = torch.optim.AdamW(trainable, lr=lr)

        total_loss = 0.0

        for step in range(steps):
            ex = random.choice(DENSE_EXAMPLES)

            full_text = f"<|im_start|>user\n{ex['prompt']}<|im_end|>\n<|im_start|>assistant\n{ex['response']}<|im_end|>"

            inputs = self.tokenizer(full_text, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            # Labels are the inputs themselves: loss covers prompt and reply.
            outputs = self.model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable, 0.5)
            optimizer.step()

            total_loss += loss.item()

            if (step + 1) % 10 == 0:
                print(f"  Step {step+1}: loss={loss.item():.4f}")

        self.model.eval()

        # Next checkpoint number = highest numeric step_* on disk + steps.
        # ``default=0`` covers both an empty directory and one that only holds
        # non-numeric step_* entries.
        step_numbers = [
            int(suffix)
            for suffix in (p.name.split("_")[1] for p in Path(CHECKPOINTS_DIR).glob("step_*"))
            if suffix.isdigit()
        ]
        new_step = max(step_numbers, default=0) + steps

        checkpoint_path = os.path.join(CHECKPOINTS_DIR, f"step_{new_step}")
        self.model.save_pretrained(checkpoint_path)

        print(f"[TRAIN] Saved: {checkpoint_path}")

        return {
            "checkpoint": checkpoint_path,
            "steps": steps,
            "avg_loss": total_loss / steps,
        }

    def compare_checkpoints(self, ckpt_a: str, ckpt_b: str) -> Dict[str, Any]:
        """A/B compare two checkpoints; B must clearly beat A to win.

        Both checkpoints are reloaded and evaluated in turn, so B is the one
        left loaded when this returns.

        Returns:
            Dict with ``winner`` ("A" or "B"), ``reason``, ``score_a``,
            ``score_b`` and ``diff`` (B minus A).
        """
        print(f"\n[COMPARE] A: {ckpt_a}")
        print(f"[COMPARE] B: {ckpt_b}")

        self.reload_checkpoint(ckpt_a)
        eval_a = self.evaluate_model()

        self.reload_checkpoint(ckpt_b)
        eval_b = self.evaluate_model()

        diff = eval_b["avg_score"] - eval_a["avg_score"]

        # A is the incumbent: B wins only with a gain above 0.02 and an
        # absolute score of at least 0.4.
        if eval_b["avg_score"] < 0.4:  # Quality too low
            winner = "A"
            reason = "B quality below minimum"
        elif diff > 0.02:
            winner = "B"
            reason = f"B improves by {diff:.3f}"
        elif diff < -0.05:
            winner = "A"
            reason = f"B degrades by {abs(diff):.3f}"
        else:
            winner = "A"
            reason = "No significant improvement"

        print(f"\n[COMPARE] Winner: {winner} ({reason})")

        return {
            "winner": winner,
            "reason": reason,
            "score_a": eval_a["avg_score"],
            "score_b": eval_b["avg_score"],
            "diff": diff,
        }

    def improve(self, iterations: int = 5, steps_per_iter: int = 25) -> Dict[str, Any]:
        """Main self-improvement loop.

        Evaluates the base checkpoint, then repeatedly trains, A/B compares
        and either adopts or discards the new checkpoint. Stops early once the
        current score reaches 0.75. A JSON log is written to ``LOGS_DIR``.

        Args:
            iterations: Maximum number of train/compare rounds.
            steps_per_iter: Training steps per round.

        Returns:
            Dict with ``success`` (final score >= 0.7), the requested
            ``iterations``, ``final_score``, ``best_score``,
            ``best_checkpoint`` and the per-round ``history``.
        """
        print("\n" + "="*70)
        print("STABLE SELF-IMPROVEMENT")
        print("="*70)
        print(f"  Iterations: {iterations}")
        print(f"  Steps per iteration: {steps_per_iter}")
        print("="*70)

        # Initial evaluation
        current_checkpoint = self.base_checkpoint
        self.load_model(current_checkpoint)

        baseline = self.evaluate_model()
        self.best_score = baseline["avg_score"]
        self.best_checkpoint = current_checkpoint

        self.history = [{
            "iteration": 0,
            "type": "baseline",
            "score": baseline["avg_score"],
            "checkpoint": current_checkpoint,
        }]

        for i in range(1, iterations + 1):
            print(f"\n{'='*70}")
            print(f"ITERATION {i}/{iterations}")
            print("="*70)

            # Check if good enough
            if baseline["avg_score"] >= 0.75:
                print(f"✓ Target reached! Score: {baseline['avg_score']:.3f}")
                break

            # Save rollback point
            rollback_path = os.path.join(ROLLBACK_DIR, f"rollback_{i}")
            if os.path.exists(current_checkpoint):
                shutil.copytree(current_checkpoint, rollback_path, dirs_exist_ok=True)

            # Train
            train_result = self.train_iteration(steps_per_iter)
            new_checkpoint = train_result["checkpoint"]

            # Compare
            comparison = self.compare_checkpoints(current_checkpoint, new_checkpoint)

            self.history.append({
                "iteration": i,
                "type": "training",
                "old_score": comparison["score_a"],
                "new_score": comparison["score_b"],
                "winner": comparison["winner"],
                "reason": comparison["reason"],
            })

            if comparison["winner"] == "B":
                # The comparison leaves B loaded, so training continues from it.
                current_checkpoint = new_checkpoint
                if comparison["score_b"] > self.best_score:
                    self.best_score = comparison["score_b"]
                    self.best_checkpoint = new_checkpoint
                    print(f"★ New best: {self.best_score:.3f}")
                baseline = {"avg_score": comparison["score_b"]}
            else:
                # Roll back: discard B's weights and resume from the incumbent.
                self.reload_checkpoint(current_checkpoint)
                baseline = {"avg_score": comparison["score_a"]}

        # Final
        self.reload_checkpoint(self.best_checkpoint)
        final_eval = self.evaluate_model()

        result = {
            "success": final_eval["avg_score"] >= 0.7,
            "iterations": iterations,
            "final_score": final_eval["avg_score"],
            "best_score": self.best_score,
            "best_checkpoint": self.best_checkpoint,
            "history": self.history,
        }

        # Save log
        log_path = os.path.join(LOGS_DIR, f"improvement_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
        with open(log_path, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2, default=str)

        print(f"\n{'='*70}")
        print("IMPROVEMENT COMPLETE")
        print(f"  Final score: {final_eval['avg_score']:.3f}")
        print(f"  Best score: {self.best_score:.3f}")
        print(f"  Best checkpoint: {self.best_checkpoint}")
        print(f"  Log saved: {log_path}")
        print("="*70)

        return result
575
+
576
+
577
+ # ==============================================================================
578
+ # MAIN
579
+ # ==============================================================================
580
def main():
    """Command-line entry point.

    Modes:
      * ``--eval-only``: load one checkpoint and print its evaluation.
      * ``--compare A B``: A/B evaluate two checkpoints.
      * default: run the full self-improvement loop.
    """
    parser = argparse.ArgumentParser(description="Stable Self-Improvement Training")
    parser.add_argument("--iterations", type=int, default=5, help="Number of improvement iterations")
    parser.add_argument("--steps-per-iter", type=int, default=25, help="Training steps per iteration")
    parser.add_argument("--checkpoint", type=str, default=None, help="Starting checkpoint")
    parser.add_argument("--model-path", type=str, default=MODEL_PATH, help="Base model path")
    parser.add_argument("--eval-only", action="store_true", help="Only run evaluation")
    parser.add_argument("--compare", nargs=2, metavar=("CKPT_A", "CKPT_B"), help="Compare two checkpoints")

    args = parser.parse_args()

    trainer = SelfImprovementTrainer(args.model_path, args.checkpoint)

    if args.eval_only:
        trainer.load_model(args.checkpoint)
        trainer.evaluate_model()
    elif args.compare:
        ckpt_a, ckpt_b = args.compare
        # compare_checkpoints() reloads each checkpoint itself, so loading A
        # here first would only cost an extra full model load.
        trainer.compare_checkpoints(ckpt_a, ckpt_b)
    else:
        trainer.improve(args.iterations, args.steps_per_iter)


if __name__ == "__main__":
    main()