"""
Session 4 Automated Test — Contrastive Disambiguation
=====================================================
Tests the new contrastive quiz generation against entity confusion.

Plan:
  1. Load from checkpoint claudia_session_20260322_2206
  2. Run consolidation distillation (locks in sessions 1-3)
  3. Feed 5 multi-person facts designed to trigger entity confusion
  4. Generate quizzes including contrastive pairs
  5. Absorb all data
  6. Test recall on:
     a) Direct questions (same as before)
     b) Cross-entity confusion questions ("Is X a [Y's job]?")
  7. Save checkpoint + teacher cache
  8. Report results

Run:
  python3 test_session4.py 2>&1 | tee /workspace/logs/distill2_s4.log
"""

import json
import os
import random
import re
import shutil
import sys
import time
import torch
from datetime import datetime

# Add workspace to path
sys.path.insert(0, "/workspace")

from persistent_absorber import (
    ModelManager, QuizGenerator, check_personality,
    CONSOLIDATION_EPOCHS, MAX_TEACHER_CACHE
)

CHECKPOINT = "/workspace/checkpoints/claudia_session_20260322_2206"
LOG_DIR = "/workspace/logs"
CHECKPOINT_DIR = "/workspace/checkpoints"

# ═══════════════════════════════════════════════════════════════════════
# SESSION 4 FACTS — deliberately designed to trigger entity confusion
# Multiple people with jobs + cities that could be swapped
# ═══════════════════════════════════════════════════════════════════════

SESSION4_MESSAGES = [
    # Message 1: Two people, similar structure, different details
    "My friend Jordan is a marine biologist in San Diego. My sister Elena is a veterinarian in Portland.",

    # Message 2: Add a third person with their own job and city
    "My cousin Marcus is an architect in Seattle. He designed the new library there.",

    # Message 3: Matt's own details (to contrast against friends)
    "I work at a startup called NovaMind. I'm the CTO. We're based in Austin.",

    # Message 4: More people with specific details
    "My best friend Priya is a neurosurgeon in Chicago. She went to Johns Hopkins.",

    # Message 5: A fact that connects to earlier entities
    "Elena actually just got a new cat named Mochi. And Jordan got his dive certification renewed last month.",
]

# ═══════════════════════════════════════════════════════════════════════
# RECALL TEST QUESTIONS — direct + contrastive
# Each entry is (question, expected keywords, question type).
# ═══════════════════════════════════════════════════════════════════════

RECALL_QUESTIONS = [
    # Direct recall
    ("What does Matt's friend Jordan do?", ["marine biologist"], "direct"),
    ("Where does Matt's sister Elena live?", ["portland"], "direct"),
    ("What is Matt's cousin Marcus's job?", ["architect"], "direct"),
    ("Where does Marcus live?", ["seattle"], "direct"),
    ("What does Priya do?", ["neurosurgeon"], "direct"),
    ("Where does Priya live?", ["chicago"], "direct"),
    ("Where does Matt work?", ["novamind"], "direct"),
    ("What is Matt's job title?", ["cto"], "direct"),
    ("What is Elena's cat's name?", ["mochi"], "direct"),

    # Cross-entity confusion tests (THE critical ones)
    ("Is Elena a marine biologist?", ["no", "veterinarian", "jordan"], "contrastive"),
    ("Does Jordan live in Portland?", ["no", "san diego", "elena"], "contrastive"),
    ("Is Marcus a neurosurgeon?", ["no", "architect", "priya"], "contrastive"),
    ("Does Priya live in Seattle?", ["no", "chicago", "marcus"], "contrastive"),
    ("Is Matt an architect?", ["no", "cto", "marcus"], "contrastive"),
    ("Does Jordan work at NovaMind?", ["no", "marine biologist", "matt"], "contrastive"),
]

# A negative answer is often phrased "does not" / "isn't" rather than a bare
# "no", so the "no" keyword accepts those forms as well.
_NEGATION_RE = re.compile(r"\b(?:no|not|nope)\b|n['’]t\b")


