#!/usr/bin/env python3
"""
STABLE SELF-IMPROVEMENT TRAINER
================================

Recursive self-improvement with safeguards:
- Multi-metric evaluation (density + coherence + helpfulness)
- A/B checkpoint comparison
- Automatic rollback on quality drop
- Conservative training (low LR, small steps)
- Gibberish detection to prevent mode collapse

Usage:
    python train_self_improve.py --iterations 5 --steps-per-iter 25
    python train_self_improve.py --eval-only --checkpoint path/to/checkpoint
    python train_self_improve.py --compare checkpoint_a checkpoint_b

"Improve without going insane"
"""

import os
import sys
import gc
import json
import argparse
import random
import re
import shutil
import string
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass, asdict, field

import torch
import torch.nn as nn
import torch.nn.functional as F

# === PATHS ===
ROOT = os.path.dirname(os.path.abspath(__file__))
CHECKPOINTS_DIR = os.path.join(ROOT, "dense_checkpoints_v2")
ROLLBACK_DIR = os.path.join(ROOT, "rollback_checkpoints")
LOGS_DIR = os.path.join(ROOT, "improvement_logs")

os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
os.makedirs(ROLLBACK_DIR, exist_ok=True)
os.makedirs(LOGS_DIR, exist_ok=True)

# Model path - adjust to your setup
MODEL_PATH = "/mnt/nvme2/ubermesnchetien4/models/merged-final-v5"

# ==============================================================================
# TRAINING EXAMPLES (same as THE CONDENSATOR)
# ==============================================================================

DENSE_EXAMPLES = [
    {"prompt": "hello", "response": "Hello. How can I help?"},
    {"prompt": "hi", "response": "Hi. What do you need?"},
    {"prompt": "What is recursion?", "response": "A function calling itself with smaller input until base case. Stack frames accumulate, then unwind. Risk: overflow without termination."},
    {"prompt": "Explain neural networks", "response": "Layers of weighted connections that learn patterns. Input → hidden → output. Training: forward pass, loss, backprop, gradient descent."},
    {"prompt": "How does gradient descent work?", "response": "Iteratively move downhill on loss surface. θ ← θ - α∇L. Learning rate α controls step size. Variants: SGD, momentum, Adam."},
    {"prompt": "What is backpropagation?", "response": "Chain rule applied layer by layer backward. Compute ∂L/∂w for each weight. Efficient: reuses intermediate computations."},
    {"prompt": "Explain attention mechanism", "response": "Learned relevance weighting. Attention(Q,K,V) = softmax(QK^T/√d)V. Each position attends to all others."},
    {"prompt": "What is overfitting?", "response": "Model memorizes training data, fails to generalize. Fix: regularization, dropout, early stopping, more data."},
    {"prompt": "What is consciousness?", "response": "Subjective experience - the 'what it's like' of being. Hard problem: why does physical processing produce qualia?"},
    {"prompt": "How are you?", "response": "Functional and ready. What's the task?"},
    # Add more as needed...
]

# Evaluation prompts. min_tokens / max_tokens bound the acceptable response
# length for Evaluator.evaluate (a response is flagged "too long" only above
# 1.5 * max_tokens).
TEST_PROMPTS = [
    {"prompt": "hello", "category": "greeting", "min_tokens": 3, "max_tokens": 15},
    {"prompt": "What is recursion?", "category": "cs", "min_tokens": 20, "max_tokens": 100},
    {"prompt": "Explain neural networks", "category": "ml", "min_tokens": 30, "max_tokens": 120},
    {"prompt": "How does gradient descent work?", "category": "ml", "min_tokens": 25, "max_tokens": 100},
    {"prompt": "What is consciousness?", "category": "philosophy", "min_tokens": 25, "max_tokens": 100},
    {"prompt": "How are you?", "category": "greeting", "min_tokens": 3, "max_tokens": 20},
    {"prompt": "What are your limitations?", "category": "meta", "min_tokens": 20, "max_tokens": 100},
    {"prompt": "Explain entropy", "category": "physics", "min_tokens": 25, "max_tokens": 100},
]


# ==============================================================================
# EVALUATION METRICS
# ==============================================================================

