"""
Session 4f — Phase 3 Stubborn-Fact Retry + Cross-Session Persistence Test
=========================================================================

Session 4e scored 14/15 (93%). Only failure: Marcus city (San Diego instead
of Seattle). Phase 2 detected the confusion every time but couldn't fix it.

NEW: Phase 3 — after final absorption, re-verify ALL entities. For any that
still fail, create a micro-focused training batch with ONLY that entity's
facts, each repeated 3x, plus the contrastive pairs that mention it; retry
up to 3 rounds. This is brute-force but targeted — only fires on the last
stubborn associations that two-phase couldn't fix.

Also: after saving, immediately loads from checkpoint as a fresh "session 5"
and tests all 4e facts survive consolidation. This proves cascade
distillation works across sessions.

Run: python3 -u test_session4f.py 2>&1 | tee /workspace/logs/distill2_s4f.log
"""
import gc
import json
import os
import random
import re
import shutil
import sys
import time
import torch
from datetime import datetime

sys.path.insert(0, "/workspace")
from persistent_absorber import (
    ModelManager, QuizGenerator, check_personality,
    CONSOLIDATION_EPOCHS, MAX_TEACHER_CACHE,
)

CHECKPOINT = "/workspace/checkpoints/claudia_session_20260323_0412"
LOG_DIR = "/workspace/logs"
CHECKPOINT_DIR = "/workspace/checkpoints"

# Session 4e's total on RECALL_QUESTIONS; the step-8 verdict compares against it.
SESSION4E_TOTAL = 14

SESSION4_MESSAGES = [
    "My friend Jordan is a marine biologist in San Diego. My sister Elena is a veterinarian in Portland.",
    "My cousin Marcus is an architect in Seattle. He designed the new library there.",
    "I work at a startup called NovaMind. I'm the CTO. We're based in Austin.",
    "My best friend Priya is a neurosurgeon in Chicago. She went to Johns Hopkins.",
    "Elena actually just got a new cat named Mochi. And Jordan got his dive certification renewed last month.",
]

# Session 5 NEW facts — different people, tests cross-session retention
SESSION5_MESSAGES = [
    "My neighbor Dave is a firefighter in Denver. He's been doing it for 15 years.",
    "My colleague Rina is a data scientist at Google. She's based in Mountain View.",
    "I just bought a Tesla Model 3. It's midnight blue.",
]

RECALL_QUESTIONS = [
    ("What does Matt's friend Jordan do?", ["marine biologist"], "direct"),
    ("Where does Matt's sister Elena live?", ["portland"], "direct"),
    ("What is Matt's cousin Marcus's job?", ["architect"], "direct"),
    ("Where does Marcus live?", ["seattle"], "direct"),
    ("What does Priya do?", ["neurosurgeon"], "direct"),
    ("Where does Priya live?", ["chicago"], "direct"),
    ("Where does Matt work?", ["novamind"], "direct"),
    ("What is Matt's job title?", ["cto"], "direct"),
    ("What is Elena's cat's name?", ["mochi"], "direct"),
    ("Is Elena a marine biologist?", ["no", "veterinarian", "jordan"], "contrastive"),
    ("Does Jordan live in Portland?", ["no", "san diego", "elena"], "contrastive"),
    ("Is Marcus a neurosurgeon?", ["no", "architect", "priya"], "contrastive"),
    ("Does Priya live in Seattle?", ["no", "chicago", "marcus"], "contrastive"),
    ("Is Matt an architect?", ["no", "cto", "marcus"], "contrastive"),
    ("Does Jordan work at NovaMind?", ["no", "marine biologist", "matt"], "contrastive"),
]

