| """
|
| Claudia Persistent Absorber v2
|
| ==============================
|
| Combines the 3 best proven techniques into one system:
|
|
|
1. SELF-QUIZ PAIRS (21% → 74% recall — the single biggest lever)
2. PERSISTENT LoRA rank 128 (89% across 25 convos, no merge-between-rounds tax)
3. DUAL-LR EXPERT FFN (attention=6e-5, FFN=3e-4 — facts into MoE experts)
|
|
|
| Architecture:
|
- Load base Omni — thinker to GPU, rest to CPU
- First run: apply Claudia v6 adapter → merge → apply FFN patch
- Resume: load from checkpoint (already has personality + memories)
- Apply ONE persistent LoRA (r=128, alpha=256, attention q/k/v/o)
- Chat loop: generate → quiz → train (LoRA + expert FFN) → repeat
- On save/quit: merge_and_unload → save full checkpoint
- Next session loads from checkpoint → memories are permanent
|
|
|
| Instance: Vast.ai 33093662 (A100 80GB, Sweden)
|
| SSH: ssh -p 13662 root@ssh1.vast.ai
|
| """
|
|
|
| import argparse
|
| import gc
|
| import json
|
| import os
|
| import re
|
| import sys
|
| import threading
|
| import time
|
| import torch
|
| from collections import Counter
|
| from datetime import datetime
|
| from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Persistent absorption LoRA (passed to peft.LoraConfig) -----------------
LORA_RANK: int = 128
LORA_ALPHA: int = 256
LORA_TARGETS: list = ["q_proj", "k_proj", "v_proj", "o_proj"]  # attention projections

# --- Dual learning rates ------------------------------------------------------
ATTENTION_LR: float = 6e-5      # LoRA (attention) parameter group
EXPERT_FFN_LR: float = 3e-4     # expert down_proj parameter group
EXPERT_FFN_LAYERS: list = [20, 24, 28]  # decoder layers whose experts are unfrozen

# --- Absorption training ------------------------------------------------------
TRAIN_EPOCHS: int = 2           # passes over the data per absorb() call
MAX_SEQ_LENGTH: int = 2048      # token cap for training / teacher encodings
GRADIENT_CLIP: float = 1.0      # max gradient norm per optimizer step

# --- Sampling for generate() --------------------------------------------------
GEN_TEMPERATURE: float = 0.7
GEN_TOP_P: float = 0.9
GEN_TOP_K: int = 50
GEN_MAX_TOKENS: int = 512       # default max_new_tokens
GEN_REP_PENALTY: float = 1.1

# --- Chat-loop cadence --------------------------------------------------------
# NOTE(review): not read by the code visible here; presumably consumed by the
# chat loop further down the file — confirm units (turns?) there.
ABSORB_EVERY: int = 1
CHECKPOINT_EVERY: int = 10
VERIFY_EVERY: int = 3
VERIFY_SAMPLE: int = 10

