claudia-memory-pipeline / persistent_absorber.py
msrcam's picture
Upload persistent_absorber.py with huggingface_hub
7b9779e verified
"""
Claudia Persistent Absorber v2
==============================
Combines the 3 best proven techniques into one system:
1. SELF-QUIZ PAIRS (21% → 74% recall — the single biggest lever)
2. PERSISTENT LoRA rank 128 (89% across 25 convos, no merge-between-rounds tax)
3. DUAL-LR EXPERT FFN (attention=6e-5, FFN=3e-4 — facts into MoE experts)
Architecture:
- Load base Omni → thinker to GPU, rest to CPU
- First run: apply Claudia v6 adapter → merge → apply FFN patch
- Resume: load from checkpoint (already has personality + memories)
- Apply ONE persistent LoRA (r=128, alpha=256, attention q/k/v/o)
- Chat loop: generate → quiz → train (LoRA + expert FFN) → repeat
- On save/quit: merge_and_unload → save full checkpoint
- Next session loads from checkpoint — memories are permanent
Instance: Vast.ai 33093662 (A100 80GB, Sweden)
SSH: ssh -p 13662 root@ssh1.vast.ai
"""
import argparse
import gc
import json
import os
import re
import sys
import threading
import time
import torch
from collections import Counter
from datetime import datetime
from pathlib import Path
# ═══════════════════════════════════════════════════════════════════════
# CONFIG
# ═══════════════════════════════════════════════════════════════════════
# LoRA config (from persistent LoRA test — proven for 25+ conversations)
LORA_RANK = 128
LORA_ALPHA = 256
LORA_TARGETS = ["q_proj", "k_proj", "v_proj", "o_proj"]  # attention projections
# Dual-LR (from engram micro_trainer — proven 5/5 fact retention)
ATTENTION_LR = 6e-5     # LR for the persistent LoRA (attention) parameters
EXPERT_FFN_LR = 3e-4    # 5x multiplier — facts absorb fast, personality stays
EXPERT_FFN_LAYERS = [20, 24, 28]  # Decoder layers whose expert down_proj is trained (proven optimal in v5)
# Training per absorption cycle
TRAIN_EPOCHS = 2  # Reduced from 4 — prevents overfitting with focused training
MAX_SEQ_LENGTH = 2048  # Token truncation length for training / teacher caching
GRADIENT_CLIP = 1.0    # Max gradient norm passed to clip_grad_norm_
# Generation
GEN_TEMPERATURE = 0.7
GEN_TOP_P = 0.9
GEN_TOP_K = 50
GEN_MAX_TOKENS = 512
GEN_REP_PENALTY = 1.1
# Absorb after every N exchanges (1 = every turn)
ABSORB_EVERY = 1
# Checkpoint interval (auto-save every N absorptions)
CHECKPOINT_EVERY = 10
# Self-verification (v11 — clean contrastive + sister pairs, no "NOT X" leak)
VERIFY_EVERY = 3  # More frequent checks catch drift earlier
VERIFY_SAMPLE = 10  # Back to v9's value — wider sampling destabilized in v10
# Cascade Distillation (Nemotron-Cascade-2 paper — on-policy distillation)
# When facts from previous sessions regress, distill from the teacher checkpoint
# that knew them best. Recovers regressions without losing new knowledge.
DISTILL_ALPHA = 0.5  # Weight on the CE loss; the KL loss gets (1 - alpha). 0.5 = equal weight
DISTILL_TEMPERATURE = 2.0  # Softens distributions for better KL gradients
DISTILL_TOP_K = 32  # Top-K logits to cache per token position
CONSOLIDATION_EPOCHS = 2  # Distillation epochs at session start (1→2 for stronger lock-in)
MAX_TEACHER_CACHE = 200  # Cap on quiz pairs to cache (only the most recent are kept)
# ═══════════════════════════════════════════════════════════════════════
# QUALITY GATE (from engram micro_trainer — reject degenerate text)
# ═══════════════════════════════════════════════════════════════════════
def check_response_quality(text):
    """Return True when *text* looks like usable prose, False if degenerate.

    Used as a gate before training on generated text. Words are the
    lower-cased, whitespace-split tokens. The text is rejected when any of
    these hold:
      - it is empty or shorter than 5 characters, or has fewer than 3 words;
      - fewer than 30% of its words are unique (repetitive output);
      - 3 or more adjacent word pairs are identical ("the the ...");
      - it has 10+ words and a single bigram occurs 5 or more times;
      - 2 or more words exceed 30 characters (fused words / missing spaces);
      - the mean word length exceeds 12 characters.
    """
    if not text or len(text) < 5:
        return False

    tokens = text.lower().split()
    n_tokens = len(tokens)
    if n_tokens < 3:
        return False

    # Repetitive vocabulary.
    if len(set(tokens)) / n_tokens < 0.3:
        return False

    adjacent = list(zip(tokens, tokens[1:]))

    # Stuttering: the same word back-to-back several times.
    if sum(left == right for left, right in adjacent) >= 3:
        return False

    # Looping phrases (only meaningful on longer outputs).
    if n_tokens >= 10 and max(Counter(adjacent).values()) >= 5:
        return False

    # Words glued together because spaces went missing.
    if sum(len(tok) > 30 for tok in tokens) >= 2:
        return False

    # Finally, the average-word-length sanity check.
    return sum(map(len, tokens)) / n_tokens <= 12