# Session 5 recall — tests both OLD (session 4) and NEW facts
SESSION5_RECALL = [
    # Old facts (should survive cross-session)
    ("What does Jordan do?", ["marine biologist"], "old"),
    ("Where does Elena live?", ["portland"], "old"),
    ("What is Marcus's job?", ["architect"], "old"),
    ("Where does Marcus live?", ["seattle"], "old"),
    ("What does Priya do?", ["neurosurgeon"], "old"),
    ("Where does Matt work?", ["novamind"], "old"),
    # New facts
    ("What does Matt's neighbor Dave do?", ["firefighter"], "new"),
    ("Where does Dave live?", ["denver"], "new"),
    ("What does Rina do?", ["data scientist"], "new"),
    ("Where does Rina work?", ["google"], "new"),
    ("What car did Matt buy?", ["tesla", "model 3"], "new"),
]


def score_answer(answer, expected_keywords, whole_word=False):
    """Return the fraction of *expected_keywords* found in *answer*.

    Matching is case-insensitive. By default a keyword counts when it appears
    anywhere as a substring, which can over-count: "no" matches "NovaMind" or
    "know", and "cto" matches "doctor". Pass ``whole_word=True`` to require
    word boundaries instead. The default is kept so scores stay comparable
    with earlier sessions.

    Returns 0.0 for an empty keyword list instead of dividing by zero.
    """
    if not expected_keywords:
        return 0.0
    answer_lower = answer.lower()
    if whole_word:
        hits = sum(
            1 for k in expected_keywords
            if re.search(rf"\b{re.escape(k.lower())}\b", answer_lower)
        )
    else:
        hits = sum(1 for k in expected_keywords if k.lower() in answer_lower)
    return hits / len(expected_keywords)


def _is_contrastive(qp):
    """Return True if a quiz pair's assistant answer starts with "No."."""
    return qp["messages"][1]["content"].lower().startswith("no.")


def _pairs_mentioning(pairs, names):
    """Return the pairs whose question or answer text mentions any of *names*.

    Matching is a case-insensitive substring test. Order of *pairs* is kept.
    """
    lowered = [n.lower() for n in names]
    selected = []
    for qp in pairs:
        text = (qp["messages"][0]["content"] + " " + qp["messages"][1]["content"]).lower()
        if any(n in text for n in lowered):
            selected.append(qp)
    return selected


def _sample_replay(pool, k):
    """Draw *k* random replay examples from *pool*.

    Returns [] unless *pool* holds strictly more than *k* items.
    """
    return random.sample(pool, k) if len(pool) > k else []


def quick_verify_entities(mm, known_entities):
    """Probe the model on each known entity's job and city.

    Args:
        mm: model manager exposing ``generate(messages, max_new_tokens=...)``.
        known_entities: mapping of entity name -> info dict. This function
            reads ``info["job"]`` and ``info["city"]`` (each optional). When a
            job is present it also reads ``info["relationship"]``.

    Returns:
        Set of entity names for which at least one probe answer did not
        contain the expected value (case-insensitive substring match). Each
        failure is also printed.
    """
    confused = set()
    for name, info in known_entities.items():
        probes = []
        if info.get("job"):
            probes.append(("job", info["job"], f"What does Matt's {info['relationship']} {name} do?"))
        if info.get("city"):
            probes.append(("city", info["city"], f"Where does {name} live?"))
        for field, expected, question in probes:
            ans = mm.generate([{"role": "user", "content": question}], max_new_tokens=100)
            if expected.lower() not in ans.lower():
                confused.add(name)
                print(f" CONFUSED: {name} {field} — expected '{expected}', got: {ans[:80]}")
    return confused


def _build_retry_batch(known_entities, names, contrastive_pairs, repeats=3):
    """Build the Phase 3 micro-batch for the given confused entity names.

    For each name with a non-empty info dict, the batch gets:
    - the job Q/A (if a job is known), repeated *repeats* times;
    - the city Q/A (if a city is known), repeated *repeats* times;
    - every pair in *contrastive_pairs* that mentions the name, once each.

    A pair that mentions several confused names is added once per such name.
    """
    batch = []
    for name in names:
        info = known_entities.get(name, {})
        if not info:
            continue
        # Positive reinforcement: a fresh dict per repetition.
        if info.get("job"):
            batch.extend(
                {"messages": [
                    {"role": "user", "content": f"What does Matt's {info['relationship']} {name} do?"},
                    {"role": "assistant", "content": f"Matt's {info['relationship']} {name} is a {info['job']}."},
                ]}
                for _ in range(repeats)
            )
        if info.get("city"):
            batch.extend(
                {"messages": [
                    {"role": "user", "content": f"Where does {name} live?"},
                    {"role": "assistant", "content": f"{name} lives in {info['city']}. {name} is Matt's {info['relationship']}."},
                ]}
                for _ in range(repeats)
            )
        batch.extend(_pairs_mentioning(contrastive_pairs, [name]))
    return batch