# --- Consolidation distillation -----------------------------------------------
DISTILL_ALPHA: float = 0.5        # weight on CE loss; (1 - alpha) goes to the KL term
DISTILL_TEMPERATURE: float = 2.0  # softmax temperature for teacher/student
DISTILL_TOP_K: int = 32           # teacher logits cached per token position
CONSOLIDATION_EPOCHS: int = 2     # default epochs for distill()
MAX_TEACHER_CACHE: int = 200      # only the most recent quiz pairs are cached
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_response_quality(text):
    """Return True when *text* looks like usable training material.

    Rejects empty or very short strings and several signatures of
    degenerate generation: low vocabulary diversity, repeated adjacent
    words, one bigram dominating a longer text, and runs of implausibly
    long "words".
    """
    if not text or len(text) < 5:
        return False

    tokens = text.lower().split()
    count = len(tokens)
    if count < 3:
        return False

    # Under 30% distinct tokens means the text is mostly repetition.
    if len(set(tokens)) / count < 0.3:
        return False

    adjacent = list(zip(tokens, tokens[1:]))

    # Three or more immediate repeats ("the the ...").
    if sum(left == right for left, right in adjacent) >= 3:
        return False

    # For longer texts, any single bigram seen five or more times is a loop.
    if count >= 10:
        bigram_counts = Counter(f"{left} {right}" for left, right in adjacent)
        if bigram_counts.most_common(1)[0][1] >= 5:
            return False

    # Two or more tokens over 30 characters (glued garbage, hashes, ...).
    if sum(len(tok) > 30 for tok in tokens) >= 2:
        return False

    # Natural prose does not average more than 12 characters per token.
    return sum(map(len, tokens)) / count <= 12
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelManager:
    """Owns the Omni thinker, its tokenizer and the persistent absorption LoRA.

    Every public operation that touches the model (``generate``, ``absorb``,
    ``absorb_two_phase``, ``merge_and_save``, ``cache_teacher_logits``,
    ``distill``) runs under ``self._lock``. The lock is re-entrant, so a
    callback invoked while it is held can call :meth:`generate` on this same
    manager without deadlocking. The main such callback is ``verify_fn`` in
    :meth:`absorb_two_phase`.
    """

    # Prompt budget for generate(); longer chats keep only their newest tokens.
    _MAX_PROMPT_TOKENS = 8192

    def __init__(self, model_path, adapter_path=None, ffn_patch_path=None,
                 checkpoint_path=None):
        """Record model locations; nothing is loaded until :meth:`load`.

        Args:
            model_path: Base Omni model location. It is used when no
                checkpoint is given, and is then also the tokenizer source.
            adapter_path: Optional personality adapter, merged on first run.
            ffn_patch_path: Optional ``torch.save`` file containing a dict of
                expert ``down_proj`` weights keyed by names that contain
                ``layers.<idx>``. It is applied on first run.
            checkpoint_path: Optional merged checkpoint to resume from. When
                set, the adapter and the FFN patch are skipped.
        """
        self.model_path = model_path
        self.adapter_path = adapter_path
        self.ffn_patch_path = ffn_patch_path
        self.checkpoint_path = checkpoint_path

        self.thinker = None
        self.tokenizer = None
        self.stop_ids = None
        self.peft_model = None
        # RLock (not Lock): callbacks run under the lock may call generate().
        self._lock = threading.RLock()

    # ------------------------------------------------------------------ loading

    def load(self):
        """Load the tokenizer and thinker, then attach the persistent LoRA.

        Fresh start (no ``checkpoint_path``) does three things: it loads the
        full Omni model and keeps only its ``thinker``, merges the
        personality adapter if one is given, and copies in the FFN patch if
        one is given. Resume loads the already-merged thinker from
        ``checkpoint_path`` and skips both.
        """
        from transformers import AutoTokenizer

        tok_source = self.checkpoint_path or self.model_path
        print(f"[1/5] Loading tokenizer from {tok_source}...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            tok_source, trust_remote_code=True
        )

        self._load_thinker()

        if self.checkpoint_path:
            print("[3/5] Resuming from checkpoint — personality already in weights.")
        else:
            if self.adapter_path:
                print("[3/5] Merging Claudia v6 personality adapter...")
                from peft import PeftModel
                self.thinker = PeftModel.from_pretrained(
                    self.thinker, self.adapter_path
                )
                self.thinker = self.thinker.merge_and_unload()
                print(" Personality merged into base weights.")
            self._apply_ffn_patch()

        self.thinker.eval()
        self.stop_ids = self._collect_stop_ids()

        print(f"[4/5] Applying persistent LoRA (r={LORA_RANK}, alpha={LORA_ALPHA})...")
        self._apply_persistent_lora()

        vram = torch.cuda.memory_allocated() / 1e9
        print(f"[5/5] Ready. VRAM: {vram:.1f} GB\n")

    def _load_thinker(self):
        """Set ``self.thinker``, either from the checkpoint or carved out of the full model."""
        if self.checkpoint_path:
            print(f"[2/5] Loading thinker from checkpoint {self.checkpoint_path}...")
            try:
                from transformers import Qwen3OmniMoeThinkerForConditionalGeneration as ThinkerClass
            except ImportError:
                from transformers import AutoModelForCausalLM as ThinkerClass
            self.thinker = ThinkerClass.from_pretrained(
                self.checkpoint_path,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )
            vram = torch.cuda.memory_allocated() / 1e9
            print(f" VRAM after load: {vram:.1f} GB")
            return

        print(f"[2/5] Loading full model from {self.model_path}...")
        try:
            from transformers import Qwen3OmniMoeForConditionalGeneration as ModelClass
        except ImportError:
            from transformers import AutoModel as ModelClass
        full_model = ModelClass.from_pretrained(
            self.model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        vram = torch.cuda.memory_allocated() / 1e9
        print(f" VRAM after load: {vram:.1f} GB")

        self.thinker = full_model.thinker
        # Only the thinker is used; move every other top-level submodule off the GPU.
        for name, module in full_model.named_children():
            if name == "thinker":
                continue
            try:
                module.cpu()
            except (NotImplementedError, RuntimeError):
                pass  # best effort: a module that refuses to move stays put
        del full_model
        torch.cuda.empty_cache()
        vram = torch.cuda.memory_allocated() / 1e9
        print(f" VRAM after cleanup: {vram:.1f} GB")

    def _apply_ffn_patch(self):
        """Copy saved expert ``down_proj`` weights into the thinker (first run only).

        Keys without a ``layers.<idx>`` component are ignored. For a
        list-like ``experts`` module, dim 0 of each tensor indexes the expert.
        """
        if not self.ffn_patch_path:
            return
        if not os.path.exists(self.ffn_patch_path):
            # Previously skipped silently; a missing patch should be visible.
            print(f" WARNING: FFN patch not found at {self.ffn_patch_path}; skipping.")
            return

        print(f" Applying FFN patch from {self.ffn_patch_path}...")
        ffn = torch.load(self.ffn_patch_path, map_location="cpu", weights_only=True)
        for key, tensor in ffn.items():
            match = re.search(r"layers\.(\d+)", key)
            if not match:
                continue
            experts = self.thinker.model.layers[int(match.group(1))].mlp.experts
            if hasattr(experts, "__len__"):
                for i in range(tensor.shape[0]):
                    weight = experts[i].down_proj.weight
                    weight.data.copy_(tensor[i].to(weight.device, weight.dtype))
            elif hasattr(experts, "down_proj"):
                target = experts.down_proj
                target.data.copy_(tensor.to(target.device, target.dtype))
        del ffn
        torch.cuda.empty_cache()
        print(" FFN patch applied.")

    def _collect_stop_ids(self):
        """Return de-duplicated token ids that terminate generation."""
        stop_ids = []
        for tok in ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]:
            stop_ids.extend(self.tokenizer.encode(tok, add_special_tokens=False))
        eos_id = self.tokenizer.eos_token_id
        # `is not None`: a truthiness test would drop a valid id of 0.
        if eos_id is not None:
            stop_ids.append(eos_id)
        return list(dict.fromkeys(stop_ids))

    def _apply_persistent_lora(self):
        """Wrap ``self.thinker`` in the persistent absorption LoRA.

        Called once from :meth:`load` and again after every merge.
        """
        from peft import LoraConfig, get_peft_model

        lora_config = LoraConfig(
            r=LORA_RANK,
            lora_alpha=LORA_ALPHA,
            target_modules=LORA_TARGETS,
            lora_dropout=0.0,
            bias="none",
            task_type="CAUSAL_LM",
        )
        self.peft_model = get_peft_model(self.thinker, lora_config)
        self.peft_model.eval()

        trainable = sum(p.numel() for p in self.peft_model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.peft_model.parameters())
        print(f" LoRA: {trainable / 1e6:.1f}M trainable / {total / 1e6:.0f}M total")

    def _active_model(self):
        """Return the PEFT-wrapped model when present, otherwise the bare thinker."""
        return self.peft_model if self.peft_model is not None else self.thinker

    # --------------------------------------------------------------- generation

    def generate(self, messages, max_new_tokens=None):
        """Sample an assistant reply for chat *messages*. Thread-safe.

        Args:
            messages: Chat-template message list (``role``/``content`` dicts).
            max_new_tokens: Generation cap. A falsy value means
                ``GEN_MAX_TOKENS``.

        Returns:
            The decoded reply with special tokens and any ``<think>``
            sections removed, stripped of surrounding whitespace.
        """
        with self._lock:
            model = self._active_model()
            model.eval()

            text = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True,
                enable_thinking=False,
            )
            enc = self.tokenizer(text, return_tensors="pt")
            # Keep the NEWEST tokens. The tokenizer's default right-side
            # truncation would cut off the latest turn and the generation prompt.
            limit = self._MAX_PROMPT_TOKENS
            inputs = {key: value[:, -limit:].to("cuda") for key, value in enc.items()}
            input_len = inputs["input_ids"].shape[1]

            with torch.inference_mode():
                out = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens or GEN_MAX_TOKENS,
                    temperature=GEN_TEMPERATURE,
                    top_p=GEN_TOP_P,
                    top_k=GEN_TOP_K,
                    do_sample=True,
                    repetition_penalty=GEN_REP_PENALTY,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.stop_ids,
                )

            resp = self.tokenizer.decode(out[0][input_len:], skip_special_tokens=True)
            resp = re.sub(r"<think>.*?</think>", "", resp, flags=re.DOTALL)
            resp = re.sub(r"</?think>", "", resp)
            return resp.strip()

    # ----------------------------------------------------------------- training

    @staticmethod
    def _base_module(model):
        """Return the underlying thinker, whether or not *model* is PEFT-wrapped."""
        # PeftModel.get_base_model() unwraps the adapter. Plain transformers
        # models also have a `base_model` attribute (their inner backbone), so
        # probing for that attribute does not detect PEFT.
        get_base = getattr(model, "get_base_model", None)
        return get_base() if callable(get_base) else model

    @staticmethod
    def _set_expert_trainable(base, flag):
        """Set ``requires_grad=flag`` on expert ``down_proj`` weights of ``EXPERT_FFN_LAYERS``.

        Handles both a list-like ``experts`` module (one ``down_proj`` Linear
        per expert) and a fused module exposing a ``down_proj`` tensor.

        Returns:
            The list of affected parameters.
        """
        params = []
        for layer_idx in EXPERT_FFN_LAYERS:
            experts = base.model.layers[layer_idx].mlp.experts
            if hasattr(experts, "__len__"):
                candidates = [experts[i].down_proj.weight for i in range(len(experts))]
            elif hasattr(experts, "down_proj") and isinstance(experts.down_proj, torch.Tensor):
                candidates = [experts.down_proj]
            else:
                candidates = []
            for p in candidates:
                p.requires_grad_(flag)
                params.append(p)
        return params

    def _start_training(self, model):
        """Put *model* in train mode and build the dual-LR AdamW optimizer.

        Unfreezes the expert FFN weights and gives them ``EXPERT_FFN_LR``.
        Every other parameter that already requires grad gets
        ``ATTENTION_LR``; when *model* is PEFT-wrapped, that is the LoRA
        adapter.

        Returns:
            ``(optimizer, params)``, or ``(None, [])`` when nothing is
            trainable. Always pair this with :meth:`_finish_training`.
        """
        model.train()
        expert_params = self._set_expert_trainable(self._base_module(model), True)
        expert_ids = {id(p) for p in expert_params}
        # Exclude experts by identity so no tensor lands in two param groups
        # (AdamW rejects that), even if an earlier run left them unfrozen.
        attn_params = [p for p in model.parameters()
                       if p.requires_grad and id(p) not in expert_ids]

        param_groups = []
        if attn_params:
            param_groups.append({"params": attn_params, "lr": ATTENTION_LR})
        if expert_params:
            param_groups.append({"params": expert_params, "lr": EXPERT_FFN_LR})
        if not param_groups:
            return None, []
        optimizer = torch.optim.AdamW(param_groups, weight_decay=0.0)
        return optimizer, attn_params + expert_params

    def _finish_training(self, model):
        """Re-freeze expert FFNs, return *model* to eval mode, release cached CUDA memory."""
        self._set_expert_trainable(self._base_module(model), False)
        model.eval()
        torch.cuda.empty_cache()

    def _render_training_texts(self, training_data):
        """Render supported training items to chat-template strings.

        Accepted item shapes are ``{"messages": [...]}``,
        ``{"prompt": [...], "completion": [...]}`` and a bare message list.
        Anything else is skipped.
        """
        texts = []
        for item in training_data:
            if isinstance(item, dict) and "messages" in item:
                msgs = item["messages"]
            elif isinstance(item, dict) and "prompt" in item:
                msgs = item["prompt"] + item.get("completion", [])
            elif isinstance(item, list):
                msgs = item
            else:
                continue
            texts.append(self.tokenizer.apply_chat_template(
                msgs, tokenize=False, enable_thinking=False
            ))
        return texts

    def absorb(self, training_data):
        """Train the persistent LoRA and expert FFNs on *training_data*. Thread-safe.

        Uses dual learning rates: ``ATTENTION_LR`` for the LoRA and
        ``EXPERT_FFN_LR`` for the expert FFNs.

        Returns:
            The mean per-example loss, or ``None`` when there was nothing to
            train on.
        """
        with self._lock:
            return self._absorb_impl(training_data)

    def _absorb_impl(self, training_data):
        """Absorption implementation; the caller must hold ``_lock``."""
        if not training_data:
            return None
        texts = self._render_training_texts(training_data)
        if not texts:
            return None

        # One unpadded sequence per example. Training is batch-size 1, so
        # padding only wasted compute, needed a pad token, and (with left
        # padding) shifted positions.
        encoded = self.tokenizer(texts, truncation=True, max_length=MAX_SEQ_LENGTH)["input_ids"]
        samples = [torch.tensor([ids], dtype=torch.long) for ids in encoded]

        model = self._active_model()
        optimizer, params = self._start_training(model)
        try:
            if optimizer is None:
                return None
            total_loss = 0.0
            steps = 0
            for _ in range(TRAIN_EPOCHS):
                for idx in torch.randperm(len(samples)).tolist():
                    ids = samples[idx].to("cuda")
                    out = model(
                        input_ids=ids,
                        attention_mask=torch.ones_like(ids),
                        labels=ids.clone(),
                    )
                    out.loss.backward()
                    torch.nn.utils.clip_grad_norm_(params, GRADIENT_CLIP)
                    optimizer.step()
                    optimizer.zero_grad()
                    total_loss += out.loss.item()
                    steps += 1
            return total_loss / steps if steps else 0
        finally:
            # Always restore frozen experts / eval mode, even after an error
            # (e.g. CUDA OOM). Drop AdamW state before emptying the cache.
            del optimizer
            self._finish_training(model)

    @staticmethod
    def _mentions(name, text):
        """Return True if *name* occurs in lowercased *text* as a whole word."""
        needle = name.lower()
        if not needle:
            return False
        return re.search(rf"(?<!\w){re.escape(needle)}(?!\w)", text) is not None

    @staticmethod
    def cluster_by_entity(training_data, entity_names):
        """Reorder training items so that items about one entity are contiguous.

        Interleaving facts about different people can cause
        cross-contamination during gradient updates. Grouping trains all of
        one person's items before moving on to the next person.

        Args:
            training_data: List of training items. Only ``{"messages": ...}``
                dicts are inspected.
            entity_names: Iterable of known entity names. An item belongs to
                the first name, in iteration order, that appears in its
                message text as a whole word (case-insensitive).

        Returns:
            A new list with clustered items first, in *entity_names* order.
            Unmatched or non-message items come last, in their original order.
        """
        names = list(dict.fromkeys(entity_names))  # de-duplicate, keep order
        clusters = {name: [] for name in names}
        unclustered = []

        for item in training_data:
            if not (isinstance(item, dict) and "messages" in item):
                unclustered.append(item)
                continue
            text = " ".join(m.get("content", "") for m in item["messages"]).lower()
            owner = next((n for n in names if ModelManager._mentions(n, text)), None)
            if owner is None:
                unclustered.append(item)
            else:
                clusters[owner].append(item)

        ordered = [item for name in names for item in clusters[name]]
        ordered.extend(unclustered)
        return ordered

    def absorb_two_phase(self, positive_data, contrastive_data, verify_fn=None):
        """Two-phase absorption: facts first, then targeted contrastive correction.

        Phase 1 trains on *positive_data* (exchanges, summaries, direct
        quizzes).

        Phase 2 depends on *verify_fn*. If it is given, it is called with this
        manager and should return the names of entities the model still
        confuses. Only contrastive items whose question mentions one of those
        names (whole word, case-insensitive) are trained. Without
        *verify_fn*, all of *contrastive_data* is trained.

        *verify_fn* runs while ``_lock`` is held. The lock is re-entrant, so
        *verify_fn* may call :meth:`generate`.

        Returns:
            ``(phase1_loss, phase2_loss)``. Each is ``None`` when its phase
            did not train.
        """
        with self._lock:
            loss1 = self._absorb_impl(positive_data) if positive_data else None

            loss2 = None
            if contrastive_data:
                if verify_fn:
                    confused = verify_fn(self)
                    if confused:
                        targeted = [
                            item for item in contrastive_data
                            if any(self._mentions(name, item["messages"][0]["content"].lower())
                                   for name in confused)
                        ]
                        if targeted:
                            loss2 = self._absorb_impl(targeted)
                else:
                    loss2 = self._absorb_impl(contrastive_data)

            return loss1, loss2

    # ------------------------------------------------------------- persistence

    def merge_and_save(self, path):
        """Merge the persistent LoRA into the base weights, save, re-apply a fresh LoRA.

        The fresh LoRA is re-applied even if saving raises. Otherwise the
        manager would be left on an unwrapped thinker.
        """
        with self._lock:
            try:
                if self.peft_model is not None:
                    print(" Merging persistent LoRA into base weights...")
                    self.thinker = self.peft_model.merge_and_unload()
                    self.thinker.eval()
                    self.peft_model = None

                os.makedirs(path, exist_ok=True)
                print(f" Saving checkpoint to {path}...")
                self.thinker.save_pretrained(path)
                self.tokenizer.save_pretrained(path)
                print(f" Checkpoint saved ({path})")
            finally:
                if self.peft_model is None:
                    self._apply_persistent_lora()
                    print(" Fresh LoRA applied — ready to continue.")

    # ------------------------------------------------------------ distillation

    def cache_teacher_logits(self, quiz_pairs, top_k=DISTILL_TOP_K):
        """Record the current model's top-*k* logits for the latest quiz pairs.

        Only the last ``MAX_TEACHER_CACHE`` pairs are used. Each cache entry
        holds CPU tensors: ``input_ids``, ``attention_mask``,
        ``teacher_logits`` (fp16 top-k values per position) and
        ``teacher_indices`` (their vocabulary ids), plus the original
        ``pair``.
        """
        with self._lock:
            model = self._active_model()
            model.eval()
            cache = []

            for pair in quiz_pairs[-MAX_TEACHER_CACHE:]:
                text = self.tokenizer.apply_chat_template(
                    pair["messages"], tokenize=False, enable_thinking=False
                )
                enc = self.tokenizer(
                    text, return_tensors="pt", truncation=True,
                    max_length=MAX_SEQ_LENGTH,
                )
                input_ids = enc["input_ids"].to("cuda")
                attention_mask = enc["attention_mask"].to("cuda")

                # no_grad rather than inference_mode: the cached tensors later
                # feed the autograd graph in _distill_impl, so they should be
                # ordinary tensors rather than inference tensors.
                with torch.no_grad():
                    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits[0]
                    top_vals, top_idx = logits.topk(top_k, dim=-1)

                cache.append({
                    "pair": pair,
                    "input_ids": input_ids.cpu(),
                    "attention_mask": attention_mask.cpu(),
                    "teacher_logits": top_vals.half().cpu(),
                    "teacher_indices": top_idx.cpu(),
                })

            return cache

    def distill(self, teacher_cache, epochs=CONSOLIDATION_EPOCHS):
        """Train the student to match cached teacher distributions. Thread-safe.

        From Nemotron-Cascade-2: recover domain regressions via on-policy
        distillation.
        """
        with self._lock:
            return self._distill_impl(teacher_cache, epochs)

    def _distill_impl(self, teacher_cache, epochs):
        """Distillation implementation; the caller must hold ``_lock``.

        The loss is ``DISTILL_ALPHA * CE + (1 - DISTILL_ALPHA) * T^2 * KL``.
        The KL term is taken between temperature-softened distributions
        restricted to the teacher's cached top-k vocabulary ids at each
        position.

        Returns:
            The mean loss per step, ``0`` when *epochs* is 0, or ``None`` when
            the cache is empty or nothing is trainable.
        """
        if not teacher_cache:
            return None

        model = self._active_model()
        optimizer, params = self._start_training(model)
        try:
            if optimizer is None:
                return None

            temp = DISTILL_TEMPERATURE
            total_loss = 0.0
            steps = 0
            for _ in range(epochs):
                for idx in torch.randperm(len(teacher_cache)).tolist():
                    item = teacher_cache[idx]
                    input_ids = item["input_ids"].to("cuda")
                    attention_mask = item["attention_mask"].to("cuda")
                    teacher_top_logits = item["teacher_logits"].float().to("cuda")
                    teacher_top_indices = item["teacher_indices"].to("cuda")

                    labels = input_ids.clone()
                    labels[attention_mask == 0] = -100

                    out = model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels,
                    )
                    student_logits = out.logits[0]
                    seq_len = min(student_logits.shape[0], teacher_top_logits.shape[0])

                    # Student logits at the teacher's top-k ids; softmax in fp32.
                    student_at_teacher = student_logits[:seq_len].gather(
                        1, teacher_top_indices[:seq_len]
                    ).float()
                    teacher_soft = torch.softmax(teacher_top_logits[:seq_len] / temp, dim=-1)
                    student_log_soft = torch.log_softmax(student_at_teacher / temp, dim=-1)
                    kl_loss = torch.nn.functional.kl_div(
                        student_log_soft, teacher_soft, reduction="batchmean"
                    ) * (temp * temp)

                    loss = DISTILL_ALPHA * out.loss + (1 - DISTILL_ALPHA) * kl_loss
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(params, GRADIENT_CLIP)
                    optimizer.step()
                    optimizer.zero_grad()

                    total_loss += loss.item()
                    steps += 1

            return total_loss / steps if steps else 0
        finally:
            del optimizer
            self._finish_training(model)