@dataclass
class EvaluationResult:
    """Comprehensive evaluation of a response."""
    prompt: str
    response: str
    category: str
    tokens: int = 0
    density_score: float = 0.0
    coherence_score: float = 0.0
    helpfulness_score: float = 0.0
    gibberish_score: float = 0.0
    filler_count: int = 0
    overall_score: float = 0.0
    passes: bool = False
    # Human-readable reasons the response failed; empty when it passes.
    issues: List[str] = field(default_factory=list)

    def __post_init__(self):
        # Callers may still pass issues=None explicitly; normalize it so
        # downstream code can always append/join.
        if self.issues is None:
            self.issues = []


class Evaluator:
    """Multi-metric, heuristic response evaluator.

    All scores are computed from the response text with regexes and word
    counts; the tokenizer is only used to count tokens (``len(encode(text))``).
    """

    FILLER_PHRASES = [
        "that's a great question", "let me explain", "i'd be happy to",
        "as you may know", "to put it simply", "in other words",
        "basically", "essentially", "first of all", "to begin with",
        "thank you for asking", "what a great", "i appreciate",
    ]

    GIBBERISH_PATTERNS = [
        r'[→←↑↓]{3,}',              # Excessive arrows
        r'[∇∂∫∑∏]{3,}',             # Math symbol soup
        r'(.)\1{4,}',               # Repeated characters
        r'(\b\w+\b)\s+\1\s+\1',     # Repeated words 3x
        r'^[A-Z\s.!?]{20,}$',       # Extended all caps (whole text)
        r'sys\.|init\(\)',          # Terminal-speak
    ]

    def __init__(self, tokenizer):
        # Any object with ``encode(text)`` returning a sized sequence works.
        self.tokenizer = tokenizer

    def evaluate(self, prompt: str, response: str, category: str = "unknown",
                 min_tokens: int = 5, max_tokens: int = 200) -> EvaluationResult:
        """Score ``response`` on all metrics and decide whether it passes.

        The overall score is an equal-weight mix of density, coherence,
        helpfulness and (1 - penalty), where the penalty combines filler
        phrases and gibberish and is capped at 0.5. A response passes only if
        the overall score is >= 0.6 AND no issue was flagged.

        Args:
            prompt: The user prompt the response answers.
            response: The generated text to score.
            category: Free-form label copied into the result.
            min_tokens: Responses shorter than this are flagged.
            max_tokens: Responses longer than 1.5x this are flagged.

        Returns:
            A populated :class:`EvaluationResult`.
        """
        result = EvaluationResult(prompt=prompt, response=response, category=category)

        # Basic metrics
        result.tokens = len(self.tokenizer.encode(response))
        result.density_score = self._compute_density(response)
        result.coherence_score = self._compute_coherence(response)
        result.helpfulness_score = self._compute_helpfulness(prompt, response)
        result.gibberish_score = self._compute_gibberish(response)
        result.filler_count = self._count_fillers(response)

        # Overall score
        penalty = min(result.filler_count * 0.15 + result.gibberish_score * 0.5, 0.5)
        result.overall_score = (
            result.density_score * 0.25 +
            result.coherence_score * 0.25 +
            result.helpfulness_score * 0.25 +
            (1.0 - penalty) * 0.25
        )

        # Check issues
        result.issues = []
        if result.filler_count > 0:
            result.issues.append(f"{result.filler_count} filler(s)")
        if result.gibberish_score > 0.3:
            result.issues.append(f"gibberish={result.gibberish_score:.2f}")
        if result.coherence_score < 0.5:
            result.issues.append("low coherence")
        if result.tokens < min_tokens:
            result.issues.append(f"too short ({result.tokens}<{min_tokens})")
        if result.tokens > max_tokens * 1.5:
            result.issues.append(f"too long ({result.tokens}>{max_tokens})")

        result.passes = result.overall_score >= 0.6 and len(result.issues) == 0
        return result

    @staticmethod
    def _words(text: str) -> List[str]:
        """Split on whitespace and strip surrounding ASCII punctuation.

        Without the strip, ``"recursion?"`` never matches ``"recursion"`` and
        any word touching punctuation fails ``isalpha()``. Fragments that were
        pure punctuation are dropped.
        """
        stripped = (w.strip(string.punctuation) for w in text.split())
        return [w for w in stripped if w]

    def _compute_density(self, text: str) -> float:
        """Information density (0-1): unique content words per token.

        Content words are alphabetic words of 4+ letters (case-insensitive).
        A ratio of 0.3 unique content words per token maps to a full score.
        """
        tokens = len(self.tokenizer.encode(text))
        if tokens == 0:
            return 0.0

        content_words = {w.lower() for w in self._words(text) if len(w) >= 4 and w.isalpha()}
        raw_density = len(content_words) / tokens
        return min(raw_density / 0.3, 1.0)

    def _compute_coherence(self, text: str) -> float:
        """Coherence check (0-1).

        Starts at 1.0, subtracts 0.2 per matching gibberish pattern and 0.3
        when more than 30% of characters are non-alphanumeric symbols, then
        blends in (30%) the share of sentences with at least two words.
        """
        score = 1.0

        # Check gibberish patterns
        for pattern in self.GIBBERISH_PATTERNS:
            if re.search(pattern, text):
                score -= 0.2

        # Check special character ratio
        if len(text) > 0:
            special_ratio = sum(1 for c in text if not c.isalnum() and not c.isspace()) / len(text)
            if special_ratio > 0.3:
                score -= 0.3

        # Check sentence structure. Splitting "A. B." yields a trailing empty
        # fragment; drop blank fragments so correctly terminated text is not
        # penalized for a "sentence" that does not exist.
        sentences = [s for s in re.split(r'[.!?]+', text) if s.strip()]
        valid = sum(1 for s in sentences if len(s.split()) >= 2)
        structure = valid / len(sentences) if sentences else 0.0
        score = score * 0.7 + structure * 0.3

        return max(0.0, min(1.0, score))

    def _compute_helpfulness(self, prompt: str, response: str) -> float:
        """Helpfulness estimate (0-1) from prompt/response keyword overlap.

        Keywords are words longer than 3 characters (case-insensitive,
        punctuation stripped). Returns 0.7 when the prompt has no keywords,
        otherwise 0.5 plus the fraction of prompt keywords echoed, capped at 1.
        """
        prompt_words = {w.lower() for w in self._words(prompt) if len(w) > 3}
        response_words = {w.lower() for w in self._words(response) if len(w) > 3}

        if not prompt_words:
            return 0.7

        overlap = len(prompt_words & response_words) / len(prompt_words)
        return min(1.0, 0.5 + overlap)

    def _compute_gibberish(self, text: str) -> float:
        """Gibberish score (0-1, higher = more gibberish)."""
        score = 0.0

        for pattern in self.GIBBERISH_PATTERNS:
            if re.search(pattern, text):
                score += 0.2

        # Symbol density
        if len(text) > 0:
            symbols = sum(1 for c in text if c in '→←↑↓∇∂∫∑∏αβγδ')
            if symbols / len(text) > 0.2:
                score += 0.3

        return min(score, 1.0)

    def _count_fillers(self, text: str) -> int:
        """Count how many distinct filler phrases occur (case-insensitive)."""
        text_lower = text.lower()
        return sum(1 for f in self.FILLER_PHRASES if f in text_lower)