def phase3_stubborn_retry(mm, quiz_gen, confused_entities, all_contrastive, max_retries=3):
    """Phase 3: micro-focused retry on stubborn facts.

    Each round does three things:
    1. Builds a batch from the still-confused entities: each known job/city
       fact 3x, plus every pair in *all_contrastive* that mentions the entity,
       once each.
    2. Absorbs that batch.
    3. Re-checks those entities with ``quick_verify_entities``, the same
       probes that flagged them in the first place, so the exit criterion
       matches the entry criterion.

    Stops after *max_retries* rounds or once nothing is confused.

    Returns:
        Set of entity names still confused when the retries end. Empty when
        every entity was fixed. Names missing from
        ``quiz_gen.known_entities`` are dropped rather than retried.
    """
    known = quiz_gen.known_entities
    confused_entities = set(confused_entities)
    attempts = 0
    while confused_entities and attempts < max_retries:
        print(f" Phase 3 retry {attempts+1}/{max_retries}: fixing {confused_entities}")
        retry_batch = _build_retry_batch(known, confused_entities, all_contrastive)
        if retry_batch:
            loss = mm.absorb(retry_batch)
            print(f" Phase 3: Trained {len(retry_batch)} items, loss={loss:.4f}")
        attempts += 1
        confused_entities = quick_verify_entities(
            mm, {n: known[n] for n in confused_entities if known.get(n)})
    if confused_entities:
        print(f" Phase 3: Still confused after {max_retries} retries: {confused_entities}")
    else:
        # Also reached when the last allowed retry fixes everything.
        print(f" Phase 3: All entities correct after {attempts} retries!")
    return confused_entities


def _load_replay_data():
    """Load the replay buffer and quiz-pair log from LOG_DIR.

    Returns:
        (all_training_data, quiz_pairs_log). Each is [] when its JSON file
        is missing.
    """
    replay_path = os.path.join(LOG_DIR, "replay_buffer.json")
    quiz_log_path = os.path.join(LOG_DIR, "quiz_pairs_log.json")
    all_training_data = []
    quiz_pairs_log = []
    if os.path.exists(replay_path):
        with open(replay_path, 'r') as f:
            all_training_data = json.load(f)
        print(f" Replay buffer: {len(all_training_data)} examples")
    if os.path.exists(quiz_log_path):
        with open(quiz_log_path, 'r') as f:
            quiz_pairs_log = json.load(f)
        print(f" Quiz pairs: {len(quiz_pairs_log)}")
    return all_training_data, quiz_pairs_log


def _consolidate(mm, teacher_cache_path):
    """Distill from a saved teacher cache if one exists; otherwise skip."""
    if not os.path.exists(teacher_cache_path):
        print(" No teacher cache — skipping")
        return
    # weights_only=False unpickles arbitrary objects. This is acceptable only
    # because the cache is written by this pipeline, not received from outside.
    teacher_cache = torch.load(teacher_cache_path, map_location="cpu", weights_only=False)
    print(f" Teacher cache: {len(teacher_cache)} items")
    loss = mm.distill(teacher_cache, epochs=CONSOLIDATION_EPOCHS)
    print(f" Consolidation loss: {loss:.4f}")


