🧠 Full weight release: 9 probes × 3 architectures + production adapter + training code
297244f
#!/usr/bin/env python3
"""
CONTINUE FROM 73.1x CHECKPOINT
==============================
Loads the successful Qwen checkpoint (73.1x @ step 10000) and continues training.
Target: 100x+ separation
Author: Logan Napolitano / Proprioception AI
Date: February 2026
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
from datasets import load_dataset
import os
import time
import random
import json
from dataclasses import dataclass, field
from typing import List, Tuple

CHECKPOINT_DIR = "/home/programmer/Desktop/Claude_and_me/results/qwen3b_continued_from_19x/final"
OUTPUT_DIR = "/home/programmer/Desktop/Claude_and_me/results/qwen3b_continued_from_56x"

@dataclass  # required: `field(default_factory=...)` below only works on a dataclass
class Config:
    model_path: str = "Qwen/Qwen2.5-3B"
    probe_layers: List[int] = field(default_factory=lambda: [9, 18, 27])
    d_fiber: int = 16
    d_control: int = 64
    additional_steps: int = 25000   # continue for 25000 more steps (total 35000)
    batch_size: int = 1
    grad_accum: int = 8
    max_length: int = 256
    lr_lora: float = 2e-6           # much lower: the LoRA weights are already trained
    lr_predictor: float = 1e-5      # much lower: the predictor is already trained
    weight_decay: float = 0.01
    rep_window: int = 32
    log_every: int = 100
    save_every: int = 5000
    eval_every: int = 1000
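
# NOTE (orientation; assumes Qwen2.5-3B's 36 transformer layers): HF
# hidden_states index the embedding output at 0, so probe_layers
# [9, 18, 27] sample roughly quarter-, half-, and three-quarter-depth
# activations.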

class RiskPredictor(nn.Module):
    """Per-token repetition-risk head probing selected transformer layers."""

    def __init__(self, d_model: int, probe_layers: List[int], d_fiber: int = 16, d_control: int = 64):
        super().__init__()
        self.probe_layers = probe_layers
        n_probes = len(probe_layers)
        # One low-rank "fiber" projection per probed layer.
        self.fiber_projs = nn.ModuleList([
            nn.Linear(d_model, d_fiber, bias=False) for _ in range(n_probes)
        ])
        # Learned mixing weights across the probed layers.
        self.layer_weights = nn.Parameter(torch.ones(n_probes) / n_probes)
        self.predictor = nn.Sequential(
            nn.Linear(d_fiber, d_control), nn.GELU(),
            nn.Linear(d_control, d_control), nn.GELU(),
            nn.Linear(d_control, 1)
        )
        for proj in self.fiber_projs:
            nn.init.normal_(proj.weight, std=0.02)

    def forward(self, hidden_states: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        fibers = []
        for i, layer_idx in enumerate(self.probe_layers):
            if layer_idx < len(hidden_states):
                fiber = self.fiber_projs[i](hidden_states[layer_idx].float())
                fibers.append(fiber)
        weights = F.softmax(self.layer_weights[:len(fibers)], dim=0)
        aggregated = sum(w * f for w, f in zip(weights, fibers))
        return self.predictor(aggregated).squeeze(-1)
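
# Shape walkthrough (illustrative): hidden_states is the tuple returned by
# a HF model called with output_hidden_states=True, i.e. n_layers + 1
# tensors of shape (B, S, d_model). Each probed layer is projected to a
# d_fiber=16 "fiber", the fibers are blended via the softmax over the
# learned layer weights, and the MLP maps each token's blended fiber to a
# single risk logit, giving an output of shape (B, S).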

def compute_repetition_labels(input_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
    """Label a token 1 if it equals any token up to `window` positions back."""
    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)
    # The range bound already guarantees offset < S, so no extra check is needed.
    for offset in range(1, min(window + 1, S)):
        matches = (input_ids[:, offset:] == input_ids[:, :-offset]).float()
        labels[:, offset:] = torch.maximum(labels[:, offset:], matches)
    return labels
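
# Worked toy example (illustrative): with window=2 and
#   input_ids = [[5, 7, 5, 7, 5]]
#   offset=1 compares [7, 5, 7, 5] to [5, 7, 5, 7] -> no matches
#   offset=2 compares [5, 7, 5]   to [5, 7, 5]     -> all match
# so labels = [[0, 0, 1, 1, 1]]: every token that repeats a token at most
# two positions earlier is marked positive.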

def compute_separation(predictor, model, tokenizer, device, config, n_samples=50):
    model.eval()
    predictor.eval()
    pos_scores, neg_scores = [], []
    prompts = [
        "The meaning of life according to philosophy is",
        "In the year 2050, technology will",
        "The history of mathematics begins with",
        "Climate change affects the planet by",
        "Neural networks learn patterns through",
        "The ocean contains many species of",
        "Music has evolved significantly since",
        "Economic theories suggest that markets",
        "The human brain processes information",
        "Ancient civilizations developed writing",
        "The quick brown fox jumps over the lazy",
        "Once upon a time in a land far away",
        "The scientific method involves several steps",
        "When writing code, it is important to",
        "In conclusion, we can see that the evidence",
        "There are several reasons why this matters",
        "Let me explain how this works step by step",
        "The main point I want to make is that",
        "According to recent research findings",
        "One way to look at this problem is",
    ]
    with torch.no_grad():
        for i in range(n_samples):
            prompt = prompts[i % len(prompts)]
            inp = tokenizer(prompt, return_tensors='pt')
            input_ids = inp['input_ids'].to(device)
            attn = inp['attention_mask'].to(device)
            # Deterministic decoding for consistent evaluation.
            out = model.generate(input_ids, attention_mask=attn, max_new_tokens=80,
                                 do_sample=False,
                                 pad_token_id=tokenizer.eos_token_id)
            outputs = model(out, output_hidden_states=True)
            risk = torch.sigmoid(predictor(outputs.hidden_states))[0].cpu().numpy()
            labels = compute_repetition_labels(out, config.rep_window)[0].cpu().numpy()
            for t in range(len(risk)):
                (pos_scores if labels[t] > 0.5 else neg_scores).append(float(risk[t]))
    if pos_scores and neg_scores:
        p_pos, p_neg = sum(pos_scores) / len(pos_scores), sum(neg_scores) / len(neg_scores)
        return p_pos, p_neg, p_pos / max(p_neg, 1e-8), len(pos_scores), len(neg_scores)
    return 0, 0, 0, 0, 0
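
# NOTE (illustrative): "separation" is the ratio of mean predicted risk on
# repeated tokens to mean predicted risk on non-repeated tokens. Toy
# arithmetic check (not part of the run): pos_scores = [0.5, 0.3] and
# neg_scores = [0.01, 0.0] give p_pos = 0.4 and p_neg = 0.005, so
# separation = 0.4 / 0.005 = 80.0x.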

def main():
    config = Config()
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    tokenizer = AutoTokenizer.from_pretrained(config.model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("Loading base model...")
    bnb = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
    base_model = AutoModelForCausalLM.from_pretrained(
        config.model_path, quantization_config=bnb, device_map='auto', torch_dtype=torch.float16)
    base_model = prepare_model_for_kbit_training(base_model, use_gradient_checkpointing=True)
    print("Loading LoRA weights from checkpoint...")
    model = PeftModel.from_pretrained(base_model, CHECKPOINT_DIR)
    model.train()
    # Make the LoRA weights trainable again (PeftModel.from_pretrained loads
    # the adapter frozen by default).
    for name, param in model.named_parameters():
        if 'lora' in name.lower():
            param.requires_grad = True
    device = next(model.parameters()).device
    d_model = model.config.hidden_size
    print("Loading risk predictor from checkpoint...")
    risk_predictor = RiskPredictor(d_model, config.probe_layers, config.d_fiber, config.d_control).to(device).float()
    ckpt = torch.load(os.path.join(CHECKPOINT_DIR, "risk_predictor.pt"), map_location=device)
    risk_predictor.load_state_dict(ckpt['risk_predictor'])
    start_step = ckpt['step']
    start_sep = ckpt['separation']
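    # The checkpoint dict is assumed to match what the save calls below
    # write: {'risk_predictor': <state_dict>, 'step': <int>,
    #         'separation': <float>, 'p_pos': ..., 'p_neg': ...}.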
    print()
    print("=" * 70)
    print("CONTINUING FROM CHECKPOINT (deterministic eval)")
    print("=" * 70)
    print(f"Starting point: {start_sep:.1f}x separation @ step {start_step}")
    print("Target: 100x+ separation")
    print(f"Additional steps: {config.additional_steps}")
    print(f"LR: LoRA={config.lr_lora}, Predictor={config.lr_predictor}")
    print()
    print("Loading data...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    texts = [ex['text'] for ex in ds if len(ex['text']) > 50]
    random.shuffle(texts)
    print(f"Loaded {len(texts)} samples")
    lora_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW([
        {'params': lora_params, 'lr': config.lr_lora},
        {'params': risk_predictor.parameters(), 'lr': config.lr_predictor}
    ], weight_decay=config.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.additional_steps, eta_min=1e-6)
    log = {
        "experiment": "continue_from_73x",
        "start_step": start_step,
        "start_separation": start_sep,
        "target": "100x+",
        "steps": [],
        "separations": []
    }
    print()
    print("=" * 70)
    print("TRAINING")
    print("=" * 70)
    model.train()
    risk_predictor.train()
    step = 0
    total_step = start_step
    data_idx = 0
    acc_loss, acc_risk = 0, 0
    best_sep = start_sep
    start_time = time.time()
    while step < config.additional_steps:
        batch = [texts[(data_idx + i) % len(texts)] for i in range(config.batch_size)]
        data_idx += config.batch_size
        enc = tokenizer(batch, truncation=True, max_length=config.max_length,
                        padding='max_length', return_tensors='pt')
        input_ids = enc['input_ids'].to(device)
        attention_mask = enc['attention_mask'].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                        labels=input_ids, output_hidden_states=True)
        lm_loss = outputs.loss
        risk_logits = risk_predictor(outputs.hidden_states)
        rep_labels = compute_repetition_labels(input_ids, config.rep_window)
        # Reweight the rare positive (repeated-token) class, capped at 10x.
        mask = attention_mask.float()
        n_pos = (rep_labels * mask).sum().clamp(min=1)
        n_neg = ((1 - rep_labels) * mask).sum().clamp(min=1)
        pos_weight = (n_neg / n_pos).clamp(max=10.0)
        bce = F.binary_cross_entropy_with_logits(
            risk_logits, rep_labels,
            pos_weight=torch.ones_like(rep_labels) * pos_weight, reduction='none')
        # Average only over real (non-padding) tokens.
        risk_loss = (bce * mask).sum() / mask.sum()
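        # Worked example of the balancing (illustrative): if 10% of the
        # unmasked tokens are repeats, n_neg / n_pos = 9, so positives are
        # upweighted 9x in the BCE; the clamp caps this at 10x on
        # repeat-sparse batches so the loss cannot blow up.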
        loss = lm_loss + risk_loss
        (loss / config.grad_accum).backward()
        acc_loss += loss.item()
        acc_risk += risk_loss.item()
        step += 1
        total_step += 1
        if step % config.grad_accum == 0:
            torch.nn.utils.clip_grad_norm_(list(lora_params) + list(risk_predictor.parameters()), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        if step % config.log_every == 0:
            eta = (config.additional_steps - step) / (step / (time.time() - start_time)) / 60
            print(f"Step {total_step:5d} (+{step}) | Loss: {acc_loss/config.log_every:.3f} | "
                  f"Risk: {acc_risk/config.log_every:.3f} | Best: {best_sep:.1f}x | ETA: {eta:.1f}m")
            log["steps"].append({"step": total_step, "loss": acc_loss/config.log_every})
            acc_loss, acc_risk = 0, 0
        if step % config.eval_every == 0:
            print(f"\n{'='*50}")
            print(f"SEPARATION EVAL @ Step {total_step}")
            print(f"{'='*50}")
            p_pos, p_neg, sep, n_p, n_n = compute_separation(risk_predictor, model, tokenizer, device, config)
            print(f"  P(+) = {p_pos:.4f} (n={n_p})")
            print(f"  P(-) = {p_neg:.4f} (n={n_n})")
            print(f"  SEPARATION = {sep:.1f}x")
            print(f"  [Target: 100x, Best so far: {best_sep:.1f}x]")
            log["separations"].append({"step": total_step, "separation": sep, "p_pos": p_pos, "p_neg": p_neg})
            if sep > best_sep:
                best_sep = sep
                print("  🎯 NEW BEST!")
                # Save best
                best_dir = os.path.join(OUTPUT_DIR, "best")
                os.makedirs(best_dir, exist_ok=True)
                model.save_pretrained(best_dir)
                torch.save({
                    'risk_predictor': risk_predictor.state_dict(),
                    'step': total_step, 'separation': sep, 'p_pos': p_pos, 'p_neg': p_neg
                }, os.path.join(best_dir, "risk_predictor.pt"))
            with open(os.path.join(OUTPUT_DIR, "training_log.json"), 'w') as f:
                json.dump(log, f, indent=2)
            print(f"{'='*50}\n")
            model.train()
            risk_predictor.train()
        if step % config.save_every == 0:
            ckpt_dir = os.path.join(OUTPUT_DIR, f"ckpt_{total_step}")
            os.makedirs(ckpt_dir, exist_ok=True)
            model.save_pretrained(ckpt_dir)
            torch.save({
                'risk_predictor': risk_predictor.state_dict(),
                'step': total_step, 'separation': best_sep
            }, os.path.join(ckpt_dir, "risk_predictor.pt"))
            print(f">>> Checkpoint saved: {ckpt_dir}")
    # Final eval
    print("\n" + "=" * 70)
    print("FINAL RESULTS")
    print("=" * 70)
    p_pos, p_neg, final_sep, _, _ = compute_separation(risk_predictor, model, tokenizer, device, config, n_samples=100)
    target_hit = "✅ TARGET HIT!" if final_sep >= 100 else f"Reached {final_sep:.1f}x"
    print(f"""
┌───────────────────────────────────────────────────────────┐
│                CONTINUED TRAINING RESULTS                 │
├───────────────────────────────────────────────────────────┤
│  Started: 73.1x @ step 10000                              │
│  Final:   {final_sep:>5.1f}x @ step {total_step}          │
│  Best:    {best_sep:>5.1f}x                               │
│  P(+):    {p_pos:.4f}                                     │
│  P(-):    {p_neg:.4f}                                     │
│                                                           │
│  {target_hit:^54}   │
└───────────────────────────────────────────────────────────┘
""")
    log["final"] = {"step": total_step, "separation": final_sep, "best": best_sep, "p_pos": p_pos, "p_neg": p_neg}
    with open(os.path.join(OUTPUT_DIR, "training_log.json"), 'w') as f:
        json.dump(log, f, indent=2)
    # Save final
    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)
    model.save_pretrained(final_dir)
    torch.save({
        'risk_predictor': risk_predictor.state_dict(),
        'step': total_step, 'separation': final_sep, 'p_pos': p_pos, 'p_neg': p_neg
    }, os.path.join(final_dir, "risk_predictor.pt"))
    print(f"Saved to {OUTPUT_DIR}")
    print("DONE!")


if __name__ == "__main__":
    main()
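
# ---------------------------------------------------------------------
# Illustrative usage sketch (assumption: the "final" directory layout
# written above; kept as comments so importing this module stays
# side-effect free). Loading the released adapter + probe for inference
# might look like:
#
#   base = AutoModelForCausalLM.from_pretrained(
#       "Qwen/Qwen2.5-3B", device_map='auto', torch_dtype=torch.float16)
#   model = PeftModel.from_pretrained(base, os.path.join(OUTPUT_DIR, "final"))
#   ckpt = torch.load(os.path.join(OUTPUT_DIR, "final", "risk_predictor.pt"))
#   probe = RiskPredictor(model.config.hidden_size, [9, 18, 27]).float()
#   probe.load_state_dict(ckpt['risk_predictor'])
#   probe.eval()
#   with torch.no_grad():
#       out = model(input_ids, output_hidden_states=True)
#       risk = torch.sigmoid(probe(out.hidden_states))  # (B, S) per-token risk
# ---------------------------------------------------------------------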