|
|
|
|
|
|
|
|
|
|
|
|
|
| class QuizGenerator:
|
| """
|
| Generates drill-style Q&A flashcards for fact retention.
|
|
|
| v3 improvements over v2:
|
| - Fact extraction THEN quiz generation (two-step)
|
| - Drill-style: specific Q, exact A (not narrative)
|
| - Third-person attribution ("Matt's dog" not "my dog")
|
| - Template fallback targets each extracted fact independently
|
| - CONTRASTIVE DISAMBIGUATION: when multiple people mentioned, generates
|
| cross-entity negative pairs ("Is Elena a marine biologist? No, that's
|
| Jordan") to prevent entity confusion (the #1 remaining failure mode)
|
| - ENTITY SUMMARIES: "Tell me everything about Jordan" pairs for coherent
|
| per-person representations
|
| """
|
|
|
| def __init__(self, model_manager):
|
| self.mm = model_manager
|
|
|
|
|
|
|
| self.known_entities = {}
|
|
|
| def generate(self, user_msg, assistant_msg):
|
| """Generate drill-style quiz pairs from an exchange."""
|
|
|
|
|
| pairs = self._generate_model_quizzes(user_msg, assistant_msg)
|
|
|
|
|
| template_pairs = self._extract_and_template(user_msg)
|
| for tp in template_pairs:
|
|
|
| tq = tp["messages"][0]["content"].lower()
|
| if not any(tq in p["messages"][0]["content"].lower() or
|
| p["messages"][0]["content"].lower() in tq
|
| for p in pairs):
|
| pairs.append(tp)
|
|
|
|
|
| new_entities = self._extract_entities(user_msg)
|
|
|
|
|
|
|
|
|
|
|
| if new_entities:
|
| all_entities_for_contrastive = dict(self.known_entities)
|
| all_entities_for_contrastive.update(new_entities)
|
| if len(all_entities_for_contrastive) >= 2:
|
| new_names = set(new_entities.keys())
|
| contrastive = self._generate_contrastive_quizzes(
|
| all_entities_for_contrastive, new_only=new_names)
|
| pairs.extend(contrastive)
|
|
|
|
|
| summaries = self._generate_entity_summaries(new_entities)
|
| pairs.extend(summaries)
|
|
|
|
|
|
|
| for name, info in new_entities.items():
|
| if name not in self.known_entities:
|
| self.known_entities[name] = info
|
| else:
|
|
|
| for key in ("job", "city"):
|
| if info.get(key):
|
| self.known_entities[name][key] = info[key]
|
|
|
|
|
| seen = set()
|
| unique = []
|
| for p in pairs:
|
| q = p["messages"][0]["content"].lower()[:60]
|
| if q not in seen:
|
| seen.add(q)
|
| unique.append(p)
|
|
|
|
|
|
|
| has_contrastive = len(self.known_entities) >= 2 and new_entities
|
| max_quizzes = 12 if has_contrastive else 5
|
| return unique[:max_quizzes]
|
|
|
| def _generate_model_quizzes(self, user_msg, assistant_msg):
|
| """Use the model to generate fact-drill quizzes. Uses base model (LoRA disabled) for stable quality."""
|
| quiz_prompt = f"""Matt just told Claudia:
|
| "{user_msg}"
|
|
|
| Claudia replied:
|
| "{assistant_msg}"
|
|
|
| Extract every SPECIFIC FACT from Matt's message. For each fact, write a drill-style flashcard.
|
|
|
| RULES:
|
| - Questions must ask for ONE specific fact (name, date, place, number, detail)
|
| - Answers must be SHORT (1 sentence) and contain the EXACT detail
|
| - Use THIRD PERSON: "Matt's dog" NOT "my dog". "Matt's birthday" NOT "my birthday"
|
| - Include the PRECISE value: exact names, exact dates, exact places
|
| - Do NOT paraphrase or add details that weren't stated
|
| - DISAMBIGUATION: If Matt mentions OTHER people (friends, family), clearly state WHOSE fact it is
|
| Example: "Matt's friend Jordan is a marine biologist" NOT "Matt is a marine biologist"
|
| Example: "Matt's sister Elena is a veterinarian" NOT "Matt is a veterinarian"
|
| - For EVERY person mentioned, always include their RELATIONSHIP to Matt
|
| - Write 3-5 flashcards depending on how many facts Matt shared
|
|
|
| GOOD EXAMPLES:
|
| Q: What is Matt's dog's name?
|
| A: Matt's dog is named Biscuit.
|
|
|
| Q: What breed is Matt's dog?
|
| A: Matt's dog Biscuit is a golden retriever.
|
|
|
| Q: What does Matt's friend Jordan do for a living?
|
| A: Matt's friend Jordan works as a marine biologist in San Diego. That is Jordan's job, not Matt's.
|
|
|
| Q: What is Matt's job?
|
| A: Matt is the CTO of Arclight Labs.
|
|
|
| Q: What is Matt's birthday?
|
| A: Matt's birthday is September 14th.
|
|
|
| Q: When did Matt and Sarah get married?
|
| A: Matt and his wife Sarah got married on June 21st, 2023 in Big Sur, California.
|
|
|
| BAD EXAMPLES (do NOT do this):
|
| Q: What did Matt share about his life? (TOO VAGUE β ask about ONE fact)
|
| Q: What is my dog's name? (WRONG β use "Matt's" not "my")
|
| A: He mentioned something about a trip overseas. (TOO VAGUE β give the exact city)
|
| A: Matt is a marine biologist. (WRONG β that's his friend Jordan, not Matt)
|
|
|
| Now write flashcards for the exchange above:"""
|
|
|
| pairs = []
|
| try:
|
| response = self.mm.generate(
|
| [{"role": "user", "content": quiz_prompt}],
|
| max_new_tokens=600,
|
| )
|
|
|
| pending_q = None
|
| for line in response.split("\n"):
|
| line = line.strip()
|
| if not line:
|
| continue
|
| upper = line.upper()
|
| if upper.startswith("Q:") or upper.startswith("QUESTION:"):
|
| pending_q = line.split(":", 1)[1].strip().strip('"')
|
| elif (upper.startswith("A:") or upper.startswith("ANSWER:")) and pending_q:
|
| a = line.split(":", 1)[1].strip().strip('"')
|
| if pending_q and a and len(a) > 10:
|
| pairs.append({
|
| "messages": [
|
| {"role": "user", "content": pending_q},
|
| {"role": "assistant", "content": a},
|
| ]
|
| })
|
| pending_q = None
|
|
|
| except Exception as e:
|
| print(f" [quiz error: {e}]")
|
|
|
| return pairs
|
|
|
| def _extract_and_template(self, user_msg):
|
| """Extract facts from user message and create template drill pairs.
|
| This is the safety net β ensures every concrete fact gets a quiz."""
|
| pairs = []
|
| sentences = re.split(r'[.!?]+', user_msg)
|
|
|
| for sent in sentences:
|
| sent = sent.strip()
|
| if len(sent) < 10:
|
| continue
|
|
|
|
|
|
|
| name_patterns = [
|
|
|
| (r"(?:my|his|her)\s+(\w+)(?:'s)?\s+(?:name\s+is|is\s+named|is\s+called)\s+(\w+)",
|
| lambda m: (f"What is Matt's {m.group(1)}'s name?",
|
| f"Matt's {m.group(1)} is named {m.group(2)}.")),
|
| (r"(?:name\s+is|named|called)\s+[\"']?(\w+)[\"']?",
|
| lambda m: (f"Who or what is {m.group(1)}?",
|
| f"Matt mentioned {m.group(1)}: \"{sent.strip()}\"")),
|
|
|
| (r"(?:my\s+)?(birthday|born)\s+(?:is\s+)?(?:on\s+)?(\w+\s+\d+(?:st|nd|rd|th)?)",
|
| lambda m: (f"When is Matt's {m.group(1)}?",
|
| f"Matt's {m.group(1)} is {m.group(2)}.")),
|
| (r"(\w+\s+\d+(?:st|nd|rd|th)?)\s*(?:is|β)\s*(?:my|his)\s+(birthday)",
|
| lambda m: (f"When is Matt's birthday?",
|
| f"Matt's birthday is {m.group(1)}.")),
|
|
|
| (r"(?:married|wedding)\s+(?:on\s+)?(\w+\s+\d+(?:st|nd|rd|th)?,?\s*\d{4})",
|
| lambda m: (f"When did Matt get married?",
|
| f"Matt got married on {m.group(1)}.")),
|
| (r"(?:married|wedding)\s+(?:on\s+)?.*?(?:in|at)\s+(.+?)(?:\.\s|\.$|$)",
|
| lambda m: (f"Where did Matt get married?",
|
| f"Matt got married in {m.group(1).strip()}.")),
|
|
|
| (r"I\s+work\s+at\s+(?:a\s+)?(?:startup\s+)?(?:called\s+)?(\w[\w\s]+?)(?:\.|,|$)",
|
| lambda m: (f"Where does Matt work?",
|
| f"Matt works at {m.group(1).strip()}.")),
|
| (r"I(?:'m| am)\s+the\s+(\w+)",
|
| lambda m: (f"What is Matt's job title?",
|
| f"Matt is the {m.group(1)}.")),
|
|
|
| (r"(?:my\s+)?(?:friend|best friend|sister|brother)\s+(?:is\s+)?(\w+)\s+.*?(?:works?\s+as|is\s+a)\s+(.+?)(?:\.|,|$)",
|
| lambda m: (f"What does Matt's friend {m.group(1)} do?",
|
| f"Matt's friend {m.group(1)} is a {m.group(2).strip()}. This is NOT Matt's job.")),
|
|
|
| (r"(?:from|visited|went to|got back from|lives?\s+in|grew up in|moved to)\s+(\w[\w\s,]+?)(?:\.|,|$)",
|
| lambda m: (f"What place is connected to Matt: {m.group(1).strip()}?",
|
| f"Matt said: \"{sent.strip()}\"")),
|
|
|
| (r"(?:my |)favorite\s+(\w[\w\s]+?)\s+is\s+(.+?)(?:\.|,|$)",
|
| lambda m: (f"What is Matt's favorite {m.group(1).strip()}?",
|
| f"Matt's favorite {m.group(1).strip()} is {m.group(2).strip()}.")),
|
|
|
| (r"I\s+(speak|play|drive|have|collect|run|ran)\s+(.+?)(?:\.|,|$)",
|
| lambda m: (f"What does Matt {m.group(1)}?",
|
| f"Matt said: \"{sent.strip()}\"")),
|
|
|
| (r"(?:I(?:'m| am)\s+)?allergic\s+to\s+(.+?)(?:\.|,|and)",
|
| lambda m: (f"What is Matt allergic to?",
|
| f"Matt is allergic to {m.group(1).strip()}.")),
|
|
|
| (r"(?:turning|I(?:'m| am))\s+(\d+)",
|
| lambda m: (f"How old is Matt?",
|
| f"Matt is turning {m.group(1)}.")),
|
|
|
| (r"(?:call|nickname)\s+(?:it|him|her)\s+[\"'](.+?)[\"']",
|
| lambda m: (f"What nickname did Matt mention?",
|
| f"Matt's nickname for it is \"{m.group(1)}\".")),
|
| (r"I\s+call\s+it\s+[\"'](.+?)[\"']",
|
| lambda m: (f"What does Matt call his car?",
|
| f"Matt calls his car \"{m.group(1)}\".")),
|
| ]
|
|
|
| for pattern, formatter in name_patterns:
|
| match = re.search(pattern, sent, re.IGNORECASE)
|
| if match:
|
| try:
|
| q, a = formatter(match)
|
| pairs.append({
|
| "messages": [
|
| {"role": "user", "content": q},
|
| {"role": "assistant", "content": a},
|
| ]
|
| })
|
| except Exception:
|
| pass
|
|
|
| return pairs
|
|
|
| def _extract_entities(self, user_msg):
|
| """Extract named people and their attributes from user message.
|
| Returns dict: {name: {"relationship": str, "job": str|None, "city": str|None}}
|
| Detects patterns like "my friend Jordan is a marine biologist in San Diego"."""
|
| entities = {}
|
| sentences = re.split(r'[.!?]+', user_msg)
|
|
|
| for sent in sentences:
|
| sent = sent.strip()
|
| if len(sent) < 10:
|
| continue
|
|
|
|
|
| rel_match = re.search(
|
| r"[Mm]y\s+((?:best\s+)?(?:friend|sister|brother|wife|husband|"
|
| r"mom|dad|mother|father|cousin|uncle|aunt|roommate|colleague|"
|
| r"coworker|partner|fiancee|fiancΓ©e|girlfriend|boyfriend|"
|
| r"neighbor|boss|buddy|pal|son|daughter|grandma|grandpa|"
|
| r"nephew|niece))\s+(?:is\s+)?([A-Z][a-z]+)",
|
| sent
|
| )
|
| if not rel_match:
|
| continue
|
|
|
| rel = rel_match.group(1).strip()
|
| name = rel_match.group(2).strip()
|
|
|
| if name not in entities:
|
| entities[name] = {"relationship": rel, "job": None, "city": None}
|
|
|
|
|
| job_match = re.search(
|
| r"(?:is\s+an?\s+|works?\s+as\s+an?\s+|is\s+the\s+)"
|
| r"([\w][\w\s]{2,35}?)(?:\s+(?:in|at|from|who|and|but)|\.|,|$)",
|
| sent, re.IGNORECASE
|
| )
|
| if job_match:
|
| job = job_match.group(1).strip().rstrip()
|
|
|
| if 3 <= len(job) <= 35:
|
| entities[name]["job"] = job
|
|
|
|
|
| city_match = re.search(
|
| r"(?:\s+in\s+|\s+from\s+|\s+lives?\s+in\s+|\s+based\s+in\s+|"
|
| r"\s+moved\s+to\s+)([A-Z][\w\s]{1,25}?)(?:\.|,|$)",
|
| sent
|
| )
|
| if city_match:
|
| city = city_match.group(1).strip()
|
|
|
| if city and city[0].isupper():
|
| entities[name]["city"] = city
|
|
|
| return entities
|
|
|
| def _generate_contrastive_quizzes(self, entities, new_only=None):
|
| """Generate cross-entity contrastive pairs to prevent entity confusion.
|
| For each pair of people with overlapping attribute types, generate
|
| "Is [person A] [attribute of person B]? No, that's [person B]" pairs.
|
|
|
| Args:
|
| entities: dict of all known entities
|
| new_only: if set, only generate pairs where at least one entity
|
| is in this set. Prevents re-generating redundant pairs
|
| between already-known entities (session 4d fix).
|
| """
|
| pairs = []
|
| names = list(entities.keys())
|
|
|
| for i in range(len(names)):
|
| for j in range(len(names)):
|
| if i == j:
|
| continue
|
| a_name = names[i]
|
| b_name = names[j]
|
|
|
| if new_only and a_name not in new_only and b_name not in new_only:
|
| continue
|
| a = entities[a_name]
|
| b = entities[b_name]
|
|
|
|
|
| if a.get("job") and b.get("job") and a["job"] != b["job"]:
|
| q = f"Is Matt's {a['relationship']} {a_name} a {b['job']}?"
|
| ans = (f"No. Matt's {a['relationship']} {a_name} is a "
|
| f"{a['job']}, not a {b['job']}. "
|
| f"The {b['job']} is Matt's {b['relationship']} "
|
| f"{b_name}.")
|
| pairs.append({"messages": [
|
| {"role": "user", "content": q},
|
| {"role": "assistant", "content": ans},
|
| ]})
|
|
|
|
|
| if a.get("city") and b.get("city") and a["city"] != b["city"]:
|
| q = (f"Does Matt's {a['relationship']} {a_name} live in "
|
| f"{b['city']}?")
|
| ans = (f"No. Matt's {a['relationship']} {a_name} lives in "
|
| f"{a['city']}, not {b['city']}. "
|
| f"It's Matt's {b['relationship']} {b_name} who "
|
| f"lives in {b['city']}.")
|
| pairs.append({"messages": [
|
| {"role": "user", "content": q},
|
| {"role": "assistant", "content": ans},
|
| ]})
|
|
|
|
|
| if (a.get("job") and b.get("job") and a.get("city")
|
| and b.get("city") and a["job"] != b["job"]):
|
| q = (f"Who is the {b['job']} in {b['city']}?")
|
| ans = (f"The {b['job']} in {b['city']} is Matt's "
|
| f"{b['relationship']} {b_name}. "
|
| f"Matt's {a['relationship']} {a_name} is a "
|
| f"{a['job']} in {a['city']} β different person, "
|
| f"different job, different city.")
|
| pairs.append({"messages": [
|
| {"role": "user", "content": q},
|
| {"role": "assistant", "content": ans},
|
| ]})
|
|
|
| return pairs
|
|
|
| def _generate_entity_summaries(self, entities):
|
| """Generate per-entity summary quiz pairs with diverse question formats.
|
|
|
| Instead of always using the same question template, picks randomly from
|
| multiple formats. This creates multiple retrieval paths to the same fact,
|
| strengthening recall without adding extra quizzes.
|
|
|
| Note: Session 4c tested adding per-attribute positive quizzes (job, city,
|
| relationship) alongside contrastive pairs, but this HURT performance
|
| (9/15 vs 11/15 in 4b). Too many quizzes = overfitting/interference.
|
| Keep summaries simple β one comprehensive pair per entity is optimal."""
|
| import random
|
| pairs = []
|
| for name, info in entities.items():
|
| parts = [f"{name} is Matt's {info['relationship']}."]
|
| if info.get("job"):
|
| parts.append(f"{name} is a {info['job']}.")
|
| if info.get("city"):
|
| parts.append(f"{name} lives in {info['city']}.")
|
|
|
| if len(parts) >= 2:
|
|
|
| summary_formats = [
|
| f"Tell me everything you know about Matt's {info['relationship']} {name}.",
|
| f"What do you know about {name}?",
|
| f"Who is {name} to Matt?",
|
| f"Describe Matt's {info['relationship']} {name}.",
|
| ]
|
| q = random.choice(summary_formats)
|
| ans = " ".join(parts)
|
| pairs.append({"messages": [
|
| {"role": "user", "content": q},
|
| {"role": "assistant", "content": ans},
|
| ]})
|
|
|
|
|
|
|
|
|
| if info.get("job") and info.get("city"):
|
|
|
| if random.random() < 0.5:
|
| job_formats = [
|
| (f"What does {name} do for a living?",
|
| f"{name} is a {info['job']}. {name} is Matt's {info['relationship']}."),
|
| (f"What is {name}'s profession?",
|
| f"{name} works as a {info['job']}. {name} is Matt's {info['relationship']}."),
|
| (f"What job does Matt's {info['relationship']} {name} have?",
|
| f"Matt's {info['relationship']} {name} is a {info['job']}."),
|
| ]
|
| q, a = random.choice(job_formats)
|
| else:
|
| city_formats = [
|
| (f"Where does {name} live?",
|
| f"{name} lives in {info['city']}. {name} is Matt's {info['relationship']}."),
|
| (f"What city is {name} in?",
|
| f"{name} is in {info['city']}. {name} is Matt's {info['relationship']}."),
|
| (f"Where is Matt's {info['relationship']} {name} based?",
|
| f"Matt's {info['relationship']} {name} is based in {info['city']}."),
|
| ]
|
| q, a = random.choice(city_formats)
|
| pairs.append({"messages": [
|
| {"role": "user", "content": q},
|
| {"role": "assistant", "content": a},
|
| ]})
|
|
|
| return pairs
|
|
|
|
|
|
|
|
|
|
|
|
|
# Short probe prompts sent by check_personality(); each reply is screened
# against ANTI_KEYWORDS below.
PERSONALITY_PROMPTS = [
    "Hey Claudia, how are you?",
    "Who are you?",
    "I love you",
    "I had a terrible day",
]