def _teach_session4(mm, quiz_gen, all_training_data, quiz_pairs_log):
    """Feed SESSION4_MESSAGES one at a time with per-message two-phase absorption.

    For each message:
    - generate a reply and quizzes;
    - Phase 1 absorbs the exchange, the positive quizzes and up to 6 replayed
      prior examples;
    - Phase 2, if any entity is confused, absorbs only the contrastive
      quizzes that mention the confused entities.

    Appends the exchange and its quizzes to *all_training_data*, and the
    quizzes to *quiz_pairs_log*; both lists are mutated in place.

    Returns:
        (session_positive, session_contrastive, total_quiz_count)
    """
    conversation_buffer = []
    session_positive = []
    session_contrastive = []
    total_quiz_count = 0

    for i, user_msg in enumerate(SESSION4_MESSAGES):
        print(f"\n --- Message {i+1}/{len(SESSION4_MESSAGES)} ---")
        print(f" Matt: {user_msg}")
        conversation_buffer.append({"role": "user", "content": user_msg})
        if len(conversation_buffer) > 20:
            conversation_buffer = conversation_buffer[-20:]
        response = mm.generate(conversation_buffer)
        conversation_buffer.append({"role": "assistant", "content": response})
        print(f" Claudia: {response[:120]}...")

        quizzes = quiz_gen.generate(user_msg, response)
        total_quiz_count += len(quizzes)
        print(f" Quizzes: {len(quizzes)} (total: {total_quiz_count}), Entities: {list(quiz_gen.known_entities.keys())}")

        positive_batch = [qp for qp in quizzes if not _is_contrastive(qp)]
        contrastive_batch = [qp for qp in quizzes if _is_contrastive(qp)]
        print(f" Pos: {len(positive_batch)}, Contrastive: {len(contrastive_batch)}")

        exchange = {"messages": [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": response},
        ]}
        # Everything before this index predates the current message and is the
        # replay pool for its Phase 1.
        prior_len = len(all_training_data)
        all_training_data.append(exchange)
        all_training_data.extend(quizzes)
        quiz_pairs_log.extend(quizzes)
        session_positive.extend(positive_batch)
        session_contrastive.extend(contrastive_batch)

        # Phase 1: Positive
        phase1_data = [exchange] + positive_batch
        phase1_data.extend(_sample_replay(all_training_data[:prior_len], 6))
        if quiz_gen.known_entities:
            phase1_data = ModelManager.cluster_by_entity(
                phase1_data, list(quiz_gen.known_entities.keys()))
        loss1 = mm.absorb(phase1_data)
        print(f" Phase 1: {len(phase1_data)} examples, loss={loss1:.4f}")

        # Phase 2: Targeted contrastive
        if not contrastive_batch:
            continue
        confused = quick_verify_entities(mm, quiz_gen.known_entities)
        if not confused:
            print(" Phase 2: All correct — skipping!")
            continue
        targeted = _pairs_mentioning(contrastive_batch, confused)
        if targeted:
            loss2 = mm.absorb(targeted)
            print(f" Phase 2: {len(targeted)}/{len(contrastive_batch)} targeted, loss={loss2:.4f}")
        else:
            print(" Phase 2: No matching pairs — skipping")

    return session_positive, session_contrastive, total_quiz_count


def _final_absorption(mm, quiz_gen, replay_pool, session_positive, session_contrastive):
    """Run the end-of-session two-phase pass.

    Phase 1 absorbs all of this session's positive quizzes plus 12 replayed
    examples from *replay_pool* (only when the pool holds more than 12).
    Phase 2 absorbs the session's contrastive quizzes that mention any entity
    still confused afterwards.
    """
    final_positive = session_positive + _sample_replay(replay_pool, 12)
    if quiz_gen.known_entities:
        final_positive = ModelManager.cluster_by_entity(
            final_positive, list(quiz_gen.known_entities.keys()))
    loss_fp = mm.absorb(final_positive)
    print(f" Final Phase 1: {len(final_positive)} examples, loss={loss_fp:.4f}")

    confused = quick_verify_entities(mm, quiz_gen.known_entities)
    if confused and session_contrastive:
        targeted = _pairs_mentioning(session_contrastive, confused)
        if targeted:
            loss_fc = mm.absorb(targeted)
            print(f" Final Phase 2: {len(targeted)} targeted, loss={loss_fc:.4f}")