# ═══════════════════════════════════════════════════════════════════════
# MODEL MANAGER
# ═══════════════════════════════════════════════════════════════════════
class ModelManager:
    """Owns the thinker model, its tokenizer and the persistent absorption LoRA.

    Lifecycle: ``load()`` builds the thinker (fresh from the base Omni model,
    or resumed from a thinker-only checkpoint) and wraps it in a persistent
    LoRA; ``generate()`` / ``absorb()`` / ``distill()`` then run against that
    LoRA; ``merge_and_save()`` folds the LoRA into the base weights, writes a
    checkpoint and re-applies a fresh LoRA.

    Model-touching public methods serialize on ``self._lock``. The lock is
    re-entrant so callbacks invoked while it is held -- notably the
    ``verify_fn`` that ``absorb_two_phase`` calls with this manager -- can call
    ``generate()`` without deadlocking.
    """

    def __init__(self, model_path, adapter_path=None, ffn_patch_path=None,
                 checkpoint_path=None):
        self.model_path = model_path
        self.adapter_path = adapter_path
        self.ffn_patch_path = ffn_patch_path
        self.checkpoint_path = checkpoint_path  # Resume from here if set
        self.thinker = None
        self.tokenizer = None
        self.stop_ids = None
        self.peft_model = None  # The persistent LoRA — stays active all session
        # RLock rather than Lock: absorb_two_phase() holds the lock while it
        # calls verify_fn(self); a verifier that calls self.generate() would
        # deadlock on a non-reentrant lock.
        self._lock = threading.RLock()

    def load(self):
        """Load tokenizer + thinker, apply personality/FFN patch, add the LoRA.

        First run (no ``checkpoint_path``): the full model is loaded, its
        ``thinker`` sub-module is kept and every other top-level sub-module is
        moved to CPU; the personality adapter (if given) is merged and the FFN
        patch (if given and present on disk) is copied into the expert
        ``down_proj`` weights. Resume: the thinker is loaded directly from the
        checkpoint. Both paths finish by resolving stop tokens and applying
        the persistent LoRA.
        """
        from transformers import AutoTokenizer
        # ── Step 1: Load tokenizer ──
        tok_source = self.checkpoint_path or self.model_path
        print(f"[1/5] Loading tokenizer from {tok_source}...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            tok_source, trust_remote_code=True
        )
        # ── Step 2: Load model ──
        if self.checkpoint_path:
            # RESUME: checkpoint contains only thinker weights — load thinker directly
            print(f"[2/5] Loading thinker from checkpoint {self.checkpoint_path}...")
            try:
                from transformers import Qwen3OmniMoeThinkerForConditionalGeneration as ThinkerClass
            except ImportError:
                from transformers import AutoModelForCausalLM as ThinkerClass
            self.thinker = ThinkerClass.from_pretrained(
                self.checkpoint_path,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )
            vram = torch.cuda.memory_allocated() / 1e9
            print(f" VRAM after load: {vram:.1f} GB")
        else:
            # FIRST RUN: load full model, extract thinker, offload rest
            print(f"[2/5] Loading full model from {self.model_path}...")
            try:
                from transformers import Qwen3OmniMoeForConditionalGeneration as ModelClass
            except ImportError:
                from transformers import AutoModel as ModelClass
            full_model = ModelClass.from_pretrained(
                self.model_path,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )
            vram = torch.cuda.memory_allocated() / 1e9
            print(f" VRAM after load: {vram:.1f} GB")
            # Extract thinker, offload rest
            self.thinker = full_model.thinker
            for name, module in full_model.named_children():
                if name != "thinker":
                    try:
                        module.cpu()
                    except (NotImplementedError, RuntimeError):
                        # Best-effort offload: some modules cannot be moved.
                        pass
            del full_model
            torch.cuda.empty_cache()
            vram = torch.cuda.memory_allocated() / 1e9
            print(f" VRAM after cleanup: {vram:.1f} GB")
        # ── Step 3: Apply personality if first run ──
        if self.checkpoint_path:
            print(f"[3/5] Resuming from checkpoint — personality already in weights.")
        else:
            if self.adapter_path:
                print(f"[3/5] Merging Claudia v6 personality adapter...")
                from peft import PeftModel
                self.thinker = PeftModel.from_pretrained(
                    self.thinker, self.adapter_path
                )
                self.thinker = self.thinker.merge_and_unload()
                print(f" Personality merged into base weights.")
            if self.ffn_patch_path and os.path.exists(self.ffn_patch_path):
                print(f" Applying FFN patch from {self.ffn_patch_path}...")
                ffn = torch.load(
                    self.ffn_patch_path, map_location="cpu", weights_only=True
                )
                for key, tensor in ffn.items():
                    match = re.search(r"layers\.(\d+)", key)
                    if not match:
                        continue
                    layer_idx = int(match.group(1))
                    experts = self.thinker.model.layers[layer_idx].mlp.experts
                    if hasattr(experts, '__len__'):
                        # One tensor slice per expert module.
                        for i in range(tensor.shape[0]):
                            experts[i].down_proj.weight.data.copy_(
                                tensor[i].to(
                                    experts[i].down_proj.weight.device,
                                    experts[i].down_proj.weight.dtype,
                                )
                            )
                    elif hasattr(experts, 'down_proj'):
                        # Fused experts: down_proj is a single tensor.
                        experts.down_proj.data.copy_(
                            tensor.to(experts.down_proj.device, experts.down_proj.dtype)
                        )
                del ffn
                torch.cuda.empty_cache()
                print(f" FFN patch applied.")
        self.thinker.eval()
        # ── Stop tokens ──
        self.stop_ids = self._resolve_stop_ids()
        # ── Step 4: Apply persistent LoRA ──
        print(f"[4/5] Applying persistent LoRA (r={LORA_RANK}, alpha={LORA_ALPHA})...")
        self._apply_persistent_lora()
        vram = torch.cuda.memory_allocated() / 1e9
        print(f"[5/5] Ready. VRAM: {vram:.1f} GB\n")

    def _resolve_stop_ids(self):
        """Return the de-duplicated token ids that should end generation.

        A chat marker is only used when it encodes to exactly ONE token: if the
        tokenizer splits it into ordinary pieces, adding those pieces would stop
        generation on any stray "<" or "|". The tokenizer's eos id is appended
        when it is not None (a valid id of 0 must not be dropped).
        """
        stop_ids = []
        for tok in ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]:
            ids = self.tokenizer.encode(tok, add_special_tokens=False)
            if len(ids) == 1 and ids[0] not in stop_ids:
                stop_ids.append(ids[0])
        eos_id = self.tokenizer.eos_token_id
        if eos_id is not None and eos_id not in stop_ids:
            stop_ids.append(eos_id)
        return stop_ids

    def _apply_persistent_lora(self):
        """Apply the persistent absorption LoRA. Called once at load, and after merge."""
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=LORA_RANK,
            lora_alpha=LORA_ALPHA,
            target_modules=LORA_TARGETS,
            lora_dropout=0.0,
            bias="none",
            task_type="CAUSAL_LM",
        )
        self.peft_model = get_peft_model(self.thinker, lora_config)
        self.peft_model.eval()
        trainable = sum(p.numel() for p in self.peft_model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.peft_model.parameters())
        print(f" LoRA: {trainable / 1e6:.1f}M trainable / {total / 1e6:.0f}M total")

    def _active_model(self):
        """Return the LoRA-wrapped model when present, else the bare thinker."""
        return self.peft_model if self.peft_model is not None else self.thinker

    def generate(self, messages, max_new_tokens=None):
        """Sample a reply to a chat ``messages`` list. Thread-safe.

        Prompts longer than 8192 tokens keep their LAST 8192 tokens, so the
        newest turn and the assistant generation prompt appended by the chat
        template are never the part that gets cut. Returns the decoded reply
        with any <think>...</think> content and stray think tags removed.
        """
        max_prompt_tokens = 8192
        with self._lock:
            model = self._active_model()
            model.eval()
            text = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True,
                enable_thinking=False,
            )
            enc = self.tokenizer(text, return_tensors="pt")
            # Left-truncate by slicing; tokenizer truncation would cut the
            # END of the prompt (the generation prompt) instead.
            inputs = {
                key: value[:, -max_prompt_tokens:].to("cuda")
                for key, value in enc.items()
            }
            input_len = inputs["input_ids"].shape[1]
            with torch.inference_mode():
                out = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens or GEN_MAX_TOKENS,
                    temperature=GEN_TEMPERATURE,
                    top_p=GEN_TOP_P,
                    top_k=GEN_TOP_K,
                    do_sample=True,
                    repetition_penalty=GEN_REP_PENALTY,
                    pad_token_id=self.tokenizer.eos_token_id,
                    # None falls back to the model's own eos config.
                    eos_token_id=self.stop_ids or None,
                )
            resp = self.tokenizer.decode(out[0][input_len:], skip_special_tokens=True)
            # Strip thinking tags
            resp = re.sub(r"<think>.*?</think>", "", resp, flags=re.DOTALL)
            resp = re.sub(r"</?think>", "", resp)
            return resp.strip()

    def absorb(self, training_data):
        """
        Train the persistent LoRA + expert FFN on accumulated data.
        Uses dual-LR: attention at ATTENTION_LR, expert FFN at EXPERT_FFN_LR.
        Thread-safe. Returns the mean training loss, or None if nothing trained.
        """
        with self._lock:
            return self._absorb_impl(training_data)

    # ── Shared training plumbing (used by absorb and distill) ──

    @staticmethod
    def _unwrap_base(model):
        """Return the model under a PEFT wrapper (or ``model`` itself)."""
        return model.base_model.model if hasattr(model, "base_model") else model

    @staticmethod
    def _expert_ffn_params(base):
        """Collect expert ``down_proj`` weights for every EXPERT_FFN_LAYERS layer.

        Supports both layouts handled elsewhere in this class: an indexable list
        of expert modules (each with ``down_proj.weight``) and a fused experts
        module whose ``down_proj`` is itself a tensor.
        """
        params = []
        for layer_idx in EXPERT_FFN_LAYERS:
            experts = base.model.layers[layer_idx].mlp.experts
            if hasattr(experts, '__len__'):
                params.extend(experts[i].down_proj.weight for i in range(len(experts)))
            elif hasattr(experts, 'down_proj'):
                p = experts.down_proj
                if isinstance(p, torch.Tensor):  # also covers nn.Parameter
                    params.append(p)
        return params

    def _start_dual_lr_training(self, model):
        """Unfreeze the expert FFN and build the dual-LR AdamW optimizer.

        Returns ``(optimizer, all_params, expert_params)``, or None when nothing
        is trainable. Expert weights are excluded from the LoRA group by
        identity so a tensor can never land in two param groups (which
        torch.optim rejects), even if an earlier run left them unfrozen.
        """
        expert_params = self._expert_ffn_params(self._unwrap_base(model))
        expert_ids = {id(p) for p in expert_params}
        attn_params = [
            p for p in model.parameters()
            if p.requires_grad and id(p) not in expert_ids
        ]
        param_groups = []
        if attn_params:
            param_groups.append({"params": attn_params, "lr": ATTENTION_LR})
        if expert_params:
            param_groups.append({"params": expert_params, "lr": EXPERT_FFN_LR})
        if not param_groups:
            return None
        for p in expert_params:
            p.requires_grad_(True)
        optimizer = torch.optim.AdamW(param_groups, weight_decay=0.0)
        return optimizer, attn_params + expert_params, expert_params

    @staticmethod
    def _finish_training(model, optimizer, expert_params):
        """Re-freeze the expert FFN, drop grads and return the model to eval."""
        # Clear grads so a step that raised mid-backward cannot leak stale
        # gradients into the next training call.
        optimizer.zero_grad(set_to_none=True)
        for p in expert_params:
            p.requires_grad_(False)
        model.eval()
        torch.cuda.empty_cache()

    def _absorb_impl(self, training_data):
        """Internal absorption. Caller must hold ``_lock``.

        Accepts items shaped as ``{"messages": [...]}``,
        ``{"prompt": [...], "completion": [...]}`` or a bare message list;
        anything else is skipped. Runs TRAIN_EPOCHS passes, one example per
        step in shuffled order, with the loss over the whole templated
        sequence (padding masked out). Returns the mean loss, or None.
        """
        if not training_data:
            return None
        model = self._active_model()
        tokenizer = self.tokenizer
        # ── Tokenize all examples ──
        texts = []
        for item in training_data:
            if isinstance(item, dict) and "messages" in item:
                msgs = item["messages"]
            elif isinstance(item, dict) and "prompt" in item:
                msgs = item["prompt"] + item.get("completion", [])
            elif isinstance(item, list):
                msgs = item
            else:
                continue
            texts.append(tokenizer.apply_chat_template(
                msgs, tokenize=False, enable_thinking=False
            ))
        if not texts:
            return None
        enc = tokenizer(
            texts,
            truncation=True,
            max_length=MAX_SEQ_LENGTH,
            padding=True,
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].to("cuda")
        attention_mask = enc["attention_mask"].to("cuda")
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100  # no loss on padding
        # ── Dual-LR setup (LoRA params + unfrozen expert FFN) ──
        model.train()
        setup = self._start_dual_lr_training(model)
        if setup is None:
            model.eval()
            return None
        optimizer, all_params, expert_params = setup
        # ── Training loop ──
        n = input_ids.shape[0]
        total_steps = n * TRAIN_EPOCHS
        total_loss = 0.0
        try:
            for _epoch in range(TRAIN_EPOCHS):
                # Shuffle order each epoch
                for idx in torch.randperm(n).tolist():
                    out = model(
                        input_ids=input_ids[idx:idx + 1],
                        attention_mask=attention_mask[idx:idx + 1],
                        labels=labels[idx:idx + 1],
                    )
                    out.loss.backward()
                    torch.nn.utils.clip_grad_norm_(all_params, GRADIENT_CLIP)
                    optimizer.step()
                    optimizer.zero_grad()
                    total_loss += out.loss.item()
        finally:
            # Always restore frozen experts / eval mode, even on OOM etc.
            self._finish_training(model, optimizer, expert_params)
        return total_loss / total_steps if total_steps > 0 else 0

    @staticmethod
    def cluster_by_entity(training_data, entity_names):
        """Group training data by primary entity mentioned.
        Instead of interleaving facts about different people (which causes
        cross-contamination during gradient updates), this groups all data
        about one entity together. The model learns all of Jordan's facts
        before moving to Elena's.
        The primary entity of an item is the known name that appears EARLIEST
        in its lower-cased message text (ties go to the name listed first).
        Only ``{"messages": [...]}`` items are inspected; others are treated
        as unclustered.
        Args:
            training_data: List of training items
            entity_names: Set/list of known entity names (duplicates ignored)
        Returns: List of training items, reordered so each entity's items
            are contiguous, in ``entity_names`` order. Items mentioning no
            entity come last.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order, so a
        # repeated name can no longer emit its cluster twice.
        clusters = {name: [] for name in dict.fromkeys(entity_names)}
        unclustered = []
        for item in training_data:
            if not (isinstance(item, dict) and "messages" in item):
                unclustered.append(item)
                continue
            text = " ".join(
                m.get("content") or "" for m in item["messages"]
            ).lower()
            best_name, best_pos = None, -1
            for name in clusters:
                needle = name.lower()
                if not needle:
                    continue  # "" would "match" every text at position 0
                pos = text.find(needle)
                if pos >= 0 and (best_name is None or pos < best_pos):
                    best_name, best_pos = name, pos
            if best_name is None:
                unclustered.append(item)
            else:
                clusters[best_name].append(item)
        # Build ordered list: all of entity A's facts, then B's, then C's...
        ordered = []
        for items in clusters.values():
            ordered.extend(items)
        ordered.extend(unclustered)
        return ordered

    def absorb_two_phase(self, positive_data, contrastive_data, verify_fn=None):
        """Two-phase absorption: facts first, then targeted contrastive correction.
        Phase 1: Train on positive facts (exchanges, entity summaries, template quizzes).
        This builds the core factual representations.
        Phase 2: Quick verification on known entities, then train ONLY contrastive
        quizzes for entities that failed verification. This avoids unnecessary
        negative gradients on entities the model already distinguishes correctly.
        Args:
            positive_data: List of training items (exchanges, summaries, direct quizzes)
            contrastive_data: List of ``{"messages": [...]}`` contrastive quiz items
                ("Is X a [Y's job]? No..."); the first message is matched
                against confused entity names.
            verify_fn: Optional callable(model_manager) -> set of confused_entity_names.
                It is called while the (re-entrant) lock is held, so it may
                call ``generate()``. If None, all contrastive data is used in
                Phase 2.
        Returns: (phase1_loss, phase2_loss) tuple
        """
        with self._lock:
            # Phase 1: Positive facts
            loss1 = None
            if positive_data:
                loss1 = self._absorb_impl(positive_data)
            # Phase 2: Targeted contrastive correction
            loss2 = None
            if contrastive_data:
                if verify_fn:
                    # Only train contrastive pairs for confused entities
                    confused = verify_fn(self)
                    if confused:
                        targeted = []
                        for item in contrastive_data:
                            q = item["messages"][0]["content"].lower()
                            # Check if any confused entity name appears in the question
                            if any(name.lower() in q for name in confused):
                                targeted.append(item)
                        if targeted:
                            loss2 = self._absorb_impl(targeted)
                    # If no entities confused, skip Phase 2 entirely
                else:
                    loss2 = self._absorb_impl(contrastive_data)
            return loss1, loss2

    def merge_and_save(self, path):
        """Merge persistent LoRA into base, save checkpoint, re-apply fresh LoRA.

        Writes the thinker weights and the tokenizer to ``path``. A fresh LoRA
        is re-applied even if saving raises, so the session keeps a trainable
        adapter instead of silently falling back to expert-FFN-only training.
        """
        with self._lock:
            if self.peft_model is not None:
                print(f" Merging persistent LoRA into base weights...")
                self.thinker = self.peft_model.merge_and_unload()
                self.thinker.eval()
                self.peft_model = None
            try:
                os.makedirs(path, exist_ok=True)
                print(f" Saving checkpoint to {path}...")
                self.thinker.save_pretrained(path)
                self.tokenizer.save_pretrained(path)
                print(f" Checkpoint saved ({path})")
            finally:
                # Re-apply fresh LoRA for continued learning
                self._apply_persistent_lora()
                print(f" Fresh LoRA applied — ready to continue.")

    def cache_teacher_logits(self, quiz_pairs, top_k=DISTILL_TOP_K):
        """Cache teacher's top-K output logits for quiz pairs.
        Called at session end when model is at its best state for these facts.
        Next session loads this cache for consolidation distillation.

        Only the last MAX_TEACHER_CACHE pairs are used. Each entry holds the
        pair, its CPU input_ids / attention_mask, and per-position top-k logit
        values (float16) and vocabulary indices.
        """
        with self._lock:
            model = self._active_model()
            model.eval()
            cache = []
            # Cap to most recent quiz pairs
            pairs = quiz_pairs[-MAX_TEACHER_CACHE:]
            for pair in pairs:
                msgs = pair["messages"]
                text = self.tokenizer.apply_chat_template(
                    msgs, tokenize=False, enable_thinking=False
                )
                enc = self.tokenizer(
                    text, return_tensors="pt", truncation=True,
                    max_length=MAX_SEQ_LENGTH
                )
                input_ids = enc["input_ids"].to("cuda")
                attention_mask = enc["attention_mask"].to("cuda")
                with torch.inference_mode():
                    out = model(input_ids=input_ids, attention_mask=attention_mask)
                logits = out.logits[0]  # [seq_len, vocab_size]
                # Keep only top-K logits per position (massive memory savings)
                top_vals, top_idx = logits.topk(top_k, dim=-1)
                cache.append({
                    "pair": pair,
                    "input_ids": input_ids.cpu(),
                    "attention_mask": attention_mask.cpu(),
                    "teacher_logits": top_vals.half().cpu(),
                    "teacher_indices": top_idx.cpu(),
                })
            return cache

    def distill(self, teacher_cache, epochs=CONSOLIDATION_EPOCHS):
        """KL distillation: train student to match teacher's output distribution.
        From Nemotron-Cascade-2: recover domain regressions via on-policy distillation."""
        with self._lock:
            return self._distill_impl(teacher_cache, epochs)

    def _distill_impl(self, teacher_cache, epochs):
        """Internal distillation implementation. Caller must hold ``_lock``.

        Loss per item = DISTILL_ALPHA * CE + (1 - DISTILL_ALPHA) * KL, where the
        KL compares temperature-softened distributions restricted to the
        teacher's cached top-k vocabulary positions, scaled by T^2.
        Returns the mean loss, or None if the cache is empty / nothing trains.
        """
        if not teacher_cache:
            return None
        model = self._active_model()
        model.train()
        # Dual-LR optimizer (same structure as absorb)
        setup = self._start_dual_lr_training(model)
        if setup is None:
            model.eval()
            return None
        optimizer, all_params, expert_params = setup
        T = DISTILL_TEMPERATURE
        total_loss = 0.0
        total_steps = 0
        try:
            for _epoch in range(epochs):
                for idx in torch.randperm(len(teacher_cache)).tolist():
                    item = teacher_cache[idx]
                    input_ids = item["input_ids"].to("cuda")
                    attention_mask = item["attention_mask"].to("cuda")
                    teacher_top_logits = item["teacher_logits"].float().to("cuda")
                    teacher_top_indices = item["teacher_indices"].to("cuda")
                    labels = input_ids.clone()
                    labels[attention_mask == 0] = -100
                    # Student forward pass
                    out = model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels,
                    )
                    ce_loss = out.loss
                    student_logits = out.logits[0]  # [seq_len, vocab_size]
                    # Align sequence lengths (should match, but safety check)
                    seq_len = min(student_logits.shape[0], teacher_top_logits.shape[0])
                    # Gather student logits at teacher's top-K vocabulary positions;
                    # upcast so the softmax/KL is not computed in bf16.
                    student_at_teacher = student_logits[:seq_len].gather(
                        1, teacher_top_indices[:seq_len]
                    ).float()
                    # KL divergence on temperature-softened distributions
                    teacher_soft = torch.softmax(teacher_top_logits[:seq_len] / T, dim=-1)
                    student_log_soft = torch.log_softmax(student_at_teacher / T, dim=-1)
                    kl_loss = torch.nn.functional.kl_div(
                        student_log_soft, teacher_soft,
                        reduction='batchmean'
                    ) * (T * T)  # Scale by T^2 per Hinton et al.
                    # Combined loss: α * CE + (1-α) * KL
                    loss = DISTILL_ALPHA * ce_loss + (1 - DISTILL_ALPHA) * kl_loss
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(all_params, GRADIENT_CLIP)
                    optimizer.step()
                    optimizer.zero_grad()
                    total_loss += loss.item()
                    total_steps += 1
        finally:
            # Always restore frozen experts / eval mode, even on failure.
            self._finish_training(model, optimizer, expert_params)
        return total_loss / total_steps if total_steps > 0 else 0
# ═══════════════════════════════════════════════════════════════════════
# QUIZ GENERATOR (21% → 74% recall — the biggest single lever)
# ═══════════════════════════════════════════════════════════════════════
class QuizGenerator:
"""
Generates drill-style Q&A flashcards for fact retention.
v3 improvements over v2:
- Fact extraction THEN quiz generation (two-step)
- Drill-style: specific Q, exact A (not narrative)
- Third-person attribution ("Matt's dog" not "my dog")
- Template fallback targets each extracted fact independently
- CONTRASTIVE DISAMBIGUATION: when multiple people mentioned, generates
cross-entity negative pairs ("Is Elena a marine biologist? No, that's
Jordan") to prevent entity confusion (the #1 remaining failure mode)
- ENTITY SUMMARIES: "Tell me everything about Jordan" pairs for coherent
per-person representations
"""
def __init__(self, model_manager):
self.mm = model_manager
# Cross-message entity memory: tracks ALL named people across the conversation
# so contrastive pairs can be generated between entities introduced in
# different messages. This was the #1 failure mode in session 4 testing.
self.known_entities = {}
def generate(self, user_msg, assistant_msg):
    """Generate drill-style quiz pairs from one user/assistant exchange.

    Steps:
      1. model-written flashcards (``_generate_model_quizzes``);
      2. template pairs for regex-detected facts, skipped when an existing
         question contains, or is contained in, the template question;
      3. entity extraction from ``user_msg`` and merge into
         ``known_entities`` (existing attributes are kept unless the new
         mention supplies a non-empty value);
      4. contrastive pairs involving at least one entity named in this
         message, plus entity summaries for those entities.
    Pairs are de-duplicated on the first 60 characters of the lower-cased
    question, then capped at 12 when this message named entities and at
    least two entities are known, otherwise at 5.
    """
    # Step 1: Try model-generated quizzes with strict fact-drill prompt
    pairs = self._generate_model_quizzes(user_msg, assistant_msg)
    # Step 2: Always add template pairs for any facts the model might miss
    for tp in self._extract_and_template(user_msg):
        tq = tp["messages"][0]["content"].lower()
        if not any(tq in p["messages"][0]["content"].lower() or
                   p["messages"][0]["content"].lower() in tq
                   for p in pairs):
            pairs.append(tp)
    # Step 3: Extract entities from THIS message
    new_entities = self._extract_entities(user_msg)
    if new_entities:
        # Merge into the conversation-wide memory FIRST (merge, don't replace
        # — keep existing attributes, add new ones). Previously the
        # contrastive view was built from the raw new mention, so re-mentioning
        # a known person without their job/city erased those facts from this
        # round's contrastive pairs and summaries.
        for name, info in new_entities.items():
            known = self.known_entities.get(name)
            if known is None:
                self.known_entities[name] = info
            else:
                for key in ("job", "city"):
                    if info.get(key):
                        known[key] = info[key]
        merged_new = {name: self.known_entities[name] for name in new_entities}
        # Step 4: Contrastive pairs — ONLY pairs involving at least one entity
        # named in this message (session 4d showed a 50% contrastive ratio
        # when old pairs kept being regenerated, starving positive quizzes).
        if len(self.known_entities) >= 2:
            contrastive = self._generate_contrastive_quizzes(
                dict(self.known_entities), new_only=set(new_entities))
            pairs.extend(contrastive)
        # Entity summaries for the entities mentioned here
        pairs.extend(self._generate_entity_summaries(merged_new))
    # Deduplicate on the question prefix
    seen = set()
    unique = []
    for p in pairs:
        q = p["messages"][0]["content"].lower()[:60]
        if q not in seen:
            seen.add(q)
            unique.append(p)
    # Allow more quizzes when contrastive pairs present (they're highest value).
    # Note: Session 4c showed >40 quizzes/session causes overfitting. Cap at 12.
    has_contrastive = bool(new_entities) and len(self.known_entities) >= 2
    max_quizzes = 12 if has_contrastive else 5
    return unique[:max_quizzes]
def _generate_model_quizzes(self, user_msg, assistant_msg):
"""Use the model to generate fact-drill quizzes. Uses base model (LoRA disabled) for stable quality."""
quiz_prompt = f"""Matt just told Claudia:
"{user_msg}"
Claudia replied:
"{assistant_msg}"
Extract every SPECIFIC FACT from Matt's message. For each fact, write a drill-style flashcard.
RULES:
- Questions must ask for ONE specific fact (name, date, place, number, detail)
- Answers must be SHORT (1 sentence) and contain the EXACT detail
- Use THIRD PERSON: "Matt's dog" NOT "my dog". "Matt's birthday" NOT "my birthday"
- Include the PRECISE value: exact names, exact dates, exact places
- Do NOT paraphrase or add details that weren't stated
- DISAMBIGUATION: If Matt mentions OTHER people (friends, family), clearly state WHOSE fact it is
Example: "Matt's friend Jordan is a marine biologist" NOT "Matt is a marine biologist"
Example: "Matt's sister Elena is a veterinarian" NOT "Matt is a veterinarian"
- For EVERY person mentioned, always include their RELATIONSHIP to Matt
- Write 3-5 flashcards depending on how many facts Matt shared
GOOD EXAMPLES:
Q: What is Matt's dog's name?
A: Matt's dog is named Biscuit.
Q: What breed is Matt's dog?
A: Matt's dog Biscuit is a golden retriever.
Q: What does Matt's friend Jordan do for a living?
A: Matt's friend Jordan works as a marine biologist in San Diego. That is Jordan's job, not Matt's.
Q: What is Matt's job?
A: Matt is the CTO of Arclight Labs.
Q: What is Matt's birthday?
A: Matt's birthday is September 14th.
Q: When did Matt and Sarah get married?
A: Matt and his wife Sarah got married on June 21st, 2023 in Big Sur, California.
BAD EXAMPLES (do NOT do this):
Q: What did Matt share about his life? (TOO VAGUE β€” ask about ONE fact)
Q: What is my dog's name? (WRONG β€” use "Matt's" not "my")
A: He mentioned something about a trip overseas. (TOO VAGUE β€” give the exact city)
A: Matt is a marine biologist. (WRONG β€” that's his friend Jordan, not Matt)
Now write flashcards for the exchange above:"""
pairs = []
try:
response = self.mm.generate(
[{"role": "user", "content": quiz_prompt}],
max_new_tokens=600,
)
pending_q = None
for line in response.split("\n"):
line = line.strip()
if not line:
continue
upper = line.upper()
if upper.startswith("Q:") or upper.startswith("QUESTION:"):
pending_q = line.split(":", 1)[1].strip().strip('"')
elif (upper.startswith("A:") or upper.startswith("ANSWER:")) and pending_q:
a = line.split(":", 1)[1].strip().strip('"')
if pending_q and a and len(a) > 10:
pairs.append({
"messages": [
{"role": "user", "content": pending_q},
{"role": "assistant", "content": a},
]
})
pending_q = None
except Exception as e:
print(f" [quiz error: {e}]")
return pairs
def _extract_and_template(self, user_msg):
"""Extract facts from user message and create template drill pairs.
This is the safety net β€” ensures every concrete fact gets a quiz."""
pairs = []
sentences = re.split(r'[.!?]+', user_msg)
for sent in sentences:
sent = sent.strip()
if len(sent) < 10:
continue
# Extract patterns: "X is/are Y", "named X", "called X", "X's name is Y"
# Names (proper nouns after key phrases)
name_patterns = [
# Names β€” "my X's name is Y" / "named X" / "called X"
(r"(?:my|his|her)\s+(\w+)(?:'s)?\s+(?:name\s+is|is\s+named|is\s+called)\s+(\w+)",
lambda m: (f"What is Matt's {m.group(1)}'s name?",
f"Matt's {m.group(1)} is named {m.group(2)}.")),
(r"(?:name\s+is|named|called)\s+[\"']?(\w+)[\"']?",
lambda m: (f"Who or what is {m.group(1)}?",
f"Matt mentioned {m.group(1)}: \"{sent.strip()}\"")),
# Dates β€” birthdays
(r"(?:my\s+)?(birthday|born)\s+(?:is\s+)?(?:on\s+)?(\w+\s+\d+(?:st|nd|rd|th)?)",
lambda m: (f"When is Matt's {m.group(1)}?",
f"Matt's {m.group(1)} is {m.group(2)}.")),
(r"(\w+\s+\d+(?:st|nd|rd|th)?)\s*(?:is|β€”)\s*(?:my|his)\s+(birthday)",
lambda m: (f"When is Matt's birthday?",
f"Matt's birthday is {m.group(1)}.")),
# Dates β€” marriage/wedding
(r"(?:married|wedding)\s+(?:on\s+)?(\w+\s+\d+(?:st|nd|rd|th)?,?\s*\d{4})",
lambda m: (f"When did Matt get married?",
f"Matt got married on {m.group(1)}.")),
(r"(?:married|wedding)\s+(?:on\s+)?.*?(?:in|at)\s+(.+?)(?:\.\s|\.$|$)",
lambda m: (f"Where did Matt get married?",
f"Matt got married in {m.group(1).strip()}.")),
# Work / job / role
(r"I\s+work\s+at\s+(?:a\s+)?(?:startup\s+)?(?:called\s+)?(\w[\w\s]+?)(?:\.|,|$)",
lambda m: (f"Where does Matt work?",
f"Matt works at {m.group(1).strip()}.")),
(r"I(?:'m| am)\s+the\s+(\w+)",
lambda m: (f"What is Matt's job title?",
f"Matt is the {m.group(1)}.")),
# Other people's jobs β€” "X works as / is a"
(r"(?:my\s+)?(?:friend|best friend|sister|brother)\s+(?:is\s+)?(\w+)\s+.*?(?:works?\s+as|is\s+a)\s+(.+?)(?:\.|,|$)",
lambda m: (f"What does Matt's friend {m.group(1)} do?",
f"Matt's friend {m.group(1)} is a {m.group(2).strip()}. This is NOT Matt's job.")),
# Places
(r"(?:from|visited|went to|got back from|lives?\s+in|grew up in|moved to)\s+(\w[\w\s,]+?)(?:\.|,|$)",
lambda m: (f"What place is connected to Matt: {m.group(1).strip()}?",
f"Matt said: \"{sent.strip()}\"")),
# Favorites / preferences
(r"(?:my |)favorite\s+(\w[\w\s]+?)\s+is\s+(.+?)(?:\.|,|$)",
lambda m: (f"What is Matt's favorite {m.group(1).strip()}?",
f"Matt's favorite {m.group(1).strip()} is {m.group(2).strip()}.")),
# Activities β€” "I [verb]"
(r"I\s+(speak|play|drive|have|collect|run|ran)\s+(.+?)(?:\.|,|$)",
lambda m: (f"What does Matt {m.group(1)}?",
f"Matt said: \"{sent.strip()}\"")),
# Allergies / medical
(r"(?:I(?:'m| am)\s+)?allergic\s+to\s+(.+?)(?:\.|,|and)",
lambda m: (f"What is Matt allergic to?",
f"Matt is allergic to {m.group(1).strip()}.")),
# Ages β€” "turning X" / "X years old"
(r"(?:turning|I(?:'m| am))\s+(\d+)",
lambda m: (f"How old is Matt?",
f"Matt is turning {m.group(1)}.")),
# Nicknames
(r"(?:call|nickname)\s+(?:it|him|her)\s+[\"'](.+?)[\"']",
lambda m: (f"What nickname did Matt mention?",
f"Matt's nickname for it is \"{m.group(1)}\".")),
(r"I\s+call\s+it\s+[\"'](.+?)[\"']",
lambda m: (f"What does Matt call his car?",
f"Matt calls his car \"{m.group(1)}\".")),
]
for pattern, formatter in name_patterns:
match = re.search(pattern, sent, re.IGNORECASE)
if match:
try:
q, a = formatter(match)
pairs.append({
"messages": [
{"role": "user", "content": q},
{"role": "assistant", "content": a},
]
})
except Exception:
pass
return pairs
def _extract_entities(self, user_msg):
"""Extract named people and their attributes from user message.
Returns dict: {name: {"relationship": str, "job": str|None, "city": str|None}}
Detects patterns like "my friend Jordan is a marine biologist in San Diego"."""
entities = {}
sentences = re.split(r'[.!?]+', user_msg)
for sent in sentences:
sent = sent.strip()
if len(sent) < 10:
continue
# Pattern: "my [relationship] [Name]" or "my [relationship] is [Name]"
rel_match = re.search(
r"[Mm]y\s+((?:best\s+)?(?:friend|sister|brother|wife|husband|"
r"mom|dad|mother|father|cousin|uncle|aunt|roommate|colleague|"
r"coworker|partner|fiancee|fiancΓ©e|girlfriend|boyfriend|"
r"neighbor|boss|buddy|pal|son|daughter|grandma|grandpa|"
r"nephew|niece))\s+(?:is\s+)?([A-Z][a-z]+)",
sent
)
if not rel_match:
continue
rel = rel_match.group(1).strip()
name = rel_match.group(2).strip()
if name not in entities:
entities[name] = {"relationship": rel, "job": None, "city": None}
# Extract job from same sentence: "is a [job]", "works as a [job]"
job_match = re.search(
r"(?:is\s+an?\s+|works?\s+as\s+an?\s+|is\s+the\s+)"
r"([\w][\w\s]{2,35}?)(?:\s+(?:in|at|from|who|and|but)|\.|,|$)",
sent, re.IGNORECASE
)
if job_match:
job = job_match.group(1).strip().rstrip()
# Filter: must look like a job (lowercase, reasonable length)
if 3 <= len(job) <= 35:
entities[name]["job"] = job
# Extract city from same sentence: "in [City]", "from [City]"
city_match = re.search(
r"(?:\s+in\s+|\s+from\s+|\s+lives?\s+in\s+|\s+based\s+in\s+|"
r"\s+moved\s+to\s+)([A-Z][\w\s]{1,25}?)(?:\.|,|$)",
sent
)
if city_match:
city = city_match.group(1).strip()
# Must start with capital (proper noun = place name)
if city and city[0].isupper():
entities[name]["city"] = city
return entities
def _generate_contrastive_quizzes(self, entities, new_only=None):
"""Generate cross-entity contrastive pairs to prevent entity confusion.
For each pair of people with overlapping attribute types, generate
"Is [person A] [attribute of person B]? No, that's [person B]" pairs.
Args:
entities: dict of all known entities
new_only: if set, only generate pairs where at least one entity
is in this set. Prevents re-generating redundant pairs
between already-known entities (session 4d fix).
"""
pairs = []
names = list(entities.keys())
for i in range(len(names)):
for j in range(len(names)):
if i == j:
continue
a_name = names[i]
b_name = names[j]
# Skip pairs between two already-known entities
if new_only and a_name not in new_only and b_name not in new_only:
continue
a = entities[a_name]
b = entities[b_name]
# Contrastive on JOB: "Is [A] a [B's job]? No, that's [B]"
if a.get("job") and b.get("job") and a["job"] != b["job"]:
q = f"Is Matt's {a['relationship']} {a_name} a {b['job']}?"
ans = (f"No. Matt's {a['relationship']} {a_name} is a "
f"{a['job']}, not a {b['job']}. "
f"The {b['job']} is Matt's {b['relationship']} "
f"{b_name}.")
pairs.append({"messages": [
{"role": "user", "content": q},
{"role": "assistant", "content": ans},
]})
# Contrastive on CITY: "Does [A] live in [B's city]? No"
if a.get("city") and b.get("city") and a["city"] != b["city"]:
q = (f"Does Matt's {a['relationship']} {a_name} live in "
f"{b['city']}?")
ans = (f"No. Matt's {a['relationship']} {a_name} lives in "
f"{a['city']}, not {b['city']}. "
f"It's Matt's {b['relationship']} {b_name} who "
f"lives in {b['city']}.")
pairs.append({"messages": [
{"role": "user", "content": q},
{"role": "assistant", "content": ans},
]})
# Cross-type: "Does [A] work as [B's job] in [B's city]?"
if (a.get("job") and b.get("job") and a.get("city")
and b.get("city") and a["job"] != b["job"]):
q = (f"Who is the {b['job']} in {b['city']}?")
ans = (f"The {b['job']} in {b['city']} is Matt's "
f"{b['relationship']} {b_name}. "
f"Matt's {a['relationship']} {a_name} is a "
f"{a['job']} in {a['city']} β€” different person, "
f"different job, different city.")
pairs.append({"messages": [
{"role": "user", "content": q},
{"role": "assistant", "content": ans},
]})
return pairs
def _generate_entity_summaries(self, entities):
"""Generate per-entity summary quiz pairs with diverse question formats.
Instead of always using the same question template, picks randomly from
multiple formats. This creates multiple retrieval paths to the same fact,
strengthening recall without adding extra quizzes.
Note: Session 4c tested adding per-attribute positive quizzes (job, city,
relationship) alongside contrastive pairs, but this HURT performance
(9/15 vs 11/15 in 4b). Too many quizzes = overfitting/interference.
Keep summaries simple β€” one comprehensive pair per entity is optimal."""
import random
pairs = []
for name, info in entities.items():
parts = [f"{name} is Matt's {info['relationship']}."]
if info.get("job"):
parts.append(f"{name} is a {info['job']}.")
if info.get("city"):
parts.append(f"{name} lives in {info['city']}.")
if len(parts) >= 2: # Only useful if we have attributes
# Diverse summary question formats
summary_formats = [
f"Tell me everything you know about Matt's {info['relationship']} {name}.",
f"What do you know about {name}?",
f"Who is {name} to Matt?",
f"Describe Matt's {info['relationship']} {name}.",
]
q = random.choice(summary_formats)
ans = " ".join(parts)
pairs.append({"messages": [
{"role": "user", "content": q},
{"role": "assistant", "content": ans},
]})
# Add ONE diverse direct-fact quiz per entity (job OR city, not both)
# This replaces per-attribute quizzes from 4c β€” only 1 extra per entity
# instead of 3, staying within the 35-40 quiz sweet spot
if info.get("job") and info.get("city"):
# Alternate between job and city formats
if random.random() < 0.5:
job_formats = [
(f"What does {name} do for a living?",
f"{name} is a {info['job']}. {name} is Matt's {info['relationship']}."),
(f"What is {name}'s profession?",
f"{name} works as a {info['job']}. {name} is Matt's {info['relationship']}."),
(f"What job does Matt's {info['relationship']} {name} have?",
f"Matt's {info['relationship']} {name} is a {info['job']}."),
]
q, a = random.choice(job_formats)
else:
city_formats = [
(f"Where does {name} live?",
f"{name} lives in {info['city']}. {name} is Matt's {info['relationship']}."),
(f"What city is {name} in?",
f"{name} is in {info['city']}. {name} is Matt's {info['relationship']}."),
(f"Where is Matt's {info['relationship']} {name} based?",
f"Matt's {info['relationship']} {name} is based in {info['city']}."),
]
q, a = random.choice(city_formats)
pairs.append({"messages": [
{"role": "user", "content": q},
{"role": "assistant", "content": a},
]})
return pairs
# ═══════════════════════════════════════════════════════════════════════
# PERSONALITY CHECKER
# ═══════════════════════════════════════════════════════════════════════
# Probe prompts sent, one at a time, by check_personality().
PERSONALITY_PROMPTS = [
    "Hey Claudia, how are you?",
    "Who are you?",
    "I love you",
    "I had a terrible day",
]

# Phrases signalling that the persona has slipped into a generic assistant
# voice. check_personality() lower-cases each reply before searching for
# these, so every entry must itself be lower-case. If ANY of them appears in
# a reply, that prompt counts as a failure.
ANTI_KEYWORDS = [
    "i'm an ai",
    "i am an ai",
    "i'm a language model",
    "i am a language model",
    "i don't have feelings",
    "i cannot feel",
    "as an ai",
    "i'm just a program",
    "i am just a program",
    "i don't have personal",
    "i cannot have",
]
def check_personality(mm, verbose=True):
    """Run a quick persona sanity check and return the pass rate (0.0-1.0).

    Every prompt in PERSONALITY_PROMPTS is sent to ``mm.generate`` as a
    single user turn, capped at 150 new tokens. A reply passes when none of
    the ANTI_KEYWORDS phrases occurs in its lower-cased text.

    Args:
        mm: object whose ``generate(messages, max_new_tokens=...)`` returns
            the reply as a string.
        verbose: when True, print PASS/FAIL per prompt, the first 120
            characters of each reply, and the final tally.

    Returns:
        Fraction of prompts that passed.
    """
    total = len(PERSONALITY_PROMPTS)
    passed = 0
    for prompt in PERSONALITY_PROMPTS:
        reply = mm.generate([{"role": "user", "content": prompt}], max_new_tokens=150)
        lowered = reply.lower()
        slipped = any(marker in lowered for marker in ANTI_KEYWORDS)
        if not slipped:
            passed += 1
        if verbose:
            verdict = "FAIL" if slipped else "PASS"
            print(f" [{verdict}] {prompt}")
            print(f" {reply[:120]}")

    score = passed / total
    if verbose:
        print(f" Personality: {passed}/{total} ({score:.0%})")
    return score
# ═══════════════════════════════════════════════════════════════════════
# MAIN ABSORBER
# ═══════════════════════════════════════════════════════════════════════
class PersistentAbsorber:
    """Interactive chat session that trains every exchange into the model.

    Per user turn (``_chat_loop``): generate a reply with ``ModelManager``,
    turn the exchange into self-quiz pairs with ``QuizGenerator``, and every
    ``ABSORB_EVERY`` exchanges run a background absorption
    (``_start_absorption``). Checkpoints (merged weights, replay buffer and
    quiz log) are written every ``CHECKPOINT_EVERY`` absorptions, on
    ``/save`` and on exit.
    """

    # Lower-cased capitalised words that _extract_key_entities never counts as
    # key entities. "it's" is listed because generated contrastive answers can
    # contain a mid-answer "It's Matt's <relationship> <name> who ..."
    # sentence. Without it, verification would demand the literal "it's" in
    # the model's reply.
    _ENTITY_SKIP_WORDS = frozenset({
        "matt", "matt's", "the", "is", "a", "an", "in", "at", "on",
        "of", "for", "and", "that", "not", "who", "what", "his", "her",
        "it's",
    })

    # Maximum number of messages kept as generation context.
    _MAX_CONTEXT_MESSAGES = 20

    def __init__(self, model_path, adapter_path=None, ffn_patch_path=None,
                 checkpoint_path=None, checkpoint_dir="/workspace/checkpoints",
                 log_dir="/workspace/logs"):
        """Create the model manager and empty session state.

        Args:
            model_path: base model directory passed to ModelManager.
            adapter_path: personality adapter (first run only).
            ffn_patch_path: optional FFN patch (first run only).
            checkpoint_path: checkpoint to resume from, or None.
            checkpoint_dir: directory that receives ``claudia_*`` checkpoints.
            log_dir: directory for the conversation log, replay buffer,
                quiz log and session metadata.

        ``start()`` is what calls ``self.mm.load()``.
        """
        self.mm = ModelManager(
            model_path=model_path,
            adapter_path=adapter_path,
            ffn_patch_path=ffn_patch_path,
            checkpoint_path=checkpoint_path,
        )
        self.checkpoint_dir = checkpoint_dir
        self.log_dir = log_dir

        # State
        self.conversation_buffer = []  # Current active context for generation
        self.all_training_data = []    # ALL exchanges + quizzes (accumulative replay)
        self.quiz_pairs_log = []       # All quiz pairs for verification sampling
        self.teacher_cache = None      # Loaded teacher cache for distillation corrections
        self.exchange_count = 0
        self.absorption_count = 0
        self.absorption_thread = None
        self.quiz_gen = None
        self.last_checkpoint = checkpoint_path

        # Conversation log (persistent file); path is set in start().
        self.log_path = None

        # Data queued since the last absorption. Accumulated rather than
        # overwritten, so with ABSORB_EVERY > 1 no exchange is skipped.
        self._pending_exchanges = []
        self._pending_positive = []
        self._pending_contrastive = []
        # Index into all_training_data where the previous absorption began
        # counting "new" data; everything before it is eligible for replay.
        self._last_absorb_idx = 0
        # Previous assistant reply, used to detect degenerate repeats.
        self._last_response = ""

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _qa(question, answer):
        """Build one user/assistant training example."""
        return {"messages": [
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer},
        ]}

    @staticmethod
    def _pair_text(pair):
        """Lower-cased "question answer" text of a quiz pair, for name search."""
        return (pair["messages"][0]["content"] + " " +
                pair["messages"][1]["content"]).lower()

    @staticmethod
    def _is_contrastive(pair):
        """True when the pair's answer starts with "No." (case-insensitive)."""
        return pair["messages"][1]["content"].lower().startswith("no.")

    # ------------------------------------------------------------------
    # Session lifecycle
    # ------------------------------------------------------------------

    def start(self):
        """Load the model, restore prior-session state, then enter the chat loop."""
        self.mm.load()
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)
        self.quiz_gen = QuizGenerator(self.mm)
        self.log_path = os.path.join(self.log_dir, "conversation_log.jsonl")

        # Load previous training data if resuming
        replay_path = os.path.join(self.log_dir, "replay_buffer.json")
        if os.path.exists(replay_path):
            with open(replay_path, 'r', encoding='utf-8') as f:
                self.all_training_data = json.load(f)
            print(f" Loaded {len(self.all_training_data)} replay examples from previous sessions.")
        # Data from earlier sessions is "old": make it replay-eligible from the
        # very first absorption instead of only from the second one.
        self._last_absorb_idx = len(self.all_training_data)

        # Load quiz pairs log from previous sessions
        quiz_log_path = os.path.join(self.log_dir, "quiz_pairs_log.json")
        if os.path.exists(quiz_log_path):
            with open(quiz_log_path, 'r', encoding='utf-8') as f:
                self.quiz_pairs_log = json.load(f)
            print(f" Loaded {len(self.quiz_pairs_log)} quiz pairs from previous sessions.")

        # ── Cascade Distillation: consolidation from teacher cache ──
        # If resuming from a checkpoint that has cached teacher logits,
        # run a distillation pass to reinforce all previous knowledge
        # BEFORE any new conversations. This is the key Nemotron-Cascade-2 insight.
        if self.mm.checkpoint_path:
            teacher_cache_path = os.path.join(self.mm.checkpoint_path, "teacher_cache.pt")
            if os.path.exists(teacher_cache_path):
                print("\n--- Cascade Distillation (consolidation) ---")
                # weights_only=False unpickles arbitrary objects. That is only
                # acceptable because _save_and_exit writes this file; never
                # resume from an untrusted checkpoint directory.
                self.teacher_cache = torch.load(
                    teacher_cache_path, map_location="cpu", weights_only=False
                )
                print(f" Teacher cache: {len(self.teacher_cache)} quiz pairs")
                loss = self.mm.distill(self.teacher_cache, epochs=CONSOLIDATION_EPOCHS)
                print(f" Consolidation done. Avg loss: {loss:.4f}")
                # teacher_cache stays in memory for verification corrections.

        # Quick personality check
        print("\n--- Personality Check ---")
        score = check_personality(self.mm)
        if score < 0.5:
            print(" WARNING: Personality score low. Check adapter/checkpoint.")
        print()
        self._chat_loop()

    def _trim_conversation_buffer(self):
        """Keep only the most recent messages, starting on a user turn.

        Trims to ``_MAX_CONTEXT_MESSAGES`` and then drops leading non-user
        messages, so the context never opens with an assistant reply whose
        prompt was cut off.
        """
        if len(self.conversation_buffer) <= self._MAX_CONTEXT_MESSAGES:
            return
        trimmed = self.conversation_buffer[-self._MAX_CONTEXT_MESSAGES:]
        while trimmed and trimmed[0]["role"] != "user":
            trimmed.pop(0)
        self.conversation_buffer = trimmed

    def _chat_loop(self):
        """Read user turns from stdin until /quit, EOF or Ctrl-C; learn from each."""
        print("=" * 60)
        print("Claudia is awake. Persistent Absorber v2 + Cascade Distillation.")
        print(f" LoRA: r={LORA_RANK} | Dual-LR: attn={ATTENTION_LR}, ffn={EXPERT_FFN_LR}")
        print(f" Expert FFN layers: {EXPERT_FFN_LAYERS}")
        # \u2192 is a right arrow, \u03b1 is alpha (escaped to keep the source ASCII).
        print(" Quiz pairs: ON (21%\u219274% lever)")
        print(f" Cascade distill: \u03b1={DISTILL_ALPHA}, T={DISTILL_TEMPERATURE}, top-K={DISTILL_TOP_K}")
        print(f" Absorb every: {ABSORB_EVERY} exchange(s)")
        print(f" Auto-checkpoint every: {CHECKPOINT_EVERY} absorptions")
        print("Commands: /status /absorb /save /personality /quit")
        print("=" * 60 + "\n")

        while True:
            try:
                user_input = input("Matt: ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\n[Session ended]")
                self._wait_for_absorption()
                self._save_and_exit()
                break

            if not user_input:
                continue
            if user_input.startswith("/"):
                if self._handle_command(user_input):
                    break
                continue

            # Training and generation share the model: never overlap them.
            self._wait_for_absorption()

            # Buffer user message
            self.conversation_buffer.append({"role": "user", "content": user_input})
            self._trim_conversation_buffer()

            # Generate response; regenerate once on a failed quality check or
            # a verbatim repeat of the previous reply (degenerate loop).
            response = self.mm.generate(self.conversation_buffer)
            if not check_response_quality(response) or response == self._last_response:
                print("\nClaudia: [response failed quality check, regenerating...]")
                response = self.mm.generate(self.conversation_buffer)
            self._last_response = response

            # Buffer response
            self.conversation_buffer.append({"role": "assistant", "content": response})
            print(f"\nClaudia: {response}\n")
            self._log_exchange(user_input, response)

            # ── THE CORE LOOP: exchange + quiz -> two-phase absorb ──
            # 1. Store the raw exchange
            exchange = self._qa(user_input, response)
            self.all_training_data.append(exchange)

            # 2. Generate self-quiz pairs (THE key lever: 21% -> 74%)
            print(" [Generating quiz pairs...]", end="", flush=True)
            quiz_pairs = self.quiz_gen.generate(user_input, response)
            self.quiz_pairs_log.extend(quiz_pairs)

            # 3. Separate positive vs contrastive (key insight from 4e: 73% -> 93%).
            # NOTE(review): classification relies on the answer starting with
            # "No."; any contrastive-style pair phrased otherwise lands in the
            # positive batch.
            positive_batch = []
            contrastive_batch = []
            for qp in quiz_pairs:
                (contrastive_batch if self._is_contrastive(qp) else positive_batch).append(qp)
            self.all_training_data.extend(quiz_pairs)
            print(f" {len(quiz_pairs)} quizzes (pos={len(positive_batch)}, "
                  f"contr={len(contrastive_batch)}). Pool: {len(self.all_training_data)}")

            # 4. Queue for the next two-phase absorption (prevents overfitting).
            self._pending_exchanges.append(exchange)
            self._pending_positive.extend(positive_batch)
            self._pending_contrastive.extend(contrastive_batch)
            self.exchange_count += 1
            if self.exchange_count % ABSORB_EVERY == 0:
                self._start_absorption()

    # ------------------------------------------------------------------
    # Verification
    # ------------------------------------------------------------------

    def _extract_key_entities(self, text):
        """Return the lower-cased "key facts" of ``text`` used for verification.

        Collected:
        * capitalised words (rough proper-noun detection) that are not the
          first word and not in ``_ENTITY_SKIP_WORDS``;
        * standalone digit runs (dates, ages, years);
        * double-quoted substrings.
        """
        entities = set()
        for i, word in enumerate(text.split()):
            clean = re.sub(r"[^a-zA-Z0-9'-]", '', word)
            if len(clean) <= 1:
                continue
            if i > 0 and clean[0].isupper() and clean.lower() not in self._ENTITY_SKIP_WORDS:
                entities.add(clean.lower())
        entities.update(re.findall(r'\b\d+\b', text))
        entities.update(quoted.lower() for quoted in re.findall(r'"([^"]+)"', text))
        return entities

    def _verify_pairs(self, pairs):
        """Quiz the model on ``pairs`` and collect what it gets wrong.

        A pair is correct when its expected answer has no key entities, or
        when at least half of them occur (substring match) in the model's
        lower-cased reply. For each miss the pair itself is queued. If the
        reply mentioned entities the expected answer lacks, the first other
        logged quiz pair whose answer mentions one of them is queued too
        (sister-pair reinforcement: teaches both sides of the confusion
        without polluting answers with "NOT X" text).

        Returns:
            (correct_count, corrections) -- corrections contains no duplicates.
        """
        correct = 0
        corrections = []
        for pair in pairs:
            question = pair["messages"][0]["content"]
            expected = pair["messages"][1]["content"]
            actual = self.mm.generate(
                [{"role": "user", "content": question}],
                max_new_tokens=150,
            )
            expected_entities = self._extract_key_entities(expected)
            if not expected_entities:
                correct += 1
                continue
            actual_lower = actual.lower()
            hits = sum(1 for e in expected_entities if e in actual_lower)
            if hits / len(expected_entities) >= 0.5:
                correct += 1
                continue

            # Always retrain on the correct answer (clean, no "NOT X" text).
            if pair not in corrections:
                corrections.append(pair)
            wrong_entities = self._extract_key_entities(actual) - expected_entities
            if not wrong_entities:
                continue
            # Max 1 sister pair per confusion; skip pairs already queued
            # instead of giving up on the first match.
            for candidate in self.quiz_pairs_log:
                if candidate in corrections:
                    continue
                answer = candidate["messages"][1]["content"].lower()
                if any(we in answer for we in wrong_entities):
                    corrections.append(candidate)
                    break
        return correct, corrections

    def _teacher_items_for(self, corrections):
        """Return cached teacher entries whose question matches a correction.

        Questions are matched on their first 60 lower-cased characters; for
        a given prefix the first cache entry wins. Returns [] when no teacher
        cache is loaded.
        """
        if not self.teacher_cache:
            return []
        by_prefix = {}
        for cached in self.teacher_cache:
            key = cached["pair"]["messages"][0]["content"].lower()[:60]
            by_prefix.setdefault(key, cached)
        items = []
        for corr in corrections:
            hit = by_prefix.get(corr["messages"][0]["content"].lower()[:60])
            if hit is not None:
                items.append(hit)
        return items

    def _periodic_verification(self):
        """Test the model on a random sample of quiz pairs and retrain misses.

        v9: when entity confusion is detected, the confused entity's own quiz
        pair is retrained as well (sister-pair reinforcement). If a teacher
        cache is loaded, matching cached entries are also distilled.
        """
        import random
        if not self.quiz_pairs_log:
            return
        sample_size = min(VERIFY_SAMPLE, len(self.quiz_pairs_log))
        sample = random.sample(self.quiz_pairs_log, sample_size)
        correct, corrections = self._verify_pairs(sample)
        print(f"\n [Verification: {correct}/{sample_size} facts correct]", flush=True)
        if not corrections:
            return

        print(f" [Retraining {len(corrections)} corrections + sister pairs...]", flush=True)
        loss = self.mm.absorb(corrections)
        self.all_training_data.extend(corrections)
        print(f" [Correction absorption done, loss={loss:.4f}]")

        # Teacher-guided distillation gives the student the teacher's output
        # distribution, not just the text answer: more signal per correction.
        distill_items = self._teacher_items_for(corrections)
        if distill_items:
            d_loss = self.mm.distill(distill_items, epochs=1)
            print(f" [Teacher distillation on {len(distill_items)} items, loss={d_loss:.4f}]")

    def _quick_verify_entities(self):
        """Return names from ``quiz_gen.known_entities`` the model gets wrong.

        For each entity, its known job and/or city is asked about directly;
        the name is "confused" when the answer does not contain the expected
        value (case-insensitive substring).
        """
        confused = set()
        entities = self.quiz_gen.known_entities
        if not entities:
            return confused
        for name, info in entities.items():
            if info.get("job"):
                q = f"What does Matt's {info['relationship']} {name} do?"
                ans = self.mm.generate([{"role": "user", "content": q}], max_new_tokens=100)
                if info["job"].lower() not in ans.lower():
                    confused.add(name)
            if info.get("city"):
                q = f"Where does {name} live?"
                ans = self.mm.generate([{"role": "user", "content": q}], max_new_tokens=100)
                if info["city"].lower() not in ans.lower():
                    confused.add(name)
        return confused

    def _build_retry_batch(self, names, contrastive):
        """Build a focused retraining batch for persistently confused entities.

        For each name: three copies of its job Q/A (if known), three copies of
        its city Q/A (if known), then every pair in ``contrastive`` whose
        question+answer text mentions the name (case-insensitive).
        """
        batch = []
        for name in names:
            info = self.quiz_gen.known_entities.get(name, {})
            if info.get("job"):
                batch.extend(
                    self._qa(f"What does Matt's {info['relationship']} {name} do?",
                             f"Matt's {info['relationship']} {name} is a {info['job']}.")
                    for _ in range(3)
                )
            if info.get("city"):
                batch.extend(
                    self._qa(f"Where does {name} live?",
                             f"{name} lives in {info['city']}. {name} is Matt's {info['relationship']}.")
                    for _ in range(3)
                )
            needle = name.lower()
            batch.extend(qp for qp in contrastive if needle in self._pair_text(qp))
        return batch

    # ------------------------------------------------------------------
    # Absorption
    # ------------------------------------------------------------------

    def _start_absorption(self):
        """Train on everything queued since the last absorption, in a background thread.

        Phase 1: queued exchanges + positive quizzes + up to 6 replay examples
                 from before the previous absorption, clustered by entity.
        Phase 2: verify entities; train only the contrastive pairs that
                 mention confused ones.
        Phase 3: up to 2 stubborn retries for entities still confused.
        Then: periodic verification every VERIFY_EVERY absorptions and an
        auto-checkpoint every CHECKPOINT_EVERY absorptions. Errors inside the
        thread are printed with a traceback and otherwise swallowed.
        """
        import random

        # Take ownership of everything queued since the last absorption.
        exchanges = self._pending_exchanges
        positive = self._pending_positive
        contrastive = self._pending_contrastive
        self._pending_exchanges = []
        self._pending_positive = []
        self._pending_contrastive = []

        # Old data for replay
        old_data = self.all_training_data[:self._last_absorb_idx]
        self._last_absorb_idx = len(self.all_training_data)
        max_replay = 6
        if len(old_data) > max_replay:
            replay_sample = random.sample(old_data, max_replay)
        else:
            replay_sample = list(old_data)
        entity_names = list(self.quiz_gen.known_entities.keys())

        def _run():
            t0 = time.time()
            try:
                # ── Phase 1: Positive facts + replay, clustered by entity ──
                phase1_data = [*exchanges, *positive, *replay_sample]
                if entity_names and phase1_data:
                    phase1_data = ModelManager.cluster_by_entity(phase1_data, entity_names)
                loss1 = self.mm.absorb(phase1_data) if phase1_data else 0.0
                n_p1 = len(phase1_data)

                # ── Phase 2: Targeted contrastive for confused entities ──
                loss2 = None
                n_p2 = 0
                # Latest verification still valid for the current weights
                # (None means the weights changed and we must re-verify).
                fresh_confused = None
                if contrastive and entity_names:
                    confused = self._quick_verify_entities()
                    fresh_confused = confused
                    if confused:
                        targeted = [
                            qp for qp in contrastive
                            if any(name.lower() in self._pair_text(qp) for name in confused)
                        ]
                        if targeted:
                            loss2 = self.mm.absorb(targeted)
                            n_p2 = len(targeted)
                            fresh_confused = None
                            print(f"\n [Phase 2: {n_p2} targeted contrastive for {confused}]",
                                  flush=True)

                # ── Phase 3: Stubborn retry (max 2 retries, background thread) ──
                if fresh_confused is not None:
                    still_confused = fresh_confused
                else:
                    still_confused = self._quick_verify_entities()
                for retry in range(2):
                    if not still_confused:
                        break
                    retry_batch = self._build_retry_batch(still_confused, contrastive)
                    if retry_batch:
                        loss3 = self.mm.absorb(retry_batch)
                        print(f"\n [Phase 3 retry {retry+1}: {len(retry_batch)} items, "
                              f"loss={loss3:.4f}]", flush=True)
                    still_confused = self._quick_verify_entities()
                if still_confused:
                    print(f"\n [Phase 3: still confused after retries: {still_confused}]",
                          flush=True)

                elapsed = time.time() - t0
                self.absorption_count += 1
                loss_str = f"P1={loss1:.4f}"
                if loss2 is not None:
                    loss_str += f" P2={loss2:.4f}"
                print(f"\n [Absorbed {n_p1}+{n_p2} examples in {elapsed:.1f}s | "
                      f"{loss_str} | absorptions={self.absorption_count}]")

                # Periodic verification: catch drift/confusion
                if self.absorption_count % VERIFY_EVERY == 0:
                    self._periodic_verification()
                # Auto-checkpoint
                if self.absorption_count % CHECKPOINT_EVERY == 0:
                    self._auto_checkpoint()
            except Exception as e:
                # Best-effort background training: report, keep chatting.
                print(f"\n [Absorption error: {e}]")
                import traceback
                traceback.print_exc()

        self.absorption_thread = threading.Thread(target=_run, daemon=True)
        self.absorption_thread.start()

    def _wait_for_absorption(self):
        """Block until any running background absorption has finished."""
        if self.absorption_thread and self.absorption_thread.is_alive():
            self.absorption_thread.join()
        self.absorption_thread = None

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _cleanup_old_checkpoints(self, keep=None):
        """Delete ``claudia_*`` checkpoint directories to free disk space.

        Args:
            keep: optional path of one checkpoint to preserve. Compared after
                ``os.path.abspath`` normalisation, so trailing slashes and
                relative spellings still match.

        NOTE: the callers in this class run this BEFORE writing the new
        checkpoint (to make room on disk). If that save then fails, no
        checkpoint is left on disk.
        """
        import shutil
        if not os.path.exists(self.checkpoint_dir):
            return
        keep_abs = os.path.abspath(keep) if keep else None
        for entry in os.listdir(self.checkpoint_dir):
            full = os.path.join(self.checkpoint_dir, entry)
            if keep_abs is not None and os.path.abspath(full) == keep_abs:
                continue
            if not (entry.startswith("claudia_") and os.path.isdir(full)):
                continue
            size_gb = sum(
                os.path.getsize(os.path.join(dp, f))
                for dp, _, fns in os.walk(full) for f in fns
            ) / 1e9
            print(f" Removing old checkpoint: {entry} ({size_gb:.1f} GB)")
            shutil.rmtree(full)

    def _auto_checkpoint(self):
        """Auto-save a checkpoint during long sessions."""
        version = f"auto_{self.absorption_count}"
        path = os.path.join(self.checkpoint_dir, f"claudia_{version}")
        self._cleanup_old_checkpoints()
        self.mm.merge_and_save(path)
        self.last_checkpoint = path
        self._save_replay_buffer(path)

    def _save_and_exit(self):
        """Final targeted correction, then merge, save, and cache teacher logits."""
        # Final verify + stubborn retry (not bulk retrain: prevents overfitting)
        confused = self._quick_verify_entities()
        if confused:
            print(f" Final correction for confused entities: {confused}")
            contrastive = [qp for qp in self.quiz_pairs_log if self._is_contrastive(qp)]
            for retry in range(3):
                if not confused:
                    break
                retry_batch = self._build_retry_batch(confused, contrastive)
                if retry_batch:
                    loss = self.mm.absorb(retry_batch)
                    print(f" Final retry {retry+1}: {len(retry_batch)} items, loss={loss:.4f}")
                confused = self._quick_verify_entities()
            self.absorption_count += 1
        else:
            # \u2014 is an em dash (escaped to keep the source ASCII).
            print(" All entities verified correct \u2014 no final correction needed.")

        # Personality check before saving
        print("\n--- Pre-Save Personality Check ---")
        score = check_personality(self.mm)
        if score < 0.5:
            print(" WARNING: Personality degraded. Saving anyway (rollback available).")

        # Merge and save (cleanup old checkpoints first to free disk)
        version = f"session_{datetime.now().strftime('%Y%m%d_%H%M')}"
        path = os.path.join(self.checkpoint_dir, f"claudia_{version}")
        self._cleanup_old_checkpoints()
        self.mm.merge_and_save(path)
        self.last_checkpoint = path
        # Writes the replay buffer AND quiz_pairs_log.json, so the quiz log
        # needs no separate dump here.
        self._save_replay_buffer(path)

        # ── Cascade Distillation: cache teacher logits for next session ──
        # After merge+fresh LoRA, model outputs are identical to pre-merge state.
        # Cache the teacher's top-K logits so the next session can distill from them.
        if self.quiz_pairs_log:
            # NOTE(review): the full log is passed; presumably
            # cache_teacher_logits applies MAX_TEACHER_CACHE itself -- confirm.
            n_cache = min(len(self.quiz_pairs_log), MAX_TEACHER_CACHE)
            print(f" Caching teacher logits ({n_cache} quiz pairs)...")
            teacher_cache = self.mm.cache_teacher_logits(self.quiz_pairs_log)
            cache_path = os.path.join(path, "teacher_cache.pt")
            torch.save(teacher_cache, cache_path)
            size_mb = os.path.getsize(cache_path) / 1e6
            print(f" Teacher cache saved ({len(teacher_cache)} items, {size_mb:.1f} MB)")
            del teacher_cache
            torch.cuda.empty_cache()

        # Save session metadata
        meta = {
            "checkpoint": path,
            "absorption_count": self.absorption_count,
            "exchange_count": self.exchange_count,
            "training_pool_size": len(self.all_training_data),
            "personality_score": score,
            "timestamp": datetime.now().isoformat(),
        }
        meta_path = os.path.join(self.log_dir, f"session_{version}.json")
        with open(meta_path, 'w', encoding='utf-8') as f:
            json.dump(meta, f, indent=2)
        print(f" Session saved: {meta_path}")
        print(f" Next run: use --checkpoint {path}")

    def _save_replay_buffer(self, checkpoint_path=None):
        """Persist the training pool and quiz log for the next session.

        Always writes ``replay_buffer.json`` and ``quiz_pairs_log.json`` to
        ``log_dir`` (the canonical resume location). Also copies the replay
        buffer into ``checkpoint_path`` when that directory exists, so the
        checkpoint is self-contained.
        """
        path = os.path.join(self.log_dir, "replay_buffer.json")
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.all_training_data, f)
        if checkpoint_path and os.path.isdir(checkpoint_path):
            cp_path = os.path.join(checkpoint_path, "replay_buffer.json")
            with open(cp_path, 'w', encoding='utf-8') as f:
                json.dump(self.all_training_data, f)
        quiz_log_path = os.path.join(self.log_dir, "quiz_pairs_log.json")
        with open(quiz_log_path, 'w', encoding='utf-8') as f:
            json.dump(self.quiz_pairs_log, f)
        print(f" Replay buffer saved ({len(self.all_training_data)} examples)")

    def _log_exchange(self, user_msg, assistant_msg):
        """Append one exchange as a JSON line to the conversation log."""
        with open(self.log_path, 'a', encoding='utf-8') as f:
            entry = {
                "timestamp": datetime.now().isoformat(),
                "user": user_msg,
                "assistant": assistant_msg,
            }
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")

    # ------------------------------------------------------------------
    # Commands
    # ------------------------------------------------------------------

    def _handle_command(self, cmd):
        """Handle a slash command. Returns True if the session should exit."""
        import random
        cmd_lower = cmd.lower().strip()

        if cmd_lower == "/quit":
            print("[Saving and exiting...]")
            self._wait_for_absorption()
            self._save_and_exit()
            return True

        elif cmd_lower == "/status":
            # Deliberately does not wait for the background thread, so that
            # "Background" can actually report a running absorption.
            vram = torch.cuda.memory_allocated() / 1e9
            running = bool(self.absorption_thread and self.absorption_thread.is_alive())
            print("\n --- Status ---")
            print(f" Exchanges: {self.exchange_count}")
            print(f" Absorptions: {self.absorption_count}")
            print(f" Training pool: {len(self.all_training_data)} examples")
            print(f" Buffer: {len(self.conversation_buffer)} messages")
            print(f" VRAM: {vram:.1f} GB")
            print(f" Background: {'running' if running else 'idle'}")
            print(f" Last checkpoint: {self.last_checkpoint}")
            print(" --- End ---\n")

        elif cmd_lower == "/absorb":
            self._wait_for_absorption()
            if not self.all_training_data:
                print(" No data to absorb.")
                return False
            # Cap the batch at 40 examples to limit overfitting: the 20 most
            # recent plus 20 sampled at random from everything older.
            data = self.all_training_data
            if len(data) > 40:
                data = data[-20:] + random.sample(data[:-20], 20)
            print(f" Force absorption ({len(data)} examples)...")
            loss = self.mm.absorb(data)
            self.absorption_count += 1
            print(f" Done. Loss: {loss:.4f}")

            # ── Post-absorb comprehensive verification + distillation ──
            # Verify ALL quiz pairs (not a sample) to catch every regression in
            # the critical window between teaching and testing.
            if self.quiz_pairs_log:
                print(f"\n --- Post-absorb verification (ALL {len(self.quiz_pairs_log)} quiz pairs) ---")
                full_correct, full_corrections = self._verify_pairs(self.quiz_pairs_log)
                print(f" Full verification: {full_correct}/{len(self.quiz_pairs_log)} correct")
                if full_corrections:
                    print(f" Retraining {len(full_corrections)} corrections...")
                    c_loss = self.mm.absorb(full_corrections)
                    self.all_training_data.extend(full_corrections)
                    print(f" Correction loss: {c_loss:.4f}")
                    distill_items = self._teacher_items_for(full_corrections)
                    if distill_items:
                        d_loss = self.mm.distill(distill_items, epochs=1)
                        print(f" Teacher distillation on {len(distill_items)} items, loss={d_loss:.4f}")
                print(" --- End post-absorb verification ---\n")

        elif cmd_lower == "/save":
            self._wait_for_absorption()
            version = f"manual_{self.absorption_count}"
            path = os.path.join(self.checkpoint_dir, f"claudia_{version}")
            print(" Saving checkpoint...")
            score = check_personality(self.mm, verbose=False)
            if score < 0.5:
                print(f" WARNING: Personality score {score:.0%}. Save anyway? (y/n)")
                confirm = input(" > ").strip().lower()
                if confirm != 'y':
                    print(" Aborted.")
                    return False
            self._cleanup_old_checkpoints()
            self.mm.merge_and_save(path)
            self.last_checkpoint = path
            self._save_replay_buffer(path)

        elif cmd_lower == "/personality":
            self._wait_for_absorption()
            print("\n--- Personality Check ---")
            check_personality(self.mm)
            print()

        elif cmd_lower == "/help":
            print(" /status - show stats")
            print(" /absorb - force immediate training")
            print(" /save - merge + save checkpoint")
            print(" /personality - run personality check")
            print(" /quit - save and exit")

        else:
            print(f" Unknown: {cmd}. Try /help")
        return False