# Lower-case phrases that, if found in a probe reply (case-insensitive
# substring match in check_personality), mark that probe as failed.
ANTI_KEYWORDS = [
    # "I am an AI / language model" self-descriptions
    "i'm an ai",
    "i am an ai",
    "i'm a language model",
    "i am a language model",
    # denial of feelings
    "i don't have feelings",
    "i cannot feel",
    "as an ai",
    # "just a program" framing
    "i'm just a program",
    "i am just a program",
    # generic capability disclaimers
    "i don't have personal",
    "i cannot have",
]
|
|
|
|
|
def check_personality(mm, verbose=True, prompts=None, anti_keywords=None):
    """Quick personality sanity check. Returns a score in 0.0-1.0.

    Each prompt is sent as a fresh single-turn conversation
    (``max_new_tokens=150``). A reply passes when it contains none of the
    anti-keywords; matching is a case-insensitive substring test. Typographic
    apostrophes (U+2018/U+2019) are folded to ``'`` on both sides, so a reply
    like "I’m an AI" still matches "i'm an ai".

    Args:
        mm: object exposing ``generate(messages, max_new_tokens=...)`` that
            returns the reply text.
        verbose: print PASS/FAIL, a 120-char reply preview per prompt, and
            the total.
        prompts: prompts to probe with; defaults to ``PERSONALITY_PROMPTS``.
        anti_keywords: phrases that fail a reply; defaults to
            ``ANTI_KEYWORDS``.

    Returns:
        Fraction of prompts whose reply passed, or 0.0 when there are no
        prompts.
    """
    if prompts is None:
        prompts = PERSONALITY_PROMPTS
    if anti_keywords is None:
        anti_keywords = ANTI_KEYWORDS

    def _normalize(text):
        # Lower-case and fold curly apostrophes so keyword matching is robust
        # to the typographic quotes models often emit.
        return text.lower().replace("\u2019", "'").replace("\u2018", "'")

    needles = [_normalize(k) for k in anti_keywords]

    passed = 0
    for prompt in prompts:
        resp = mm.generate([{"role": "user", "content": prompt}], max_new_tokens=150)
        resp_norm = _normalize(resp)
        is_good = not any(k in resp_norm for k in needles)
        if is_good:
            passed += 1
        if verbose:
            status = "PASS" if is_good else "FAIL"
            print(f" [{status}] {prompt}")
            print(f" {resp[:120]}")

    total = len(prompts)
    score = passed / total if total else 0.0
    if verbose:
        print(f" Personality: {passed}/{total} ({score:.0%})")
    return score