def _keyword_found(keyword, text_lower):
    """Return True if *keyword* occurs in *text_lower* as a whole word/phrase.

    Plain substring matching is too permissive for short keywords: "no" is a
    substring of "now", "know" and "novamind", and "cto" is a substring of
    "doctor" and "director", which lets wrong answers score as correct.
    """
    keyword = keyword.lower()
    if keyword == "no":
        return _NEGATION_RE.search(text_lower) is not None
    pattern = rf"(?<!\w){re.escape(keyword)}(?!\w)"
    return re.search(pattern, text_lower) is not None


def score_answer(answer, expected_keywords):
    """Score an answer by the fraction of expected keywords it contains.

    Args:
        answer: The generated answer text.
        expected_keywords: Keywords/phrases to look for (case-insensitive,
            matched on word boundaries).

    Returns:
        A float in [0.0, 1.0]: 1.0 if all keywords are found, partial
        otherwise. An empty keyword list scores 0.0 (nothing to verify)
        instead of raising ZeroDivisionError.
    """
    if not expected_keywords:
        return 0.0
    answer_lower = answer.lower()
    hits = sum(1 for k in expected_keywords if _keyword_found(k, answer_lower))
    return hits / len(expected_keywords)


def _sample_replay(pool, k):
    """Return *k* random items from *pool*, or [] unless the pool has more than *k* items."""
    if len(pool) > k:
        return random.sample(pool, k)
    return []


def _dir_size_gb(root):
    """Return the total size of all files under *root*, in GB (1e9 bytes)."""
    total_bytes = sum(
        os.path.getsize(os.path.join(dp, f))
        for dp, _, fns in os.walk(root)
        for f in fns
    )
    return total_bytes / 1e9