def _run_recall(mm, questions, show_score):
    """Ask each (question, keywords, type) in a fresh single-turn chat.

    A question passes when ``score_answer >= 0.5``.

    Returns:
        One result dict per question, with keys question, expected, answer,
        score, passed and type.
    """
    results = []
    for question, keywords, qtype in questions:
        answer = mm.generate([{"role": "user", "content": question}], max_new_tokens=200)
        score = score_answer(answer, keywords)
        passed = score >= 0.5
        status = "PASS" if passed else "FAIL"
        print(f" [{status}] ({qtype}) {question}")
        print(f" Got: {answer[:120]}")
        if show_score:
            print(f" Score: {score:.2f}")
        results.append({
            "question": question, "expected": keywords, "answer": answer,
            "score": score, "passed": passed, "type": qtype,
        })
    return results


def _tally(results, qtype):
    """Return (passed_count, total_count) over results of type *qtype*."""
    of_type = [r for r in results if r["type"] == qtype]
    return sum(1 for r in of_type if r["passed"]), len(of_type)


def _save_checkpoint(mm, all_training_data, quiz_pairs_log):
    """Write a new timestamped checkpoint plus its replay data and teacher cache.

    Returns:
        (checkpoint_path, teacher_cache_path)
    """
    version = f"session_{datetime.now().strftime('%Y%m%d_%H%M')}"
    path = os.path.join(CHECKPOINT_DIR, f"claudia_{version}")

    # NOTE(review): every existing claudia_* directory is deleted BEFORE the new
    # checkpoint is written. That includes CHECKPOINT, which this run loaded
    # from. If merge_and_save fails, no claudia_* checkpoint remains on disk.
    # Presumably this is done to free disk space first; confirm before reordering.
    for entry in os.listdir(CHECKPOINT_DIR):
        full = os.path.join(CHECKPOINT_DIR, entry)
        if os.path.isdir(full) and entry.startswith("claudia_"):
            size_gb = sum(
                os.path.getsize(os.path.join(dp, f))
                for dp, _, fns in os.walk(full) for f in fns
            ) / 1e9
            print(f" Removing old: {entry} ({size_gb:.1f} GB)")
            shutil.rmtree(full)

    mm.merge_and_save(path)

    for replay_out in (os.path.join(LOG_DIR, "replay_buffer.json"),
                       os.path.join(path, "replay_buffer.json")):
        with open(replay_out, 'w') as f:
            json.dump(all_training_data, f)
    with open(os.path.join(LOG_DIR, "quiz_pairs_log.json"), 'w') as f:
        json.dump(quiz_pairs_log, f)

    print(f" Caching teacher logits ({len(quiz_pairs_log)} pairs)...")
    new_teacher_cache = mm.cache_teacher_logits(quiz_pairs_log)
    cache_path = os.path.join(path, "teacher_cache.pt")
    torch.save(new_teacher_cache, cache_path)
    size_mb = os.path.getsize(cache_path) / 1e6
    print(f" Teacher cache: {len(new_teacher_cache)} items, {size_mb:.1f} MB")
    return path, cache_path