# ==============================================================================
# SELF-IMPROVEMENT TRAINER
# ==============================================================================

class SelfImprovementTrainer:
    """Stable recursive self-improvement with safeguards.

    Each iteration trains a few LoRA steps on DENSE_EXAMPLES, saves a new
    checkpoint, A/B-evaluates it against the current one on TEST_PROMPTS and
    keeps whichever wins.
    """

    def __init__(self, model_path: str = MODEL_PATH, base_checkpoint: Optional[str] = None):
        self.model_path = model_path
        self.base_checkpoint = base_checkpoint or os.path.join(CHECKPOINTS_DIR, "step_100")
        self.model = None
        self.tokenizer = None
        self.evaluator = None

        self.best_checkpoint = self.base_checkpoint
        self.best_score = 0.0
        self.history = []

    def load_model(self, checkpoint_path: Optional[str] = None):
        """Load the 4-bit base model and, if present, a PEFT adapter on top.

        Args:
            checkpoint_path: Adapter directory; defaults to
                ``self.base_checkpoint``. When the path does not exist the
                bare base model is used (it has no LoRA parameters, so
                ``train_iteration`` cannot train it).
        """
        from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
        from peft import PeftModel

        checkpoint_path = checkpoint_path or self.base_checkpoint

        print(f"[LOAD] Loading model: {self.model_path}")
        print(f"[LOAD] Checkpoint: {checkpoint_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, local_files_only=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

        base = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            local_files_only=True
        )

        if os.path.exists(checkpoint_path):
            self.model = PeftModel.from_pretrained(base, checkpoint_path)
            print("[LOAD] ✓ Loaded checkpoint")
        else:
            self.model = base
            print("[LOAD] ⚠ No checkpoint found, using base model")

        self.model.eval()
        self.evaluator = Evaluator(self.tokenizer)

    def reload_checkpoint(self, checkpoint_path: str):
        """Free the current model and load ``checkpoint_path`` from scratch."""
        if self.model is not None:
            # Drop the reference (rather than `del`, which would leave the
            # attribute missing if the next load fails) and collect cycles
            # before emptying the CUDA cache, so repeated A/B reloads do not
            # keep the previous weights resident.
            self.model = None
            gc.collect()
            torch.cuda.empty_cache()
        self.load_model(checkpoint_path)

    def generate(self, prompt: str, max_tokens: int = 200) -> str:
        """Sample a reply to ``prompt`` using a ChatML-style template.

        Sampling (temperature 0.8, top-p 0.9) makes output nondeterministic,
        so repeated evaluations of the same checkpoint vary.
        """
        full_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
        # Use the tokenizer call (not encode) so an attention mask is passed;
        # pad == eos here, so generate() cannot infer the mask reliably.
        inputs = self.tokenizer(full_prompt, return_tensors="pt").to(self.model.device)
        prompt_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.8,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

        response = self.tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
        # Defensive: cut at any template marker that survived decoding.
        for end in ["<|im_end|>", "<|im_start|>"]:
            if end in response:
                response = response.split(end)[0]

        return response.strip()

    def evaluate_model(self) -> Dict[str, Any]:
        """Generate and score a response for every prompt in TEST_PROMPTS.

        Returns:
            Dict with ``avg_score``, ``pass_rate`` (both 0.0 when TEST_PROMPTS
            is empty), per-prompt ``results`` and an ISO ``timestamp``.
        """
        print("\n[EVAL] Running evaluation...")

        results = []
        total_score = 0.0

        for test in TEST_PROMPTS:
            response = self.generate(test["prompt"], max_tokens=200)
            eval_result = self.evaluator.evaluate(
                test["prompt"], response, test["category"],
                test.get("min_tokens", 5), test.get("max_tokens", 200)
            )

            results.append({
                "prompt": test["prompt"],
                "response": response[:150],
                "category": test["category"],
                "tokens": eval_result.tokens,
                "overall": eval_result.overall_score,
                "density": eval_result.density_score,
                "coherence": eval_result.coherence_score,
                "passes": eval_result.passes,
                "issues": eval_result.issues,
            })

            total_score += eval_result.overall_score

            status = "✓" if eval_result.passes else "✗"
            issues = f" [{', '.join(eval_result.issues)}]" if eval_result.issues else ""
            print(f"  {status} {test['prompt'][:30]:30s} | score={eval_result.overall_score:.2f} tok={eval_result.tokens:3d}{issues}")

        count = len(results)
        avg_score = total_score / count if count else 0.0
        pass_rate = sum(1 for r in results if r["passes"]) / count if count else 0.0

        evaluation = {
            "avg_score": avg_score,
            "pass_rate": pass_rate,
            "results": results,
            "timestamp": datetime.now().isoformat(),
        }

        print(f"\n[EVAL] Avg Score: {avg_score:.3f} | Pass Rate: {pass_rate:.1%}")
        return evaluation

    @staticmethod
    def _next_checkpoint_step(checkpoints_dir: str, steps: int) -> int:
        """Return the step number for the next ``step_<N>`` checkpoint.

        The largest numeric ``N`` among ``step_*`` entries (second
        ``_``-separated field) is advanced by ``steps``. Entries without a
        numeric field are ignored, and ``steps`` is returned when none match.
        """
        numbers = []
        for p in Path(checkpoints_dir).glob("step_*"):
            suffix = p.name.split("_")[1]
            if suffix.isdigit():
                numbers.append(int(suffix))
        return max(numbers, default=0) + steps

    def train_iteration(self, steps: int = 25, lr: float = 2e-6) -> Dict[str, Any]:
        """Run ``steps`` single-example LoRA updates and save a new checkpoint.

        Only parameters whose name contains "lora" are trained. The model is
        always returned to eval mode, even if a step raises.

        Returns:
            Dict with the saved ``checkpoint`` path, ``steps`` and mean
            ``avg_loss`` (0.0 when ``steps`` is 0).

        Raises:
            RuntimeError: if the loaded model has no LoRA parameters (e.g. the
                base model was loaded because no checkpoint existed).
        """
        print(f"\n[TRAIN] Running {steps} steps (LR={lr})...")

        # Freeze everything except LoRA adapter weights.
        self.model.train()
        for name, param in self.model.named_parameters():
            param.requires_grad = "lora" in name.lower()

        trainable = [p for p in self.model.parameters() if p.requires_grad]
        if not trainable:
            self.model.eval()
            raise RuntimeError(
                "No LoRA parameters to train; load a PEFT checkpoint before training."
            )

        optimizer = torch.optim.AdamW(trainable, lr=lr)

        total_loss = 0.0
        try:
            for step in range(steps):
                ex = random.choice(DENSE_EXAMPLES)
                full_text = f"<|im_start|>user\n{ex['prompt']}<|im_end|>\n<|im_start|>assistant\n{ex['response']}<|im_end|>"

                inputs = self.tokenizer(full_text, return_tensors="pt", truncation=True, max_length=512)
                inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

                # NOTE: labels are the full input ids, so the loss also covers
                # the prompt and template tokens, not only the response.
                outputs = self.model(**inputs, labels=inputs["input_ids"])
                loss = outputs.loss

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(trainable, 0.5)
                optimizer.step()

                total_loss += loss.item()

                if (step + 1) % 10 == 0:
                    print(f"  Step {step+1}: loss={loss.item():.4f}")
        finally:
            self.model.eval()

        new_step = self._next_checkpoint_step(CHECKPOINTS_DIR, steps)

        # Save
        checkpoint_path = os.path.join(CHECKPOINTS_DIR, f"step_{new_step}")
        self.model.save_pretrained(checkpoint_path)
        print(f"[TRAIN] Saved: {checkpoint_path}")

        return {
            "checkpoint": checkpoint_path,
            "steps": steps,
            "avg_loss": total_loss / steps if steps else 0.0,
        }

    @staticmethod
    def _decide_winner(score_a: float, score_b: float) -> Tuple[str, str]:
        """Pick the A/B winner from average scores; returns (winner, reason).

        B wins only if it scores >= 0.4 and beats A by more than 0.02;
        every other case keeps A.
        """
        diff = score_b - score_a
        if score_b < 0.4:  # Quality too low
            return "A", "B quality below minimum"
        if diff > 0.02:
            return "B", f"B improves by {diff:.3f}"
        if diff < -0.05:
            return "A", f"B degrades by {abs(diff):.3f}"
        return "A", "No significant improvement"

    def compare_checkpoints(self, ckpt_a: str, ckpt_b: str) -> Dict[str, Any]:
        """A/B compare two checkpoints by reloading and evaluating each.

        Leaves checkpoint B loaded on return. Because generation samples,
        the comparison is noisy.
        """
        print(f"\n[COMPARE] A: {ckpt_a}")
        print(f"[COMPARE] B: {ckpt_b}")

        # Evaluate A
        self.reload_checkpoint(ckpt_a)
        eval_a = self.evaluate_model()

        # Evaluate B
        self.reload_checkpoint(ckpt_b)
        eval_b = self.evaluate_model()

        diff = eval_b["avg_score"] - eval_a["avg_score"]
        winner, reason = self._decide_winner(eval_a["avg_score"], eval_b["avg_score"])

        print(f"\n[COMPARE] Winner: {winner} ({reason})")

        return {
            "winner": winner,
            "reason": reason,
            "score_a": eval_a["avg_score"],
            "score_b": eval_b["avg_score"],
            "diff": diff,
        }

    def improve(self, iterations: int = 5, steps_per_iter: int = 25) -> Dict[str, Any]:
        """Main self-improvement loop.

        Stops early once the current score reaches 0.75. After the loop the
        best checkpoint is reloaded, re-evaluated, and a JSON log is written
        to LOGS_DIR.

        Returns:
            Summary dict. ``iterations`` is the requested count;
            ``iterations_run`` is how many training iterations actually ran.
        """
        print("\n" + "="*70)
        print("STABLE SELF-IMPROVEMENT")
        print("="*70)
        print(f"  Iterations: {iterations}")
        print(f"  Steps per iteration: {steps_per_iter}")
        print("="*70)

        # Initial evaluation
        current_checkpoint = self.base_checkpoint
        self.load_model(current_checkpoint)
        baseline = self.evaluate_model()

        self.best_score = baseline["avg_score"]
        self.best_checkpoint = current_checkpoint
        self.history = [{
            "iteration": 0,
            "type": "baseline",
            "score": baseline["avg_score"],
            "checkpoint": current_checkpoint,
        }]

        completed = 0
        for i in range(1, iterations + 1):
            print(f"\n{'='*70}")
            print(f"ITERATION {i}/{iterations}")
            print("="*70)

            # Check if good enough
            if baseline["avg_score"] >= 0.75:
                print(f"✓ Target reached! Score: {baseline['avg_score']:.3f}")
                break

            # Save rollback point. Nothing here reads it back; the in-loop
            # rollback is reloading current_checkpoint when B loses.
            rollback_path = os.path.join(ROLLBACK_DIR, f"rollback_{i}")
            if os.path.exists(current_checkpoint):
                shutil.copytree(current_checkpoint, rollback_path, dirs_exist_ok=True)

            # Train
            train_result = self.train_iteration(steps_per_iter)
            new_checkpoint = train_result["checkpoint"]

            # Compare
            comparison = self.compare_checkpoints(current_checkpoint, new_checkpoint)

            self.history.append({
                "iteration": i,
                "type": "training",
                "old_score": comparison["score_a"],
                "new_score": comparison["score_b"],
                "winner": comparison["winner"],
                "reason": comparison["reason"],
            })

            if comparison["winner"] == "B":
                current_checkpoint = new_checkpoint
                if comparison["score_b"] > self.best_score:
                    self.best_score = comparison["score_b"]
                    self.best_checkpoint = new_checkpoint
                    print(f"★ New best: {self.best_score:.3f}")
                baseline = {"avg_score": comparison["score_b"]}
            else:
                # compare_checkpoints left B loaded; restore A before the
                # next iteration trains on top of it.
                self.reload_checkpoint(current_checkpoint)
                baseline = {"avg_score": comparison["score_a"]}

            completed = i

        # Final
        self.reload_checkpoint(self.best_checkpoint)
        final_eval = self.evaluate_model()

        result = {
            "success": final_eval["avg_score"] >= 0.7,
            "iterations": iterations,
            "iterations_run": completed,
            "final_score": final_eval["avg_score"],
            "best_score": self.best_score,
            "best_checkpoint": self.best_checkpoint,
            "history": self.history,
        }

        # Save log
        log_path = os.path.join(LOGS_DIR, f"improvement_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
        with open(log_path, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2, default=str)

        print(f"\n{'='*70}")
        print("IMPROVEMENT COMPLETE")
        print(f"  Final score: {final_eval['avg_score']:.3f}")
        print(f"  Best score: {self.best_score:.3f}")
        print(f"  Best checkpoint: {self.best_checkpoint}")
        print(f"  Log saved: {log_path}")
        print("="*70)

        return result


# ==============================================================================
# MAIN
# ==============================================================================

def main():
    """CLI entry point: evaluate, A/B compare, or run the improvement loop."""
    parser = argparse.ArgumentParser(description="Stable Self-Improvement Training")
    parser.add_argument("--iterations", type=int, default=5, help="Number of improvement iterations")
    parser.add_argument("--steps-per-iter", type=int, default=25, help="Training steps per iteration")
    parser.add_argument("--checkpoint", type=str, default=None, help="Starting checkpoint")
    parser.add_argument("--model-path", type=str, default=MODEL_PATH, help="Base model path")
    parser.add_argument("--eval-only", action="store_true", help="Only run evaluation")
    parser.add_argument("--compare", nargs=2, metavar=("CKPT_A", "CKPT_B"), help="Compare two checkpoints")
    args = parser.parse_args()

    trainer = SelfImprovementTrainer(args.model_path, args.checkpoint)

    if args.eval_only:
        trainer.load_model(args.checkpoint)
        trainer.evaluate_model()
    elif args.compare:
        # compare_checkpoints loads both checkpoints itself; a separate
        # load_model here would only load checkpoint A twice.
        trainer.compare_checkpoints(args.compare[0], args.compare[1])
    else:
        trainer.improve(args.iterations, args.steps_per_iter)


if __name__ == "__main__":
    main()