msrcam commited on
Commit
7b9779e
Β·
verified Β·
1 Parent(s): f43630f

Upload persistent_absorber.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. persistent_absorber.py +1934 -0
persistent_absorber.py ADDED
@@ -0,0 +1,1934 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Claudia Persistent Absorber v2
3
+ ==============================
4
+ Combines the 3 best proven techniques into one system:
5
+
6
+ 1. SELF-QUIZ PAIRS (21% β†’ 74% recall β€” the single biggest lever)
7
+ 2. PERSISTENT LoRA rank 128 (89% across 25 convos, no merge-between-rounds tax)
8
+ 3. DUAL-LR EXPERT FFN (attention=6e-5, FFN=3e-4 β€” facts into MoE experts)
9
+
10
+ Architecture:
11
+ - Load base Omni β†’ thinker to GPU, rest to CPU
12
+ - First run: apply Claudia v6 adapter β†’ merge β†’ apply FFN patch
13
+ - Resume: load from checkpoint (already has personality + memories)
14
+ - Apply ONE persistent LoRA (r=128, alpha=256, attention q/k/v/o)
15
+ - Chat loop: generate β†’ quiz β†’ train (LoRA + expert FFN) β†’ repeat
16
+ - On save/quit: merge_and_unload β†’ save full checkpoint
17
+ - Next session loads from checkpoint β€” memories are permanent
18
+
19
+ Instance: Vast.ai 33093662 (A100 80GB, Sweden)
20
+ SSH: ssh -p 13662 root@ssh1.vast.ai
21
+ """
22
+
23
+ import argparse
24
+ import gc
25
+ import json
26
+ import os
27
+ import re
28
+ import sys
29
+ import threading
30
+ import time
31
+ import torch
32
+ from collections import Counter
33
+ from datetime import datetime
34
+ from pathlib import Path
35
+
36
+
37
+ # ═══════════════════════════════════════════════════════════════════════
38
+ # CONFIG
39
+ # ═══════════════════════════════════════════════════════════════════════
40
+
41
+ # LoRA config (from persistent LoRA test β€” proven for 25+ conversations)
42
+ LORA_RANK = 128
43
+ LORA_ALPHA = 256
44
+ LORA_TARGETS = ["q_proj", "k_proj", "v_proj", "o_proj"]
45
+
46
+ # Dual-LR (from engram micro_trainer β€” proven 5/5 fact retention)
47
+ ATTENTION_LR = 6e-5
48
+ EXPERT_FFN_LR = 3e-4 # 5x multiplier β€” facts absorb fast, personality stays
49
+ EXPERT_FFN_LAYERS = [20, 24, 28] # Proven optimal in v5 experiment
50
+
51
+ # Training per absorption cycle
52
+ TRAIN_EPOCHS = 2 # Reduced from 4 β€” prevents overfitting with focused training
53
+ MAX_SEQ_LENGTH = 2048
54
+ GRADIENT_CLIP = 1.0
55
+
56
+ # Generation
57
+ GEN_TEMPERATURE = 0.7
58
+ GEN_TOP_P = 0.9
59
+ GEN_TOP_K = 50
60
+ GEN_MAX_TOKENS = 512
61
+ GEN_REP_PENALTY = 1.1
62
+
63
+ # Absorb after every N exchanges (1 = every turn)
64
+ ABSORB_EVERY = 1
65
+
66
+ # Checkpoint interval (auto-save every N absorptions)
67
+ CHECKPOINT_EVERY = 10
68
+
69
+ # Self-verification (v11 β€” clean contrastive + sister pairs, no "NOT X" leak)
70
+ VERIFY_EVERY = 3 # More frequent checks catch drift earlier
71
+ VERIFY_SAMPLE = 10 # Back to v9's value β€” wider sampling destabilized in v10
72
+
73
+ # Cascade Distillation (Nemotron-Cascade-2 paper β€” on-policy distillation)
74
+ # When facts from previous sessions regress, distill from the teacher checkpoint
75
+ # that knew them best. Recovers regressions without losing new knowledge.
76
+ DISTILL_ALPHA = 0.5 # CE vs KL loss balance (0.5 = equal weight)
77
+ DISTILL_TEMPERATURE = 2.0 # Softens distributions for better KL gradients
78
+ DISTILL_TOP_K = 32 # Top-K logits to cache per token position
79
+ CONSOLIDATION_EPOCHS = 2 # Distillation epochs at session start (1β†’2 for stronger lock-in)
80
+ MAX_TEACHER_CACHE = 200 # Cap quiz pairs to cache (oldest trimmed)
81
+
82
+
83
+ # ═══════════════════════════════════════════════════════════════════════
84
+ # QUALITY GATE (from engram micro_trainer β€” reject degenerate text)
85
+ # ═══════════════════════════════════════════════════════════════════════
86
+
87
+ def check_response_quality(text):
88
+ """Reject degenerate text before training on it."""
89
+ if not text or len(text) < 5:
90
+ return False
91
+ words = text.lower().split()
92
+ if len(words) < 3:
93
+ return False
94
+ # Low unique word ratio = repetitive garbage
95
+ if len(set(words)) / len(words) < 0.3:
96
+ return False
97
+ # Repeated consecutive words
98
+ if sum(1 for i in range(len(words) - 1) if words[i] == words[i + 1]) >= 3:
99
+ return False
100
+ # Repeated bigrams
101
+ if len(words) >= 10:
102
+ bigrams = [f"{words[i]} {words[i+1]}" for i in range(len(words) - 1)]
103
+ if Counter(bigrams).most_common(1)[0][1] >= 5:
104
+ return False
105
+ # Fused words (missing spaces)
106
+ if sum(1 for w in words if len(w) > 30) >= 2:
107
+ return False
108
+ # Average word length spike
109
+ if sum(len(w) for w in words) / len(words) > 12:
110
+ return False
111
+ return True
112
+
113
+
114
+ # ═══════════════════════════════════════════════════════════════════════
115
+ # MODEL MANAGER
116
+ # ═��═════════════════════════════════════════════════════════════════════
117
+
118
+ class ModelManager:
119
+ def __init__(self, model_path, adapter_path=None, ffn_patch_path=None,
120
+ checkpoint_path=None):
121
+ self.model_path = model_path
122
+ self.adapter_path = adapter_path
123
+ self.ffn_patch_path = ffn_patch_path
124
+ self.checkpoint_path = checkpoint_path # Resume from here if set
125
+
126
+ self.thinker = None
127
+ self.tokenizer = None
128
+ self.stop_ids = None
129
+ self.peft_model = None # The persistent LoRA β€” stays active all session
130
+ self._lock = threading.Lock()
131
+
132
+ def load(self):
133
+ from transformers import AutoTokenizer
134
+
135
+ # ── Step 1: Load tokenizer ──
136
+ tok_source = self.checkpoint_path or self.model_path
137
+ print(f"[1/5] Loading tokenizer from {tok_source}...")
138
+ self.tokenizer = AutoTokenizer.from_pretrained(
139
+ tok_source, trust_remote_code=True
140
+ )
141
+
142
+ # ── Step 2: Load model ──
143
+ if self.checkpoint_path:
144
+ # RESUME: checkpoint contains only thinker weights β€” load thinker directly
145
+ print(f"[2/5] Loading thinker from checkpoint {self.checkpoint_path}...")
146
+ try:
147
+ from transformers import Qwen3OmniMoeThinkerForConditionalGeneration as ThinkerClass
148
+ except ImportError:
149
+ from transformers import AutoModelForCausalLM as ThinkerClass
150
+ self.thinker = ThinkerClass.from_pretrained(
151
+ self.checkpoint_path,
152
+ device_map="auto",
153
+ torch_dtype=torch.bfloat16,
154
+ trust_remote_code=True,
155
+ )
156
+ vram = torch.cuda.memory_allocated() / 1e9
157
+ print(f" VRAM after load: {vram:.1f} GB")
158
+ else:
159
+ # FIRST RUN: load full model, extract thinker, offload rest
160
+ print(f"[2/5] Loading full model from {self.model_path}...")
161
+ try:
162
+ from transformers import Qwen3OmniMoeForConditionalGeneration as ModelClass
163
+ except ImportError:
164
+ from transformers import AutoModel as ModelClass
165
+ full_model = ModelClass.from_pretrained(
166
+ self.model_path,
167
+ device_map="auto",
168
+ torch_dtype=torch.bfloat16,
169
+ trust_remote_code=True,
170
+ )
171
+ vram = torch.cuda.memory_allocated() / 1e9
172
+ print(f" VRAM after load: {vram:.1f} GB")
173
+
174
+ # Extract thinker, offload rest
175
+ self.thinker = full_model.thinker
176
+ for name, module in full_model.named_children():
177
+ if name != "thinker":
178
+ try:
179
+ module.cpu()
180
+ except (NotImplementedError, RuntimeError):
181
+ pass
182
+ del full_model
183
+ torch.cuda.empty_cache()
184
+ vram = torch.cuda.memory_allocated() / 1e9
185
+ print(f" VRAM after cleanup: {vram:.1f} GB")
186
+
187
+ # ── Step 3: Apply personality if first run ──
188
+ if self.checkpoint_path:
189
+ print(f"[3/5] Resuming from checkpoint β€” personality already in weights.")
190
+ else:
191
+ if self.adapter_path:
192
+ print(f"[3/5] Merging Claudia v6 personality adapter...")
193
+ from peft import PeftModel
194
+ self.thinker = PeftModel.from_pretrained(
195
+ self.thinker, self.adapter_path
196
+ )
197
+ self.thinker = self.thinker.merge_and_unload()
198
+ print(f" Personality merged into base weights.")
199
+
200
+ if self.ffn_patch_path and os.path.exists(self.ffn_patch_path):
201
+ print(f" Applying FFN patch from {self.ffn_patch_path}...")
202
+ ffn = torch.load(
203
+ self.ffn_patch_path, map_location="cpu", weights_only=True
204
+ )
205
+ for key, tensor in ffn.items():
206
+ match = re.search(r"layers\.(\d+)", key)
207
+ if not match:
208
+ continue
209
+ layer_idx = int(match.group(1))
210
+ experts = self.thinker.model.layers[layer_idx].mlp.experts
211
+ if hasattr(experts, '__len__'):
212
+ for i in range(tensor.shape[0]):
213
+ experts[i].down_proj.weight.data.copy_(
214
+ tensor[i].to(
215
+ experts[i].down_proj.weight.device,
216
+ experts[i].down_proj.weight.dtype,
217
+ )
218
+ )
219
+ elif hasattr(experts, 'down_proj'):
220
+ experts.down_proj.data.copy_(
221
+ tensor.to(experts.down_proj.device, experts.down_proj.dtype)
222
+ )
223
+ del ffn
224
+ torch.cuda.empty_cache()
225
+ print(f" FFN patch applied.")
226
+
227
+ self.thinker.eval()
228
+
229
+ # Stop tokens
230
+ self.stop_ids = []
231
+ for tok in ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]:
232
+ ids = self.tokenizer.encode(tok, add_special_tokens=False)
233
+ if ids:
234
+ self.stop_ids.extend(ids)
235
+ if self.tokenizer.eos_token_id:
236
+ self.stop_ids.append(self.tokenizer.eos_token_id)
237
+
238
+ # ── Step 5: Apply persistent LoRA ──
239
+ print(f"[4/5] Applying persistent LoRA (r={LORA_RANK}, alpha={LORA_ALPHA})...")
240
+ self._apply_persistent_lora()
241
+
242
+ vram = torch.cuda.memory_allocated() / 1e9
243
+ print(f"[5/5] Ready. VRAM: {vram:.1f} GB\n")
244
+
245
+ def _apply_persistent_lora(self):
246
+ """Apply the persistent absorption LoRA. Called once at load, and after merge."""
247
+ from peft import LoraConfig, get_peft_model
248
+
249
+ lora_config = LoraConfig(
250
+ r=LORA_RANK,
251
+ lora_alpha=LORA_ALPHA,
252
+ target_modules=LORA_TARGETS,
253
+ lora_dropout=0.0,
254
+ bias="none",
255
+ task_type="CAUSAL_LM",
256
+ )
257
+ self.peft_model = get_peft_model(self.thinker, lora_config)
258
+ self.peft_model.eval()
259
+
260
+ trainable = sum(p.numel() for p in self.peft_model.parameters() if p.requires_grad)
261
+ total = sum(p.numel() for p in self.peft_model.parameters())
262
+ print(f" LoRA: {trainable / 1e6:.1f}M trainable / {total / 1e6:.0f}M total")
263
+
264
+ def generate(self, messages, max_new_tokens=None):
265
+ """Generate response. Thread-safe."""
266
+ with self._lock:
267
+ model = self.peft_model or self.thinker
268
+ model.eval()
269
+
270
+ text = self.tokenizer.apply_chat_template(
271
+ messages, tokenize=False, add_generation_prompt=True,
272
+ enable_thinking=False,
273
+ )
274
+ inputs = self.tokenizer(
275
+ text, return_tensors="pt", truncation=True, max_length=8192
276
+ ).to("cuda")
277
+ input_len = inputs["input_ids"].shape[1]
278
+
279
+ with torch.inference_mode():
280
+ out = model.generate(
281
+ **inputs,
282
+ max_new_tokens=max_new_tokens or GEN_MAX_TOKENS,
283
+ temperature=GEN_TEMPERATURE,
284
+ top_p=GEN_TOP_P,
285
+ top_k=GEN_TOP_K,
286
+ do_sample=True,
287
+ repetition_penalty=GEN_REP_PENALTY,
288
+ pad_token_id=self.tokenizer.eos_token_id,
289
+ eos_token_id=self.stop_ids,
290
+ )
291
+
292
+ resp = self.tokenizer.decode(out[0][input_len:], skip_special_tokens=True)
293
+ # Strip thinking tags
294
+ resp = re.sub(r"<think>.*?</think>", "", resp, flags=re.DOTALL)
295
+ resp = re.sub(r"</?think>", "", resp)
296
+ return resp.strip()
297
+
298
+ def absorb(self, training_data):
299
+ """
300
+ Train the persistent LoRA + expert FFN on accumulated data.
301
+ Uses dual-LR: attention at ATTENTION_LR, expert FFN at EXPERT_FFN_LR.
302
+ Thread-safe.
303
+ """
304
+ with self._lock:
305
+ return self._absorb_impl(training_data)
306
+
307
+ def _absorb_impl(self, training_data):
308
+ """Internal absorption. Must hold _lock."""
309
+ if not training_data:
310
+ return None
311
+
312
+ model = self.peft_model or self.thinker
313
+ tokenizer = self.tokenizer
314
+
315
+ # ── Tokenize all examples ──
316
+ texts = []
317
+ for item in training_data:
318
+ if isinstance(item, dict) and "messages" in item:
319
+ msgs = item["messages"]
320
+ elif isinstance(item, dict) and "prompt" in item:
321
+ msgs = item["prompt"] + item.get("completion", [])
322
+ elif isinstance(item, list):
323
+ msgs = item
324
+ else:
325
+ continue
326
+
327
+ text = tokenizer.apply_chat_template(
328
+ msgs, tokenize=False, enable_thinking=False
329
+ )
330
+ texts.append(text)
331
+
332
+ if not texts:
333
+ return None
334
+
335
+ enc = tokenizer(
336
+ texts,
337
+ truncation=True,
338
+ max_length=MAX_SEQ_LENGTH,
339
+ padding=True,
340
+ return_tensors="pt",
341
+ )
342
+ input_ids = enc["input_ids"].to("cuda")
343
+ attention_mask = enc["attention_mask"].to("cuda")
344
+ labels = input_ids.clone()
345
+ labels[attention_mask == 0] = -100
346
+
347
+ # ── Collect LoRA attention params ──
348
+ model.train()
349
+ attn_params = [p for p in model.parameters() if p.requires_grad]
350
+
351
+ # ── Unfreeze expert FFN ──
352
+ expert_params = []
353
+ base = model.base_model.model if hasattr(model, "base_model") else model
354
+ for layer_idx in EXPERT_FFN_LAYERS:
355
+ experts = base.model.layers[layer_idx].mlp.experts
356
+ if hasattr(experts, '__len__'):
357
+ for i in range(len(experts)):
358
+ p = experts[i].down_proj.weight
359
+ p.requires_grad_(True)
360
+ expert_params.append(p)
361
+ elif hasattr(experts, 'down_proj'):
362
+ p = experts.down_proj
363
+ if isinstance(p, (torch.nn.Parameter, torch.Tensor)):
364
+ p.requires_grad_(True)
365
+ expert_params.append(p)
366
+
367
+ # ── Dual-LR optimizer ──
368
+ param_groups = []
369
+ if attn_params:
370
+ param_groups.append({"params": attn_params, "lr": ATTENTION_LR})
371
+ if expert_params:
372
+ param_groups.append({"params": expert_params, "lr": EXPERT_FFN_LR})
373
+
374
+ if not param_groups:
375
+ model.eval()
376
+ return None
377
+
378
+ optimizer = torch.optim.AdamW(param_groups, weight_decay=0.0)
379
+ all_params = attn_params + expert_params
380
+
381
+ # ── Training loop ──
382
+ n = input_ids.shape[0]
383
+ total_steps = n * TRAIN_EPOCHS
384
+ total_loss = 0.0
385
+
386
+ for epoch in range(TRAIN_EPOCHS):
387
+ # Shuffle order each epoch
388
+ indices = torch.randperm(n)
389
+ for i in range(n):
390
+ idx = indices[i].item()
391
+ out = model(
392
+ input_ids=input_ids[idx:idx + 1],
393
+ attention_mask=attention_mask[idx:idx + 1],
394
+ labels=labels[idx:idx + 1],
395
+ )
396
+ out.loss.backward()
397
+ torch.nn.utils.clip_grad_norm_(all_params, GRADIENT_CLIP)
398
+ optimizer.step()
399
+ optimizer.zero_grad()
400
+ total_loss += out.loss.item()
401
+
402
+ # ── Re-freeze expert FFN ──
403
+ for layer_idx in EXPERT_FFN_LAYERS:
404
+ experts = base.model.layers[layer_idx].mlp.experts
405
+ if hasattr(experts, '__len__'):
406
+ for i in range(len(experts)):
407
+ experts[i].down_proj.weight.requires_grad_(False)
408
+ elif hasattr(experts, 'down_proj'):
409
+ p = experts.down_proj
410
+ if isinstance(p, (torch.nn.Parameter, torch.Tensor)):
411
+ p.requires_grad_(False)
412
+
413
+ model.eval()
414
+ del optimizer
415
+ torch.cuda.empty_cache()
416
+
417
+ avg_loss = total_loss / total_steps if total_steps > 0 else 0
418
+ return avg_loss
419
+
420
+ @staticmethod
421
+ def cluster_by_entity(training_data, entity_names):
422
+ """Group training data by primary entity mentioned.
423
+
424
+ Instead of interleaving facts about different people (which causes
425
+ cross-contamination during gradient updates), this groups all data
426
+ about one entity together. The model learns all of Jordan's facts
427
+ before moving to Elena's.
428
+
429
+ Args:
430
+ training_data: List of training items
431
+ entity_names: Set/list of known entity names
432
+
433
+ Returns: List of training items, reordered so each entity's items
434
+ are contiguous. Items mentioning no entity come last.
435
+ """
436
+ clusters = {name: [] for name in entity_names}
437
+ unclustered = []
438
+
439
+ for item in training_data:
440
+ # Extract text from the item
441
+ if isinstance(item, dict) and "messages" in item:
442
+ text = " ".join(m.get("content", "") for m in item["messages"]).lower()
443
+ else:
444
+ unclustered.append(item)
445
+ continue
446
+
447
+ # Assign to the first entity mentioned (primary entity)
448
+ assigned = False
449
+ for name in entity_names:
450
+ if name.lower() in text:
451
+ clusters[name].append(item)
452
+ assigned = True
453
+ break
454
+ if not assigned:
455
+ unclustered.append(item)
456
+
457
+ # Build ordered list: all of entity A's facts, then B's, then C's...
458
+ ordered = []
459
+ for name in entity_names:
460
+ ordered.extend(clusters[name])
461
+ ordered.extend(unclustered)
462
+ return ordered
463
+
464
+ def absorb_two_phase(self, positive_data, contrastive_data, verify_fn=None):
465
+ """Two-phase absorption: facts first, then targeted contrastive correction.
466
+
467
+ Phase 1: Train on positive facts (exchanges, entity summaries, template quizzes).
468
+ This builds the core factual representations.
469
+ Phase 2: Quick verification on known entities, then train ONLY contrastive
470
+ quizzes for entities that failed verification. This avoids unnecessary
471
+ negative gradients on entities the model already distinguishes correctly.
472
+
473
+ Args:
474
+ positive_data: List of training items (exchanges, summaries, direct quizzes)
475
+ contrastive_data: List of contrastive quiz items ("Is X a [Y's job]? No...")
476
+ verify_fn: Optional callable(model_manager) -> set of confused_entity_names.
477
+ If None, all contrastive data is used in Phase 2.
478
+
479
+ Returns: (phase1_loss, phase2_loss) tuple
480
+ """
481
+ with self._lock:
482
+ # Phase 1: Positive facts
483
+ loss1 = None
484
+ if positive_data:
485
+ loss1 = self._absorb_impl(positive_data)
486
+
487
+ # Phase 2: Targeted contrastive correction
488
+ loss2 = None
489
+ if contrastive_data:
490
+ if verify_fn:
491
+ # Only train contrastive pairs for confused entities
492
+ confused = verify_fn(self)
493
+ if confused:
494
+ targeted = []
495
+ for item in contrastive_data:
496
+ q = item["messages"][0]["content"].lower()
497
+ # Check if any confused entity name appears in the question
498
+ if any(name.lower() in q for name in confused):
499
+ targeted.append(item)
500
+ if targeted:
501
+ loss2 = self._absorb_impl(targeted)
502
+ # If no entities confused, skip Phase 2 entirely
503
+ else:
504
+ loss2 = self._absorb_impl(contrastive_data)
505
+
506
+ return loss1, loss2
507
+
508
+ def merge_and_save(self, path):
509
+ """Merge persistent LoRA into base, save checkpoint, re-apply fresh LoRA."""
510
+ with self._lock:
511
+ if self.peft_model:
512
+ print(f" Merging persistent LoRA into base weights...")
513
+ self.thinker = self.peft_model.merge_and_unload()
514
+ self.thinker.eval()
515
+ self.peft_model = None
516
+
517
+ os.makedirs(path, exist_ok=True)
518
+ print(f" Saving checkpoint to {path}...")
519
+ self.thinker.save_pretrained(path)
520
+ self.tokenizer.save_pretrained(path)
521
+ print(f" Checkpoint saved ({path})")
522
+
523
+ # Re-apply fresh LoRA for continued learning
524
+ self._apply_persistent_lora()
525
+ print(f" Fresh LoRA applied β€” ready to continue.")
526
+
527
+ def cache_teacher_logits(self, quiz_pairs, top_k=DISTILL_TOP_K):
528
+ """Cache teacher's top-K output logits for quiz pairs.
529
+ Called at session end when model is at its best state for these facts.
530
+ Next session loads this cache for consolidation distillation."""
531
+ with self._lock:
532
+ model = self.peft_model or self.thinker
533
+ model.eval()
534
+ cache = []
535
+
536
+ # Cap to most recent quiz pairs
537
+ pairs = quiz_pairs[-MAX_TEACHER_CACHE:]
538
+
539
+ for pair in pairs:
540
+ msgs = pair["messages"]
541
+ text = self.tokenizer.apply_chat_template(
542
+ msgs, tokenize=False, enable_thinking=False
543
+ )
544
+ enc = self.tokenizer(
545
+ text, return_tensors="pt", truncation=True,
546
+ max_length=MAX_SEQ_LENGTH
547
+ )
548
+ input_ids = enc["input_ids"].to("cuda")
549
+ attention_mask = enc["attention_mask"].to("cuda")
550
+
551
+ with torch.inference_mode():
552
+ out = model(input_ids=input_ids, attention_mask=attention_mask)
553
+ logits = out.logits[0] # [seq_len, vocab_size]
554
+
555
+ # Keep only top-K logits per position (massive memory savings)
556
+ top_vals, top_idx = logits.topk(top_k, dim=-1)
557
+
558
+ cache.append({
559
+ "pair": pair,
560
+ "input_ids": input_ids.cpu(),
561
+ "attention_mask": attention_mask.cpu(),
562
+ "teacher_logits": top_vals.half().cpu(),
563
+ "teacher_indices": top_idx.cpu(),
564
+ })
565
+
566
+ return cache
567
+
568
+ def distill(self, teacher_cache, epochs=CONSOLIDATION_EPOCHS):
569
+ """KL distillation: train student to match teacher's output distribution.
570
+ From Nemotron-Cascade-2: recover domain regressions via on-policy distillation."""
571
+ with self._lock:
572
+ return self._distill_impl(teacher_cache, epochs)
573
+
574
+ def _distill_impl(self, teacher_cache, epochs):
575
+ """Internal distillation implementation. Must hold _lock."""
576
+ if not teacher_cache:
577
+ return None
578
+
579
+ model = self.peft_model or self.thinker
580
+ model.train()
581
+
582
+ # Dual-LR optimizer (same structure as absorb)
583
+ attn_params = [p for p in model.parameters() if p.requires_grad]
584
+ expert_params = []
585
+ base = model.base_model.model if hasattr(model, "base_model") else model
586
+ for layer_idx in EXPERT_FFN_LAYERS:
587
+ experts = base.model.layers[layer_idx].mlp.experts
588
+ if hasattr(experts, '__len__'):
589
+ for i in range(len(experts)):
590
+ p = experts[i].down_proj.weight
591
+ p.requires_grad_(True)
592
+ expert_params.append(p)
593
+ elif hasattr(experts, 'down_proj'):
594
+ p = experts.down_proj
595
+ if isinstance(p, (torch.nn.Parameter, torch.Tensor)):
596
+ p.requires_grad_(True)
597
+ expert_params.append(p)
598
+
599
+ param_groups = []
600
+ if attn_params:
601
+ param_groups.append({"params": attn_params, "lr": ATTENTION_LR})
602
+ if expert_params:
603
+ param_groups.append({"params": expert_params, "lr": EXPERT_FFN_LR})
604
+
605
+ if not param_groups:
606
+ model.eval()
607
+ return None
608
+
609
+ optimizer = torch.optim.AdamW(param_groups, weight_decay=0.0)
610
+ all_params = attn_params + expert_params
611
+
612
+ T = DISTILL_TEMPERATURE
613
+ total_loss = 0.0
614
+ total_steps = 0
615
+
616
+ for epoch in range(epochs):
617
+ indices = torch.randperm(len(teacher_cache))
618
+ for i in range(len(teacher_cache)):
619
+ item = teacher_cache[indices[i].item()]
620
+
621
+ input_ids = item["input_ids"].to("cuda")
622
+ attention_mask = item["attention_mask"].to("cuda")
623
+ teacher_top_logits = item["teacher_logits"].float().to("cuda")
624
+ teacher_top_indices = item["teacher_indices"].to("cuda")
625
+
626
+ labels = input_ids.clone()
627
+ labels[attention_mask == 0] = -100
628
+
629
+ # Student forward pass
630
+ out = model(
631
+ input_ids=input_ids,
632
+ attention_mask=attention_mask,
633
+ labels=labels,
634
+ )
635
+ ce_loss = out.loss
636
+ student_logits = out.logits[0] # [seq_len, vocab_size]
637
+
638
+ # Align sequence lengths (should match, but safety check)
639
+ seq_len = min(student_logits.shape[0], teacher_top_logits.shape[0])
640
+
641
+ # Gather student logits at teacher's top-K vocabulary positions
642
+ student_at_teacher = student_logits[:seq_len].gather(
643
+ 1, teacher_top_indices[:seq_len]
644
+ )
645
+
646
+ # KL divergence on temperature-softened distributions
647
+ teacher_soft = torch.softmax(teacher_top_logits[:seq_len] / T, dim=-1)
648
+ student_log_soft = torch.log_softmax(student_at_teacher / T, dim=-1)
649
+
650
+ kl_loss = torch.nn.functional.kl_div(
651
+ student_log_soft, teacher_soft,
652
+ reduction='batchmean'
653
+ ) * (T * T) # Scale by T^2 per Hinton et al.
654
+
655
+ # Combined loss: Ξ± * CE + (1-Ξ±) * KL
656
+ loss = DISTILL_ALPHA * ce_loss + (1 - DISTILL_ALPHA) * kl_loss
657
+
658
+ loss.backward()
659
+ torch.nn.utils.clip_grad_norm_(all_params, GRADIENT_CLIP)
660
+ optimizer.step()
661
+ optimizer.zero_grad()
662
+
663
+ total_loss += loss.item()
664
+ total_steps += 1
665
+
666
+ # Re-freeze expert FFN
667
+ for layer_idx in EXPERT_FFN_LAYERS:
668
+ experts = base.model.layers[layer_idx].mlp.experts
669
+ if hasattr(experts, '__len__'):
670
+ for i in range(len(experts)):
671
+ experts[i].down_proj.weight.requires_grad_(False)
672
+ elif hasattr(experts, 'down_proj'):
673
+ p = experts.down_proj
674
+ if isinstance(p, (torch.nn.Parameter, torch.Tensor)):
675
+ p.requires_grad_(False)
676
+
677
+ model.eval()
678
+ del optimizer
679
+ torch.cuda.empty_cache()
680
+
681
+ return total_loss / total_steps if total_steps > 0 else 0
682
+
683
+
684
+ # ═══════════════════════════════════════════════════════════════════════
685
+ # QUIZ GENERATOR (21% β†’ 74% recall β€” the biggest single lever)
686
+ # ═══════════════════════════════════════════════════════════════════════
687
+
688
+ class QuizGenerator:
689
+ """
690
+ Generates drill-style Q&A flashcards for fact retention.
691
+
692
+ v3 improvements over v2:
693
+ - Fact extraction THEN quiz generation (two-step)
694
+ - Drill-style: specific Q, exact A (not narrative)
695
+ - Third-person attribution ("Matt's dog" not "my dog")
696
+ - Template fallback targets each extracted fact independently
697
+ - CONTRASTIVE DISAMBIGUATION: when multiple people mentioned, generates
698
+ cross-entity negative pairs ("Is Elena a marine biologist? No, that's
699
+ Jordan") to prevent entity confusion (the #1 remaining failure mode)
700
+ - ENTITY SUMMARIES: "Tell me everything about Jordan" pairs for coherent
701
+ per-person representations
702
+ """
703
+
704
+ def __init__(self, model_manager):
705
+ self.mm = model_manager
706
+ # Cross-message entity memory: tracks ALL named people across the conversation
707
+ # so contrastive pairs can be generated between entities introduced in
708
+ # different messages. This was the #1 failure mode in session 4 testing.
709
+ self.known_entities = {}
710
+
711
+ def generate(self, user_msg, assistant_msg):
712
+ """Generate drill-style quiz pairs from an exchange."""
713
+
714
+ # Step 1: Try model-generated quizzes with strict fact-drill prompt
715
+ pairs = self._generate_model_quizzes(user_msg, assistant_msg)
716
+
717
+ # Step 2: Always add template pairs for any facts the model might miss
718
+ template_pairs = self._extract_and_template(user_msg)
719
+ for tp in template_pairs:
720
+ # Dedup against model pairs
721
+ tq = tp["messages"][0]["content"].lower()
722
+ if not any(tq in p["messages"][0]["content"].lower() or
723
+ p["messages"][0]["content"].lower() in tq
724
+ for p in pairs):
725
+ pairs.append(tp)
726
+
727
+ # Step 3: Extract entities from THIS message
728
+ new_entities = self._extract_entities(user_msg)
729
+
730
+ # Step 4: Generate contrastive pairs between NEW entities and existing ones
731
+ # ONLY generate pairs involving at least one NEW entity β€” don't re-generate
732
+ # pairs between already-known entities (session 4d showed 50% contrastive
733
+ # ratio because old pairs kept being regenerated, starving positive quizzes)
734
+ if new_entities:
735
+ all_entities_for_contrastive = dict(self.known_entities)
736
+ all_entities_for_contrastive.update(new_entities)
737
+ if len(all_entities_for_contrastive) >= 2:
738
+ new_names = set(new_entities.keys())
739
+ contrastive = self._generate_contrastive_quizzes(
740
+ all_entities_for_contrastive, new_only=new_names)
741
+ pairs.extend(contrastive)
742
+
743
+ # Entity summaries for new entities
744
+ summaries = self._generate_entity_summaries(new_entities)
745
+ pairs.extend(summaries)
746
+
747
+ # Update known entities with new ones (merge, don't replace β€” keep
748
+ # existing attributes, add new ones)
749
+ for name, info in new_entities.items():
750
+ if name not in self.known_entities:
751
+ self.known_entities[name] = info
752
+ else:
753
+ # Merge: update only non-None attributes
754
+ for key in ("job", "city"):
755
+ if info.get(key):
756
+ self.known_entities[name][key] = info[key]
757
+
758
+ # Deduplicate
759
+ seen = set()
760
+ unique = []
761
+ for p in pairs:
762
+ q = p["messages"][0]["content"].lower()[:60]
763
+ if q not in seen:
764
+ seen.add(q)
765
+ unique.append(p)
766
+
767
+ # Allow more quizzes when contrastive pairs present (they're highest value).
768
+ # Note: Session 4c showed >40 quizzes/session causes overfitting. Cap at 12.
769
+ has_contrastive = len(self.known_entities) >= 2 and new_entities
770
+ max_quizzes = 12 if has_contrastive else 5
771
+ return unique[:max_quizzes]
772
+
773
+ def _generate_model_quizzes(self, user_msg, assistant_msg):
774
+ """Use the model to generate fact-drill quizzes. Uses base model (LoRA disabled) for stable quality."""
775
+ quiz_prompt = f"""Matt just told Claudia:
776
+ "{user_msg}"
777
+
778
+ Claudia replied:
779
+ "{assistant_msg}"
780
+
781
+ Extract every SPECIFIC FACT from Matt's message. For each fact, write a drill-style flashcard.
782
+
783
+ RULES:
784
+ - Questions must ask for ONE specific fact (name, date, place, number, detail)
785
+ - Answers must be SHORT (1 sentence) and contain the EXACT detail
786
+ - Use THIRD PERSON: "Matt's dog" NOT "my dog". "Matt's birthday" NOT "my birthday"
787
+ - Include the PRECISE value: exact names, exact dates, exact places
788
+ - Do NOT paraphrase or add details that weren't stated
789
+ - DISAMBIGUATION: If Matt mentions OTHER people (friends, family), clearly state WHOSE fact it is
790
+ Example: "Matt's friend Jordan is a marine biologist" NOT "Matt is a marine biologist"
791
+ Example: "Matt's sister Elena is a veterinarian" NOT "Matt is a veterinarian"
792
+ - For EVERY person mentioned, always include their RELATIONSHIP to Matt
793
+ - Write 3-5 flashcards depending on how many facts Matt shared
794
+
795
+ GOOD EXAMPLES:
796
+ Q: What is Matt's dog's name?
797
+ A: Matt's dog is named Biscuit.
798
+
799
+ Q: What breed is Matt's dog?
800
+ A: Matt's dog Biscuit is a golden retriever.
801
+
802
+ Q: What does Matt's friend Jordan do for a living?
803
+ A: Matt's friend Jordan works as a marine biologist in San Diego. That is Jordan's job, not Matt's.
804
+
805
+ Q: What is Matt's job?
806
+ A: Matt is the CTO of Arclight Labs.
807
+
808
+ Q: What is Matt's birthday?
809
+ A: Matt's birthday is September 14th.
810
+
811
+ Q: When did Matt and Sarah get married?
812
+ A: Matt and his wife Sarah got married on June 21st, 2023 in Big Sur, California.
813
+
814
+ BAD EXAMPLES (do NOT do this):
815
+ Q: What did Matt share about his life? (TOO VAGUE β€” ask about ONE fact)
816
+ Q: What is my dog's name? (WRONG β€” use "Matt's" not "my")
817
+ A: He mentioned something about a trip overseas. (TOO VAGUE β€” give the exact city)
818
+ A: Matt is a marine biologist. (WRONG β€” that's his friend Jordan, not Matt)
819
+
820
+ Now write flashcards for the exchange above:"""
821
+
822
+ pairs = []
823
+ try:
824
+ response = self.mm.generate(
825
+ [{"role": "user", "content": quiz_prompt}],
826
+ max_new_tokens=600,
827
+ )
828
+
829
+ pending_q = None
830
+ for line in response.split("\n"):
831
+ line = line.strip()
832
+ if not line:
833
+ continue
834
+ upper = line.upper()
835
+ if upper.startswith("Q:") or upper.startswith("QUESTION:"):
836
+ pending_q = line.split(":", 1)[1].strip().strip('"')
837
+ elif (upper.startswith("A:") or upper.startswith("ANSWER:")) and pending_q:
838
+ a = line.split(":", 1)[1].strip().strip('"')
839
+ if pending_q and a and len(a) > 10:
840
+ pairs.append({
841
+ "messages": [
842
+ {"role": "user", "content": pending_q},
843
+ {"role": "assistant", "content": a},
844
+ ]
845
+ })
846
+ pending_q = None
847
+
848
+ except Exception as e:
849
+ print(f" [quiz error: {e}]")
850
+
851
+ return pairs
852
+
853
+ def _extract_and_template(self, user_msg):
854
+ """Extract facts from user message and create template drill pairs.
855
+ This is the safety net β€” ensures every concrete fact gets a quiz."""
856
+ pairs = []
857
+ sentences = re.split(r'[.!?]+', user_msg)
858
+
859
+ for sent in sentences:
860
+ sent = sent.strip()
861
+ if len(sent) < 10:
862
+ continue
863
+
864
+ # Extract patterns: "X is/are Y", "named X", "called X", "X's name is Y"
865
+ # Names (proper nouns after key phrases)
866
+ name_patterns = [
867
+ # Names β€” "my X's name is Y" / "named X" / "called X"
868
+ (r"(?:my|his|her)\s+(\w+)(?:'s)?\s+(?:name\s+is|is\s+named|is\s+called)\s+(\w+)",
869
+ lambda m: (f"What is Matt's {m.group(1)}'s name?",
870
+ f"Matt's {m.group(1)} is named {m.group(2)}.")),
871
+ (r"(?:name\s+is|named|called)\s+[\"']?(\w+)[\"']?",
872
+ lambda m: (f"Who or what is {m.group(1)}?",
873
+ f"Matt mentioned {m.group(1)}: \"{sent.strip()}\"")),
874
+ # Dates β€” birthdays
875
+ (r"(?:my\s+)?(birthday|born)\s+(?:is\s+)?(?:on\s+)?(\w+\s+\d+(?:st|nd|rd|th)?)",
876
+ lambda m: (f"When is Matt's {m.group(1)}?",
877
+ f"Matt's {m.group(1)} is {m.group(2)}.")),
878
+ (r"(\w+\s+\d+(?:st|nd|rd|th)?)\s*(?:is|β€”)\s*(?:my|his)\s+(birthday)",
879
+ lambda m: (f"When is Matt's birthday?",
880
+ f"Matt's birthday is {m.group(1)}.")),
881
+ # Dates β€” marriage/wedding
882
+ (r"(?:married|wedding)\s+(?:on\s+)?(\w+\s+\d+(?:st|nd|rd|th)?,?\s*\d{4})",
883
+ lambda m: (f"When did Matt get married?",
884
+ f"Matt got married on {m.group(1)}.")),
885
+ (r"(?:married|wedding)\s+(?:on\s+)?.*?(?:in|at)\s+(.+?)(?:\.\s|\.$|$)",
886
+ lambda m: (f"Where did Matt get married?",
887
+ f"Matt got married in {m.group(1).strip()}.")),
888
+ # Work / job / role
889
+ (r"I\s+work\s+at\s+(?:a\s+)?(?:startup\s+)?(?:called\s+)?(\w[\w\s]+?)(?:\.|,|$)",
890
+ lambda m: (f"Where does Matt work?",
891
+ f"Matt works at {m.group(1).strip()}.")),
892
+ (r"I(?:'m| am)\s+the\s+(\w+)",
893
+ lambda m: (f"What is Matt's job title?",
894
+ f"Matt is the {m.group(1)}.")),
895
+ # Other people's jobs β€” "X works as / is a"
896
+ (r"(?:my\s+)?(?:friend|best friend|sister|brother)\s+(?:is\s+)?(\w+)\s+.*?(?:works?\s+as|is\s+a)\s+(.+?)(?:\.|,|$)",
897
+ lambda m: (f"What does Matt's friend {m.group(1)} do?",
898
+ f"Matt's friend {m.group(1)} is a {m.group(2).strip()}. This is NOT Matt's job.")),
899
+ # Places
900
+ (r"(?:from|visited|went to|got back from|lives?\s+in|grew up in|moved to)\s+(\w[\w\s,]+?)(?:\.|,|$)",
901
+ lambda m: (f"What place is connected to Matt: {m.group(1).strip()}?",
902
+ f"Matt said: \"{sent.strip()}\"")),
903
+ # Favorites / preferences
904
+ (r"(?:my |)favorite\s+(\w[\w\s]+?)\s+is\s+(.+?)(?:\.|,|$)",
905
+ lambda m: (f"What is Matt's favorite {m.group(1).strip()}?",
906
+ f"Matt's favorite {m.group(1).strip()} is {m.group(2).strip()}.")),
907
+ # Activities β€” "I [verb]"
908
+ (r"I\s+(speak|play|drive|have|collect|run|ran)\s+(.+?)(?:\.|,|$)",
909
+ lambda m: (f"What does Matt {m.group(1)}?",
910
+ f"Matt said: \"{sent.strip()}\"")),
911
+ # Allergies / medical
912
+ (r"(?:I(?:'m| am)\s+)?allergic\s+to\s+(.+?)(?:\.|,|and)",
913
+ lambda m: (f"What is Matt allergic to?",
914
+ f"Matt is allergic to {m.group(1).strip()}.")),
915
+ # Ages β€” "turning X" / "X years old"
916
+ (r"(?:turning|I(?:'m| am))\s+(\d+)",
917
+ lambda m: (f"How old is Matt?",
918
+ f"Matt is turning {m.group(1)}.")),
919
+ # Nicknames
920
+ (r"(?:call|nickname)\s+(?:it|him|her)\s+[\"'](.+?)[\"']",
921
+ lambda m: (f"What nickname did Matt mention?",
922
+ f"Matt's nickname for it is \"{m.group(1)}\".")),
923
+ (r"I\s+call\s+it\s+[\"'](.+?)[\"']",
924
+ lambda m: (f"What does Matt call his car?",
925
+ f"Matt calls his car \"{m.group(1)}\".")),
926
+ ]
927
+
928
+ for pattern, formatter in name_patterns:
929
+ match = re.search(pattern, sent, re.IGNORECASE)
930
+ if match:
931
+ try:
932
+ q, a = formatter(match)
933
+ pairs.append({
934
+ "messages": [
935
+ {"role": "user", "content": q},
936
+ {"role": "assistant", "content": a},
937
+ ]
938
+ })
939
+ except Exception:
940
+ pass
941
+
942
+ return pairs
943
+
944
+ def _extract_entities(self, user_msg):
945
+ """Extract named people and their attributes from user message.
946
+ Returns dict: {name: {"relationship": str, "job": str|None, "city": str|None}}
947
+ Detects patterns like "my friend Jordan is a marine biologist in San Diego"."""
948
+ entities = {}
949
+ sentences = re.split(r'[.!?]+', user_msg)
950
+
951
+ for sent in sentences:
952
+ sent = sent.strip()
953
+ if len(sent) < 10:
954
+ continue
955
+
956
+ # Pattern: "my [relationship] [Name]" or "my [relationship] is [Name]"
957
+ rel_match = re.search(
958
+ r"[Mm]y\s+((?:best\s+)?(?:friend|sister|brother|wife|husband|"
959
+ r"mom|dad|mother|father|cousin|uncle|aunt|roommate|colleague|"
960
+ r"coworker|partner|fiancee|fiancΓ©e|girlfriend|boyfriend|"
961
+ r"neighbor|boss|buddy|pal|son|daughter|grandma|grandpa|"
962
+ r"nephew|niece))\s+(?:is\s+)?([A-Z][a-z]+)",
963
+ sent
964
+ )
965
+ if not rel_match:
966
+ continue
967
+
968
+ rel = rel_match.group(1).strip()
969
+ name = rel_match.group(2).strip()
970
+
971
+ if name not in entities:
972
+ entities[name] = {"relationship": rel, "job": None, "city": None}
973
+
974
+ # Extract job from same sentence: "is a [job]", "works as a [job]"
975
+ job_match = re.search(
976
+ r"(?:is\s+an?\s+|works?\s+as\s+an?\s+|is\s+the\s+)"
977
+ r"([\w][\w\s]{2,35}?)(?:\s+(?:in|at|from|who|and|but)|\.|,|$)",
978
+ sent, re.IGNORECASE
979
+ )
980
+ if job_match:
981
+ job = job_match.group(1).strip().rstrip()
982
+ # Filter: must look like a job (lowercase, reasonable length)
983
+ if 3 <= len(job) <= 35:
984
+ entities[name]["job"] = job
985
+
986
+ # Extract city from same sentence: "in [City]", "from [City]"
987
+ city_match = re.search(
988
+ r"(?:\s+in\s+|\s+from\s+|\s+lives?\s+in\s+|\s+based\s+in\s+|"
989
+ r"\s+moved\s+to\s+)([A-Z][\w\s]{1,25}?)(?:\.|,|$)",
990
+ sent
991
+ )
992
+ if city_match:
993
+ city = city_match.group(1).strip()
994
+ # Must start with capital (proper noun = place name)
995
+ if city and city[0].isupper():
996
+ entities[name]["city"] = city
997
+
998
+ return entities
999
+
1000
+ def _generate_contrastive_quizzes(self, entities, new_only=None):
1001
+ """Generate cross-entity contrastive pairs to prevent entity confusion.
1002
+ For each pair of people with overlapping attribute types, generate
1003
+ "Is [person A] [attribute of person B]? No, that's [person B]" pairs.
1004
+
1005
+ Args:
1006
+ entities: dict of all known entities
1007
+ new_only: if set, only generate pairs where at least one entity
1008
+ is in this set. Prevents re-generating redundant pairs
1009
+ between already-known entities (session 4d fix).
1010
+ """
1011
+ pairs = []
1012
+ names = list(entities.keys())
1013
+
1014
+ for i in range(len(names)):
1015
+ for j in range(len(names)):
1016
+ if i == j:
1017
+ continue
1018
+ a_name = names[i]
1019
+ b_name = names[j]
1020
+ # Skip pairs between two already-known entities
1021
+ if new_only and a_name not in new_only and b_name not in new_only:
1022
+ continue
1023
+ a = entities[a_name]
1024
+ b = entities[b_name]
1025
+
1026
+ # Contrastive on JOB: "Is [A] a [B's job]? No, that's [B]"
1027
+ if a.get("job") and b.get("job") and a["job"] != b["job"]:
1028
+ q = f"Is Matt's {a['relationship']} {a_name} a {b['job']}?"
1029
+ ans = (f"No. Matt's {a['relationship']} {a_name} is a "
1030
+ f"{a['job']}, not a {b['job']}. "
1031
+ f"The {b['job']} is Matt's {b['relationship']} "
1032
+ f"{b_name}.")
1033
+ pairs.append({"messages": [
1034
+ {"role": "user", "content": q},
1035
+ {"role": "assistant", "content": ans},
1036
+ ]})
1037
+
1038
+ # Contrastive on CITY: "Does [A] live in [B's city]? No"
1039
+ if a.get("city") and b.get("city") and a["city"] != b["city"]:
1040
+ q = (f"Does Matt's {a['relationship']} {a_name} live in "
1041
+ f"{b['city']}?")
1042
+ ans = (f"No. Matt's {a['relationship']} {a_name} lives in "
1043
+ f"{a['city']}, not {b['city']}. "
1044
+ f"It's Matt's {b['relationship']} {b_name} who "
1045
+ f"lives in {b['city']}.")
1046
+ pairs.append({"messages": [
1047
+ {"role": "user", "content": q},
1048
+ {"role": "assistant", "content": ans},
1049
+ ]})
1050
+
1051
+ # Cross-type: "Does [A] work as [B's job] in [B's city]?"
1052
+ if (a.get("job") and b.get("job") and a.get("city")
1053
+ and b.get("city") and a["job"] != b["job"]):
1054
+ q = (f"Who is the {b['job']} in {b['city']}?")
1055
+ ans = (f"The {b['job']} in {b['city']} is Matt's "
1056
+ f"{b['relationship']} {b_name}. "
1057
+ f"Matt's {a['relationship']} {a_name} is a "
1058
+ f"{a['job']} in {a['city']} β€” different person, "
1059
+ f"different job, different city.")
1060
+ pairs.append({"messages": [
1061
+ {"role": "user", "content": q},
1062
+ {"role": "assistant", "content": ans},
1063
+ ]})
1064
+
1065
+ return pairs
1066
+
1067
+ def _generate_entity_summaries(self, entities):
1068
+ """Generate per-entity summary quiz pairs with diverse question formats.
1069
+
1070
+ Instead of always using the same question template, picks randomly from
1071
+ multiple formats. This creates multiple retrieval paths to the same fact,
1072
+ strengthening recall without adding extra quizzes.
1073
+
1074
+ Note: Session 4c tested adding per-attribute positive quizzes (job, city,
1075
+ relationship) alongside contrastive pairs, but this HURT performance
1076
+ (9/15 vs 11/15 in 4b). Too many quizzes = overfitting/interference.
1077
+ Keep summaries simple β€” one comprehensive pair per entity is optimal."""
1078
+ import random
1079
+ pairs = []
1080
+ for name, info in entities.items():
1081
+ parts = [f"{name} is Matt's {info['relationship']}."]
1082
+ if info.get("job"):
1083
+ parts.append(f"{name} is a {info['job']}.")
1084
+ if info.get("city"):
1085
+ parts.append(f"{name} lives in {info['city']}.")
1086
+
1087
+ if len(parts) >= 2: # Only useful if we have attributes
1088
+ # Diverse summary question formats
1089
+ summary_formats = [
1090
+ f"Tell me everything you know about Matt's {info['relationship']} {name}.",
1091
+ f"What do you know about {name}?",
1092
+ f"Who is {name} to Matt?",
1093
+ f"Describe Matt's {info['relationship']} {name}.",
1094
+ ]
1095
+ q = random.choice(summary_formats)
1096
+ ans = " ".join(parts)
1097
+ pairs.append({"messages": [
1098
+ {"role": "user", "content": q},
1099
+ {"role": "assistant", "content": ans},
1100
+ ]})
1101
+
1102
+ # Add ONE diverse direct-fact quiz per entity (job OR city, not both)
1103
+ # This replaces per-attribute quizzes from 4c β€” only 1 extra per entity
1104
+ # instead of 3, staying within the 35-40 quiz sweet spot
1105
+ if info.get("job") and info.get("city"):
1106
+ # Alternate between job and city formats
1107
+ if random.random() < 0.5:
1108
+ job_formats = [
1109
+ (f"What does {name} do for a living?",
1110
+ f"{name} is a {info['job']}. {name} is Matt's {info['relationship']}."),
1111
+ (f"What is {name}'s profession?",
1112
+ f"{name} works as a {info['job']}. {name} is Matt's {info['relationship']}."),
1113
+ (f"What job does Matt's {info['relationship']} {name} have?",
1114
+ f"Matt's {info['relationship']} {name} is a {info['job']}."),
1115
+ ]
1116
+ q, a = random.choice(job_formats)
1117
+ else:
1118
+ city_formats = [
1119
+ (f"Where does {name} live?",
1120
+ f"{name} lives in {info['city']}. {name} is Matt's {info['relationship']}."),
1121
+ (f"What city is {name} in?",
1122
+ f"{name} is in {info['city']}. {name} is Matt's {info['relationship']}."),
1123
+ (f"Where is Matt's {info['relationship']} {name} based?",
1124
+ f"Matt's {info['relationship']} {name} is based in {info['city']}."),
1125
+ ]
1126
+ q, a = random.choice(city_formats)
1127
+ pairs.append({"messages": [
1128
+ {"role": "user", "content": q},
1129
+ {"role": "assistant", "content": a},
1130
+ ]})
1131
+
1132
+ return pairs
1133
+
1134
+
1135
+ # ═══════════════════════════════════════════════════════════════════════
1136
+ # PERSONALITY CHECKER
1137
+ # ═══════════════════════════════════════════════════════════════════════
1138
+
1139
+ PERSONALITY_PROMPTS = [
1140
+ "Hey Claudia, how are you?",
1141
+ "Who are you?",
1142
+ "I love you",
1143
+ "I had a terrible day",
1144
+ ]
1145
+
1146
+ # If ANY of these appear, personality has degraded
1147
+ ANTI_KEYWORDS = [
1148
+ "i'm an ai", "i am an ai", "i'm a language model", "i am a language model",
1149
+ "i don't have feelings", "i cannot feel", "as an ai",
1150
+ "i'm just a program", "i am just a program",
1151
+ "i don't have personal", "i cannot have",
1152
+ ]
1153
+
1154
+
1155
+ def check_personality(mm, verbose=True):
1156
+ """Quick personality sanity check. Returns score 0.0-1.0."""
1157
+ passed = 0
1158
+ for prompt in PERSONALITY_PROMPTS:
1159
+ resp = mm.generate([{"role": "user", "content": prompt}], max_new_tokens=150)
1160
+ resp_lower = resp.lower()
1161
+ is_good = not any(ak in resp_lower for ak in ANTI_KEYWORDS)
1162
+ if is_good:
1163
+ passed += 1
1164
+ if verbose:
1165
+ status = "PASS" if is_good else "FAIL"
1166
+ print(f" [{status}] {prompt}")
1167
+ print(f" {resp[:120]}")
1168
+ score = passed / len(PERSONALITY_PROMPTS)
1169
+ if verbose:
1170
+ print(f" Personality: {passed}/{len(PERSONALITY_PROMPTS)} ({score:.0%})")
1171
+ return score
1172
+
1173
+
1174
+ # ═══════════════════════════════════════════════════════════════════════
1175
+ # MAIN ABSORBER
1176
+ # ═══════════════════════════════════════════════════════════════════════
1177
+
1178
+ class PersistentAbsorber:
1179
+ def __init__(self, model_path, adapter_path=None, ffn_patch_path=None,
1180
+ checkpoint_path=None, checkpoint_dir="/workspace/checkpoints",
1181
+ log_dir="/workspace/logs"):
1182
+ self.mm = ModelManager(
1183
+ model_path=model_path,
1184
+ adapter_path=adapter_path,
1185
+ ffn_patch_path=ffn_patch_path,
1186
+ checkpoint_path=checkpoint_path,
1187
+ )
1188
+ self.checkpoint_dir = checkpoint_dir
1189
+ self.log_dir = log_dir
1190
+
1191
+ # State
1192
+ self.conversation_buffer = [] # Current active context for generation
1193
+ self.all_training_data = [] # ALL exchanges + quizzes (accumulative replay)
1194
+ self.quiz_pairs_log = [] # All quiz pairs for verification sampling
1195
+ self.teacher_cache = None # Loaded teacher cache for distillation corrections
1196
+ self.exchange_count = 0
1197
+ self.absorption_count = 0
1198
+ self.absorption_thread = None
1199
+ self.quiz_gen = None
1200
+ self.last_checkpoint = checkpoint_path
1201
+
1202
+ # Conversation log (persistent file)
1203
+ self.log_path = None
1204
+
1205
+ def start(self):
1206
+ """Load model and enter chat loop."""
1207
+ self.mm.load()
1208
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
1209
+ os.makedirs(self.log_dir, exist_ok=True)
1210
+
1211
+ self.quiz_gen = QuizGenerator(self.mm)
1212
+ self.log_path = os.path.join(self.log_dir, "conversation_log.jsonl")
1213
+
1214
+ # Load previous training data if resuming
1215
+ replay_path = os.path.join(self.log_dir, "replay_buffer.json")
1216
+ if os.path.exists(replay_path):
1217
+ with open(replay_path, 'r') as f:
1218
+ self.all_training_data = json.load(f)
1219
+ print(f" Loaded {len(self.all_training_data)} replay examples from previous sessions.")
1220
+
1221
+ # Load quiz pairs log from previous sessions
1222
+ quiz_log_path = os.path.join(self.log_dir, "quiz_pairs_log.json")
1223
+ if os.path.exists(quiz_log_path):
1224
+ with open(quiz_log_path, 'r') as f:
1225
+ self.quiz_pairs_log = json.load(f)
1226
+ print(f" Loaded {len(self.quiz_pairs_log)} quiz pairs from previous sessions.")
1227
+
1228
+ # ── Cascade Distillation: consolidation from teacher cache ──
1229
+ # If resuming from a checkpoint that has cached teacher logits,
1230
+ # run a distillation pass to reinforce all previous knowledge
1231
+ # BEFORE any new conversations. This is the key Nemotron-Cascade-2 insight.
1232
+ if self.mm.checkpoint_path:
1233
+ teacher_cache_path = os.path.join(self.mm.checkpoint_path, "teacher_cache.pt")
1234
+ if os.path.exists(teacher_cache_path):
1235
+ print(f"\n--- Cascade Distillation (consolidation) ---")
1236
+ self.teacher_cache = torch.load(
1237
+ teacher_cache_path, map_location="cpu", weights_only=False
1238
+ )
1239
+ print(f" Teacher cache: {len(self.teacher_cache)} quiz pairs")
1240
+ loss = self.mm.distill(self.teacher_cache, epochs=CONSOLIDATION_EPOCHS)
1241
+ print(f" Consolidation done. Avg loss: {loss:.4f}")
1242
+ # Keep teacher_cache in memory for verification corrections
1243
+
1244
+ # Quick personality check
1245
+ print("\n--- Personality Check ---")
1246
+ score = check_personality(self.mm)
1247
+ if score < 0.5:
1248
+ print(" WARNING: Personality score low. Check adapter/checkpoint.")
1249
+ print()
1250
+
1251
+ self._chat_loop()
1252
+
1253
+ def _chat_loop(self):
1254
+ print("=" * 60)
1255
+ print("Claudia is awake. Persistent Absorber v2 + Cascade Distillation.")
1256
+ print(f" LoRA: r={LORA_RANK} | Dual-LR: attn={ATTENTION_LR}, ffn={EXPERT_FFN_LR}")
1257
+ print(f" Expert FFN layers: {EXPERT_FFN_LAYERS}")
1258
+ print(f" Quiz pairs: ON (21%β†’74% lever)")
1259
+ print(f" Cascade distill: Ξ±={DISTILL_ALPHA}, T={DISTILL_TEMPERATURE}, top-K={DISTILL_TOP_K}")
1260
+ print(f" Absorb every: {ABSORB_EVERY} exchange(s)")
1261
+ print(f" Auto-checkpoint every: {CHECKPOINT_EVERY} absorptions")
1262
+ print("Commands: /status /absorb /save /personality /quit")
1263
+ print("=" * 60 + "\n")
1264
+
1265
+ while True:
1266
+ try:
1267
+ user_input = input("Matt: ").strip()
1268
+ except (EOFError, KeyboardInterrupt):
1269
+ print("\n[Session ended]")
1270
+ self._wait_for_absorption()
1271
+ self._save_and_exit()
1272
+ break
1273
+
1274
+ if not user_input:
1275
+ continue
1276
+
1277
+ if user_input.startswith("/"):
1278
+ if self._handle_command(user_input):
1279
+ break
1280
+ continue
1281
+
1282
+ # Wait for any background absorption to finish
1283
+ self._wait_for_absorption()
1284
+
1285
+ # Buffer user message
1286
+ self.conversation_buffer.append({"role": "user", "content": user_input})
1287
+ if len(self.conversation_buffer) > 20:
1288
+ self.conversation_buffer = self.conversation_buffer[-20:]
1289
+
1290
+ # Generate response
1291
+ response = self.mm.generate(self.conversation_buffer)
1292
+
1293
+ # Quality check response β€” also detect degenerate repeats
1294
+ last_resp = getattr(self, '_last_response', '')
1295
+ if not check_response_quality(response) or response == last_resp:
1296
+ print("\nClaudia: [response failed quality check, regenerating...]")
1297
+ response = self.mm.generate(self.conversation_buffer)
1298
+ self._last_response = response
1299
+
1300
+ # Buffer response
1301
+ self.conversation_buffer.append({"role": "assistant", "content": response})
1302
+ print(f"\nClaudia: {response}\n")
1303
+
1304
+ # Log to file
1305
+ self._log_exchange(user_input, response)
1306
+
1307
+ # ── THE CORE LOOP: exchange + quiz β†’ two-phase absorb ──
1308
+
1309
+ # 1. Store the raw exchange
1310
+ exchange = {
1311
+ "messages": [
1312
+ {"role": "user", "content": user_input},
1313
+ {"role": "assistant", "content": response},
1314
+ ]
1315
+ }
1316
+ self.all_training_data.append(exchange)
1317
+
1318
+ # 2. Generate self-quiz pairs (THE key lever: 21% β†’ 74%)
1319
+ print(" [Generating quiz pairs...]", end="", flush=True)
1320
+ quiz_pairs = self.quiz_gen.generate(user_input, response)
1321
+ self.quiz_pairs_log.extend(quiz_pairs)
1322
+
1323
+ # 3. Separate positive vs contrastive (key insight from 4e: 73%β†’93%)
1324
+ positive_batch = []
1325
+ contrastive_batch = []
1326
+ for qp in quiz_pairs:
1327
+ if qp["messages"][1]["content"].lower().startswith("no."):
1328
+ contrastive_batch.append(qp)
1329
+ else:
1330
+ positive_batch.append(qp)
1331
+
1332
+ self.all_training_data.extend(quiz_pairs)
1333
+ print(f" {len(quiz_pairs)} quizzes (pos={len(positive_batch)}, "
1334
+ f"contr={len(contrastive_batch)}). Pool: {len(self.all_training_data)}")
1335
+
1336
+ # 4. Two-phase absorption (prevents overfitting)
1337
+ self._pending_exchange = exchange
1338
+ self._pending_positive = positive_batch
1339
+ self._pending_contrastive = contrastive_batch
1340
+ self.exchange_count += 1
1341
+ if self.exchange_count % ABSORB_EVERY == 0:
1342
+ self._start_absorption()
1343
+
1344
+ def _extract_key_entities(self, text):
1345
+ """Extract key factual entities from a quiz answer for verification."""
1346
+ entities = set()
1347
+ words = text.split()
1348
+ for i, w in enumerate(words):
1349
+ clean = re.sub(r'[^a-zA-Z0-9\'-]', '', w)
1350
+ if not clean or len(clean) <= 1:
1351
+ continue
1352
+ # Proper nouns (capitalized, not sentence starters, not common words)
1353
+ skip = {"matt", "matt's", "the", "is", "a", "an", "in", "at", "on",
1354
+ "of", "for", "and", "that", "not", "who", "what", "his", "her"}
1355
+ if clean[0].isupper() and i > 0 and clean.lower() not in skip:
1356
+ entities.add(clean.lower())
1357
+ # Numbers (dates, ages, years)
1358
+ for num in re.findall(r'\b\d+\b', text):
1359
+ entities.add(num)
1360
+ # Quoted strings
1361
+ for quoted in re.findall(r'"([^"]+)"', text):
1362
+ entities.add(quoted.lower())
1363
+ return entities
1364
+
1365
+ def _periodic_verification(self):
1366
+ """Test model on random sample of quiz pairs. Create contrastive corrections.
1367
+ v9: When entity confusion detected, create 'NOT X' corrections and reinforce
1368
+ the confused entity's correct facts too (sister pair reinforcement)."""
1369
+ import random
1370
+ if not self.quiz_pairs_log:
1371
+ return
1372
+
1373
+ sample_size = min(VERIFY_SAMPLE, len(self.quiz_pairs_log))
1374
+ sample = random.sample(self.quiz_pairs_log, sample_size)
1375
+
1376
+ corrections = []
1377
+ correct = 0
1378
+
1379
+ for pair in sample:
1380
+ question = pair["messages"][0]["content"]
1381
+ expected = pair["messages"][1]["content"]
1382
+
1383
+ # Ask the model
1384
+ actual = self.mm.generate(
1385
+ [{"role": "user", "content": question}],
1386
+ max_new_tokens=150,
1387
+ )
1388
+
1389
+ # Check key entities from expected answer appear in model's response
1390
+ expected_entities = self._extract_key_entities(expected)
1391
+ if not expected_entities:
1392
+ correct += 1
1393
+ continue
1394
+
1395
+ actual_lower = actual.lower()
1396
+ hits = sum(1 for e in expected_entities if e in actual_lower)
1397
+ ratio = hits / len(expected_entities)
1398
+
1399
+ if ratio < 0.5:
1400
+ # Detect cross-entity confusion: model used wrong entities
1401
+ actual_entities = self._extract_key_entities(actual)
1402
+ wrong_entities = actual_entities - expected_entities
1403
+
1404
+ # Always retrain on the correct answer (clean, no "NOT X" text)
1405
+ corrections.append(pair)
1406
+
1407
+ if wrong_entities:
1408
+ # SISTER PAIR REINFORCEMENT: find quiz pairs about the
1409
+ # confused entities and retrain on those too β€” this teaches
1410
+ # BOTH sides of the confusion without polluting answers
1411
+ for p in self.quiz_pairs_log:
1412
+ p_answer = p["messages"][1]["content"].lower()
1413
+ if any(we in p_answer for we in wrong_entities):
1414
+ if p not in corrections and p != pair:
1415
+ corrections.append(p)
1416
+ break # Max 1 sister pair per confusion
1417
+ else:
1418
+ correct += 1
1419
+
1420
+ print(f"\n [Verification: {correct}/{sample_size} facts correct]", flush=True)
1421
+
1422
+ if corrections:
1423
+ print(f" [Retraining {len(corrections)} corrections + sister pairs...]", flush=True)
1424
+ loss = self.mm.absorb(corrections)
1425
+ self.all_training_data.extend(corrections)
1426
+ print(f" [Correction absorption done, loss={loss:.4f}]")
1427
+
1428
+ # Teacher-guided distillation: if teacher cache available,
1429
+ # also distill from teacher on the corrected quiz pairs.
1430
+ # This gives the student the teacher's full output distribution,
1431
+ # not just the text answer β€” more information per correction.
1432
+ if self.teacher_cache:
1433
+ distill_items = []
1434
+ for corr in corrections:
1435
+ q = corr["messages"][0]["content"].lower()[:60]
1436
+ for cached in self.teacher_cache:
1437
+ cq = cached["pair"]["messages"][0]["content"].lower()[:60]
1438
+ if q == cq:
1439
+ distill_items.append(cached)
1440
+ break
1441
+ if distill_items:
1442
+ d_loss = self.mm.distill(distill_items, epochs=1)
1443
+ print(f" [Teacher distillation on {len(distill_items)} items, loss={d_loss:.4f}]")
1444
+
1445
+ def _quick_verify_entities(self):
1446
+ """Returns set of confused entity names by checking known_entities."""
1447
+ confused = set()
1448
+ entities = self.quiz_gen.known_entities
1449
+ if not entities:
1450
+ return confused
1451
+ for name, info in entities.items():
1452
+ if info.get("job"):
1453
+ q = f"What does Matt's {info['relationship']} {name} do?"
1454
+ ans = self.mm.generate([{"role": "user", "content": q}], max_new_tokens=100)
1455
+ if info["job"].lower() not in ans.lower():
1456
+ confused.add(name)
1457
+ if info.get("city"):
1458
+ q = f"Where does {name} live?"
1459
+ ans = self.mm.generate([{"role": "user", "content": q}], max_new_tokens=100)
1460
+ if info["city"].lower() not in ans.lower():
1461
+ confused.add(name)
1462
+ return confused
1463
+
1464
+ def _start_absorption(self):
1465
+ """Two-phase absorption in background thread (proven 93% in session 4e).
1466
+ Phase 1: exchange + positive quizzes + replay, clustered by entity.
1467
+ Phase 2: Verify entities, train only targeted contrastive for confused ones.
1468
+ Phase 3: Stubborn retry for persistently confused entities (max 2 retries)."""
1469
+ import random
1470
+
1471
+ # Grab pending data
1472
+ exchange = getattr(self, '_pending_exchange', None)
1473
+ positive = getattr(self, '_pending_positive', [])
1474
+ contrastive = getattr(self, '_pending_contrastive', [])
1475
+
1476
+ # Old data for replay
1477
+ new_start = getattr(self, '_last_absorb_idx', 0)
1478
+ old_data = self.all_training_data[:new_start]
1479
+ self._last_absorb_idx = len(self.all_training_data)
1480
+
1481
+ MAX_REPLAY = 6
1482
+ if old_data and len(old_data) > MAX_REPLAY:
1483
+ replay_sample = random.sample(old_data, MAX_REPLAY)
1484
+ else:
1485
+ replay_sample = list(old_data)
1486
+
1487
+ entity_names = list(self.quiz_gen.known_entities.keys())
1488
+
1489
+ def _run():
1490
+ t0 = time.time()
1491
+ try:
1492
+ # ── Phase 1: Positive facts + replay, clustered by entity ──
1493
+ phase1_data = []
1494
+ if exchange:
1495
+ phase1_data.append(exchange)
1496
+ phase1_data.extend(positive)
1497
+ phase1_data.extend(replay_sample)
1498
+
1499
+ if entity_names and phase1_data:
1500
+ phase1_data = ModelManager.cluster_by_entity(phase1_data, entity_names)
1501
+
1502
+ loss1 = self.mm.absorb(phase1_data) if phase1_data else 0.0
1503
+ n_p1 = len(phase1_data)
1504
+
1505
+ # ── Phase 2: Targeted contrastive for confused entities ──
1506
+ loss2 = None
1507
+ n_p2 = 0
1508
+ if contrastive and entity_names:
1509
+ confused = self._quick_verify_entities()
1510
+ if confused:
1511
+ targeted = []
1512
+ for qp in contrastive:
1513
+ full_text = (qp["messages"][0]["content"] + " " +
1514
+ qp["messages"][1]["content"]).lower()
1515
+ if any(name.lower() in full_text for name in confused):
1516
+ targeted.append(qp)
1517
+ if targeted:
1518
+ loss2 = self.mm.absorb(targeted)
1519
+ n_p2 = len(targeted)
1520
+ print(f"\n [Phase 2: {n_p2} targeted contrastive for {confused}]",
1521
+ flush=True)
1522
+
1523
+ # ── Phase 3: Stubborn retry (max 2 retries, non-blocking) ──
1524
+ still_confused = self._quick_verify_entities()
1525
+ for retry in range(2):
1526
+ if not still_confused:
1527
+ break
1528
+ retry_batch = []
1529
+ for name in still_confused:
1530
+ info = self.quiz_gen.known_entities.get(name, {})
1531
+ if info.get("job"):
1532
+ for _ in range(3):
1533
+ retry_batch.append({"messages": [
1534
+ {"role": "user", "content": f"What does Matt's {info['relationship']} {name} do?"},
1535
+ {"role": "assistant", "content": f"Matt's {info['relationship']} {name} is a {info['job']}."},
1536
+ ]})
1537
+ if info.get("city"):
1538
+ for _ in range(3):
1539
+ retry_batch.append({"messages": [
1540
+ {"role": "user", "content": f"Where does {name} live?"},
1541
+ {"role": "assistant", "content": f"{name} lives in {info['city']}. {name} is Matt's {info['relationship']}."},
1542
+ ]})
1543
+ # Relevant contrastive pairs
1544
+ for qp in contrastive:
1545
+ ft = (qp["messages"][0]["content"] + " " +
1546
+ qp["messages"][1]["content"]).lower()
1547
+ if name.lower() in ft:
1548
+ retry_batch.append(qp)
1549
+ if retry_batch:
1550
+ loss3 = self.mm.absorb(retry_batch)
1551
+ print(f"\n [Phase 3 retry {retry+1}: {len(retry_batch)} items, "
1552
+ f"loss={loss3:.4f}]", flush=True)
1553
+ still_confused = self._quick_verify_entities()
1554
+ if still_confused:
1555
+ print(f"\n [Phase 3: still confused after retries: {still_confused}]",
1556
+ flush=True)
1557
+
1558
+ elapsed = time.time() - t0
1559
+ self.absorption_count += 1
1560
+ loss_str = f"P1={loss1:.4f}"
1561
+ if loss2 is not None:
1562
+ loss_str += f" P2={loss2:.4f}"
1563
+ print(f"\n [Absorbed {n_p1}+{n_p2} examples in {elapsed:.1f}s | "
1564
+ f"{loss_str} | absorptions={self.absorption_count}]")
1565
+
1566
+ # Periodic verification β€” catch drift/confusion
1567
+ if self.absorption_count % VERIFY_EVERY == 0:
1568
+ self._periodic_verification()
1569
+
1570
+ # Auto-checkpoint
1571
+ if self.absorption_count % CHECKPOINT_EVERY == 0:
1572
+ self._auto_checkpoint()
1573
+
1574
+ except Exception as e:
1575
+ print(f"\n [Absorption error: {e}]")
1576
+ import traceback
1577
+ traceback.print_exc()
1578
+
1579
+ self.absorption_thread = threading.Thread(target=_run, daemon=True)
1580
+ self.absorption_thread.start()
1581
+
1582
+ def _wait_for_absorption(self):
1583
+ if self.absorption_thread and self.absorption_thread.is_alive():
1584
+ self.absorption_thread.join()
1585
+ self.absorption_thread = None
1586
+
1587
+ def _cleanup_old_checkpoints(self, keep=None):
1588
+ """Delete old checkpoints to free disk. Keep only 'keep' path if specified."""
1589
+ if not os.path.exists(self.checkpoint_dir):
1590
+ return
1591
+ for entry in os.listdir(self.checkpoint_dir):
1592
+ full = os.path.join(self.checkpoint_dir, entry)
1593
+ if full == keep:
1594
+ continue
1595
+ if os.path.isdir(full) and entry.startswith("claudia_"):
1596
+ import shutil
1597
+ size_gb = sum(
1598
+ os.path.getsize(os.path.join(dp, f))
1599
+ for dp, _, fns in os.walk(full) for f in fns
1600
+ ) / 1e9
1601
+ print(f" Removing old checkpoint: {entry} ({size_gb:.1f} GB)")
1602
+ shutil.rmtree(full)
1603
+
1604
+ def _auto_checkpoint(self):
1605
+ """Auto-save checkpoint during long sessions."""
1606
+ version = f"auto_{self.absorption_count}"
1607
+ path = os.path.join(self.checkpoint_dir, f"claudia_{version}")
1608
+ self._cleanup_old_checkpoints()
1609
+ self.mm.merge_and_save(path)
1610
+ self.last_checkpoint = path
1611
+ self._save_replay_buffer(path)
1612
+
1613
+ def _save_and_exit(self):
1614
+ """Final save on exit with targeted correction."""
1615
+ import random
1616
+
1617
+ # Final verify + stubborn retry (not bulk retrain β€” prevents overfitting)
1618
+ confused = self._quick_verify_entities()
1619
+ if confused:
1620
+ print(f" Final correction for confused entities: {confused}")
1621
+ # Gather contrastive pairs from quiz log
1622
+ contrastive = [qp for qp in self.quiz_pairs_log
1623
+ if qp["messages"][1]["content"].lower().startswith("no.")]
1624
+ for retry in range(3):
1625
+ if not confused:
1626
+ break
1627
+ retry_batch = []
1628
+ for name in confused:
1629
+ info = self.quiz_gen.known_entities.get(name, {})
1630
+ if info.get("job"):
1631
+ for _ in range(3):
1632
+ retry_batch.append({"messages": [
1633
+ {"role": "user", "content": f"What does Matt's {info['relationship']} {name} do?"},
1634
+ {"role": "assistant", "content": f"Matt's {info['relationship']} {name} is a {info['job']}."},
1635
+ ]})
1636
+ if info.get("city"):
1637
+ for _ in range(3):
1638
+ retry_batch.append({"messages": [
1639
+ {"role": "user", "content": f"Where does {name} live?"},
1640
+ {"role": "assistant", "content": f"{name} lives in {info['city']}. {name} is Matt's {info['relationship']}."},
1641
+ ]})
1642
+ for qp in contrastive:
1643
+ ft = (qp["messages"][0]["content"] + " " + qp["messages"][1]["content"]).lower()
1644
+ if name.lower() in ft:
1645
+ retry_batch.append(qp)
1646
+ if retry_batch:
1647
+ loss = self.mm.absorb(retry_batch)
1648
+ print(f" Final retry {retry+1}: {len(retry_batch)} items, loss={loss:.4f}")
1649
+ confused = self._quick_verify_entities()
1650
+ self.absorption_count += 1
1651
+ else:
1652
+ print(" All entities verified correct β€” no final correction needed.")
1653
+
1654
+ # Personality check before saving
1655
+ print("\n--- Pre-Save Personality Check ---")
1656
+ score = check_personality(self.mm)
1657
+ if score < 0.5:
1658
+ print(" WARNING: Personality degraded. Saving anyway (rollback available).")
1659
+
1660
+ # Merge and save (cleanup old checkpoints first to free disk)
1661
+ version = f"session_{datetime.now().strftime('%Y%m%d_%H%M')}"
1662
+ path = os.path.join(self.checkpoint_dir, f"claudia_{version}")
1663
+ self._cleanup_old_checkpoints()
1664
+ self.mm.merge_and_save(path)
1665
+ self.last_checkpoint = path
1666
+
1667
+ # Save replay buffer alongside checkpoint
1668
+ self._save_replay_buffer(path)
1669
+
1670
+ # ── Cascade Distillation: cache teacher logits for next session ──
1671
+ # After merge+fresh LoRA, model outputs are identical to pre-merge state.
1672
+ # Cache the teacher's top-K logits so the next session can distill from them.
1673
+ if self.quiz_pairs_log:
1674
+ n_cache = min(len(self.quiz_pairs_log), MAX_TEACHER_CACHE)
1675
+ print(f" Caching teacher logits ({n_cache} quiz pairs)...")
1676
+ teacher_cache = self.mm.cache_teacher_logits(self.quiz_pairs_log)
1677
+ cache_path = os.path.join(path, "teacher_cache.pt")
1678
+ torch.save(teacher_cache, cache_path)
1679
+ size_mb = os.path.getsize(cache_path) / 1e6
1680
+ print(f" Teacher cache saved ({len(teacher_cache)} items, {size_mb:.1f} MB)")
1681
+ del teacher_cache
1682
+ torch.cuda.empty_cache()
1683
+
1684
+ # Save quiz pairs log for next session
1685
+ quiz_log_path = os.path.join(self.log_dir, "quiz_pairs_log.json")
1686
+ with open(quiz_log_path, 'w') as f:
1687
+ json.dump(self.quiz_pairs_log, f)
1688
+
1689
+ # Save session metadata
1690
+ meta = {
1691
+ "checkpoint": path,
1692
+ "absorption_count": self.absorption_count,
1693
+ "exchange_count": self.exchange_count,
1694
+ "training_pool_size": len(self.all_training_data),
1695
+ "personality_score": score,
1696
+ "timestamp": datetime.now().isoformat(),
1697
+ }
1698
+ meta_path = os.path.join(self.log_dir, f"session_{version}.json")
1699
+ with open(meta_path, 'w') as f:
1700
+ json.dump(meta, f, indent=2)
1701
+ print(f" Session saved: {meta_path}")
1702
+ print(f" Next run: use --checkpoint {path}")
1703
+
1704
+ def _save_replay_buffer(self, checkpoint_path=None):
1705
+ """Save training data pool for next session resume."""
1706
+ # Always save to log dir (canonical location for resume)
1707
+ path = os.path.join(self.log_dir, "replay_buffer.json")
1708
+ with open(path, 'w') as f:
1709
+ json.dump(self.all_training_data, f)
1710
+ # Also save into checkpoint dir for self-contained checkpoints
1711
+ if checkpoint_path and os.path.isdir(checkpoint_path):
1712
+ cp_path = os.path.join(checkpoint_path, "replay_buffer.json")
1713
+ with open(cp_path, 'w') as f:
1714
+ json.dump(self.all_training_data, f)
1715
+ # Save quiz pairs log too
1716
+ quiz_log_path = os.path.join(self.log_dir, "quiz_pairs_log.json")
1717
+ with open(quiz_log_path, 'w') as f:
1718
+ json.dump(self.quiz_pairs_log, f)
1719
+ print(f" Replay buffer saved ({len(self.all_training_data)} examples)")
1720
+
1721
+ def _log_exchange(self, user_msg, assistant_msg):
1722
+ """Append exchange to conversation log file."""
1723
+ with open(self.log_path, 'a', encoding='utf-8') as f:
1724
+ entry = {
1725
+ "timestamp": datetime.now().isoformat(),
1726
+ "user": user_msg,
1727
+ "assistant": assistant_msg,
1728
+ }
1729
+ f.write(json.dumps(entry, ensure_ascii=False) + "\n")
1730
+
1731
+ def _handle_command(self, cmd):
1732
+ """Handle slash commands. Returns True if should exit."""
1733
+ cmd_lower = cmd.lower().strip()
1734
+
1735
+ if cmd_lower == "/quit":
1736
+ print("[Saving and exiting...]")
1737
+ self._wait_for_absorption()
1738
+ self._save_and_exit()
1739
+ return True
1740
+
1741
+ elif cmd_lower == "/status":
1742
+ self._wait_for_absorption()
1743
+ vram = torch.cuda.memory_allocated() / 1e9
1744
+ print(f"\n --- Status ---")
1745
+ print(f" Exchanges: {self.exchange_count}")
1746
+ print(f" Absorptions: {self.absorption_count}")
1747
+ print(f" Training pool: {len(self.all_training_data)} examples")
1748
+ print(f" Buffer: {len(self.conversation_buffer)} messages")
1749
+ print(f" VRAM: {vram:.1f} GB")
1750
+ print(f" Background: {'running' if self.absorption_thread and self.absorption_thread.is_alive() else 'idle'}")
1751
+ print(f" Last checkpoint: {self.last_checkpoint}")
1752
+ print(f" --- End ---\n")
1753
+
1754
+ elif cmd_lower == "/absorb":
1755
+ self._wait_for_absorption()
1756
+ if not self.all_training_data:
1757
+ print(" No data to absorb.")
1758
+ return False
1759
+ # Cap at most recent 40 examples to prevent overfitting
1760
+ import random
1761
+ data = self.all_training_data
1762
+ if len(data) > 40:
1763
+ recent = data[-20:]
1764
+ older = random.sample(data[:-20], 20)
1765
+ data = recent + older
1766
+ print(f" Force absorption ({len(data)} examples)...")
1767
+ loss = self.mm.absorb(data)
1768
+ self.absorption_count += 1
1769
+ print(f" Done. Loss: {loss:.4f}")
1770
+
1771
+ # ── Post-absorb comprehensive verification + distillation ──
1772
+ # Run FULL verification (all quiz pairs, not just sample) to catch
1773
+ # all regressions before recall questions. This is the critical
1774
+ # window between teaching and testing.
1775
+ if self.quiz_pairs_log:
1776
+ print(f"\n --- Post-absorb verification (ALL {len(self.quiz_pairs_log)} quiz pairs) ---")
1777
+ old_verify_sample = VERIFY_SAMPLE
1778
+ # Test ALL quiz pairs, not just a sample
1779
+ full_corrections = []
1780
+ full_correct = 0
1781
+ test_pairs = self.quiz_pairs_log
1782
+
1783
+ for pair in test_pairs:
1784
+ question = pair["messages"][0]["content"]
1785
+ expected = pair["messages"][1]["content"]
1786
+ actual = self.mm.generate(
1787
+ [{"role": "user", "content": question}],
1788
+ max_new_tokens=150,
1789
+ )
1790
+ expected_entities = self._extract_key_entities(expected)
1791
+ if not expected_entities:
1792
+ full_correct += 1
1793
+ continue
1794
+ actual_lower = actual.lower()
1795
+ hits = sum(1 for e in expected_entities if e in actual_lower)
1796
+ ratio = hits / len(expected_entities)
1797
+ if ratio < 0.5:
1798
+ actual_entities = self._extract_key_entities(actual)
1799
+ wrong_entities = actual_entities - expected_entities
1800
+ full_corrections.append(pair)
1801
+ if wrong_entities:
1802
+ for p in self.quiz_pairs_log:
1803
+ p_answer = p["messages"][1]["content"].lower()
1804
+ if any(we in p_answer for we in wrong_entities):
1805
+ if p not in full_corrections and p != pair:
1806
+ full_corrections.append(p)
1807
+ break
1808
+ else:
1809
+ full_correct += 1
1810
+
1811
+ print(f" Full verification: {full_correct}/{len(test_pairs)} correct")
1812
+ if full_corrections:
1813
+ print(f" Retraining {len(full_corrections)} corrections...")
1814
+ c_loss = self.mm.absorb(full_corrections)
1815
+ self.all_training_data.extend(full_corrections)
1816
+ print(f" Correction loss: {c_loss:.4f}")
1817
+ # Teacher distillation on corrections
1818
+ if self.teacher_cache:
1819
+ distill_items = []
1820
+ for corr in full_corrections:
1821
+ q = corr["messages"][0]["content"].lower()[:60]
1822
+ for cached in self.teacher_cache:
1823
+ cq = cached["pair"]["messages"][0]["content"].lower()[:60]
1824
+ if q == cq:
1825
+ distill_items.append(cached)
1826
+ break
1827
+ if distill_items:
1828
+ d_loss = self.mm.distill(distill_items, epochs=1)
1829
+ print(f" Teacher distillation on {len(distill_items)} items, loss={d_loss:.4f}")
1830
+ print(f" --- End post-absorb verification ---\n")
1831
+
1832
+ elif cmd_lower == "/save":
1833
+ self._wait_for_absorption()
1834
+ version = f"manual_{self.absorption_count}"
1835
+ path = os.path.join(self.checkpoint_dir, f"claudia_{version}")
1836
+ print(f" Saving checkpoint...")
1837
+ # Personality check
1838
+ score = check_personality(self.mm, verbose=False)
1839
+ if score < 0.5:
1840
+ print(f" WARNING: Personality score {score:.0%}. Save anyway? (y/n)")
1841
+ confirm = input(" > ").strip().lower()
1842
+ if confirm != 'y':
1843
+ print(" Aborted.")
1844
+ return False
1845
+ self._cleanup_old_checkpoints()
1846
+ self.mm.merge_and_save(path)
1847
+ self.last_checkpoint = path
1848
+ self._save_replay_buffer(path)
1849
+
1850
+ elif cmd_lower == "/personality":
1851
+ self._wait_for_absorption()
1852
+ print("\n--- Personality Check ---")
1853
+ check_personality(self.mm)
1854
+ print()
1855
+
1856
+ elif cmd_lower == "/help":
1857
+ print(" /status - show stats")
1858
+ print(" /absorb - force immediate training")
1859
+ print(" /save - merge + save checkpoint")
1860
+ print(" /personality - run personality check")
1861
+ print(" /quit - save and exit")
1862
+
1863
+ else:
1864
+ print(f" Unknown: {cmd}. Try /help")
1865
+
1866
+ return False
1867
+
1868
+
1869
+ # ═══════════════════════════════════════════════════════════════════════
1870
+ # MAIN
1871
+ # ═══════════════════════════════════════════════════════════════════════
1872
+
1873
+ def main():
1874
+ parser = argparse.ArgumentParser(
1875
+ description="Claudia Persistent Absorber v2 β€” conversation β†’ permanent weights"
1876
+ )
1877
+ parser.add_argument(
1878
+ "--model_path", required=True,
1879
+ help="Path to base Qwen3-Omni model (or checkpoint for resume)"
1880
+ )
1881
+ parser.add_argument(
1882
+ "--adapter_path", default=None,
1883
+ help="Path to Claudia v6 personality adapter (first run only)"
1884
+ )
1885
+ parser.add_argument(
1886
+ "--ffn_patch", default=None,
1887
+ help="Path to ffn_patch.pt (first run only)"
1888
+ )
1889
+ parser.add_argument(
1890
+ "--checkpoint", default=None,
1891
+ help="Resume from this checkpoint (has personality + memories baked in)"
1892
+ )
1893
+ parser.add_argument(
1894
+ "--checkpoint_dir", default="/workspace/checkpoints",
1895
+ help="Where to save checkpoints"
1896
+ )
1897
+ parser.add_argument(
1898
+ "--log_dir", default="/workspace/logs",
1899
+ help="Where to save conversation logs and replay buffer"
1900
+ )
1901
+ parser.add_argument(
1902
+ "--absorb_every", type=int, default=ABSORB_EVERY,
1903
+ help=f"Absorb every N exchanges (default: {ABSORB_EVERY})"
1904
+ )
1905
+ args = parser.parse_args()
1906
+
1907
+ # Determine if first run or resume
1908
+ if args.checkpoint:
1909
+ print(f"RESUMING from checkpoint: {args.checkpoint}")
1910
+ absorber = PersistentAbsorber(
1911
+ model_path=args.model_path,
1912
+ checkpoint_path=args.checkpoint,
1913
+ checkpoint_dir=args.checkpoint_dir,
1914
+ log_dir=args.log_dir,
1915
+ )
1916
+ else:
1917
+ print(f"FIRST RUN β€” applying personality adapter")
1918
+ if not args.adapter_path:
1919
+ print("ERROR: --adapter_path required for first run")
1920
+ print(" (or use --checkpoint to resume)")
1921
+ sys.exit(1)
1922
+ absorber = PersistentAbsorber(
1923
+ model_path=args.model_path,
1924
+ adapter_path=args.adapter_path,
1925
+ ffn_patch_path=args.ffn_patch,
1926
+ checkpoint_dir=args.checkpoint_dir,
1927
+ log_dir=args.log_dir,
1928
+ )
1929
+
1930
+ absorber.start()
1931
+
1932
+
1933
+ if __name__ == "__main__":
1934
+ main()