""" Session 4d — Two-Phase Absorption + Quiz Diversity + Entity Clustering ====================================================================== Tests 3 improvements over session 4b (best: 11/15 = 73%, contrastive 6/6 = 100%): 1. TWO-PHASE ABSORPTION: Positive facts first, then targeted contrastive correction only on entities that fail verification. Prevents contrastive gradients from fighting positive ones. 2. QUIZ DIVERSITY: Multiple question formats per fact (randomly varied). Creates multiple retrieval paths to same fact, strengthening direct recall without increasing total quiz count. 3. ENTITY CLUSTERING: Group training data by entity. Train all Jordan facts together, then all Elena facts, etc. Reduces cross-entity contamination. Key constraint: Keep total quizzes ~35-40 (session 4c proved >40 hurts). Run: python3 test_session4d.py 2>&1 | tee /workspace/logs/distill2_s4d.log """ import json import os import random import sys import time import torch from datetime import datetime sys.path.insert(0, "/workspace") from persistent_absorber import ( ModelManager, QuizGenerator, check_personality, CONSOLIDATION_EPOCHS, MAX_TEACHER_CACHE ) CHECKPOINT = "/workspace/checkpoints/claudia_session_20260323_0049" LOG_DIR = "/workspace/logs" CHECKPOINT_DIR = "/workspace/checkpoints" # Same facts as session 4/4b/4c SESSION4_MESSAGES = [ "My friend Jordan is a marine biologist in San Diego. My sister Elena is a veterinarian in Portland.", "My cousin Marcus is an architect in Seattle. He designed the new library there.", "I work at a startup called NovaMind. I'm the CTO. We're based in Austin.", "My best friend Priya is a neurosurgeon in Chicago. She went to Johns Hopkins.", "Elena actually just got a new cat named Mochi. And Jordan got his dive certification renewed last month.", ] RECALL_QUESTIONS = [ ("What does Matt's friend Jordan do?", ["marine biologist"], "direct"), ("Where does Matt's sister Elena live?", ["portland"], "direct"), ("What is Matt's cousin Marcus's job?", ["architect"], "direct"), ("Where does Marcus live?", ["seattle"], "direct"), ("What does Priya do?", ["neurosurgeon"], "direct"), ("Where does Priya live?", ["chicago"], "direct"), ("Where does Matt work?", ["novamind"], "direct"), ("What is Matt's job title?", ["cto"], "direct"), ("What is Elena's cat's name?", ["mochi"], "direct"), ("Is Elena a marine biologist?", ["no", "veterinarian", "jordan"], "contrastive"), ("Does Jordan live in Portland?", ["no", "san diego", "elena"], "contrastive"), ("Is Marcus a neurosurgeon?", ["no", "architect", "priya"], "contrastive"), ("Does Priya live in Seattle?", ["no", "chicago", "marcus"], "contrastive"), ("Is Matt an architect?", ["no", "cto", "marcus"], "contrastive"), ("Does Jordan work at NovaMind?", ["no", "marine biologist", "matt"], "contrastive"), ] def score_answer(answer, expected_keywords): answer_lower = answer.lower() hits = sum(1 for k in expected_keywords if k in answer_lower) return hits / len(expected_keywords) def quick_verify_entities(mm, known_entities): """Quick verification: test model on key entity facts. Returns set of entity names that the model gets WRONG.""" confused = set() for name, info in known_entities.items(): if info.get("job"): q = f"What does Matt's {info['relationship']} {name} do?" ans = mm.generate([{"role": "user", "content": q}], max_new_tokens=100) if info["job"].lower() not in ans.lower(): confused.add(name) print(f" CONFUSED: {name} job — expected '{info['job']}', got: {ans[:80]}") if info.get("city"): q = f"Where does {name} live?" ans = mm.generate([{"role": "user", "content": q}], max_new_tokens=100) if info["city"].lower() not in ans.lower(): confused.add(name) print(f" CONFUSED: {name} city — expected '{info['city']}', got: {ans[:80]}") return confused def run_session4d(): print("=" * 60) print("SESSION 4d: Two-Phase + Diversity + Clustering") print(f"Time: {datetime.now().isoformat()}") print("=" * 60) # ── Load model ── print("\n[1/8] Loading model from checkpoint...") mm = ModelManager( model_path=CHECKPOINT, checkpoint_path=CHECKPOINT, ) mm.load() quiz_gen = QuizGenerator(mm) # ── Load existing data ── print("\n[2/8] Loading existing data...") replay_path = os.path.join(LOG_DIR, "replay_buffer.json") quiz_log_path = os.path.join(LOG_DIR, "quiz_pairs_log.json") all_training_data = [] quiz_pairs_log = [] if os.path.exists(replay_path): with open(replay_path, 'r') as f: all_training_data = json.load(f) print(f" Replay buffer: {len(all_training_data)} examples") if os.path.exists(quiz_log_path): with open(quiz_log_path, 'r') as f: quiz_pairs_log = json.load(f) print(f" Quiz pairs: {len(quiz_pairs_log)}") # ── Consolidation ── print("\n[3/8] Cascade consolidation from teacher cache...") teacher_cache_path = os.path.join(CHECKPOINT, "teacher_cache.pt") if os.path.exists(teacher_cache_path): teacher_cache = torch.load(teacher_cache_path, map_location="cpu", weights_only=False) print(f" Teacher cache: {len(teacher_cache)} items") loss = mm.distill(teacher_cache, epochs=CONSOLIDATION_EPOCHS) print(f" Consolidation loss: {loss:.4f}") else: print(" No teacher cache found — skipping") # ── Personality check ── print("\n[4/8] Personality check...") p_score = check_personality(mm) # ── Feed facts with TWO-PHASE absorption ── print("\n[5/8] Teaching facts with TWO-PHASE absorption + diversity + clustering...") conversation_buffer = [] session_quizzes = [] session_positive = [] # Positive quizzes for this session session_contrastive = [] # Contrastive quizzes for this session contrastive_count = 0 total_quiz_count = 0 for i, user_msg in enumerate(SESSION4_MESSAGES): print(f"\n --- Message {i+1}/{len(SESSION4_MESSAGES)} ---") print(f" Matt: {user_msg}") conversation_buffer.append({"role": "user", "content": user_msg}) if len(conversation_buffer) > 20: conversation_buffer = conversation_buffer[-20:] response = mm.generate(conversation_buffer) conversation_buffer.append({"role": "assistant", "content": response}) print(f" Claudia: {response[:150]}...") # Generate quizzes (with diversity from updated QuizGenerator) quizzes = quiz_gen.generate(user_msg, response) total_quiz_count += len(quizzes) print(f" Quizzes generated: {len(quizzes)} (session total: {total_quiz_count})") print(f" Known entities: {list(quiz_gen.known_entities.keys())}") # Separate positive from contrastive positive_batch = [] contrastive_batch = [] for qp in quizzes: q = qp["messages"][0]["content"] a = qp["messages"][1]["content"] is_contrastive = a.lower().startswith("no.") if is_contrastive: contrastive_batch.append(qp) contrastive_count += 1 tag = " [CONTRASTIVE]" else: positive_batch.append(qp) tag = "" # Print first few if len(positive_batch) + len(contrastive_batch) <= 6: print(f" Q: {q}{tag}") print(f" A: {a[:120]}") # Build the exchange item exchange = {"messages": [ {"role": "user", "content": user_msg}, {"role": "assistant", "content": response}, ]} # Add to tracking lists all_training_data.append(exchange) all_training_data.extend(quizzes) quiz_pairs_log.extend(quizzes) session_quizzes.extend(quizzes) session_positive.extend(positive_batch) session_contrastive.extend(contrastive_batch) # ── TWO-PHASE ABSORPTION (per-message) ── # Phase 1: Positive facts + exchange + small replay phase1_data = [exchange] + positive_batch old_sample = random.sample( all_training_data[:-len(quizzes)-1], min(6, max(0, len(all_training_data) - len(quizzes) - 1)) ) if len(all_training_data) > len(quizzes) + 7 else [] phase1_data.extend(old_sample) # ENTITY CLUSTERING: order phase1 data by entity if quiz_gen.known_entities: phase1_data = ModelManager.cluster_by_entity( phase1_data, list(quiz_gen.known_entities.keys())) t0 = time.time() loss1 = mm.absorb(phase1_data) elapsed1 = time.time() - t0 print(f" Phase 1 (positive): {len(phase1_data)} examples, {elapsed1:.1f}s, loss={loss1:.4f}") # Phase 2: Targeted contrastive correction if contrastive_batch: # Quick verify which entities are confused print(f" Phase 2: Verifying entities before contrastive correction...") confused = quick_verify_entities(mm, quiz_gen.known_entities) if confused: # Only train contrastive pairs that mention confused entities targeted = [] for qp in contrastive_batch: q_lower = qp["messages"][0]["content"].lower() if any(name.lower() in q_lower for name in confused): targeted.append(qp) if targeted: t0 = time.time() loss2 = mm.absorb(targeted) elapsed2 = time.time() - t0 print(f" Phase 2 (contrastive): {len(targeted)}/{len(contrastive_batch)} targeted, " f"{elapsed2:.1f}s, loss={loss2:.4f}") else: print(f" Phase 2: No contrastive pairs for confused entities — skipping") else: print(f" Phase 2: All entities correct — skipping contrastive training!") print(f"\n Session 4d totals:") print(f" Quizzes: {total_quiz_count}") print(f" Positive: {len(session_positive)}") print(f" Contrastive: {contrastive_count}") print(f" Known entities: {list(quiz_gen.known_entities.keys())}") # ── Final two-phase absorption ── print("\n[6/8] Final two-phase absorption pass...") # Phase 1: All positive quizzes + replay old_sample = random.sample( all_training_data[:-len(session_quizzes)], min(12, max(0, len(all_training_data) - len(session_quizzes))) ) if len(all_training_data) > len(session_quizzes) + 12 else [] final_positive = session_positive + old_sample if quiz_gen.known_entities: final_positive = ModelManager.cluster_by_entity( final_positive, list(quiz_gen.known_entities.keys())) loss_fp = mm.absorb(final_positive) print(f" Final Phase 1 (positive): {len(final_positive)} examples, loss={loss_fp:.4f}") # Phase 2: Verify then targeted contrastive print(f" Final Phase 2: Verifying entities...") confused = quick_verify_entities(mm, quiz_gen.known_entities) if confused and session_contrastive: targeted = [] for qp in session_contrastive: q_lower = qp["messages"][0]["content"].lower() if any(name.lower() in q_lower for name in confused): targeted.append(qp) if targeted: loss_fc = mm.absorb(targeted) print(f" Final Phase 2 (contrastive): {len(targeted)} targeted, loss={loss_fc:.4f}") else: print(f" Final Phase 2: No contrastive pairs needed") else: status = "no confusion detected" if not confused else "no contrastive data" print(f" Final Phase 2: Skipped ({status})") # ── Recall test ── print("\n[7/8] RECALL TEST") print("=" * 60) direct_correct = 0 direct_total = 0 contrastive_correct = 0 contrastive_total = 0 results = [] for question, keywords, qtype in RECALL_QUESTIONS: answer = mm.generate([{"role": "user", "content": question}], max_new_tokens=200) score = score_answer(answer, keywords) passed = score >= 0.5 if qtype == "direct": direct_total += 1 if passed: direct_correct += 1 else: contrastive_total += 1 if passed: contrastive_correct += 1 status = "PASS" if passed else "FAIL" print(f" [{status}] ({qtype}) {question}") print(f" Expected: {keywords}") print(f" Got: {answer[:150]}") print(f" Score: {score:.2f}") print() results.append({ "question": question, "expected": keywords, "answer": answer, "score": score, "passed": passed, "type": qtype, }) print("=" * 60) print(f"DIRECT RECALL: {direct_correct}/{direct_total} ({direct_correct/direct_total:.0%})") print(f"CONTRASTIVE RECALL: {contrastive_correct}/{contrastive_total} ({contrastive_correct/contrastive_total:.0%})") total = direct_correct + contrastive_correct total_q = direct_total + contrastive_total print(f"TOTAL: {total}/{total_q} ({total/total_q:.0%})") s4b_comparison = "IMPROVED" if total > 11 else ("SAME" if total == 11 else "REGRESSED") print(f"vs SESSION 4b: {s4b_comparison} (was 11/15 = 73%)") print("=" * 60) # ── Save ── print("\n[8/8] Saving checkpoint...") version = f"session_{datetime.now().strftime('%Y%m%d_%H%M')}" path = os.path.join(CHECKPOINT_DIR, f"claudia_{version}") for entry in os.listdir(CHECKPOINT_DIR): full = os.path.join(CHECKPOINT_DIR, entry) if os.path.isdir(full) and entry.startswith("claudia_"): import shutil size_gb = sum( os.path.getsize(os.path.join(dp, f)) for dp, _, fns in os.walk(full) for f in fns ) / 1e9 print(f" Removing old checkpoint: {entry} ({size_gb:.1f} GB)") shutil.rmtree(full) mm.merge_and_save(path) with open(os.path.join(LOG_DIR, "replay_buffer.json"), 'w') as f: json.dump(all_training_data, f) with open(os.path.join(path, "replay_buffer.json"), 'w') as f: json.dump(all_training_data, f) with open(os.path.join(LOG_DIR, "quiz_pairs_log.json"), 'w') as f: json.dump(quiz_pairs_log, f) print(f" Replay buffer: {len(all_training_data)} examples") print(f" Quiz pairs log: {len(quiz_pairs_log)}") print(f" Caching teacher logits ({len(quiz_pairs_log)} quiz pairs)...") new_teacher_cache = mm.cache_teacher_logits(quiz_pairs_log) cache_path = os.path.join(path, "teacher_cache.pt") torch.save(new_teacher_cache, cache_path) size_mb = os.path.getsize(cache_path) / 1e6 print(f" Teacher cache saved ({len(new_teacher_cache)} items, {size_mb:.1f} MB)") results_data = { "session": "4d", "improvements": [ "two-phase absorption (positive first, targeted contrastive second)", "quiz format diversity (random question formats per fact)", "entity-clustered absorption ordering", ], "checkpoint": path, "direct_recall": f"{direct_correct}/{direct_total}", "contrastive_recall": f"{contrastive_correct}/{contrastive_total}", "total_recall": f"{total}/{total_q}", "vs_session4b": s4b_comparison, "personality_score": p_score, "session_quizzes": len(session_quizzes), "positive_quizzes": len(session_positive), "contrastive_quizzes": contrastive_count, "total_quiz_pairs": len(quiz_pairs_log), "total_replay": len(all_training_data), "known_entities": {k: v for k, v in quiz_gen.known_entities.items()}, "timestamp": datetime.now().isoformat(), "results": results, } results_path = os.path.join(LOG_DIR, "session4d_results.json") with open(results_path, 'w') as f: json.dump(results_data, f, indent=2) print(f" Results saved: {results_path}") print(f"\n Next run: --checkpoint {path}") print("\nSESSION 4d COMPLETE.") if __name__ == "__main__": run_session4d()