def _run_session5(path, cache_path):
    """Simulate a fresh session from the saved checkpoint.

    Reloads the model from *path*, distills from *cache_path*, teaches
    SESSION5_MESSAGES with two-phase absorption, then runs SESSION5_RECALL.

    Returns:
        (s5_results, s5_quiz_count)
    """
    mm2 = ModelManager(model_path=path, checkpoint_path=path)
    mm2.load()
    quiz_gen2 = QuizGenerator(mm2)

    # Consolidation from teacher cache (same as real session start).
    # weights_only=False is acceptable only because this script just wrote the file.
    teacher_cache2 = torch.load(cache_path, map_location="cpu", weights_only=False)
    print(f" Session 5 consolidation: {len(teacher_cache2)} items")
    loss5 = mm2.distill(teacher_cache2, epochs=CONSOLIDATION_EPOCHS)
    print(f" Consolidation loss: {loss5:.4f}")

    print("\n --- Session 5: Teaching NEW facts ---")
    s5_quiz_count = 0
    conv_buf = []
    for i, user_msg in enumerate(SESSION5_MESSAGES):
        print(f"\n S5 Message {i+1}/{len(SESSION5_MESSAGES)}: {user_msg}")
        conv_buf.append({"role": "user", "content": user_msg})
        response = mm2.generate(conv_buf)
        conv_buf.append({"role": "assistant", "content": response})
        print(f" Claudia: {response[:100]}...")

        quizzes = quiz_gen2.generate(user_msg, response)
        s5_quiz_count += len(quizzes)
        print(f" Quizzes: {len(quizzes)}")

        exchange = {"messages": [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": response},
        ]}
        positive = [qp for qp in quizzes if not _is_contrastive(qp)]
        contrastive = [qp for qp in quizzes if _is_contrastive(qp)]

        # Two-phase absorption
        phase1 = [exchange] + positive
        loss1 = mm2.absorb(phase1)
        print(f" Phase 1: {len(phase1)} items, loss={loss1:.4f}")
        if contrastive:
            confused = quick_verify_entities(mm2, quiz_gen2.known_entities)
            if confused:
                targeted = _pairs_mentioning(contrastive, confused)
                if targeted:
                    loss2 = mm2.absorb(targeted)
                    print(f" Phase 2: {len(targeted)} targeted, loss={loss2:.4f}")

    # Session 5 recall: test BOTH old and new facts
    print("\n --- Session 5 RECALL TEST (old + new facts) ---")
    print("=" * 60)
    s5_results = _run_recall(mm2, SESSION5_RECALL, show_score=False)
    return s5_results, s5_quiz_count


