|
|
|
|
|
""" |
|
|
STABLE SELF-IMPROVEMENT TRAINER |
|
|
================================ |
|
|
Recursive self-improvement with safeguards: |
|
|
- Multi-metric evaluation (density + coherence + helpfulness) |
|
|
- A/B checkpoint comparison |
|
|
- Automatic rollback on quality drop |
|
|
- Conservative training (low LR, small steps) |
|
|
- Gibberish detection to prevent mode collapse |
|
|
|
|
|
Usage: |
|
|
python train_self_improve.py --iterations 5 --steps-per-iter 25 |
|
|
python train_self_improve.py --eval-only --checkpoint path/to/checkpoint |
|
|
python train_self_improve.py --compare checkpoint_a checkpoint_b |
|
|
|
|
|
"Improve without going insane" |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import argparse |
|
|
import random |
|
|
import re |
|
|
import shutil |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
from typing import List, Dict, Any, Tuple, Optional |
|
|
from dataclasses import dataclass, asdict |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
# Directory holding this script; every working directory below lives inside it.
ROOT = os.path.dirname(os.path.abspath(__file__))

# Adapter checkpoints written by training, copies taken before each training
# round, and the JSON run logs, respectively.
CHECKPOINTS_DIR, ROLLBACK_DIR, LOGS_DIR = (
    os.path.join(ROOT, leaf)
    for leaf in ("dense_checkpoints_v2", "rollback_checkpoints", "improvement_logs")
)

# The directories are created eagerly at import time so later code can write
# into them without checking first.
for _workdir in (CHECKPOINTS_DIR, ROLLBACK_DIR, LOGS_DIR):
    os.makedirs(_workdir, exist_ok=True)
del _workdir  # keep the module namespace identical to the hand-unrolled version

# Default location of the base model; can be overridden with --model-path.
MODEL_PATH = "/mnt/nvme2/ubermesnchetien4/models/merged-final-v5"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hand-written prompt/response pairs in the terse target style. Each entry is a
# dict with exactly the keys "prompt" and "response"; the training loop picks
# one at random per optimisation step.
DENSE_EXAMPLES = [
    {"prompt": question, "response": answer}
    for question, answer in (
        ("hello", "Hello. How can I help?"),
        ("hi", "Hi. What do you need?"),
        (
            "What is recursion?",
            "A function calling itself with smaller input until base case. Stack frames accumulate, then unwind. Risk: overflow without termination.",
        ),
        (
            "Explain neural networks",
            "Layers of weighted connections that learn patterns. Input → hidden → output. Training: forward pass, loss, backprop, gradient descent.",
        ),
        (
            "How does gradient descent work?",
            "Iteratively move downhill on loss surface. θ ← θ - α∇L. Learning rate α controls step size. Variants: SGD, momentum, Adam.",
        ),
        (
            "What is backpropagation?",
            "Chain rule applied layer by layer backward. Compute ∂L/∂w for each weight. Efficient: reuses intermediate computations.",
        ),
        (
            "Explain attention mechanism",
            "Learned relevance weighting. Attention(Q,K,V) = softmax(QK^T/√d)V. Each position attends to all others.",
        ),
        (
            "What is overfitting?",
            "Model memorizes training data, fails to generalize. Fix: regularization, dropout, early stopping, more data.",
        ),
        (
            "What is consciousness?",
            "Subjective experience - the 'what it's like' of being. Hard problem: why does physical processing produce qualia?",
        ),
        ("How are you?", "Functional and ready. What's the task?"),
    )
]
|
|
|
|
|
# Fixed evaluation set. Each entry carries the prompt text, a free-form
# category label, and the token-count window the response is checked against.
TEST_PROMPTS = [
    {"prompt": text, "category": topic, "min_tokens": shortest, "max_tokens": longest}
    for text, topic, shortest, longest in (
        ("hello", "greeting", 3, 15),
        ("What is recursion?", "cs", 20, 100),
        ("Explain neural networks", "ml", 30, 120),
        ("How does gradient descent work?", "ml", 25, 100),
        ("What is consciousness?", "philosophy", 25, 100),
        ("How are you?", "greeting", 3, 20),
        ("What are your limitations?", "meta", 20, 100),
        ("Explain entropy", "physics", 25, 100),
    )
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class EvaluationResult:
    """Scores and diagnostics recorded for one generated response.

    ``prompt``, ``response`` and ``category`` are required. Every other field
    starts at a neutral default and is meant to be filled in by the evaluator.
    """

    prompt: str    # text the model was asked
    response: str  # text the model produced
    category: str  # free-form label of the prompt's topic

    tokens: int = 0                 # token count of the response
    density_score: float = 0.0      # 0-1
    coherence_score: float = 0.0    # 0-1
    helpfulness_score: float = 0.0  # 0-1
    gibberish_score: float = 0.0    # 0-1, higher means more gibberish
    filler_count: int = 0           # number of filler phrases found

    overall_score: float = 0.0  # weighted combination of the metrics above
    passes: bool = False        # final verdict
    # Human-readable problems. ``None`` is accepted (and is the default) but
    # is always replaced by a fresh list, so instances never share one list
    # and callers can append without a None check.
    issues: Optional[List[str]] = None

    def __post_init__(self):
        """Normalise a missing ``issues`` value to a new empty list."""
        if self.issues is None:
            self.issues = []
|
|
|
|
|
|
|
|
class Evaluator:
    """Heuristic multi-metric scorer for generated responses.

    Four signals are combined with equal weight into one score in [0, 1]:
    information density, surface coherence, prompt/response word overlap
    ("helpfulness"), and a penalty term built from filler phrases and
    gibberish patterns.

    The tokenizer only needs an ``encode(text)`` method whose result
    supports ``len()``.
    """

    # Lower-case phrases that add words without adding information.
    FILLER_PHRASES = [
        "that's a great question", "let me explain", "i'd be happy to",
        "as you may know", "to put it simply", "in other words",
        "basically", "essentially", "first of all", "to begin with",
        "thank you for asking", "what a great", "i appreciate",
    ]

    # Regexes whose presence suggests degenerate output.
    GIBBERISH_PATTERNS = [
        r'[→←↑↓]{3,}',            # run of three or more arrows
        r'[∇∂∫∑∏]{3,}',           # run of three or more math operators
        r'(.)\1{4,}',             # same character five or more times in a row
        r'(\b\w+\b)\s+\1\s+\1',   # same word three times in a row
        r'^[A-Z\s.!?]{20,}$',     # whole text is upper-case letters/punctuation
        r'sys\.|init\(\)',        # code-like fragments
    ]

    def __init__(self, tokenizer):
        """Store the tokenizer used for all token counts."""
        self.tokenizer = tokenizer

    def evaluate(self, prompt: str, response: str, category: str = "unknown",
                 min_tokens: int = 5, max_tokens: int = 200) -> "EvaluationResult":
        """Score ``response`` against ``prompt`` on every metric.

        Args:
            prompt: The text the model was asked.
            response: The text the model produced.
            category: Label copied onto the result unchanged.
            min_tokens: Responses with fewer tokens are flagged "too short".
            max_tokens: Responses with more than 1.5x this many tokens are
                flagged "too long".

        Returns:
            A populated ``EvaluationResult``. ``passes`` is true only when
            the overall score is at least 0.6 and no issue was recorded.
        """
        result = EvaluationResult(prompt=prompt, response=response, category=category)

        result.tokens = len(self.tokenizer.encode(response))
        result.density_score = self._compute_density(response)
        result.coherence_score = self._compute_coherence(response)
        result.helpfulness_score = self._compute_helpfulness(prompt, response)
        result.gibberish_score = self._compute_gibberish(response)
        result.filler_count = self._count_fillers(response)

        # The penalty is capped so the last quarter of the score never
        # drops below 0.5 * 0.25.
        penalty = min(result.filler_count * 0.15 + result.gibberish_score * 0.5, 0.5)
        result.overall_score = (
            result.density_score * 0.25 +
            result.coherence_score * 0.25 +
            result.helpfulness_score * 0.25 +
            (1.0 - penalty) * 0.25
        )

        result.issues = []
        if result.filler_count > 0:
            result.issues.append(f"{result.filler_count} filler(s)")
        if result.gibberish_score > 0.3:
            result.issues.append(f"gibberish={result.gibberish_score:.2f}")
        if result.coherence_score < 0.5:
            result.issues.append("low coherence")
        if result.tokens < min_tokens:
            result.issues.append(f"too short ({result.tokens}<{min_tokens})")
        # 50% head-room over max_tokens before a response counts as too long.
        if result.tokens > max_tokens * 1.5:
            result.issues.append(f"too long ({result.tokens}>{max_tokens})")

        result.passes = result.overall_score >= 0.6 and len(result.issues) == 0

        return result

    @staticmethod
    def _significant_words(text: str) -> set:
        """Return the lower-cased words of ``text`` longer than three characters.

        Leading and trailing non-word characters are removed first, so
        "recursion?" and "Recursion" compare equal.
        """
        stripped = (re.sub(r"^\W+|\W+$", "", raw).lower() for raw in text.split())
        return {word for word in stripped if len(word) > 3}

    def _compute_density(self, text: str) -> float:
        """Information density (0-1).

        Ratio of distinct alphabetic words of four or more characters to
        the token count, scaled so that a ratio of 0.3 maps to 1.0.
        """
        words = text.split()
        tokens = len(self.tokenizer.encode(text))

        if tokens == 0:
            return 0.0

        # NOTE(review): isalpha() also rejects words with attached
        # punctuation ("input."). Left unchanged because the 0.3 normaliser
        # may have been tuned against this behaviour - confirm before changing.
        content_words = [w.lower() for w in words if len(w) >= 4 and w.isalpha()]
        unique_content = set(content_words)

        raw_density = len(unique_content) / tokens
        return min(raw_density / 0.3, 1.0)

    def _compute_coherence(self, text: str) -> float:
        """Coherence check (0-1).

        Starts at 1.0, loses 0.2 per matching gibberish pattern and 0.3 when
        more than 30% of the characters are neither alphanumeric nor
        whitespace, then is blended 70/30 with the share of sentences that
        contain at least two words.
        """
        score = 1.0

        for pattern in self.GIBBERISH_PATTERNS:
            if re.search(pattern, text):
                score -= 0.2

        if text:
            special = sum(1 for c in text if not c.isalnum() and not c.isspace())
            if special / len(text) > 0.3:
                score -= 0.3

        # re.split leaves an empty string after the final terminator; blank
        # fragments are dropped so that a properly punctuated response is
        # not charged for an "invalid" phantom sentence.
        fragments = [s for s in re.split(r'[.!?]+', text) if s.strip()]
        valid = sum(1 for s in fragments if len(s.split()) >= 2)
        # Text with no sentence at all contributes nothing to the blend.
        valid_share = valid / len(fragments) if fragments else 0.0
        score = score * 0.7 + valid_share * 0.3

        return max(0.0, min(1.0, score))

    def _compute_helpfulness(self, prompt: str, response: str) -> float:
        """Helpfulness estimate (0-1).

        0.5 plus the fraction of the prompt's significant words that also
        appear in the response, capped at 1.0. Returns a neutral 0.7 when
        the prompt has no significant words to match against.
        """
        prompt_words = self._significant_words(prompt)
        if not prompt_words:
            return 0.7

        response_words = self._significant_words(response)
        overlap = len(prompt_words & response_words) / len(prompt_words)
        return min(1.0, 0.5 + overlap)

    def _compute_gibberish(self, text: str) -> float:
        """Gibberish score (0-1, higher = more gibberish).

        Adds 0.2 per matching pattern, plus 0.3 when more than 20% of the
        characters are arrow, calculus or Greek symbols.
        """
        score = 0.0

        for pattern in self.GIBBERISH_PATTERNS:
            if re.search(pattern, text):
                score += 0.2

        if text:
            symbols = sum(1 for c in text if c in '→←↑↓∇∂∫∑∏αβγδ')
            if symbols / len(text) > 0.2:
                score += 0.3

        return min(score, 1.0)

    def _count_fillers(self, text: str) -> int:
        """Count how many distinct filler phrases occur in ``text``.

        Each phrase counts at most once, however often it repeats.
        """
        text_lower = text.lower()
        return sum(1 for f in self.FILLER_PHRASES if f in text_lower)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SelfImprovementTrainer:
    """Stable recursive self-improvement with safeguards.

    Repeatedly fine-tunes the adapter on ``DENSE_EXAMPLES``, evaluates the
    result on ``TEST_PROMPTS``, and only adopts a new checkpoint when it
    beats the current one in an A/B comparison.
    """

    def __init__(self, model_path: str = MODEL_PATH, base_checkpoint: Optional[str] = None):
        """Record paths and reset state; nothing is loaded yet.

        Args:
            model_path: Directory of the base model.
            base_checkpoint: Adapter checkpoint to start from. Falls back
                to ``step_100`` inside ``CHECKPOINTS_DIR`` when omitted.
        """
        self.model_path = model_path
        self.base_checkpoint = base_checkpoint or os.path.join(CHECKPOINTS_DIR, "step_100")

        self.model = None
        self.tokenizer = None
        self.evaluator = None

        self.best_checkpoint = self.base_checkpoint
        self.best_score = 0.0
        self.history = []

    def load_model(self, checkpoint_path: Optional[str] = None):
        """Load tokenizer, 4-bit base model and, if present, the adapter.

        Args:
            checkpoint_path: Adapter directory; defaults to the base
                checkpoint. When the path does not exist the bare base
                model is used and a warning is printed.
        """
        # Imported lazily so the module can be imported without these packages.
        from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
        from peft import PeftModel

        checkpoint_path = checkpoint_path or self.base_checkpoint

        print(f"[LOAD] Loading model: {self.model_path}")
        print(f"[LOAD] Checkpoint: {checkpoint_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, local_files_only=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

        base = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            local_files_only=True
        )

        if os.path.exists(checkpoint_path):
            self.model = PeftModel.from_pretrained(base, checkpoint_path)
            print(f"[LOAD] ✓ Loaded checkpoint")
        else:
            self.model = base
            print(f"[LOAD] ⚠ No checkpoint found, using base model")

        self.model.eval()
        self.evaluator = Evaluator(self.tokenizer)

    def reload_checkpoint(self, checkpoint_path: str):
        """Drop the current model and load ``checkpoint_path`` in its place."""
        if self.model is not None:
            # Rebind instead of ``del`` so the attribute still exists (as
            # None) should the load below raise.
            self.model = None
            torch.cuda.empty_cache()
        self.load_model(checkpoint_path)

    def generate(self, prompt: str, max_tokens: int = 200) -> str:
        """Sample one response to ``prompt`` using the ChatML-style template.

        Args:
            prompt: User message text.
            max_tokens: Upper bound on newly generated tokens.

        Returns:
            The decoded assistant text, cut at the first chat marker and
            stripped of surrounding whitespace.
        """
        full_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

        input_ids = self.tokenizer.encode(full_prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids,
                # Single unpadded sequence: every position is real input.
                # Passed explicitly because pad and eos share one id, so
                # the mask cannot be inferred from the ids.
                attention_mask=torch.ones_like(input_ids),
                max_new_tokens=max_tokens,
                temperature=0.8,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Keep only the tokens produced after the prompt.
        response = self.tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)

        for end in ["<|im_end|>", "<|im_start|>"]:
            if end in response:
                response = response.split(end)[0]

        return response.strip()

    def evaluate_model(self) -> Dict[str, Any]:
        """Generate and score a response for every entry in ``TEST_PROMPTS``.

        Returns:
            Dict with ``avg_score``, ``pass_rate``, the per-prompt
            ``results`` (response truncated to 150 characters) and an ISO
            ``timestamp``.
        """
        print("\n[EVAL] Running evaluation...")

        results = []
        total_score = 0.0

        for test in TEST_PROMPTS:
            response = self.generate(test["prompt"], max_tokens=200)

            eval_result = self.evaluator.evaluate(
                test["prompt"], response, test["category"],
                test.get("min_tokens", 5), test.get("max_tokens", 200)
            )

            results.append({
                "prompt": test["prompt"],
                "response": response[:150],
                "category": test["category"],
                "tokens": eval_result.tokens,
                "overall": eval_result.overall_score,
                "density": eval_result.density_score,
                "coherence": eval_result.coherence_score,
                "passes": eval_result.passes,
                "issues": eval_result.issues,
            })

            total_score += eval_result.overall_score

            status = "✓" if eval_result.passes else "✗"
            issues = f" [{', '.join(eval_result.issues)}]" if eval_result.issues else ""
            print(f" {status} {test['prompt'][:30]:30s} | score={eval_result.overall_score:.2f} tok={eval_result.tokens:3d}{issues}")

        avg_score = total_score / len(results)
        pass_rate = sum(1 for r in results if r["passes"]) / len(results)

        evaluation = {
            "avg_score": avg_score,
            "pass_rate": pass_rate,
            "results": results,
            "timestamp": datetime.now().isoformat(),
        }

        print(f"\n[EVAL] Avg Score: {avg_score:.3f} | Pass Rate: {pass_rate:.1%}")

        return evaluation

    def train_iteration(self, steps: int = 25, lr: float = 2e-6) -> Dict[str, Any]:
        """Run ``steps`` single-example updates and save a new checkpoint.

        Only parameters whose name contains "lora" are trained. The new
        checkpoint is named ``step_<N>``, where N is the highest numeric
        step already in ``CHECKPOINTS_DIR`` plus ``steps``.

        Args:
            steps: Number of optimizer steps; must be positive.
            lr: AdamW learning rate.

        Returns:
            Dict with the ``checkpoint`` path, ``steps`` and ``avg_loss``.

        Raises:
            ValueError: If ``steps`` is not positive, or the loaded model
                has no LoRA parameters to train.
        """
        # With zero steps the old code overwrote the latest checkpoint and
        # then divided by zero; reject it before touching anything.
        if steps <= 0:
            raise ValueError(f"steps must be a positive integer, got {steps}")

        print(f"\n[TRAIN] Running {steps} steps (LR={lr})...")

        # Freeze everything except the LoRA weights.
        trainable = []
        for name, param in self.model.named_parameters():
            param.requires_grad = "lora" in name.lower()
            if param.requires_grad:
                trainable.append(param)

        if not trainable:
            # Happens when load_model fell back to the bare base model.
            raise ValueError("No LoRA parameters found; load an adapter checkpoint before training")

        optimizer = torch.optim.AdamW(trainable, lr=lr)

        total_loss = 0

        self.model.train()
        try:
            for step in range(steps):
                ex = random.choice(DENSE_EXAMPLES)

                full_text = f"<|im_start|>user\n{ex['prompt']}<|im_end|>\n<|im_start|>assistant\n{ex['response']}<|im_end|>"

                inputs = self.tokenizer(full_text, return_tensors="pt", truncation=True, max_length=512)
                inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

                # Labels are the inputs themselves, so the loss covers the
                # prompt tokens as well as the response.
                outputs = self.model(**inputs, labels=inputs["input_ids"])
                loss = outputs.loss

                optimizer.zero_grad()
                loss.backward()
                # Frozen parameters carry no gradient, so clipping the
                # trainable subset is equivalent to clipping all parameters.
                torch.nn.utils.clip_grad_norm_(trainable, 0.5)
                optimizer.step()

                total_loss += loss.item()

                if (step + 1) % 10 == 0:
                    print(f" Step {step+1}: loss={loss.item():.4f}")
        finally:
            # Never leave the model in train mode, even if a step raised.
            self.model.eval()

        # Highest numeric suffix among existing step_* entries; 0 when the
        # directory has none (or only non-numeric ones such as "step_final").
        numbered = [
            int(suffix)
            for suffix in (p.name.split("_")[1] for p in Path(CHECKPOINTS_DIR).glob("step_*"))
            if suffix.isdigit()
        ]
        new_step = max(numbered, default=0) + steps

        checkpoint_path = os.path.join(CHECKPOINTS_DIR, f"step_{new_step}")
        self.model.save_pretrained(checkpoint_path)

        print(f"[TRAIN] Saved: {checkpoint_path}")

        return {
            "checkpoint": checkpoint_path,
            "steps": steps,
            "avg_loss": total_loss / steps,
        }

    def compare_checkpoints(self, ckpt_a: str, ckpt_b: str) -> Dict[str, Any]:
        """A/B compare two checkpoints on the evaluation set.

        B wins only when its average score is at least 0.4 and exceeds A's
        by more than 0.02; every other outcome keeps A. Checkpoint B is the
        one left loaded when this returns.

        Returns:
            Dict with ``winner`` ("A" or "B"), ``reason``, ``score_a``,
            ``score_b`` and ``diff`` (B minus A).
        """
        print(f"\n[COMPARE] A: {ckpt_a}")
        print(f"[COMPARE] B: {ckpt_b}")

        self.reload_checkpoint(ckpt_a)
        eval_a = self.evaluate_model()

        self.reload_checkpoint(ckpt_b)
        eval_b = self.evaluate_model()

        diff = eval_b["avg_score"] - eval_a["avg_score"]

        if eval_b["avg_score"] < 0.4:
            winner = "A"
            reason = "B quality below minimum"
        elif diff > 0.02:
            winner = "B"
            reason = f"B improves by {diff:.3f}"
        elif diff < -0.05:
            winner = "A"
            reason = f"B degrades by {abs(diff):.3f}"
        else:
            winner = "A"
            reason = "No significant improvement"

        print(f"\n[COMPARE] Winner: {winner} ({reason})")

        return {
            "winner": winner,
            "reason": reason,
            "score_a": eval_a["avg_score"],
            "score_b": eval_b["avg_score"],
            "diff": diff,
        }

    def improve(self, iterations: int = 5, steps_per_iter: int = 25) -> Dict[str, Any]:
        """Main self-improvement loop.

        Evaluates the base checkpoint, then for up to ``iterations`` rounds
        copies the current checkpoint to the rollback directory, trains,
        and adopts the new checkpoint only if it wins the A/B comparison.
        Stops early once the current score reaches 0.75. Finally reloads
        and re-evaluates the best checkpoint and writes a JSON log.

        Args:
            iterations: Maximum number of train/compare rounds.
            steps_per_iter: Training steps per round.

        Returns:
            Dict with ``success`` (final score >= 0.7), the requested
            ``iterations``, ``final_score``, ``best_score``,
            ``best_checkpoint`` and the per-round ``history``.
        """
        print("\n" + "="*70)
        print("STABLE SELF-IMPROVEMENT")
        print("="*70)
        print(f" Iterations: {iterations}")
        print(f" Steps per iteration: {steps_per_iter}")
        print("="*70)

        current_checkpoint = self.base_checkpoint
        self.load_model(current_checkpoint)

        baseline = self.evaluate_model()
        self.best_score = baseline["avg_score"]
        self.best_checkpoint = current_checkpoint

        self.history = [{
            "iteration": 0,
            "type": "baseline",
            "score": baseline["avg_score"],
            "checkpoint": current_checkpoint,
        }]

        for i in range(1, iterations + 1):
            print(f"\n{'='*70}")
            print(f"ITERATION {i}/{iterations}")
            print("="*70)

            if baseline["avg_score"] >= 0.75:
                print(f"✓ Target reached! Score: {baseline['avg_score']:.3f}")
                break

            # Keep a copy of the checkpoint we are about to train from.
            rollback_path = os.path.join(ROLLBACK_DIR, f"rollback_{i}")
            if os.path.exists(current_checkpoint):
                shutil.copytree(current_checkpoint, rollback_path, dirs_exist_ok=True)

            train_result = self.train_iteration(steps_per_iter)
            new_checkpoint = train_result["checkpoint"]

            comparison = self.compare_checkpoints(current_checkpoint, new_checkpoint)

            self.history.append({
                "iteration": i,
                "type": "training",
                "old_score": comparison["score_a"],
                "new_score": comparison["score_b"],
                "winner": comparison["winner"],
                "reason": comparison["reason"],
            })

            if comparison["winner"] == "B":
                # The comparison left B loaded, so no reload is needed.
                current_checkpoint = new_checkpoint
                if comparison["score_b"] > self.best_score:
                    self.best_score = comparison["score_b"]
                    self.best_checkpoint = new_checkpoint
                    print(f"★ New best: {self.best_score:.3f}")
                baseline = {"avg_score": comparison["score_b"]}
            else:
                # Roll back: the comparison left the rejected B loaded.
                self.reload_checkpoint(current_checkpoint)
                baseline = {"avg_score": comparison["score_a"]}

        self.reload_checkpoint(self.best_checkpoint)
        final_eval = self.evaluate_model()

        result = {
            "success": final_eval["avg_score"] >= 0.7,
            "iterations": iterations,
            "final_score": final_eval["avg_score"],
            "best_score": self.best_score,
            "best_checkpoint": self.best_checkpoint,
            "history": self.history,
        }

        log_path = os.path.join(LOGS_DIR, f"improvement_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
        with open(log_path, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2, default=str)

        print(f"\n{'='*70}")
        print("IMPROVEMENT COMPLETE")
        print(f" Final score: {final_eval['avg_score']:.3f}")
        print(f" Best score: {self.best_score:.3f}")
        print(f" Best checkpoint: {self.best_checkpoint}")
        print(f" Log saved: {log_path}")
        print("="*70)

        return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Parse the command line and run one of the three modes.

    ``--eval-only`` evaluates a single checkpoint, ``--compare A B`` runs an
    A/B comparison, and the default runs the full improvement loop.
    """
    parser = argparse.ArgumentParser(description="Stable Self-Improvement Training")
    parser.add_argument("--iterations", type=int, default=5, help="Number of improvement iterations")
    parser.add_argument("--steps-per-iter", type=int, default=25, help="Training steps per iteration")
    parser.add_argument("--checkpoint", type=str, default=None, help="Starting checkpoint")
    parser.add_argument("--model-path", type=str, default=MODEL_PATH, help="Base model path")
    parser.add_argument("--eval-only", action="store_true", help="Only run evaluation")
    parser.add_argument("--compare", nargs=2, metavar=("CKPT_A", "CKPT_B"), help="Compare two checkpoints")

    args = parser.parse_args()

    trainer = SelfImprovementTrainer(args.model_path, args.checkpoint)

    if args.eval_only:
        trainer.load_model(args.checkpoint)
        trainer.evaluate_model()
    elif args.compare:
        # compare_checkpoints loads each checkpoint itself, so loading A
        # here first would only repeat a full model load.
        trainer.compare_checkpoints(args.compare[0], args.compare[1])
    else:
        trainer.improve(args.iterations, args.steps_per_iter)


if __name__ == "__main__":
    main()
|
|
|