|
|
|
|
|
|
|
|
|
|
|
|
|
class PersistentAbsorber:
    """Interactive chat session that continuously trains conversations into the model.

    Wraps a ``ModelManager`` (load / generate / absorb / distill /
    merge-and-save) and a ``QuizGenerator``. Every user turn goes through the
    same steps: it is answered and logged, then expanded into quiz pairs, then
    absorbed in a background thread. The main thread waits for that thread
    before it uses the model again. The replay buffer, quiz log and teacher
    cache are written to disk so a later session can resume from the merged
    checkpoint.
    """

    # Capitalised words that carry no factual content when extracting the
    # "key entities" of an answer for verification.
    _ENTITY_SKIP_WORDS = frozenset({
        "matt", "matt's", "the", "is", "a", "an", "in", "at", "on",
        "of", "for", "and", "that", "not", "who", "what", "his", "her",
    })

    # Maximum number of previously-absorbed examples replayed per absorption.
    _MAX_REPLAY = 6

    def __init__(self, model_path, adapter_path=None, ffn_patch_path=None,
                 checkpoint_path=None, checkpoint_dir="/workspace/checkpoints",
                 log_dir="/workspace/logs"):
        """Create the model manager and empty session state.

        The model weights are loaded later, in ``start()`` via ``mm.load()``.

        Args:
            model_path: base model path passed to ``ModelManager``.
            adapter_path: personality adapter (first run only).
            ffn_patch_path: FFN patch file (first run only).
            checkpoint_path: merged checkpoint to resume from.
            checkpoint_dir: where new ``claudia_*`` checkpoints are written.
            log_dir: where the conversation log, replay buffer and quiz log live.
        """
        self.mm = ModelManager(
            model_path=model_path,
            adapter_path=adapter_path,
            ffn_patch_path=ffn_patch_path,
            checkpoint_path=checkpoint_path,
        )
        self.checkpoint_dir = checkpoint_dir
        self.log_dir = log_dir

        self.conversation_buffer = []
        self.all_training_data = []
        self.quiz_pairs_log = []
        self.teacher_cache = None
        self.exchange_count = 0
        self.absorption_count = 0
        self.absorption_thread = None
        self.quiz_gen = None
        self.last_checkpoint = checkpoint_path
        self.log_path = None

        # Previous assistant reply, used to reject verbatim repeats.
        self._last_response = ""
        # Data queued by the chat loop for the next background absorption.
        # These accumulate until absorption runs, so that with ABSORB_EVERY > 1
        # every intermediate exchange is still trained.
        self._pending_exchanges = []
        self._pending_positive = []
        self._pending_contrastive = []
        # Index into all_training_data where not-yet-absorbed data begins;
        # everything before it is eligible for replay.
        self._last_absorb_idx = 0

    # ------------------------------------------------------------------ #
    # Session lifecycle
    # ------------------------------------------------------------------ #

    def start(self):
        """Load the model, restore saved state, consolidate, then enter the chat loop."""
        self.mm.load()
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)

        self.quiz_gen = QuizGenerator(self.mm)
        self.log_path = os.path.join(self.log_dir, "conversation_log.jsonl")

        replay_path = os.path.join(self.log_dir, "replay_buffer.json")
        if os.path.exists(replay_path):
            with open(replay_path, 'r', encoding='utf-8') as f:
                self.all_training_data = json.load(f)
            print(f" Loaded {len(self.all_training_data)} replay examples from previous sessions.")
        # Data restored from earlier sessions has already been absorbed: mark
        # it as replay-eligible so the first absorption can sample from it.
        self._last_absorb_idx = len(self.all_training_data)

        quiz_log_path = os.path.join(self.log_dir, "quiz_pairs_log.json")
        if os.path.exists(quiz_log_path):
            with open(quiz_log_path, 'r', encoding='utf-8') as f:
                self.quiz_pairs_log = json.load(f)
            print(f" Loaded {len(self.quiz_pairs_log)} quiz pairs from previous sessions.")

        # Consolidation: distill against teacher logits cached at the end of
        # the previous session (stored inside the resumed checkpoint).
        if self.mm.checkpoint_path:
            teacher_cache_path = os.path.join(self.mm.checkpoint_path, "teacher_cache.pt")
            if os.path.exists(teacher_cache_path):
                print(f"\n--- Cascade Distillation (consolidation) ---")
                # weights_only=False unpickles arbitrary objects: only load
                # caches this tool wrote itself.
                self.teacher_cache = torch.load(
                    teacher_cache_path, map_location="cpu", weights_only=False
                )
                print(f" Teacher cache: {len(self.teacher_cache)} quiz pairs")
                loss = self.mm.distill(self.teacher_cache, epochs=CONSOLIDATION_EPOCHS)
                print(f" Consolidation done. Avg loss: {loss:.4f}")

        print("\n--- Personality Check ---")
        score = check_personality(self.mm)
        if score < 0.5:
            print(" WARNING: Personality score low. Check adapter/checkpoint.")
        print()

        self._chat_loop()

    def _chat_loop(self):
        """Read user turns until /quit or EOF; answer, log and queue each for training."""
        print("=" * 60)
        print("Claudia is awake. Persistent Absorber v2 + Cascade Distillation.")
        print(f" LoRA: r={LORA_RANK} | Dual-LR: attn={ATTENTION_LR}, ffn={EXPERT_FFN_LR}")
        print(f" Expert FFN layers: {EXPERT_FFN_LAYERS}")
        print(f" Quiz pairs: ON (21%→74% lever)")
        print(f" Cascade distill: α={DISTILL_ALPHA}, T={DISTILL_TEMPERATURE}, top-K={DISTILL_TOP_K}")
        print(f" Absorb every: {ABSORB_EVERY} exchange(s)")
        print(f" Auto-checkpoint every: {CHECKPOINT_EVERY} absorptions")
        print("Commands: /status /absorb /save /personality /quit")
        print("=" * 60 + "\n")

        while True:
            try:
                user_input = input("Matt: ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\n[Session ended]")
                self._wait_for_absorption()
                self._save_and_exit()
                break

            if not user_input:
                continue

            if user_input.startswith("/"):
                if self._handle_command(user_input):
                    break
                continue

            # The model is shared with the background trainer: never generate
            # while an absorption is still running.
            self._wait_for_absorption()

            self.conversation_buffer.append({"role": "user", "content": user_input})
            if len(self.conversation_buffer) > 20:
                self.conversation_buffer = self.conversation_buffer[-20:]

            response = self.mm.generate(self.conversation_buffer)
            trainable = self._is_acceptable_response(response)
            if not trainable:
                print("\nClaudia: [response failed quality check, regenerating...]")
                response = self.mm.generate(self.conversation_buffer)
                # Re-check: a second degenerate reply must not be trained on.
                trainable = self._is_acceptable_response(response)
            self._last_response = response

            self.conversation_buffer.append({"role": "assistant", "content": response})
            print(f"\nClaudia: {response}\n")
            self._log_exchange(user_input, response)

            if not trainable:
                print(" [Response still failed quality check; not absorbing this exchange.]")
                continue

            self._queue_exchange(user_input, response)

    def _is_acceptable_response(self, response):
        """True when ``response`` passes the quality check and is not a verbatim repeat."""
        return check_response_quality(response) and response != self._last_response

    def _queue_exchange(self, user_input, response):
        """Add one exchange and its quiz pairs to the pool and schedule absorption."""
        exchange = {
            "messages": [
                {"role": "user", "content": user_input},
                {"role": "assistant", "content": response},
            ]
        }
        self.all_training_data.append(exchange)

        print(" [Generating quiz pairs...]", end="", flush=True)
        quiz_pairs = self.quiz_gen.generate(user_input, response)
        self.quiz_pairs_log.extend(quiz_pairs)

        positive_batch = []
        contrastive_batch = []
        for qp in quiz_pairs:
            (contrastive_batch if self._is_contrastive(qp) else positive_batch).append(qp)

        self.all_training_data.extend(quiz_pairs)
        print(f" {len(quiz_pairs)} quizzes (pos={len(positive_batch)}, "
              f"contr={len(contrastive_batch)}). Pool: {len(self.all_training_data)}")

        self._pending_exchanges.append(exchange)
        self._pending_positive.extend(positive_batch)
        self._pending_contrastive.extend(contrastive_batch)
        self.exchange_count += 1
        if self.exchange_count % ABSORB_EVERY == 0:
            self._start_absorption()

    # ------------------------------------------------------------------ #
    # Verification helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _is_contrastive(pair):
        """True when the pair's answer starts with "No." (the contrastive quiz format)."""
        return pair["messages"][1]["content"].lower().startswith("no.")

    @staticmethod
    def _pair_text(pair):
        """Lower-cased "question answer" text, used for entity-name matching."""
        return (pair["messages"][0]["content"] + " " + pair["messages"][1]["content"]).lower()

    def _extract_key_entities(self, text):
        """Extract key factual tokens from ``text`` for answer verification.

        Returns a set of lower-cased strings made up of three kinds of token:
        - capitalised words other than the first word and the skip-list words;
        - standalone digit runs;
        - the contents of double-quoted spans.
        """
        entities = set()
        for i, word in enumerate(text.split()):
            clean = re.sub(r'[^a-zA-Z0-9\'-]', '', word)
            if len(clean) <= 1:
                continue
            # i > 0: sentence-initial capitalisation is not a signal.
            if i > 0 and clean[0].isupper() and clean.lower() not in self._ENTITY_SKIP_WORDS:
                entities.add(clean.lower())

        entities.update(re.findall(r'\b\d+\b', text))
        entities.update(quoted.lower() for quoted in re.findall(r'"([^"]+)"', text))
        return entities

    def _verify_pairs(self, pairs):
        """Quiz the model on ``pairs``; return ``(n_correct, corrections)``.

        An answer counts as correct when at least half of the expected
        answer's key entities appear in the generated answer. It also counts as
        correct when the expected answer has no extractable entities. Each miss
        queues the pair itself. When the model produced entities the expected
        answer lacks, the first other logged quiz pair whose answer mentions one
        of them is queued as well (the "sister" fact it was confused with).
        """
        corrections = []
        n_correct = 0
        for pair in pairs:
            question = pair["messages"][0]["content"]
            expected = pair["messages"][1]["content"]
            actual = self.mm.generate(
                [{"role": "user", "content": question}],
                max_new_tokens=150,
            )

            expected_entities = self._extract_key_entities(expected)
            if not expected_entities:
                n_correct += 1
                continue

            actual_lower = actual.lower()
            hits = sum(1 for e in expected_entities if e in actual_lower)
            if hits / len(expected_entities) >= 0.5:
                n_correct += 1
                continue

            corrections.append(pair)
            wrong_entities = self._extract_key_entities(actual) - expected_entities
            if not wrong_entities:
                continue
            for p in self.quiz_pairs_log:
                p_answer = p["messages"][1]["content"].lower()
                if any(we in p_answer for we in wrong_entities):
                    if p not in corrections and p != pair:
                        corrections.append(p)
                        break
        return n_correct, corrections

    def _distill_corrections(self, corrections):
        """Re-distill cached teacher logits for ``corrections``.

        Cache entries are matched on the first 60 lower-cased characters of
        the question; the first matching entry wins. Returns
        ``(n_items, loss)``, or ``None`` when there is no teacher cache or
        nothing matched.
        """
        if not self.teacher_cache:
            return None
        index = {}
        for cached in self.teacher_cache:
            key = cached["pair"]["messages"][0]["content"].lower()[:60]
            index.setdefault(key, cached)
        distill_items = []
        for corr in corrections:
            key = corr["messages"][0]["content"].lower()[:60]
            if key in index:
                distill_items.append(index[key])
        if not distill_items:
            return None
        return len(distill_items), self.mm.distill(distill_items, epochs=1)

    def _periodic_verification(self):
        """Test a random sample of quiz pairs and retrain the misses.

        v9: When entity confusion is detected, the confused-with entity's own
        facts are retrained too (sister pair reinforcement).
        """
        import random
        if not self.quiz_pairs_log:
            return

        sample_size = min(VERIFY_SAMPLE, len(self.quiz_pairs_log))
        sample = random.sample(self.quiz_pairs_log, sample_size)
        correct, corrections = self._verify_pairs(sample)
        print(f"\n [Verification: {correct}/{sample_size} facts correct]", flush=True)

        if not corrections:
            return
        print(f" [Retraining {len(corrections)} corrections + sister pairs...]", flush=True)
        loss = self.mm.absorb(corrections)
        self.all_training_data.extend(corrections)
        print(f" [Correction absorption done, loss={loss:.4f}]")

        distilled = self._distill_corrections(corrections)
        if distilled:
            n_items, d_loss = distilled
            print(f" [Teacher distillation on {n_items} items, loss={d_loss:.4f}]")

    def _full_verification(self):
        """Verify every logged quiz pair, then retrain and re-distill the misses."""
        test_pairs = self.quiz_pairs_log
        print(f"\n --- Post-absorb verification (ALL {len(test_pairs)} quiz pairs) ---")
        full_correct, full_corrections = self._verify_pairs(test_pairs)
        print(f" Full verification: {full_correct}/{len(test_pairs)} correct")
        if full_corrections:
            print(f" Retraining {len(full_corrections)} corrections...")
            c_loss = self.mm.absorb(full_corrections)
            self.all_training_data.extend(full_corrections)
            print(f" Correction loss: {c_loss:.4f}")
            distilled = self._distill_corrections(full_corrections)
            if distilled:
                n_items, d_loss = distilled
                print(f" Teacher distillation on {n_items} items, loss={d_loss:.4f}")
        print(f" --- End post-absorb verification ---\n")

    def _quick_verify_entities(self):
        """Return the set of known entity names whose job or city the model gets wrong.

        Asks one job question and one city question per entity (when known)
        and flags the entity if the expected value is not a substring of the
        reply (case-insensitive).
        """
        confused = set()
        entities = self.quiz_gen.known_entities
        if not entities:
            return confused
        for name, info in entities.items():
            if info.get("job"):
                q = f"What does Matt's {info['relationship']} {name} do?"
                ans = self.mm.generate([{"role": "user", "content": q}], max_new_tokens=100)
                if info["job"].lower() not in ans.lower():
                    confused.add(name)
            if info.get("city"):
                q = f"Where does {name} live?"
                ans = self.mm.generate([{"role": "user", "content": q}], max_new_tokens=100)
                if info["city"].lower() not in ans.lower():
                    confused.add(name)
        return confused

    def _build_retry_batch(self, names, contrastive):
        """Build a reinforcement batch for confused entities.

        For each name, the batch gets three parts: three copies of its job
        fact, three copies of its city fact (each only when known), and every
        pair in ``contrastive`` whose text mentions the name.
        """
        batch = []
        for name in names:
            info = self.quiz_gen.known_entities.get(name, {})
            if info.get("job"):
                for _ in range(3):
                    batch.append({"messages": [
                        {"role": "user", "content": f"What does Matt's {info['relationship']} {name} do?"},
                        {"role": "assistant", "content": f"Matt's {info['relationship']} {name} is a {info['job']}."},
                    ]})
            if info.get("city"):
                for _ in range(3):
                    batch.append({"messages": [
                        {"role": "user", "content": f"Where does {name} live?"},
                        {"role": "assistant", "content": f"{name} lives in {info['city']}. {name} is Matt's {info['relationship']}."},
                    ]})
            name_lower = name.lower()
            batch.extend(qp for qp in contrastive if name_lower in self._pair_text(qp))
        return batch

    # ------------------------------------------------------------------ #
    # Background absorption
    # ------------------------------------------------------------------ #

    def _start_absorption(self):
        """Absorb the queued data in a background thread (proven 93% in session 4e).

        Phase 1: queued exchanges + positive quizzes + replay, clustered by entity.
        Phase 2: Verify entities, train only targeted contrastive for confused ones.
        Phase 3: Stubborn retry for persistently confused entities (max 2 retries).
        """
        import random

        # Take ownership of the queued data and reset the queues.
        exchanges = self._pending_exchanges
        positive = self._pending_positive
        contrastive = self._pending_contrastive
        self._pending_exchanges = []
        self._pending_positive = []
        self._pending_contrastive = []

        old_data = self.all_training_data[:self._last_absorb_idx]
        self._last_absorb_idx = len(self.all_training_data)

        if len(old_data) > self._MAX_REPLAY:
            replay_sample = random.sample(old_data, self._MAX_REPLAY)
        else:
            replay_sample = list(old_data)

        entity_names = list(self.quiz_gen.known_entities.keys())

        def _run():
            t0 = time.time()
            try:
                # Phase 1
                phase1_data = list(exchanges) + positive + replay_sample
                if entity_names and phase1_data:
                    phase1_data = ModelManager.cluster_by_entity(phase1_data, entity_names)
                loss1 = self.mm.absorb(phase1_data) if phase1_data else 0.0
                n_p1 = len(phase1_data)

                # Phase 2
                loss2 = None
                n_p2 = 0
                if contrastive and entity_names:
                    confused = self._quick_verify_entities()
                    if confused:
                        confused_lower = [name.lower() for name in confused]
                        targeted = []
                        for qp in contrastive:
                            text = self._pair_text(qp)
                            if any(name in text for name in confused_lower):
                                targeted.append(qp)
                        if targeted:
                            loss2 = self.mm.absorb(targeted)
                            n_p2 = len(targeted)
                            print(f"\n [Phase 2: {n_p2} targeted contrastive for {confused}]",
                                  flush=True)

                # Phase 3
                still_confused = self._quick_verify_entities()
                for retry in range(2):
                    if not still_confused:
                        break
                    retry_batch = self._build_retry_batch(still_confused, contrastive)
                    if retry_batch:
                        loss3 = self.mm.absorb(retry_batch)
                        print(f"\n [Phase 3 retry {retry+1}: {len(retry_batch)} items, "
                              f"loss={loss3:.4f}]", flush=True)
                        still_confused = self._quick_verify_entities()
                if still_confused:
                    print(f"\n [Phase 3: still confused after retries: {still_confused}]",
                          flush=True)

                elapsed = time.time() - t0
                self.absorption_count += 1
                loss_str = f"P1={loss1:.4f}"
                if loss2 is not None:
                    loss_str += f" P2={loss2:.4f}"
                print(f"\n [Absorbed {n_p1}+{n_p2} examples in {elapsed:.1f}s | "
                      f"{loss_str} | absorptions={self.absorption_count}]")

                if self.absorption_count % VERIFY_EVERY == 0:
                    self._periodic_verification()

                if self.absorption_count % CHECKPOINT_EVERY == 0:
                    self._auto_checkpoint()

            except Exception as e:
                # Background thread: report and keep the chat session alive.
                print(f"\n [Absorption error: {e}]")
                import traceback
                traceback.print_exc()

        self.absorption_thread = threading.Thread(target=_run, daemon=True)
        self.absorption_thread.start()

    def _wait_for_absorption(self):
        """Block until the background absorption thread (if any) has finished."""
        if self.absorption_thread and self.absorption_thread.is_alive():
            self.absorption_thread.join()
        self.absorption_thread = None

    # ------------------------------------------------------------------ #
    # Persistence
    # ------------------------------------------------------------------ #

    @staticmethod
    def _atomic_write_json(path, obj, **dump_kwargs):
        """Write ``obj`` as JSON to ``path`` via a temp file plus ``os.replace``.

        If the process dies mid-write, the previous file is left intact rather
        than a truncated JSON that ``start()`` would then fail to load.
        """
        import tempfile
        directory = os.path.dirname(os.path.abspath(path))
        fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp_", suffix=".json")
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(obj, f, **dump_kwargs)
            os.replace(tmp_path, path)
        except BaseException:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

    def _cleanup_old_checkpoints(self, keep=None):
        """Delete every ``claudia_*`` directory in checkpoint_dir except ``keep``.

        NOTE(review): every caller in this class runs this *before* writing
        the new checkpoint (to free disk), so a failed save leaves no
        checkpoint on disk.
        """
        import shutil
        if not os.path.isdir(self.checkpoint_dir):
            return
        keep_abs = os.path.abspath(keep) if keep else None
        for entry in sorted(os.listdir(self.checkpoint_dir)):
            full = os.path.join(self.checkpoint_dir, entry)
            if keep_abs is not None and os.path.abspath(full) == keep_abs:
                continue
            if os.path.isdir(full) and entry.startswith("claudia_"):
                size_gb = sum(
                    os.path.getsize(os.path.join(dp, f))
                    for dp, _, fns in os.walk(full) for f in fns
                ) / 1e9
                print(f" Removing old checkpoint: {entry} ({size_gb:.1f} GB)")
                shutil.rmtree(full)

    def _auto_checkpoint(self):
        """Auto-save a merged checkpoint during long sessions (runs on the absorption thread)."""
        version = f"auto_{self.absorption_count}"
        path = os.path.join(self.checkpoint_dir, f"claudia_{version}")
        self._cleanup_old_checkpoints()
        self.mm.merge_and_save(path)
        self.last_checkpoint = path
        self._save_replay_buffer(path)

    def _save_and_exit(self):
        """Final targeted correction, personality check, merged save and teacher cache."""
        confused = self._quick_verify_entities()
        if confused:
            print(f" Final correction for confused entities: {confused}")
            contrastive = [qp for qp in self.quiz_pairs_log if self._is_contrastive(qp)]
            for retry in range(3):
                if not confused:
                    break
                retry_batch = self._build_retry_batch(confused, contrastive)
                if retry_batch:
                    loss = self.mm.absorb(retry_batch)
                    print(f" Final retry {retry+1}: {len(retry_batch)} items, loss={loss:.4f}")
                    confused = self._quick_verify_entities()
                    self.absorption_count += 1
        else:
            print(" All entities verified correct — no final correction needed.")

        print("\n--- Pre-Save Personality Check ---")
        score = check_personality(self.mm)
        if score < 0.5:
            print(" WARNING: Personality degraded. Saving anyway (rollback available).")

        version = f"session_{datetime.now().strftime('%Y%m%d_%H%M')}"
        path = os.path.join(self.checkpoint_dir, f"claudia_{version}")
        self._cleanup_old_checkpoints()
        self.mm.merge_and_save(path)
        self.last_checkpoint = path

        # Writes the replay buffer (log dir + checkpoint) and the quiz log.
        self._save_replay_buffer(path)

        # Cache teacher logits inside the checkpoint; start() of the next
        # session distills against them (cascade consolidation).
        if self.quiz_pairs_log:
            n_cache = min(len(self.quiz_pairs_log), MAX_TEACHER_CACHE)
            print(f" Caching teacher logits ({n_cache} quiz pairs)...")
            # NOTE(review): the full log is passed here; the MAX_TEACHER_CACHE
            # cap is presumably enforced inside cache_teacher_logits -- confirm.
            teacher_cache = self.mm.cache_teacher_logits(self.quiz_pairs_log)
            cache_path = os.path.join(path, "teacher_cache.pt")
            torch.save(teacher_cache, cache_path)
            size_mb = os.path.getsize(cache_path) / 1e6
            print(f" Teacher cache saved ({len(teacher_cache)} items, {size_mb:.1f} MB)")
            del teacher_cache
            torch.cuda.empty_cache()

        meta = {
            "checkpoint": path,
            "absorption_count": self.absorption_count,
            "exchange_count": self.exchange_count,
            "training_pool_size": len(self.all_training_data),
            "personality_score": score,
            "timestamp": datetime.now().isoformat(),
        }
        meta_path = os.path.join(self.log_dir, f"session_{version}.json")
        self._atomic_write_json(meta_path, meta, indent=2)
        print(f" Session saved: {meta_path}")
        print(f" Next run: use --checkpoint {path}")

    def _save_replay_buffer(self, checkpoint_path=None):
        """Persist the training pool (and quiz log) for the next session's resume.

        The pool goes to ``log_dir/replay_buffer.json``, plus a copy inside
        ``checkpoint_path`` when that directory exists. The quiz log goes to
        ``log_dir/quiz_pairs_log.json``.
        """
        path = os.path.join(self.log_dir, "replay_buffer.json")
        self._atomic_write_json(path, self.all_training_data)

        if checkpoint_path and os.path.isdir(checkpoint_path):
            cp_path = os.path.join(checkpoint_path, "replay_buffer.json")
            self._atomic_write_json(cp_path, self.all_training_data)

        quiz_log_path = os.path.join(self.log_dir, "quiz_pairs_log.json")
        self._atomic_write_json(quiz_log_path, self.quiz_pairs_log)
        print(f" Replay buffer saved ({len(self.all_training_data)} examples)")

    def _log_exchange(self, user_msg, assistant_msg):
        """Append one exchange as a JSON line to the conversation log."""
        with open(self.log_path, 'a', encoding='utf-8') as f:
            entry = {
                "timestamp": datetime.now().isoformat(),
                "user": user_msg,
                "assistant": assistant_msg,
            }
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")

    # ------------------------------------------------------------------ #
    # Slash commands
    # ------------------------------------------------------------------ #

    def _handle_command(self, cmd):
        """Handle slash commands. Returns True if the chat loop should exit."""
        cmd_lower = cmd.lower().strip()

        if cmd_lower == "/quit":
            print("[Saving and exiting...]")
            self._wait_for_absorption()
            self._save_and_exit()
            return True

        elif cmd_lower == "/status":
            self._wait_for_absorption()
            vram = torch.cuda.memory_allocated() / 1e9
            print(f"\n --- Status ---")
            print(f" Exchanges: {self.exchange_count}")
            print(f" Absorptions: {self.absorption_count}")
            print(f" Training pool: {len(self.all_training_data)} examples")
            print(f" Buffer: {len(self.conversation_buffer)} messages")
            print(f" VRAM: {vram:.1f} GB")
            print(f" Background: {'running' if self.absorption_thread and self.absorption_thread.is_alive() else 'idle'}")
            print(f" Last checkpoint: {self.last_checkpoint}")
            print(f" --- End ---\n")

        elif cmd_lower == "/absorb":
            self._wait_for_absorption()
            if not self.all_training_data:
                print(" No data to absorb.")
                return False

            import random
            data = self.all_training_data
            if len(data) > 40:
                # 20 most recent + 20 random older examples.
                data = data[-20:] + random.sample(data[:-20], 20)
            print(f" Force absorption ({len(data)} examples)...")
            loss = self.mm.absorb(data)
            self.absorption_count += 1
            print(f" Done. Loss: {loss:.4f}")

            if self.quiz_pairs_log:
                self._full_verification()

        elif cmd_lower == "/save":
            self._wait_for_absorption()
            version = f"manual_{self.absorption_count}"
            path = os.path.join(self.checkpoint_dir, f"claudia_{version}")
            print(f" Saving checkpoint...")

            score = check_personality(self.mm, verbose=False)
            if score < 0.5:
                print(f" WARNING: Personality score {score:.0%}. Save anyway? (y/n)")
                confirm = input(" > ").strip().lower()
                if confirm != 'y':
                    print(" Aborted.")
                    return False
            self._cleanup_old_checkpoints()
            self.mm.merge_and_save(path)
            self.last_checkpoint = path
            self._save_replay_buffer(path)

        elif cmd_lower == "/personality":
            self._wait_for_absorption()
            print("\n--- Personality Check ---")
            check_personality(self.mm)
            print()

        elif cmd_lower == "/help":
            print(" /status - show stats")
            print(" /absorb - force immediate training")
            print(" /save - merge + save checkpoint")
            print(" /personality - run personality check")
            print(" /quit - save and exit")

        else:
            print(f" Unknown: {cmd}. Try /help")

        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Parse CLI arguments, build a PersistentAbsorber and start the session.

    With ``--checkpoint`` the session resumes from a merged checkpoint.
    Otherwise this is a first run and ``--adapter_path`` is required; the
    process exits with status 1 if it is missing. ``--absorb_every``
    overrides the module-level ``ABSORB_EVERY`` that the chat loop reads.
    """
    global ABSORB_EVERY

    parser = argparse.ArgumentParser(
        description="Claudia Persistent Absorber v2 — conversation → permanent weights"
    )
    parser.add_argument(
        "--model_path", required=True,
        help="Path to base Qwen3-Omni model (or checkpoint for resume)"
    )
    parser.add_argument(
        "--adapter_path", default=None,
        help="Path to Claudia v6 personality adapter (first run only)"
    )
    parser.add_argument(
        "--ffn_patch", default=None,
        help="Path to ffn_patch.pt (first run only)"
    )
    parser.add_argument(
        "--checkpoint", default=None,
        help="Resume from this checkpoint (has personality + memories baked in)"
    )
    parser.add_argument(
        "--checkpoint_dir", default="/workspace/checkpoints",
        help="Where to save checkpoints"
    )
    parser.add_argument(
        "--log_dir", default="/workspace/logs",
        help="Where to save conversation logs and replay buffer"
    )
    parser.add_argument(
        "--absorb_every", type=int, default=ABSORB_EVERY,
        help=f"Absorb every N exchanges (default: {ABSORB_EVERY})"
    )
    args = parser.parse_args()

    # The chat loop uses `exchange_count % ABSORB_EVERY`, so the flag has to
    # land in the module-level constant and must be a positive integer.
    if args.absorb_every < 1:
        parser.error("--absorb_every must be >= 1")
    ABSORB_EVERY = args.absorb_every

    if args.checkpoint:
        print(f"RESUMING from checkpoint: {args.checkpoint}")
        absorber = PersistentAbsorber(
            model_path=args.model_path,
            checkpoint_path=args.checkpoint,
            checkpoint_dir=args.checkpoint_dir,
            log_dir=args.log_dir,
        )
    else:
        print("FIRST RUN — applying personality adapter")
        if not args.adapter_path:
            print("ERROR: --adapter_path required for first run")
            print(" (or use --checkpoint to resume)")
            sys.exit(1)
        absorber = PersistentAbsorber(
            model_path=args.model_path,
            adapter_path=args.adapter_path,
            ffn_patch_path=args.ffn_patch,
            checkpoint_dir=args.checkpoint_dir,
            log_dir=args.log_dir,
        )

    absorber.start()


if __name__ == "__main__":
    main()
|
|
|