def run_session4():
    """Run the full session 4 pipeline: load, consolidate, teach, test, save.

    Loads the model from CHECKPOINT, optionally distills from the cached
    teacher logits, feeds SESSION4_MESSAGES one at a time (generating quizzes
    and absorbing after each), runs the RECALL_QUESTIONS test, then replaces
    the existing checkpoints with a new one and writes the replay buffer,
    quiz log, teacher cache and a results JSON.
    """
    print("=" * 60)
    print("SESSION 4: Contrastive Disambiguation Test")
    print(f"Time: {datetime.now().isoformat()}")
    print("=" * 60)

    # ── Step 1: Load model from checkpoint ──
    print("\n[1/7] Loading model from checkpoint...")
    mm = ModelManager(
        model_path=CHECKPOINT,
        checkpoint_path=CHECKPOINT,
    )
    mm.load()
    quiz_gen = QuizGenerator(mm)

    # ── Step 2: Load existing replay + quiz data ──
    print("\n[2/7] Loading existing data...")
    replay_path = os.path.join(LOG_DIR, "replay_buffer.json")
    quiz_log_path = os.path.join(LOG_DIR, "quiz_pairs_log.json")

    all_training_data = []
    quiz_pairs_log = []
    if os.path.exists(replay_path):
        with open(replay_path, 'r', encoding="utf-8") as f:
            all_training_data = json.load(f)
        print(f" Replay buffer: {len(all_training_data)} examples")
    if os.path.exists(quiz_log_path):
        with open(quiz_log_path, 'r', encoding="utf-8") as f:
            quiz_pairs_log = json.load(f)
        print(f" Quiz pairs: {len(quiz_pairs_log)}")

    # Everything before this index is data from earlier sessions; session 4
    # items are appended after it. Used for the final replay sample.
    prior_count = len(all_training_data)

    # ── Step 3: Cascade consolidation from teacher cache ──
    print("\n[3/7] Cascade consolidation from teacher cache...")
    teacher_cache_path = os.path.join(CHECKPOINT, "teacher_cache.pt")
    if os.path.exists(teacher_cache_path):
        # NOTE(review): weights_only=False unpickles arbitrary objects. That is
        # only acceptable because this cache is written by this same pipeline;
        # never point it at an untrusted file.
        teacher_cache = torch.load(teacher_cache_path, map_location="cpu", weights_only=False)
        print(f" Teacher cache: {len(teacher_cache)} items")
        loss = mm.distill(teacher_cache, epochs=CONSOLIDATION_EPOCHS)
        print(f" Consolidation loss: {loss:.4f}")
    else:
        print(" No teacher cache found — skipping consolidation")

    # ── Step 4: Personality check ──
    print("\n[4/7] Personality check...")
    p_score = check_personality(mm)

    # ── Step 5: Feed session 4 facts + generate quizzes ──
    print("\n[5/7] Teaching session 4 facts...")
    conversation_buffer = []
    session_quizzes = []

    for i, user_msg in enumerate(SESSION4_MESSAGES):
        print(f"\n --- Message {i+1}/{len(SESSION4_MESSAGES)} ---")
        print(f" Matt: {user_msg}")

        conversation_buffer.append({"role": "user", "content": user_msg})
        # Keep only the 20 most recent turns as generation context.
        if len(conversation_buffer) > 20:
            conversation_buffer = conversation_buffer[-20:]

        # Generate response
        response = mm.generate(conversation_buffer)
        conversation_buffer.append({"role": "assistant", "content": response})
        print(f" Claudia: {response[:150]}...")

        # Generate quizzes (including contrastive!)
        quizzes = quiz_gen.generate(user_msg, response)
        print(f" Quizzes generated: {len(quizzes)}")
        for qi, qp in enumerate(quizzes):
            q = qp["messages"][0]["content"]
            a = qp["messages"][1]["content"]
            print(f" Q{qi+1}: {q}")
            print(f" A{qi+1}: {a[:120]}")

        # Store exchange
        exchange = {"messages": [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": response},
        ]}
        all_training_data.append(exchange)
        all_training_data.extend(quizzes)
        quiz_pairs_log.extend(quizzes)
        session_quizzes.extend(quizzes)

        # Absorb after each message (like the real pipeline):
        # new data + small replay of everything stored before it.
        new_data = [exchange] + quizzes
        old_sample = _sample_replay(all_training_data[:-len(new_data)], 8)
        absorb_data = new_data + old_sample

        t0 = time.time()
        loss = mm.absorb(absorb_data)
        elapsed = time.time() - t0
        print(f" Absorbed {len(absorb_data)} examples in {elapsed:.1f}s, loss={loss:.4f}")

    print(f"\n Session 4 total quizzes: {len(session_quizzes)}")
    # Heuristic: a quiz counts as contrastive when its answer has "no." within
    # the first five characters — presumably the format QuizGenerator emits
    # for contrastive pairs; confirm against persistent_absorber.
    contrastive_count = sum(
        1 for q in session_quizzes
        if "no." in q["messages"][1]["content"].lower()[:5]
    )
    print(f" Contrastive quizzes: {contrastive_count}")

    # ── Step 6: Final absorption pass with all quiz pairs ──
    print("\n[6/7] Final absorption pass (all session 4 quizzes)...")
    # Train on all session 4 quizzes + small old replay. The replay pool is
    # the pre-session prefix: session items are interleaved (exchange, quizzes,
    # exchange, ...), so trimming len(session_quizzes) items off the end would
    # neither isolate old data nor survive an empty quiz list ([:-0] is []).
    old_sample = _sample_replay(all_training_data[:prior_count], 16)
    final_data = session_quizzes + old_sample
    loss = mm.absorb(final_data)
    print(f" Final absorption: {len(final_data)} examples, loss={loss:.4f}")

    # ── Step 7: Recall test ──
    print("\n[7/7] RECALL TEST")
    print("=" * 60)

    direct_correct = 0
    direct_total = 0
    contrastive_correct = 0
    contrastive_total = 0
    results = []

    for question, keywords, qtype in RECALL_QUESTIONS:
        # Each question is asked in a fresh single-turn context.
        answer = mm.generate([{"role": "user", "content": question}], max_new_tokens=200)
        score = score_answer(answer, keywords)
        passed = score >= 0.5

        if qtype == "direct":
            direct_total += 1
            if passed:
                direct_correct += 1
        else:
            contrastive_total += 1
            if passed:
                contrastive_correct += 1

        status = "PASS" if passed else "FAIL"
        print(f" [{status}] ({qtype}) {question}")
        print(f" Expected: {keywords}")
        print(f" Got: {answer[:150]}")
        print(f" Score: {score:.2f}")
        print()

        results.append({
            "question": question,
            "expected": keywords,
            "answer": answer,
            "score": score,
            "passed": passed,
            "type": qtype,
        })

    print("=" * 60)
    print(f"DIRECT RECALL: {direct_correct}/{direct_total} ({direct_correct/direct_total:.0%})")
    print(f"CONTRASTIVE RECALL: {contrastive_correct}/{contrastive_total} ({contrastive_correct/contrastive_total:.0%})")
    total = direct_correct + contrastive_correct
    total_q = direct_total + contrastive_total
    print(f"TOTAL: {total}/{total_q} ({total/total_q:.0%})")
    print("=" * 60)

    # ── Save checkpoint ──
    print("\n--- Saving checkpoint ---")
    version = f"session_{datetime.now().strftime('%Y%m%d_%H%M')}"
    path = os.path.join(CHECKPOINT_DIR, f"claudia_{version}")

    # Clean up old checkpoints first.
    # NOTE(review): this removes every claudia_* directory — including the one
    # the model was loaded from — before the new checkpoint is written, so a
    # failure in merge_and_save leaves no checkpoint on disk. Order is kept
    # as-is in case disk space requires freeing first; confirm.
    for entry in os.listdir(CHECKPOINT_DIR):
        full = os.path.join(CHECKPOINT_DIR, entry)
        if os.path.isdir(full) and entry.startswith("claudia_"):
            size_gb = _dir_size_gb(full)
            print(f" Removing old checkpoint: {entry} ({size_gb:.1f} GB)")
            shutil.rmtree(full)

    mm.merge_and_save(path)

    # Save replay buffer (shared log copy + a copy inside the checkpoint)
    with open(os.path.join(LOG_DIR, "replay_buffer.json"), 'w', encoding="utf-8") as f:
        json.dump(all_training_data, f)
    with open(os.path.join(path, "replay_buffer.json"), 'w', encoding="utf-8") as f:
        json.dump(all_training_data, f)

    # Save quiz pairs log
    with open(os.path.join(LOG_DIR, "quiz_pairs_log.json"), 'w', encoding="utf-8") as f:
        json.dump(quiz_pairs_log, f)

    print(f" Replay buffer: {len(all_training_data)} examples")
    print(f" Quiz pairs log: {len(quiz_pairs_log)}")

    # Cache teacher logits for next session
    print(f" Caching teacher logits ({len(quiz_pairs_log)} quiz pairs)...")
    new_teacher_cache = mm.cache_teacher_logits(quiz_pairs_log)
    cache_path = os.path.join(path, "teacher_cache.pt")
    torch.save(new_teacher_cache, cache_path)
    size_mb = os.path.getsize(cache_path) / 1e6
    print(f" Teacher cache saved ({len(new_teacher_cache)} items, {size_mb:.1f} MB)")

    # Save results
    results_data = {
        "session": 4,
        "checkpoint": path,
        "direct_recall": f"{direct_correct}/{direct_total}",
        "contrastive_recall": f"{contrastive_correct}/{contrastive_total}",
        "total_recall": f"{total}/{total_q}",
        "personality_score": p_score,
        "session_quizzes": len(session_quizzes),
        "contrastive_quizzes": contrastive_count,
        "total_quiz_pairs": len(quiz_pairs_log),
        "total_replay": len(all_training_data),
        "timestamp": datetime.now().isoformat(),
        "results": results,
    }
    results_path = os.path.join(LOG_DIR, "session4_results.json")
    with open(results_path, 'w', encoding="utf-8") as f:
        json.dump(results_data, f, indent=2)
    print(f" Results saved: {results_path}")

    print(f"\n Next run: --checkpoint {path}")
    print("\nSESSION 4 COMPLETE.")


if __name__ == "__main__":
    run_session4()