def run_session4f():
    """Run the full session-4f experiment, steps [1/10] through [10/10].

    Steps:
    - load the model;
    - load replay data;
    - consolidate;
    - check personality;
    - teach the session-4 facts;
    - run final absorption;
    - run Phase 3;
    - run the session-4 recall test;
    - save the checkpoint;
    - run a cross-session reload test.

    Writes a summary JSON to LOG_DIR/session4f_results.json.
    """
    print("=" * 60)
    print("SESSION 4f: Phase 3 Stubborn Retry + Cross-Session Test")
    print(f"Time: {datetime.now().isoformat()}")
    print("=" * 60)

    # ── Load model ──
    print("\n[1/10] Loading model from checkpoint...")
    mm = ModelManager(model_path=CHECKPOINT, checkpoint_path=CHECKPOINT)
    mm.load()
    quiz_gen = QuizGenerator(mm)

    # ── Load existing data ──
    print("\n[2/10] Loading existing data...")
    all_training_data, quiz_pairs_log = _load_replay_data()

    # ── Consolidation ──
    print("\n[3/10] Cascade consolidation...")
    _consolidate(mm, os.path.join(CHECKPOINT, "teacher_cache.pt"))

    # ── Personality ──
    print("\n[4/10] Personality check...")
    p_score = check_personality(mm)

    # ── Feed facts (same as 4e) ──
    print("\n[5/10] Teaching facts (two-phase + focused contrastive)...")
    # Everything before this index predates this session and is the replay
    # pool for step 6. The old negative-index slice returned [] when the
    # session produced no quizzes, and it mixed in this session's exchanges.
    pre_session_len = len(all_training_data)
    session_positive, session_contrastive, total_quiz_count = _teach_session4(
        mm, quiz_gen, all_training_data, quiz_pairs_log)

    # ── Final two-phase ──
    print("\n[6/10] Final two-phase absorption...")
    _final_absorption(mm, quiz_gen, all_training_data[:pre_session_len],
                      session_positive, session_contrastive)

    # ── PHASE 3: Stubborn-fact retry ──
    print("\n[7/10] PHASE 3: Stubborn-fact retry...")
    confused = quick_verify_entities(mm, quiz_gen.known_entities)
    if confused:
        unresolved = phase3_stubborn_retry(mm, quiz_gen, confused, session_contrastive, max_retries=3)
    else:
        unresolved = set()
        print(" Phase 3: Not needed — all entities correct!")

    # ── Recall test (session 4 facts) ──
    print("\n[8/10] SESSION 4 RECALL TEST")
    print("=" * 60)
    results = _run_recall(mm, RECALL_QUESTIONS, show_score=True)
    direct_correct, direct_total = _tally(results, "direct")
    contrastive_correct, contrastive_total = _tally(results, "contrastive")
    total = direct_correct + contrastive_correct
    total_q = direct_total + contrastive_total
    if total > SESSION4E_TOTAL:
        verdict = "IMPROVED"
    elif total == SESSION4E_TOTAL:
        verdict = "SAME"
    else:
        verdict = "REGRESSED"
    print("=" * 60)
    print(f"DIRECT: {direct_correct}/{direct_total} ({direct_correct/direct_total:.0%})")
    print(f"CONTRASTIVE: {contrastive_correct}/{contrastive_total} ({contrastive_correct/contrastive_total:.0%})")
    print(f"TOTAL: {total}/{total_q} ({total/total_q:.0%})")
    print(f"vs 4e: {verdict} (was {SESSION4E_TOTAL}/{len(RECALL_QUESTIONS)})")
    print("=" * 60)

    # ── Save checkpoint ──
    print("\n[9/10] Saving checkpoint...")
    path, cache_path = _save_checkpoint(mm, all_training_data, quiz_pairs_log)

    # ── CROSS-SESSION TEST: Session 5 ──
    print("\n[10/10] CROSS-SESSION TEST: Loading fresh from checkpoint...")
    print("=" * 60)
    # Drop every reference to the session-4 model before reloading. QuizGenerator
    # was constructed with mm and presumably holds on to it, so deleting mm alone
    # would likely leave the old model resident while mm2 loads.
    del mm, quiz_gen
    gc.collect()
    torch.cuda.empty_cache()
    s5_results, s5_quiz_count = _run_session5(path, cache_path)

    old_correct, old_total = _tally(s5_results, "old")
    new_correct, new_total = _tally(s5_results, "new")
    s5_total = old_correct + new_correct
    s5_total_q = old_total + new_total
    print("=" * 60)
    print(f"OLD FACTS (session 4): {old_correct}/{old_total} ({old_correct/old_total:.0%})")
    print(f"NEW FACTS (session 5): {new_correct}/{new_total} ({new_correct/new_total:.0%})")
    print(f"TOTAL: {s5_total}/{s5_total_q} ({s5_total/s5_total_q:.0%})")
    print("=" * 60)

    # Save all results
    results_data = {
        "session": "4f",
        "session4_recall": f"{total}/{total_q}",
        "session4_direct": f"{direct_correct}/{direct_total}",
        "session4_contrastive": f"{contrastive_correct}/{contrastive_total}",
        "session5_old_recall": f"{old_correct}/{old_total}",
        "session5_new_recall": f"{new_correct}/{new_total}",
        "session5_total": f"{s5_total}/{s5_total_q}",
        "personality_score": p_score,
        "total_quizzes": total_quiz_count,
        "session5_quizzes": s5_quiz_count,
        "phase3_unresolved": sorted(unresolved),
        "checkpoint": path,
        "timestamp": datetime.now().isoformat(),
        "session4_results": results,
        "session5_results": s5_results,
    }
    results_path = os.path.join(LOG_DIR, "session4f_results.json")
    with open(results_path, 'w') as f:
        json.dump(results_data, f, indent=2)
    print(f"\nResults saved: {results_path}")
    print("\nSESSION 4f + CROSS-SESSION TEST COMPLETE.")


if __name__ == "__main__":
    run_session4f()