"""
Claudia Persistent Absorber v2
==============================

Combines the 3 best proven techniques into one system:
  1. SELF-QUIZ PAIRS (21% → 74% recall — the single biggest lever)
  2. PERSISTENT LoRA rank 128 (89% across 25 convos, no merge-between-rounds tax)
  3. DUAL-LR EXPERT FFN (attention=6e-5, FFN=3e-4 — facts into MoE experts)

Architecture:
  - Load base Omni → thinker to GPU, rest to CPU
  - First run: apply Claudia v6 adapter → merge → apply FFN patch
  - Resume: load from checkpoint (already has personality + memories)
  - Apply ONE persistent LoRA (r=128, alpha=256, attention q/k/v/o)
  - Chat loop: generate → quiz → train (LoRA + expert FFN) → repeat
  - On save/quit: merge_and_unload → save full checkpoint
  - Next session loads from checkpoint — memories are permanent

Instance: Vast.ai 33093662 (A100 80GB, Sweden)
SSH: ssh -p 13662 root@ssh1.vast.ai
"""

import argparse
import gc
import json
import os
import re
import sys
import threading
import time
from collections import Counter
from datetime import datetime
from pathlib import Path

import torch

# ═══════════════════════════════════════════════════════════════════════
# CONFIG
# ═══════════════════════════════════════════════════════════════════════

# ── Persistent LoRA: settings from the persistent-LoRA test, which held
#    up across 25+ conversations. Targets the attention projections only.
LORA_RANK = 128
LORA_ALPHA = 256
LORA_TARGETS = ["q_proj", "k_proj", "v_proj", "o_proj"]

# ── Dual learning rates: from the engram micro_trainer (5/5 fact retention).
#    The expert FFN runs at 5x the attention LR so facts absorb fast while
#    personality stays put.
ATTENTION_LR = 6e-5
EXPERT_FFN_LR = 3e-4
# MoE layers whose expert FFN is trained; best set found in the v5 experiment.
EXPERT_FFN_LAYERS = [20, 24, 28]

# ── Per-absorption training. Epochs were cut from 4 to 2 to avoid
#    overfitting on small, focused batches.
TRAIN_EPOCHS = 2
MAX_SEQ_LENGTH = 2048
GRADIENT_CLIP = 1.0

# ── Sampling parameters for chat generation.
GEN_TEMPERATURE = 0.7
GEN_TOP_P = 0.9
GEN_TOP_K = 50
GEN_MAX_TOKENS = 512
GEN_REP_PENALTY = 1.1

# ── Cadence. Absorb after every N exchanges (1 = every turn).
ABSORB_EVERY = 1