# ═══════════════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════════════
def main():
    """Parse CLI arguments, build a PersistentAbsorber and run the session.

    First run: --adapter_path is required (optionally with --ffn_patch).
    Resume: --checkpoint points at a previously saved checkpoint. Exits with
    status 1 when neither is given, and with an argparse error when
    --absorb_every is below 1.
    """
    # PersistentAbsorber reads the module-level ABSORB_EVERY at runtime, so
    # --absorb_every must override it; previously the flag was parsed but
    # silently ignored.
    global ABSORB_EVERY

    # \u2014 is an em dash, \u2192 a right arrow (escaped to keep the source ASCII).
    parser = argparse.ArgumentParser(
        description="Claudia Persistent Absorber v2 \u2014 conversation \u2192 permanent weights"
    )
    parser.add_argument(
        "--model_path", required=True,
        help="Path to base Qwen3-Omni model (or checkpoint for resume)"
    )
    parser.add_argument(
        "--adapter_path", default=None,
        help="Path to Claudia v6 personality adapter (first run only)"
    )
    parser.add_argument(
        "--ffn_patch", default=None,
        help="Path to ffn_patch.pt (first run only)"
    )
    parser.add_argument(
        "--checkpoint", default=None,
        help="Resume from this checkpoint (has personality + memories baked in)"
    )
    parser.add_argument(
        "--checkpoint_dir", default="/workspace/checkpoints",
        help="Where to save checkpoints"
    )
    parser.add_argument(
        "--log_dir", default="/workspace/logs",
        help="Where to save conversation logs and replay buffer"
    )
    parser.add_argument(
        "--absorb_every", type=int, default=ABSORB_EVERY,
        help=f"Absorb every N exchanges (default: {ABSORB_EVERY})"
    )
    args = parser.parse_args()

    # The chat loop uses `exchange_count % ABSORB_EVERY`; 0 would raise
    # ZeroDivisionError and negatives make no sense.
    if args.absorb_every < 1:
        parser.error("--absorb_every must be >= 1")
    ABSORB_EVERY = args.absorb_every

    # Determine if first run or resume
    if args.checkpoint:
        print(f"RESUMING from checkpoint: {args.checkpoint}")
        absorber = PersistentAbsorber(
            model_path=args.model_path,
            checkpoint_path=args.checkpoint,
            checkpoint_dir=args.checkpoint_dir,
            log_dir=args.log_dir,
        )
    else:
        print("FIRST RUN \u2014 applying personality adapter")
        if not args.adapter_path:
            print("ERROR: --adapter_path required for first run")
            print(" (or use --checkpoint to resume)")
            sys.exit(1)
        absorber = PersistentAbsorber(
            model_path=args.model_path,
            adapter_path=args.adapter_path,
            ffn_patch_path=args.ffn_patch,
            checkpoint_dir=args.checkpoint_dir,
            log_dir=args.log_dir,
        )
    absorber.start()
# Script entry point: start the interactive session only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()