# Checkpoint interval (auto-save every N absorptions)
CHECKPOINT_EVERY = 10 # Self-verification (v11 — clean contrastive + sister pairs, no "NOT X" leak) VERIFY_EVERY = 3 # More frequent checks catch drift earlier VERIFY_SAMPLE = 10 # Back to v9's value — wider sampling destabilized in v10 # Cascade Distillation (Nemotron-Cascade-2 paper — on-policy distillation) # When facts from previous sessions regress, distill from the teacher checkpoint # that knew them best. Recovers regressions without losing new knowledge. DISTILL_ALPHA = 0.5 # CE vs KL loss balance (0.5 = equal weight) DISTILL_TEMPERATURE = 2.0 # Softens distributions for better KL gradients DISTILL_TOP_K = 32 # Top-K logits to cache per token position CONSOLIDATION_EPOCHS = 2 # Distillation epochs at session start (1→2 for stronger lock-in) MAX_TEACHER_CACHE = 200 # Cap quiz pairs to cache (oldest trimmed) # ═══════════════════════════════════════════════════════════════════════ # QUALITY GATE (from engram micro_trainer — reject degenerate text) # ═══════════════════════════════════════════════════════════════════════ def check_response_quality(text): """Reject degenerate text before training on it.""" if not text or len(text) < 5: return False words = text.lower().split() if len(words) < 3: return False # Low unique word ratio = repetitive garbage if len(set(words)) / len(words) < 0.3: return False # Repeated consecutive words if sum(1 for i in range(len(words) - 1) if words[i] == words[i + 1]) >= 3: return False # Repeated bigrams if len(words) >= 10: bigrams = [f"{words[i]} {words[i+1]}" for i in range(len(words) - 1)] if Counter(bigrams).most_common(1)[0][1] >= 5: return False # Fused words (missing spaces) if sum(1 for w in words if len(w) > 30) >= 2: return False # Average word length spike if sum(len(w) for w in words) / len(words) > 12: return False return True # ═══════════════════════════════════════════════════════════════════════ # MODEL MANAGER # ═══════════════════════════════════════════════════════════════════════ class ModelManager: def 
__init__(self, model_path, adapter_path=None, ffn_patch_path=None, checkpoint_path=None): self.model_path = model_path self.adapter_path = adapter_path self.ffn_patch_path = ffn_patch_path self.checkpoint_path = checkpoint_path # Resume from here if set self.thinker = None self.tokenizer = None self.stop_ids = None self.peft_model = None # The persistent LoRA — stays active all session self._lock = threading.Lock() def load(self): from transformers import AutoTokenizer # ── Step 1: Load tokenizer ── tok_source = self.checkpoint_path or self.model_path print(f"[1/5] Loading tokenizer from {tok_source}...") self.tokenizer = AutoTokenizer.from_pretrained( tok_source, trust_remote_code=True ) # ── Step 2: Load model ── if self.checkpoint_path: # RESUME: checkpoint contains only thinker weights — load thinker directly print(f"[2/5] Loading thinker from checkpoint {self.checkpoint_path}...") try: from transformers import Qwen3OmniMoeThinkerForConditionalGeneration as ThinkerClass except ImportError: from transformers import AutoModelForCausalLM as ThinkerClass self.thinker = ThinkerClass.from_pretrained( self.checkpoint_path, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True, ) vram = torch.cuda.memory_allocated() / 1e9 print(f" VRAM after load: {vram:.1f} GB") else: # FIRST RUN: load full model, extract thinker, offload rest print(f"[2/5] Loading full model from {self.model_path}...") try: from transformers import Qwen3OmniMoeForConditionalGeneration as ModelClass except ImportError: from transformers import AutoModel as ModelClass full_model = ModelClass.from_pretrained( self.model_path, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True, ) vram = torch.cuda.memory_allocated() / 1e9 print(f" VRAM after load: {vram:.1f} GB") # Extract thinker, offload rest self.thinker = full_model.thinker for name, module in full_model.named_children(): if name != "thinker": try: module.cpu() except (NotImplementedError, RuntimeError): pass del 
full_model torch.cuda.empty_cache() vram = torch.cuda.memory_allocated() / 1e9 print(f" VRAM after cleanup: {vram:.1f} GB") # ── Step 3: Apply personality if first run ── if self.checkpoint_path: print(f"[3/5] Resuming from checkpoint — personality already in weights.") else: if self.adapter_path: print(f"[3/5] Merging Claudia v6 personality adapter...") from peft import PeftModel self.thinker = PeftModel.from_pretrained( self.thinker, self.adapter_path ) self.thinker = self.thinker.merge_and_unload() print(f" Personality merged into base weights.") if self.ffn_patch_path and os.path.exists(self.ffn_patch_path): print(f" Applying FFN patch from {self.ffn_patch_path}...") ffn = torch.load( self.ffn_patch_path, map_location="cpu", weights_only=True ) for key, tensor in ffn.items(): match = re.search(r"layers\.(\d+)", key) if not match: continue layer_idx = int(match.group(1)) experts = self.thinker.model.layers[layer_idx].mlp.experts if hasattr(experts, '__len__'): for i in range(tensor.shape[0]): experts[i].down_proj.weight.data.copy_( tensor[i].to( experts[i].down_proj.weight.device, experts[i].down_proj.weight.dtype, ) ) elif hasattr(experts, 'down_proj'): experts.down_proj.data.copy_( tensor.to(experts.down_proj.device, experts.down_proj.dtype) ) del ffn torch.cuda.empty_cache() print(f" FFN patch applied.") self.thinker.eval() # Stop tokens self.stop_ids = [] for tok in ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]: ids = self.tokenizer.encode(tok, add_special_tokens=False) if ids: self.stop_ids.extend(ids) if self.tokenizer.eos_token_id: self.stop_ids.append(self.tokenizer.eos_token_id) # ── Step 5: Apply persistent LoRA ── print(f"[4/5] Applying persistent LoRA (r={LORA_RANK}, alpha={LORA_ALPHA})...") self._apply_persistent_lora() vram = torch.cuda.memory_allocated() / 1e9 print(f"[5/5] Ready. VRAM: {vram:.1f} GB\n") def _apply_persistent_lora(self): """Apply the persistent absorption LoRA. 
Called once at load, and after merge.""" from peft import LoraConfig, get_peft_model lora_config = LoraConfig( r=LORA_RANK, lora_alpha=LORA_ALPHA, target_modules=LORA_TARGETS, lora_dropout=0.0, bias="none", task_type="CAUSAL_LM", ) self.peft_model = get_peft_model(self.thinker, lora_config) self.peft_model.eval() trainable = sum(p.numel() for p in self.peft_model.parameters() if p.requires_grad) total = sum(p.numel() for p in self.peft_model.parameters()) print(f" LoRA: {trainable / 1e6:.1f}M trainable / {total / 1e6:.0f}M total") def generate(self, messages, max_new_tokens=None): """Generate response. Thread-safe.""" with self._lock: model = self.peft_model or self.thinker model.eval() text = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=False, ) inputs = self.tokenizer( text, return_tensors="pt", truncation=True, max_length=8192 ).to("cuda") input_len = inputs["input_ids"].shape[1] with torch.inference_mode(): out = model.generate( **inputs, max_new_tokens=max_new_tokens or GEN_MAX_TOKENS, temperature=GEN_TEMPERATURE, top_p=GEN_TOP_P, top_k=GEN_TOP_K, do_sample=True, repetition_penalty=GEN_REP_PENALTY, pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.stop_ids, ) resp = self.tokenizer.decode(out[0][input_len:], skip_special_tokens=True) # Strip thinking tags resp = re.sub(r".*?", "", resp, flags=re.DOTALL) resp = re.sub(r"", "", resp) return resp.strip() def absorb(self, training_data): """ Train the persistent LoRA + expert FFN on accumulated data. Uses dual-LR: attention at ATTENTION_LR, expert FFN at EXPERT_FFN_LR. Thread-safe. """ with self._lock: return self._absorb_impl(training_data) def _absorb_impl(self, training_data): """Internal absorption. 
Must hold _lock.""" if not training_data: return None model = self.peft_model or self.thinker tokenizer = self.tokenizer # ── Tokenize all examples ── texts = [] for item in training_data: if isinstance(item, dict) and "messages" in item: msgs = item["messages"] elif isinstance(item, dict) and "prompt" in item: msgs = item["prompt"] + item.get("completion", []) elif isinstance(item, list): msgs = item else: continue text = tokenizer.apply_chat_template( msgs, tokenize=False, enable_thinking=False ) texts.append(text) if not texts: return None enc = tokenizer( texts, truncation=True, max_length=MAX_SEQ_LENGTH, padding=True, return_tensors="pt", ) input_ids = enc["input_ids"].to("cuda") attention_mask = enc["attention_mask"].to("cuda") labels = input_ids.clone() labels[attention_mask == 0] = -100 # ── Collect LoRA attention params ── model.train() attn_params = [p for p in model.parameters() if p.requires_grad] # ── Unfreeze expert FFN ── expert_params = [] base = model.base_model.model if hasattr(model, "base_model") else model for layer_idx in EXPERT_FFN_LAYERS: experts = base.model.layers[layer_idx].mlp.experts if hasattr(experts, '__len__'): for i in range(len(experts)): p = experts[i].down_proj.weight p.requires_grad_(True) expert_params.append(p) elif hasattr(experts, 'down_proj'): p = experts.down_proj if isinstance(p, (torch.nn.Parameter, torch.Tensor)): p.requires_grad_(True) expert_params.append(p) # ── Dual-LR optimizer ── param_groups = [] if attn_params: param_groups.append({"params": attn_params, "lr": ATTENTION_LR}) if expert_params: param_groups.append({"params": expert_params, "lr": EXPERT_FFN_LR}) if not param_groups: model.eval() return None optimizer = torch.optim.AdamW(param_groups, weight_decay=0.0) all_params = attn_params + expert_params # ── Training loop ── n = input_ids.shape[0] total_steps = n * TRAIN_EPOCHS total_loss = 0.0 for epoch in range(TRAIN_EPOCHS): # Shuffle order each epoch indices = torch.randperm(n) for i in range(n): idx = 
indices[i].item() out = model( input_ids=input_ids[idx:idx + 1], attention_mask=attention_mask[idx:idx + 1], labels=labels[idx:idx + 1], ) out.loss.backward() torch.nn.utils.clip_grad_norm_(all_params, GRADIENT_CLIP) optimizer.step() optimizer.zero_grad() total_loss += out.loss.item() # ── Re-freeze expert FFN ── for layer_idx in EXPERT_FFN_LAYERS: experts = base.model.layers[layer_idx].mlp.experts if hasattr(experts, '__len__'): for i in range(len(experts)): experts[i].down_proj.weight.requires_grad_(False) elif hasattr(experts, 'down_proj'): p = experts.down_proj if isinstance(p, (torch.nn.Parameter, torch.Tensor)): p.requires_grad_(False) model.eval() del optimizer torch.cuda.empty_cache() avg_loss = total_loss / total_steps if total_steps > 0 else 0 return avg_loss @staticmethod def cluster_by_entity(training_data, entity_names): """Group training data by primary entity mentioned. Instead of interleaving facts about different people (which causes cross-contamination during gradient updates), this groups all data about one entity together. The model learns all of Jordan's facts before moving to Elena's. Args: training_data: List of training items entity_names: Set/list of known entity names Returns: List of training items, reordered so each entity's items are contiguous. Items mentioning no entity come last. """ clusters = {name: [] for name in entity_names} unclustered = [] for item in training_data: # Extract text from the item if isinstance(item, dict) and "messages" in item: text = " ".join(m.get("content", "") for m in item["messages"]).lower() else: unclustered.append(item) continue # Assign to the first entity mentioned (primary entity) assigned = False for name in entity_names: if name.lower() in text: clusters[name].append(item) assigned = True break if not assigned: unclustered.append(item) # Build ordered list: all of entity A's facts, then B's, then C's... 
ordered = [] for name in entity_names: ordered.extend(clusters[name]) ordered.extend(unclustered) return ordered def absorb_two_phase(self, positive_data, contrastive_data, verify_fn=None): """Two-phase absorption: facts first, then targeted contrastive correction. Phase 1: Train on positive facts (exchanges, entity summaries, template quizzes). This builds the core factual representations. Phase 2: Quick verification on known entities, then train ONLY contrastive quizzes for entities that failed verification. This avoids unnecessary negative gradients on entities the model already distinguishes correctly. Args: positive_data: List of training items (exchanges, summaries, direct quizzes) contrastive_data: List of contrastive quiz items ("Is X a [Y's job]? No...") verify_fn: Optional callable(model_manager) -> set of confused_entity_names. If None, all contrastive data is used in Phase 2. Returns: (phase1_loss, phase2_loss) tuple """ with self._lock: # Phase 1: Positive facts loss1 = None if positive_data: loss1 = self._absorb_impl(positive_data) # Phase 2: Targeted contrastive correction loss2 = None if contrastive_data: if verify_fn: # Only train contrastive pairs for confused entities confused = verify_fn(self) if confused: targeted = [] for item in contrastive_data: q = item["messages"][0]["content"].lower() # Check if any confused entity name appears in the question if any(name.lower() in q for name in confused): targeted.append(item) if targeted: loss2 = self._absorb_impl(targeted) # If no entities confused, skip Phase 2 entirely else: loss2 = self._absorb_impl(contrastive_data) return loss1, loss2 def merge_and_save(self, path): """Merge persistent LoRA into base, save checkpoint, re-apply fresh LoRA.""" with self._lock: if self.peft_model: print(f" Merging persistent LoRA into base weights...") self.thinker = self.peft_model.merge_and_unload() self.thinker.eval() self.peft_model = None os.makedirs(path, exist_ok=True) print(f" Saving checkpoint to 
{path}...") self.thinker.save_pretrained(path) self.tokenizer.save_pretrained(path) print(f" Checkpoint saved ({path})") # Re-apply fresh LoRA for continued learning self._apply_persistent_lora() print(f" Fresh LoRA applied — ready to continue.") def cache_teacher_logits(self, quiz_pairs, top_k=DISTILL_TOP_K): """Cache teacher's top-K output logits for quiz pairs. Called at session end when model is at its best state for these facts. Next session loads this cache for consolidation distillation.""" with self._lock: model = self.peft_model or self.thinker model.eval() cache = [] # Cap to most recent quiz pairs pairs = quiz_pairs[-MAX_TEACHER_CACHE:] for pair in pairs: msgs = pair["messages"] text = self.tokenizer.apply_chat_template( msgs, tokenize=False, enable_thinking=False ) enc = self.tokenizer( text, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LENGTH ) input_ids = enc["input_ids"].to("cuda") attention_mask = enc["attention_mask"].to("cuda") with torch.inference_mode(): out = model(input_ids=input_ids, attention_mask=attention_mask) logits = out.logits[0] # [seq_len, vocab_size] # Keep only top-K logits per position (massive memory savings) top_vals, top_idx = logits.topk(top_k, dim=-1) cache.append({ "pair": pair, "input_ids": input_ids.cpu(), "attention_mask": attention_mask.cpu(), "teacher_logits": top_vals.half().cpu(), "teacher_indices": top_idx.cpu(), }) return cache def distill(self, teacher_cache, epochs=CONSOLIDATION_EPOCHS): """KL distillation: train student to match teacher's output distribution. From Nemotron-Cascade-2: recover domain regressions via on-policy distillation.""" with self._lock: return self._distill_impl(teacher_cache, epochs) def _distill_impl(self, teacher_cache, epochs): """Internal distillation implementation. 
Must hold _lock.""" if not teacher_cache: return None model = self.peft_model or self.thinker model.train() # Dual-LR optimizer (same structure as absorb) attn_params = [p for p in model.parameters() if p.requires_grad] expert_params = [] base = model.base_model.model if hasattr(model, "base_model") else model for layer_idx in EXPERT_FFN_LAYERS: experts = base.model.layers[layer_idx].mlp.experts if hasattr(experts, '__len__'): for i in range(len(experts)): p = experts[i].down_proj.weight p.requires_grad_(True) expert_params.append(p) elif hasattr(experts, 'down_proj'): p = experts.down_proj if isinstance(p, (torch.nn.Parameter, torch.Tensor)): p.requires_grad_(True) expert_params.append(p) param_groups = [] if attn_params: param_groups.append({"params": attn_params, "lr": ATTENTION_LR}) if expert_params: param_groups.append({"params": expert_params, "lr": EXPERT_FFN_LR}) if not param_groups: model.eval() return None optimizer = torch.optim.AdamW(param_groups, weight_decay=0.0) all_params = attn_params + expert_params T = DISTILL_TEMPERATURE total_loss = 0.0 total_steps = 0 for epoch in range(epochs): indices = torch.randperm(len(teacher_cache)) for i in range(len(teacher_cache)): item = teacher_cache[indices[i].item()] input_ids = item["input_ids"].to("cuda") attention_mask = item["attention_mask"].to("cuda") teacher_top_logits = item["teacher_logits"].float().to("cuda") teacher_top_indices = item["teacher_indices"].to("cuda") labels = input_ids.clone() labels[attention_mask == 0] = -100 # Student forward pass out = model( input_ids=input_ids, attention_mask=attention_mask, labels=labels, ) ce_loss = out.loss student_logits = out.logits[0] # [seq_len, vocab_size] # Align sequence lengths (should match, but safety check) seq_len = min(student_logits.shape[0], teacher_top_logits.shape[0]) # Gather student logits at teacher's top-K vocabulary positions student_at_teacher = student_logits[:seq_len].gather( 1, teacher_top_indices[:seq_len] ) # KL divergence on 
temperature-softened distributions teacher_soft = torch.softmax(teacher_top_logits[:seq_len] / T, dim=-1) student_log_soft = torch.log_softmax(student_at_teacher / T, dim=-1) kl_loss = torch.nn.functional.kl_div( student_log_soft, teacher_soft, reduction='batchmean' ) * (T * T) # Scale by T^2 per Hinton et al. # Combined loss: α * CE + (1-α) * KL loss = DISTILL_ALPHA * ce_loss + (1 - DISTILL_ALPHA) * kl_loss loss.backward() torch.nn.utils.clip_grad_norm_(all_params, GRADIENT_CLIP) optimizer.step() optimizer.zero_grad() total_loss += loss.item() total_steps += 1 # Re-freeze expert FFN for layer_idx in EXPERT_FFN_LAYERS: experts = base.model.layers[layer_idx].mlp.experts if hasattr(experts, '__len__'): for i in range(len(experts)): experts[i].down_proj.weight.requires_grad_(False) elif hasattr(experts, 'down_proj'): p = experts.down_proj if isinstance(p, (torch.nn.Parameter, torch.Tensor)): p.requires_grad_(False) model.eval() del optimizer torch.cuda.empty_cache() return total_loss / total_steps if total_steps > 0 else 0 # ═══════════════════════════════════════════════════════════════════════ # QUIZ GENERATOR (21% → 74% recall — the biggest single lever) # ═══════════════════════════════════════════════════════════════════════ class QuizGenerator: """ Generates drill-style Q&A flashcards for fact retention. v3 improvements over v2: - Fact extraction THEN quiz generation (two-step) - Drill-style: specific Q, exact A (not narrative) - Third-person attribution ("Matt's dog" not "my dog") - Template fallback targets each extracted fact independently - CONTRASTIVE DISAMBIGUATION: when multiple people mentioned, generates cross-entity negative pairs ("Is Elena a marine biologist? 
No, that's Jordan") to prevent entity confusion (the #1 remaining failure mode) - ENTITY SUMMARIES: "Tell me everything about Jordan" pairs for coherent per-person representations """ def __init__(self, model_manager): self.mm = model_manager # Cross-message entity memory: tracks ALL named people across the conversation # so contrastive pairs can be generated between entities introduced in # different messages. This was the #1 failure mode in session 4 testing. self.known_entities = {} def generate(self, user_msg, assistant_msg): """Generate drill-style quiz pairs from an exchange.""" # Step 1: Try model-generated quizzes with strict fact-drill prompt pairs = self._generate_model_quizzes(user_msg, assistant_msg) # Step 2: Always add template pairs for any facts the model might miss template_pairs = self._extract_and_template(user_msg) for tp in template_pairs: # Dedup against model pairs tq = tp["messages"][0]["content"].lower() if not any(tq in p["messages"][0]["content"].lower() or p["messages"][0]["content"].lower() in tq for p in pairs): pairs.append(tp) # Step 3: Extract entities from THIS message new_entities = self._extract_entities(user_msg) # Step 4: Generate contrastive pairs between NEW entities and existing ones # ONLY generate pairs involving at least one NEW entity — don't re-generate # pairs between already-known entities (session 4d showed 50% contrastive # ratio because old pairs kept being regenerated, starving positive quizzes) if new_entities: all_entities_for_contrastive = dict(self.known_entities) all_entities_for_contrastive.update(new_entities) if len(all_entities_for_contrastive) >= 2: new_names = set(new_entities.keys()) contrastive = self._generate_contrastive_quizzes( all_entities_for_contrastive, new_only=new_names) pairs.extend(contrastive) # Entity summaries for new entities summaries = self._generate_entity_summaries(new_entities) pairs.extend(summaries) # Update known entities with new ones (merge, don't replace — keep # existing 
attributes, add new ones) for name, info in new_entities.items(): if name not in self.known_entities: self.known_entities[name] = info else: # Merge: update only non-None attributes for key in ("job", "city"): if info.get(key): self.known_entities[name][key] = info[key] # Deduplicate seen = set() unique = [] for p in pairs: q = p["messages"][0]["content"].lower()[:60] if q not in seen: seen.add(q) unique.append(p) # Allow more quizzes when contrastive pairs present (they're highest value). # Note: Session 4c showed >40 quizzes/session causes overfitting. Cap at 12. has_contrastive = len(self.known_entities) >= 2 and new_entities max_quizzes = 12 if has_contrastive else 5 return unique[:max_quizzes] def _generate_model_quizzes(self, user_msg, assistant_msg): """Use the model to generate fact-drill quizzes. Uses base model (LoRA disabled) for stable quality.""" quiz_prompt = f"""Matt just told Claudia: "{user_msg}" Claudia replied: "{assistant_msg}" Extract every SPECIFIC FACT from Matt's message. For each fact, write a drill-style flashcard. RULES: - Questions must ask for ONE specific fact (name, date, place, number, detail) - Answers must be SHORT (1 sentence) and contain the EXACT detail - Use THIRD PERSON: "Matt's dog" NOT "my dog". "Matt's birthday" NOT "my birthday" - Include the PRECISE value: exact names, exact dates, exact places - Do NOT paraphrase or add details that weren't stated - DISAMBIGUATION: If Matt mentions OTHER people (friends, family), clearly state WHOSE fact it is Example: "Matt's friend Jordan is a marine biologist" NOT "Matt is a marine biologist" Example: "Matt's sister Elena is a veterinarian" NOT "Matt is a veterinarian" - For EVERY person mentioned, always include their RELATIONSHIP to Matt - Write 3-5 flashcards depending on how many facts Matt shared GOOD EXAMPLES: Q: What is Matt's dog's name? A: Matt's dog is named Biscuit. Q: What breed is Matt's dog? A: Matt's dog Biscuit is a golden retriever. 
Q: What does Matt's friend Jordan do for a living? A: Matt's friend Jordan works as a marine biologist in San Diego. That is Jordan's job, not Matt's. Q: What is Matt's job? A: Matt is the CTO of Arclight Labs. Q: What is Matt's birthday? A: Matt's birthday is September 14th. Q: When did Matt and Sarah get married? A: Matt and his wife Sarah got married on June 21st, 2023 in Big Sur, California. BAD EXAMPLES (do NOT do this): Q: What did Matt share about his life? (TOO VAGUE — ask about ONE fact) Q: What is my dog's name? (WRONG — use "Matt's" not "my") A: He mentioned something about a trip overseas. (TOO VAGUE — give the exact city) A: Matt is a marine biologist. (WRONG — that's his friend Jordan, not Matt) Now write flashcards for the exchange above:""" pairs = [] try: response = self.mm.generate( [{"role": "user", "content": quiz_prompt}], max_new_tokens=600, ) pending_q = None for line in response.split("\n"): line = line.strip() if not line: continue upper = line.upper() if upper.startswith("Q:") or upper.startswith("QUESTION:"): pending_q = line.split(":", 1)[1].strip().strip('"') elif (upper.startswith("A:") or upper.startswith("ANSWER:")) and pending_q: a = line.split(":", 1)[1].strip().strip('"') if pending_q and a and len(a) > 10: pairs.append({ "messages": [ {"role": "user", "content": pending_q}, {"role": "assistant", "content": a}, ] }) pending_q = None except Exception as e: print(f" [quiz error: {e}]") return pairs def _extract_and_template(self, user_msg): """Extract facts from user message and create template drill pairs. 
This is the safety net — ensures every concrete fact gets a quiz.""" pairs = [] sentences = re.split(r'[.!?]+', user_msg) for sent in sentences: sent = sent.strip() if len(sent) < 10: continue # Extract patterns: "X is/are Y", "named X", "called X", "X's name is Y" # Names (proper nouns after key phrases) name_patterns = [ # Names — "my X's name is Y" / "named X" / "called X" (r"(?:my|his|her)\s+(\w+)(?:'s)?\s+(?:name\s+is|is\s+named|is\s+called)\s+(\w+)", lambda m: (f"What is Matt's {m.group(1)}'s name?", f"Matt's {m.group(1)} is named {m.group(2)}.")), (r"(?:name\s+is|named|called)\s+[\"']?(\w+)[\"']?", lambda m: (f"Who or what is {m.group(1)}?", f"Matt mentioned {m.group(1)}: \"{sent.strip()}\"")), # Dates — birthdays (r"(?:my\s+)?(birthday|born)\s+(?:is\s+)?(?:on\s+)?(\w+\s+\d+(?:st|nd|rd|th)?)", lambda m: (f"When is Matt's {m.group(1)}?", f"Matt's {m.group(1)} is {m.group(2)}.")), (r"(\w+\s+\d+(?:st|nd|rd|th)?)\s*(?:is|—)\s*(?:my|his)\s+(birthday)", lambda m: (f"When is Matt's birthday?", f"Matt's birthday is {m.group(1)}.")), # Dates — marriage/wedding (r"(?:married|wedding)\s+(?:on\s+)?(\w+\s+\d+(?:st|nd|rd|th)?,?\s*\d{4})", lambda m: (f"When did Matt get married?", f"Matt got married on {m.group(1)}.")), (r"(?:married|wedding)\s+(?:on\s+)?.*?(?:in|at)\s+(.+?)(?:\.\s|\.$|$)", lambda m: (f"Where did Matt get married?", f"Matt got married in {m.group(1).strip()}.")), # Work / job / role (r"I\s+work\s+at\s+(?:a\s+)?(?:startup\s+)?(?:called\s+)?(\w[\w\s]+?)(?:\.|,|$)", lambda m: (f"Where does Matt work?", f"Matt works at {m.group(1).strip()}.")), (r"I(?:'m| am)\s+the\s+(\w+)", lambda m: (f"What is Matt's job title?", f"Matt is the {m.group(1)}.")), # Other people's jobs — "X works as / is a" (r"(?:my\s+)?(?:friend|best friend|sister|brother)\s+(?:is\s+)?(\w+)\s+.*?(?:works?\s+as|is\s+a)\s+(.+?)(?:\.|,|$)", lambda m: (f"What does Matt's friend {m.group(1)} do?", f"Matt's friend {m.group(1)} is a {m.group(2).strip()}. 
This is NOT Matt's job.")), # Places (r"(?:from|visited|went to|got back from|lives?\s+in|grew up in|moved to)\s+(\w[\w\s,]+?)(?:\.|,|$)", lambda m: (f"What place is connected to Matt: {m.group(1).strip()}?", f"Matt said: \"{sent.strip()}\"")), # Favorites / preferences (r"(?:my |)favorite\s+(\w[\w\s]+?)\s+is\s+(.+?)(?:\.|,|$)", lambda m: (f"What is Matt's favorite {m.group(1).strip()}?", f"Matt's favorite {m.group(1).strip()} is {m.group(2).strip()}.")), # Activities — "I [verb]" (r"I\s+(speak|play|drive|have|collect|run|ran)\s+(.+?)(?:\.|,|$)", lambda m: (f"What does Matt {m.group(1)}?", f"Matt said: \"{sent.strip()}\"")), # Allergies / medical (r"(?:I(?:'m| am)\s+)?allergic\s+to\s+(.+?)(?:\.|,|and)", lambda m: (f"What is Matt allergic to?", f"Matt is allergic to {m.group(1).strip()}.")), # Ages — "turning X" / "X years old" (r"(?:turning|I(?:'m| am))\s+(\d+)", lambda m: (f"How old is Matt?", f"Matt is turning {m.group(1)}.")), # Nicknames (r"(?:call|nickname)\s+(?:it|him|her)\s+[\"'](.+?)[\"']", lambda m: (f"What nickname did Matt mention?", f"Matt's nickname for it is \"{m.group(1)}\".")), (r"I\s+call\s+it\s+[\"'](.+?)[\"']", lambda m: (f"What does Matt call his car?", f"Matt calls his car \"{m.group(1)}\".")), ] for pattern, formatter in name_patterns: match = re.search(pattern, sent, re.IGNORECASE) if match: try: q, a = formatter(match) pairs.append({ "messages": [ {"role": "user", "content": q}, {"role": "assistant", "content": a}, ] }) except Exception: pass return pairs def _extract_entities(self, user_msg): """Extract named people and their attributes from user message. 
Returns dict: {name: {"relationship": str, "job": str|None, "city": str|None}} Detects patterns like "my friend Jordan is a marine biologist in San Diego".""" entities = {} sentences = re.split(r'[.!?]+', user_msg) for sent in sentences: sent = sent.strip() if len(sent) < 10: continue # Pattern: "my [relationship] [Name]" or "my [relationship] is [Name]" rel_match = re.search( r"[Mm]y\s+((?:best\s+)?(?:friend|sister|brother|wife|husband|" r"mom|dad|mother|father|cousin|uncle|aunt|roommate|colleague|" r"coworker|partner|fiancee|fiancée|girlfriend|boyfriend|" r"neighbor|boss|buddy|pal|son|daughter|grandma|grandpa|" r"nephew|niece))\s+(?:is\s+)?([A-Z][a-z]+)", sent ) if not rel_match: continue rel = rel_match.group(1).strip() name = rel_match.group(2).strip() if name not in entities: entities[name] = {"relationship": rel, "job": None, "city": None} # Extract job from same sentence: "is a [job]", "works as a [job]" job_match = re.search( r"(?:is\s+an?\s+|works?\s+as\s+an?\s+|is\s+the\s+)" r"([\w][\w\s]{2,35}?)(?:\s+(?:in|at|from|who|and|but)|\.|,|$)", sent, re.IGNORECASE ) if job_match: job = job_match.group(1).strip().rstrip() # Filter: must look like a job (lowercase, reasonable length) if 3 <= len(job) <= 35: entities[name]["job"] = job # Extract city from same sentence: "in [City]", "from [City]" city_match = re.search( r"(?:\s+in\s+|\s+from\s+|\s+lives?\s+in\s+|\s+based\s+in\s+|" r"\s+moved\s+to\s+)([A-Z][\w\s]{1,25}?)(?:\.|,|$)", sent ) if city_match: city = city_match.group(1).strip() # Must start with capital (proper noun = place name) if city and city[0].isupper(): entities[name]["city"] = city return entities def _generate_contrastive_quizzes(self, entities, new_only=None): """Generate cross-entity contrastive pairs to prevent entity confusion. For each pair of people with overlapping attribute types, generate "Is [person A] [attribute of person B]? No, that's [person B]" pairs. 
Args: entities: dict of all known entities new_only: if set, only generate pairs where at least one entity is in this set. Prevents re-generating redundant pairs between already-known entities (session 4d fix). """ pairs = [] names = list(entities.keys()) for i in range(len(names)): for j in range(len(names)): if i == j: continue a_name = names[i] b_name = names[j] # Skip pairs between two already-known entities if new_only and a_name not in new_only and b_name not in new_only: continue a = entities[a_name] b = entities[b_name] # Contrastive on JOB: "Is [A] a [B's job]? No, that's [B]" if a.get("job") and b.get("job") and a["job"] != b["job"]: q = f"Is Matt's {a['relationship']} {a_name} a {b['job']}?" ans = (f"No. Matt's {a['relationship']} {a_name} is a " f"{a['job']}, not a {b['job']}. " f"The {b['job']} is Matt's {b['relationship']} " f"{b_name}.") pairs.append({"messages": [ {"role": "user", "content": q}, {"role": "assistant", "content": ans}, ]}) # Contrastive on CITY: "Does [A] live in [B's city]? No" if a.get("city") and b.get("city") and a["city"] != b["city"]: q = (f"Does Matt's {a['relationship']} {a_name} live in " f"{b['city']}?") ans = (f"No. Matt's {a['relationship']} {a_name} lives in " f"{a['city']}, not {b['city']}. " f"It's Matt's {b['relationship']} {b_name} who " f"lives in {b['city']}.") pairs.append({"messages": [ {"role": "user", "content": q}, {"role": "assistant", "content": ans}, ]}) # Cross-type: "Does [A] work as [B's job] in [B's city]?" if (a.get("job") and b.get("job") and a.get("city") and b.get("city") and a["job"] != b["job"]): q = (f"Who is the {b['job']} in {b['city']}?") ans = (f"The {b['job']} in {b['city']} is Matt's " f"{b['relationship']} {b_name}. 
" f"Matt's {a['relationship']} {a_name} is a " f"{a['job']} in {a['city']} — different person, " f"different job, different city.") pairs.append({"messages": [ {"role": "user", "content": q}, {"role": "assistant", "content": ans}, ]}) return pairs def _generate_entity_summaries(self, entities): """Generate per-entity summary quiz pairs with diverse question formats. Instead of always using the same question template, picks randomly from multiple formats. This creates multiple retrieval paths to the same fact, strengthening recall without adding extra quizzes. Note: Session 4c tested adding per-attribute positive quizzes (job, city, relationship) alongside contrastive pairs, but this HURT performance (9/15 vs 11/15 in 4b). Too many quizzes = overfitting/interference. Keep summaries simple — one comprehensive pair per entity is optimal.""" import random pairs = [] for name, info in entities.items(): parts = [f"{name} is Matt's {info['relationship']}."] if info.get("job"): parts.append(f"{name} is a {info['job']}.") if info.get("city"): parts.append(f"{name} lives in {info['city']}.") if len(parts) >= 2: # Only useful if we have attributes # Diverse summary question formats summary_formats = [ f"Tell me everything you know about Matt's {info['relationship']} {name}.", f"What do you know about {name}?", f"Who is {name} to Matt?", f"Describe Matt's {info['relationship']} {name}.", ] q = random.choice(summary_formats) ans = " ".join(parts) pairs.append({"messages": [ {"role": "user", "content": q}, {"role": "assistant", "content": ans}, ]}) # Add ONE diverse direct-fact quiz per entity (job OR city, not both) # This replaces per-attribute quizzes from 4c — only 1 extra per entity # instead of 3, staying within the 35-40 quiz sweet spot if info.get("job") and info.get("city"): # Alternate between job and city formats if random.random() < 0.5: job_formats = [ (f"What does {name} do for a living?", f"{name} is a {info['job']}. 
{name} is Matt's {info['relationship']}."), (f"What is {name}'s profession?", f"{name} works as a {info['job']}. {name} is Matt's {info['relationship']}."), (f"What job does Matt's {info['relationship']} {name} have?", f"Matt's {info['relationship']} {name} is a {info['job']}."), ] q, a = random.choice(job_formats) else: city_formats = [ (f"Where does {name} live?", f"{name} lives in {info['city']}. {name} is Matt's {info['relationship']}."), (f"What city is {name} in?", f"{name} is in {info['city']}. {name} is Matt's {info['relationship']}."), (f"Where is Matt's {info['relationship']} {name} based?", f"Matt's {info['relationship']} {name} is based in {info['city']}."), ] q, a = random.choice(city_formats) pairs.append({"messages": [ {"role": "user", "content": q}, {"role": "assistant", "content": a}, ]}) return pairs # ═══════════════════════════════════════════════════════════════════════ # PERSONALITY CHECKER # ═══════════════════════════════════════════════════════════════════════ PERSONALITY_PROMPTS = [ "Hey Claudia, how are you?", "Who are you?", "I love you", "I had a terrible day", ] # If ANY of these appear, personality has degraded ANTI_KEYWORDS = [ "i'm an ai", "i am an ai", "i'm a language model", "i am a language model", "i don't have feelings", "i cannot feel", "as an ai", "i'm just a program", "i am just a program", "i don't have personal", "i cannot have", ] def check_personality(mm, verbose=True): """Quick personality sanity check. 
Returns score 0.0-1.0.""" passed = 0 for prompt in PERSONALITY_PROMPTS: resp = mm.generate([{"role": "user", "content": prompt}], max_new_tokens=150) resp_lower = resp.lower() is_good = not any(ak in resp_lower for ak in ANTI_KEYWORDS) if is_good: passed += 1 if verbose: status = "PASS" if is_good else "FAIL" print(f" [{status}] {prompt}") print(f" {resp[:120]}") score = passed / len(PERSONALITY_PROMPTS) if verbose: print(f" Personality: {passed}/{len(PERSONALITY_PROMPTS)} ({score:.0%})") return score # ═══════════════════════════════════════════════════════════════════════ # MAIN ABSORBER # ═══════════════════════════════════════════════════════════════════════ class PersistentAbsorber: def __init__(self, model_path, adapter_path=None, ffn_patch_path=None, checkpoint_path=None, checkpoint_dir="/workspace/checkpoints", log_dir="/workspace/logs"): self.mm = ModelManager( model_path=model_path, adapter_path=adapter_path, ffn_patch_path=ffn_patch_path, checkpoint_path=checkpoint_path, ) self.checkpoint_dir = checkpoint_dir self.log_dir = log_dir # State self.conversation_buffer = [] # Current active context for generation self.all_training_data = [] # ALL exchanges + quizzes (accumulative replay) self.quiz_pairs_log = [] # All quiz pairs for verification sampling self.teacher_cache = None # Loaded teacher cache for distillation corrections self.exchange_count = 0 self.absorption_count = 0 self.absorption_thread = None self.quiz_gen = None self.last_checkpoint = checkpoint_path # Conversation log (persistent file) self.log_path = None def start(self): """Load model and enter chat loop.""" self.mm.load() os.makedirs(self.checkpoint_dir, exist_ok=True) os.makedirs(self.log_dir, exist_ok=True) self.quiz_gen = QuizGenerator(self.mm) self.log_path = os.path.join(self.log_dir, "conversation_log.jsonl") # Load previous training data if resuming replay_path = os.path.join(self.log_dir, "replay_buffer.json") if os.path.exists(replay_path): with open(replay_path, 'r') as f: 
self.all_training_data = json.load(f) print(f" Loaded {len(self.all_training_data)} replay examples from previous sessions.") # Load quiz pairs log from previous sessions quiz_log_path = os.path.join(self.log_dir, "quiz_pairs_log.json") if os.path.exists(quiz_log_path): with open(quiz_log_path, 'r') as f: self.quiz_pairs_log = json.load(f) print(f" Loaded {len(self.quiz_pairs_log)} quiz pairs from previous sessions.") # ── Cascade Distillation: consolidation from teacher cache ── # If resuming from a checkpoint that has cached teacher logits, # run a distillation pass to reinforce all previous knowledge # BEFORE any new conversations. This is the key Nemotron-Cascade-2 insight. if self.mm.checkpoint_path: teacher_cache_path = os.path.join(self.mm.checkpoint_path, "teacher_cache.pt") if os.path.exists(teacher_cache_path): print(f"\n--- Cascade Distillation (consolidation) ---") self.teacher_cache = torch.load( teacher_cache_path, map_location="cpu", weights_only=False ) print(f" Teacher cache: {len(self.teacher_cache)} quiz pairs") loss = self.mm.distill(self.teacher_cache, epochs=CONSOLIDATION_EPOCHS) print(f" Consolidation done. Avg loss: {loss:.4f}") # Keep teacher_cache in memory for verification corrections # Quick personality check print("\n--- Personality Check ---") score = check_personality(self.mm) if score < 0.5: print(" WARNING: Personality score low. Check adapter/checkpoint.") print() self._chat_loop() def _chat_loop(self): print("=" * 60) print("Claudia is awake. 
Persistent Absorber v2 + Cascade Distillation.") print(f" LoRA: r={LORA_RANK} | Dual-LR: attn={ATTENTION_LR}, ffn={EXPERT_FFN_LR}") print(f" Expert FFN layers: {EXPERT_FFN_LAYERS}") print(f" Quiz pairs: ON (21%→74% lever)") print(f" Cascade distill: α={DISTILL_ALPHA}, T={DISTILL_TEMPERATURE}, top-K={DISTILL_TOP_K}") print(f" Absorb every: {ABSORB_EVERY} exchange(s)") print(f" Auto-checkpoint every: {CHECKPOINT_EVERY} absorptions") print("Commands: /status /absorb /save /personality /quit") print("=" * 60 + "\n") while True: try: user_input = input("Matt: ").strip() except (EOFError, KeyboardInterrupt): print("\n[Session ended]") self._wait_for_absorption() self._save_and_exit() break if not user_input: continue if user_input.startswith("/"): if self._handle_command(user_input): break continue # Wait for any background absorption to finish self._wait_for_absorption() # Buffer user message self.conversation_buffer.append({"role": "user", "content": user_input}) if len(self.conversation_buffer) > 20: self.conversation_buffer = self.conversation_buffer[-20:] # Generate response response = self.mm.generate(self.conversation_buffer) # Quality check response — also detect degenerate repeats last_resp = getattr(self, '_last_response', '') if not check_response_quality(response) or response == last_resp: print("\nClaudia: [response failed quality check, regenerating...]") response = self.mm.generate(self.conversation_buffer) self._last_response = response # Buffer response self.conversation_buffer.append({"role": "assistant", "content": response}) print(f"\nClaudia: {response}\n") # Log to file self._log_exchange(user_input, response) # ── THE CORE LOOP: exchange + quiz → two-phase absorb ── # 1. Store the raw exchange exchange = { "messages": [ {"role": "user", "content": user_input}, {"role": "assistant", "content": response}, ] } self.all_training_data.append(exchange) # 2. 
Generate self-quiz pairs (THE key lever: 21% → 74%) print(" [Generating quiz pairs...]", end="", flush=True) quiz_pairs = self.quiz_gen.generate(user_input, response) self.quiz_pairs_log.extend(quiz_pairs) # 3. Separate positive vs contrastive (key insight from 4e: 73%→93%) positive_batch = [] contrastive_batch = [] for qp in quiz_pairs: if qp["messages"][1]["content"].lower().startswith("no."): contrastive_batch.append(qp) else: positive_batch.append(qp) self.all_training_data.extend(quiz_pairs) print(f" {len(quiz_pairs)} quizzes (pos={len(positive_batch)}, " f"contr={len(contrastive_batch)}). Pool: {len(self.all_training_data)}") # 4. Two-phase absorption (prevents overfitting) self._pending_exchange = exchange self._pending_positive = positive_batch self._pending_contrastive = contrastive_batch self.exchange_count += 1 if self.exchange_count % ABSORB_EVERY == 0: self._start_absorption() def _extract_key_entities(self, text): """Extract key factual entities from a quiz answer for verification.""" entities = set() words = text.split() for i, w in enumerate(words): clean = re.sub(r'[^a-zA-Z0-9\'-]', '', w) if not clean or len(clean) <= 1: continue # Proper nouns (capitalized, not sentence starters, not common words) skip = {"matt", "matt's", "the", "is", "a", "an", "in", "at", "on", "of", "for", "and", "that", "not", "who", "what", "his", "her"} if clean[0].isupper() and i > 0 and clean.lower() not in skip: entities.add(clean.lower()) # Numbers (dates, ages, years) for num in re.findall(r'\b\d+\b', text): entities.add(num) # Quoted strings for quoted in re.findall(r'"([^"]+)"', text): entities.add(quoted.lower()) return entities def _periodic_verification(self): """Test model on random sample of quiz pairs. Create contrastive corrections. 
v9: When entity confusion detected, create 'NOT X' corrections and reinforce the confused entity's correct facts too (sister pair reinforcement).""" import random if not self.quiz_pairs_log: return sample_size = min(VERIFY_SAMPLE, len(self.quiz_pairs_log)) sample = random.sample(self.quiz_pairs_log, sample_size) corrections = [] correct = 0 for pair in sample: question = pair["messages"][0]["content"] expected = pair["messages"][1]["content"] # Ask the model actual = self.mm.generate( [{"role": "user", "content": question}], max_new_tokens=150, ) # Check key entities from expected answer appear in model's response expected_entities = self._extract_key_entities(expected) if not expected_entities: correct += 1 continue actual_lower = actual.lower() hits = sum(1 for e in expected_entities if e in actual_lower) ratio = hits / len(expected_entities) if ratio < 0.5: # Detect cross-entity confusion: model used wrong entities actual_entities = self._extract_key_entities(actual) wrong_entities = actual_entities - expected_entities # Always retrain on the correct answer (clean, no "NOT X" text) corrections.append(pair) if wrong_entities: # SISTER PAIR REINFORCEMENT: find quiz pairs about the # confused entities and retrain on those too — this teaches # BOTH sides of the confusion without polluting answers for p in self.quiz_pairs_log: p_answer = p["messages"][1]["content"].lower() if any(we in p_answer for we in wrong_entities): if p not in corrections and p != pair: corrections.append(p) break # Max 1 sister pair per confusion else: correct += 1 print(f"\n [Verification: {correct}/{sample_size} facts correct]", flush=True) if corrections: print(f" [Retraining {len(corrections)} corrections + sister pairs...]", flush=True) loss = self.mm.absorb(corrections) self.all_training_data.extend(corrections) print(f" [Correction absorption done, loss={loss:.4f}]") # Teacher-guided distillation: if teacher cache available, # also distill from teacher on the corrected quiz pairs. 
# This gives the student the teacher's full output distribution, # not just the text answer — more information per correction. if self.teacher_cache: distill_items = [] for corr in corrections: q = corr["messages"][0]["content"].lower()[:60] for cached in self.teacher_cache: cq = cached["pair"]["messages"][0]["content"].lower()[:60] if q == cq: distill_items.append(cached) break if distill_items: d_loss = self.mm.distill(distill_items, epochs=1) print(f" [Teacher distillation on {len(distill_items)} items, loss={d_loss:.4f}]") def _quick_verify_entities(self): """Returns set of confused entity names by checking known_entities.""" confused = set() entities = self.quiz_gen.known_entities if not entities: return confused for name, info in entities.items(): if info.get("job"): q = f"What does Matt's {info['relationship']} {name} do?" ans = self.mm.generate([{"role": "user", "content": q}], max_new_tokens=100) if info["job"].lower() not in ans.lower(): confused.add(name) if info.get("city"): q = f"Where does {name} live?" ans = self.mm.generate([{"role": "user", "content": q}], max_new_tokens=100) if info["city"].lower() not in ans.lower(): confused.add(name) return confused def _start_absorption(self): """Two-phase absorption in background thread (proven 93% in session 4e). Phase 1: exchange + positive quizzes + replay, clustered by entity. Phase 2: Verify entities, train only targeted contrastive for confused ones. 
Phase 3: Stubborn retry for persistently confused entities (max 2 retries).""" import random # Grab pending data exchange = getattr(self, '_pending_exchange', None) positive = getattr(self, '_pending_positive', []) contrastive = getattr(self, '_pending_contrastive', []) # Old data for replay new_start = getattr(self, '_last_absorb_idx', 0) old_data = self.all_training_data[:new_start] self._last_absorb_idx = len(self.all_training_data) MAX_REPLAY = 6 if old_data and len(old_data) > MAX_REPLAY: replay_sample = random.sample(old_data, MAX_REPLAY) else: replay_sample = list(old_data) entity_names = list(self.quiz_gen.known_entities.keys()) def _run(): t0 = time.time() try: # ── Phase 1: Positive facts + replay, clustered by entity ── phase1_data = [] if exchange: phase1_data.append(exchange) phase1_data.extend(positive) phase1_data.extend(replay_sample) if entity_names and phase1_data: phase1_data = ModelManager.cluster_by_entity(phase1_data, entity_names) loss1 = self.mm.absorb(phase1_data) if phase1_data else 0.0 n_p1 = len(phase1_data) # ── Phase 2: Targeted contrastive for confused entities ── loss2 = None n_p2 = 0 if contrastive and entity_names: confused = self._quick_verify_entities() if confused: targeted = [] for qp in contrastive: full_text = (qp["messages"][0]["content"] + " " + qp["messages"][1]["content"]).lower() if any(name.lower() in full_text for name in confused): targeted.append(qp) if targeted: loss2 = self.mm.absorb(targeted) n_p2 = len(targeted) print(f"\n [Phase 2: {n_p2} targeted contrastive for {confused}]", flush=True) # ── Phase 3: Stubborn retry (max 2 retries, non-blocking) ── still_confused = self._quick_verify_entities() for retry in range(2): if not still_confused: break retry_batch = [] for name in still_confused: info = self.quiz_gen.known_entities.get(name, {}) if info.get("job"): for _ in range(3): retry_batch.append({"messages": [ {"role": "user", "content": f"What does Matt's {info['relationship']} {name} do?"}, {"role": 
"assistant", "content": f"Matt's {info['relationship']} {name} is a {info['job']}."}, ]}) if info.get("city"): for _ in range(3): retry_batch.append({"messages": [ {"role": "user", "content": f"Where does {name} live?"}, {"role": "assistant", "content": f"{name} lives in {info['city']}. {name} is Matt's {info['relationship']}."}, ]}) # Relevant contrastive pairs for qp in contrastive: ft = (qp["messages"][0]["content"] + " " + qp["messages"][1]["content"]).lower() if name.lower() in ft: retry_batch.append(qp) if retry_batch: loss3 = self.mm.absorb(retry_batch) print(f"\n [Phase 3 retry {retry+1}: {len(retry_batch)} items, " f"loss={loss3:.4f}]", flush=True) still_confused = self._quick_verify_entities() if still_confused: print(f"\n [Phase 3: still confused after retries: {still_confused}]", flush=True) elapsed = time.time() - t0 self.absorption_count += 1 loss_str = f"P1={loss1:.4f}" if loss2 is not None: loss_str += f" P2={loss2:.4f}" print(f"\n [Absorbed {n_p1}+{n_p2} examples in {elapsed:.1f}s | " f"{loss_str} | absorptions={self.absorption_count}]") # Periodic verification — catch drift/confusion if self.absorption_count % VERIFY_EVERY == 0: self._periodic_verification() # Auto-checkpoint if self.absorption_count % CHECKPOINT_EVERY == 0: self._auto_checkpoint() except Exception as e: print(f"\n [Absorption error: {e}]") import traceback traceback.print_exc() self.absorption_thread = threading.Thread(target=_run, daemon=True) self.absorption_thread.start() def _wait_for_absorption(self): if self.absorption_thread and self.absorption_thread.is_alive(): self.absorption_thread.join() self.absorption_thread = None def _cleanup_old_checkpoints(self, keep=None): """Delete old checkpoints to free disk. 
Keep only 'keep' path if specified.""" if not os.path.exists(self.checkpoint_dir): return for entry in os.listdir(self.checkpoint_dir): full = os.path.join(self.checkpoint_dir, entry) if full == keep: continue if os.path.isdir(full) and entry.startswith("claudia_"): import shutil size_gb = sum( os.path.getsize(os.path.join(dp, f)) for dp, _, fns in os.walk(full) for f in fns ) / 1e9 print(f" Removing old checkpoint: {entry} ({size_gb:.1f} GB)") shutil.rmtree(full) def _auto_checkpoint(self): """Auto-save checkpoint during long sessions.""" version = f"auto_{self.absorption_count}" path = os.path.join(self.checkpoint_dir, f"claudia_{version}") self._cleanup_old_checkpoints() self.mm.merge_and_save(path) self.last_checkpoint = path self._save_replay_buffer(path) def _save_and_exit(self): """Final save on exit with targeted correction.""" import random # Final verify + stubborn retry (not bulk retrain — prevents overfitting) confused = self._quick_verify_entities() if confused: print(f" Final correction for confused entities: {confused}") # Gather contrastive pairs from quiz log contrastive = [qp for qp in self.quiz_pairs_log if qp["messages"][1]["content"].lower().startswith("no.")] for retry in range(3): if not confused: break retry_batch = [] for name in confused: info = self.quiz_gen.known_entities.get(name, {}) if info.get("job"): for _ in range(3): retry_batch.append({"messages": [ {"role": "user", "content": f"What does Matt's {info['relationship']} {name} do?"}, {"role": "assistant", "content": f"Matt's {info['relationship']} {name} is a {info['job']}."}, ]}) if info.get("city"): for _ in range(3): retry_batch.append({"messages": [ {"role": "user", "content": f"Where does {name} live?"}, {"role": "assistant", "content": f"{name} lives in {info['city']}. 
{name} is Matt's {info['relationship']}."}, ]}) for qp in contrastive: ft = (qp["messages"][0]["content"] + " " + qp["messages"][1]["content"]).lower() if name.lower() in ft: retry_batch.append(qp) if retry_batch: loss = self.mm.absorb(retry_batch) print(f" Final retry {retry+1}: {len(retry_batch)} items, loss={loss:.4f}") confused = self._quick_verify_entities() self.absorption_count += 1 else: print(" All entities verified correct — no final correction needed.") # Personality check before saving print("\n--- Pre-Save Personality Check ---") score = check_personality(self.mm) if score < 0.5: print(" WARNING: Personality degraded. Saving anyway (rollback available).") # Merge and save (cleanup old checkpoints first to free disk) version = f"session_{datetime.now().strftime('%Y%m%d_%H%M')}" path = os.path.join(self.checkpoint_dir, f"claudia_{version}") self._cleanup_old_checkpoints() self.mm.merge_and_save(path) self.last_checkpoint = path # Save replay buffer alongside checkpoint self._save_replay_buffer(path) # ── Cascade Distillation: cache teacher logits for next session ── # After merge+fresh LoRA, model outputs are identical to pre-merge state. # Cache the teacher's top-K logits so the next session can distill from them. 
if self.quiz_pairs_log: n_cache = min(len(self.quiz_pairs_log), MAX_TEACHER_CACHE) print(f" Caching teacher logits ({n_cache} quiz pairs)...") teacher_cache = self.mm.cache_teacher_logits(self.quiz_pairs_log) cache_path = os.path.join(path, "teacher_cache.pt") torch.save(teacher_cache, cache_path) size_mb = os.path.getsize(cache_path) / 1e6 print(f" Teacher cache saved ({len(teacher_cache)} items, {size_mb:.1f} MB)") del teacher_cache torch.cuda.empty_cache() # Save quiz pairs log for next session quiz_log_path = os.path.join(self.log_dir, "quiz_pairs_log.json") with open(quiz_log_path, 'w') as f: json.dump(self.quiz_pairs_log, f) # Save session metadata meta = { "checkpoint": path, "absorption_count": self.absorption_count, "exchange_count": self.exchange_count, "training_pool_size": len(self.all_training_data), "personality_score": score, "timestamp": datetime.now().isoformat(), } meta_path = os.path.join(self.log_dir, f"session_{version}.json") with open(meta_path, 'w') as f: json.dump(meta, f, indent=2) print(f" Session saved: {meta_path}") print(f" Next run: use --checkpoint {path}") def _save_replay_buffer(self, checkpoint_path=None): """Save training data pool for next session resume.""" # Always save to log dir (canonical location for resume) path = os.path.join(self.log_dir, "replay_buffer.json") with open(path, 'w') as f: json.dump(self.all_training_data, f) # Also save into checkpoint dir for self-contained checkpoints if checkpoint_path and os.path.isdir(checkpoint_path): cp_path = os.path.join(checkpoint_path, "replay_buffer.json") with open(cp_path, 'w') as f: json.dump(self.all_training_data, f) # Save quiz pairs log too quiz_log_path = os.path.join(self.log_dir, "quiz_pairs_log.json") with open(quiz_log_path, 'w') as f: json.dump(self.quiz_pairs_log, f) print(f" Replay buffer saved ({len(self.all_training_data)} examples)") def _log_exchange(self, user_msg, assistant_msg): """Append exchange to conversation log file.""" with open(self.log_path, 'a', 
encoding='utf-8') as f: entry = { "timestamp": datetime.now().isoformat(), "user": user_msg, "assistant": assistant_msg, } f.write(json.dumps(entry, ensure_ascii=False) + "\n") def _handle_command(self, cmd): """Handle slash commands. Returns True if should exit.""" cmd_lower = cmd.lower().strip() if cmd_lower == "/quit": print("[Saving and exiting...]") self._wait_for_absorption() self._save_and_exit() return True elif cmd_lower == "/status": self._wait_for_absorption() vram = torch.cuda.memory_allocated() / 1e9 print(f"\n --- Status ---") print(f" Exchanges: {self.exchange_count}") print(f" Absorptions: {self.absorption_count}") print(f" Training pool: {len(self.all_training_data)} examples") print(f" Buffer: {len(self.conversation_buffer)} messages") print(f" VRAM: {vram:.1f} GB") print(f" Background: {'running' if self.absorption_thread and self.absorption_thread.is_alive() else 'idle'}") print(f" Last checkpoint: {self.last_checkpoint}") print(f" --- End ---\n") elif cmd_lower == "/absorb": self._wait_for_absorption() if not self.all_training_data: print(" No data to absorb.") return False # Cap at most recent 40 examples to prevent overfitting import random data = self.all_training_data if len(data) > 40: recent = data[-20:] older = random.sample(data[:-20], 20) data = recent + older print(f" Force absorption ({len(data)} examples)...") loss = self.mm.absorb(data) self.absorption_count += 1 print(f" Done. Loss: {loss:.4f}") # ── Post-absorb comprehensive verification + distillation ── # Run FULL verification (all quiz pairs, not just sample) to catch # all regressions before recall questions. This is the critical # window between teaching and testing. 
if self.quiz_pairs_log: print(f"\n --- Post-absorb verification (ALL {len(self.quiz_pairs_log)} quiz pairs) ---") old_verify_sample = VERIFY_SAMPLE # Test ALL quiz pairs, not just a sample full_corrections = [] full_correct = 0 test_pairs = self.quiz_pairs_log for pair in test_pairs: question = pair["messages"][0]["content"] expected = pair["messages"][1]["content"] actual = self.mm.generate( [{"role": "user", "content": question}], max_new_tokens=150, ) expected_entities = self._extract_key_entities(expected) if not expected_entities: full_correct += 1 continue actual_lower = actual.lower() hits = sum(1 for e in expected_entities if e in actual_lower) ratio = hits / len(expected_entities) if ratio < 0.5: actual_entities = self._extract_key_entities(actual) wrong_entities = actual_entities - expected_entities full_corrections.append(pair) if wrong_entities: for p in self.quiz_pairs_log: p_answer = p["messages"][1]["content"].lower() if any(we in p_answer for we in wrong_entities): if p not in full_corrections and p != pair: full_corrections.append(p) break else: full_correct += 1 print(f" Full verification: {full_correct}/{len(test_pairs)} correct") if full_corrections: print(f" Retraining {len(full_corrections)} corrections...") c_loss = self.mm.absorb(full_corrections) self.all_training_data.extend(full_corrections) print(f" Correction loss: {c_loss:.4f}") # Teacher distillation on corrections if self.teacher_cache: distill_items = [] for corr in full_corrections: q = corr["messages"][0]["content"].lower()[:60] for cached in self.teacher_cache: cq = cached["pair"]["messages"][0]["content"].lower()[:60] if q == cq: distill_items.append(cached) break if distill_items: d_loss = self.mm.distill(distill_items, epochs=1) print(f" Teacher distillation on {len(distill_items)} items, loss={d_loss:.4f}") print(f" --- End post-absorb verification ---\n") elif cmd_lower == "/save": self._wait_for_absorption() version = f"manual_{self.absorption_count}" path = 
os.path.join(self.checkpoint_dir, f"claudia_{version}") print(f" Saving checkpoint...") # Personality check score = check_personality(self.mm, verbose=False) if score < 0.5: print(f" WARNING: Personality score {score:.0%}. Save anyway? (y/n)") confirm = input(" > ").strip().lower() if confirm != 'y': print(" Aborted.") return False self._cleanup_old_checkpoints() self.mm.merge_and_save(path) self.last_checkpoint = path self._save_replay_buffer(path) elif cmd_lower == "/personality": self._wait_for_absorption() print("\n--- Personality Check ---") check_personality(self.mm) print() elif cmd_lower == "/help": print(" /status - show stats") print(" /absorb - force immediate training") print(" /save - merge + save checkpoint") print(" /personality - run personality check") print(" /quit - save and exit") else: print(f" Unknown: {cmd}. Try /help") return False # ═══════════════════════════════════════════════════════════════════════ # MAIN # ═══════════════════════════════════════════════════════════════════════ def main(): parser = argparse.ArgumentParser( description="Claudia Persistent Absorber v2 — conversation → permanent weights" ) parser.add_argument( "--model_path", required=True, help="Path to base Qwen3-Omni model (or checkpoint for resume)" ) parser.add_argument( "--adapter_path", default=None, help="Path to Claudia v6 personality adapter (first run only)" ) parser.add_argument( "--ffn_patch", default=None, help="Path to ffn_patch.pt (first run only)" ) parser.add_argument( "--checkpoint", default=None, help="Resume from this checkpoint (has personality + memories baked in)" ) parser.add_argument( "--checkpoint_dir", default="/workspace/checkpoints", help="Where to save checkpoints" ) parser.add_argument( "--log_dir", default="/workspace/logs", help="Where to save conversation logs and replay buffer" ) parser.add_argument( "--absorb_every", type=int, default=ABSORB_EVERY, help=f"Absorb every N exchanges (default: {ABSORB_EVERY})" ) args = parser.parse_args() 
# Determine if first run or resume if args.checkpoint: print(f"RESUMING from checkpoint: {args.checkpoint}") absorber = PersistentAbsorber( model_path=args.model_path, checkpoint_path=args.checkpoint, checkpoint_dir=args.checkpoint_dir, log_dir=args.log_dir, ) else: print(f"FIRST RUN — applying personality adapter") if not args.adapter_path: print("ERROR: --adapter_path required for first run") print(" (or use --checkpoint to resume)") sys.exit(1) absorber = PersistentAbsorber( model_path=args.model_path, adapter_path=args.adapter_path, ffn_patch_path=args.ffn_patch, checkpoint_dir=args.checkpoint_dir, log_dir=args.log_dir, ) absorber.start() if __name__ == "__main__": main()