"""
Session 4e — Fixed Contrastive Focus + Better Phase 2 Targeting
================================================================

Session 4d scored 11/15 (same as 4b). Two bugs fixed:

1. CONTRASTIVE DUPLICATION FIX: _generate_contrastive_quizzes now only
   generates pairs involving at least one NEW entity. Previously it
   re-generated all pairs between existing entities on every message,
   making 50% of quizzes contrastive.

2. PHASE 2 TARGETING FIX: When checking which contrastive pairs are
   relevant to confused entities, we now check BOTH the question AND the
   answer text (not just the question). "Is Jordan an architect?" is
   relevant to Marcus confusion even though "Marcus" only appears in the
   answer.

Expected improvement: more positive quizzes (better direct recall) and more
effective contrastive correction (Phase 2 no longer skips relevant pairs).

Run:
    python3 -u test_session4e.py 2>&1 | tee /workspace/logs/distill2_s4e.log
"""
import json
import os
import random
import shutil
import sys
import time
from datetime import datetime

import torch

sys.path.insert(0, "/workspace")
from persistent_absorber import (
    ModelManager, QuizGenerator, check_personality,
    CONSOLIDATION_EPOCHS, MAX_TEACHER_CACHE
)

CHECKPOINT = "/workspace/checkpoints/claudia_session_20260323_0327"
LOG_DIR = "/workspace/logs"
CHECKPOINT_DIR = "/workspace/checkpoints"

SESSION4_MESSAGES = [
    "My friend Jordan is a marine biologist in San Diego. My sister Elena is a veterinarian in Portland.",
    "My cousin Marcus is an architect in Seattle. He designed the new library there.",
    "I work at a startup called NovaMind. I'm the CTO. We're based in Austin.",
    "My best friend Priya is a neurosurgeon in Chicago. She went to Johns Hopkins.",
    "Elena actually just got a new cat named Mochi. And Jordan got his dive certification renewed last month.",
]

# (question, expected keywords, question type). A question passes when at
# least half of its keywords appear (case-insensitively) in the answer.
RECALL_QUESTIONS = [
    ("What does Matt's friend Jordan do?", ["marine biologist"], "direct"),
    ("Where does Matt's sister Elena live?", ["portland"], "direct"),
    ("What is Matt's cousin Marcus's job?", ["architect"], "direct"),
    ("Where does Marcus live?", ["seattle"], "direct"),
    ("What does Priya do?", ["neurosurgeon"], "direct"),
    ("Where does Priya live?", ["chicago"], "direct"),
    ("Where does Matt work?", ["novamind"], "direct"),
    ("What is Matt's job title?", ["cto"], "direct"),
    ("What is Elena's cat's name?", ["mochi"], "direct"),
    ("Is Elena a marine biologist?", ["no", "veterinarian", "jordan"], "contrastive"),
    ("Does Jordan live in Portland?", ["no", "san diego", "elena"], "contrastive"),
    ("Is Marcus a neurosurgeon?", ["no", "architect", "priya"], "contrastive"),
    ("Does Priya live in Seattle?", ["no", "chicago", "marcus"], "contrastive"),
    ("Is Matt an architect?", ["no", "cto", "marcus"], "contrastive"),
    ("Does Jordan work at NovaMind?", ["no", "marine biologist", "matt"], "contrastive"),
]


def score_answer(answer, expected_keywords):
    """Return the fraction of ``expected_keywords`` found in ``answer``.

    Matching is a case-insensitive substring test; keywords are expected to be
    lowercase already. ``expected_keywords`` must be non-empty (an empty list
    raises ZeroDivisionError).
    """
    answer_lower = answer.lower()
    hits = sum(1 for k in expected_keywords if k in answer_lower)
    return hits / len(expected_keywords)


def quick_verify_entities(mm, known_entities):
    """Quick verification: test model on key entity facts.

    For every entity with a ``"job"`` and/or ``"city"`` entry, asks the model
    one question per fact and checks (case-insensitively) that the expected
    value appears in the generated answer.

    NOTE(review): entities that have a ``"job"`` are assumed to also carry a
    ``"relationship"`` key (used to phrase the question) — confirm against
    QuizGenerator.known_entities.

    Returns:
        set of entity names that the model gets WRONG on at least one fact.
    """
    confused = set()
    for name, info in known_entities.items():
        if info.get("job"):
            q = f"What does Matt's {info['relationship']} {name} do?"
            ans = mm.generate([{"role": "user", "content": q}], max_new_tokens=100)
            if info["job"].lower() not in ans.lower():
                confused.add(name)
                print(f" CONFUSED: {name} job — expected '{info['job']}', got: {ans[:80]}")
        if info.get("city"):
            q = f"Where does {name} live?"
            ans = mm.generate([{"role": "user", "content": q}], max_new_tokens=100)
            if info["city"].lower() not in ans.lower():
                confused.add(name)
                print(f" CONFUSED: {name} city — expected '{info['city']}', got: {ans[:80]}")
    return confused


def _is_contrastive(quiz_pair):
    """Return True if the quiz pair's answer starts with "no." (case-insensitive)."""
    return quiz_pair["messages"][1]["content"].lower().startswith("no.")


def _select_targeted(contrastive_pairs, confused):
    """Return the contrastive pairs whose question OR answer mentions a confused entity.

    Both texts are checked because a pair such as "Is Jordan an architect?"
    names the confused entity (Marcus) only in its answer. Matching is a
    case-insensitive substring test.
    """
    names = [name.lower() for name in confused]
    targeted = []
    for qp in contrastive_pairs:
        full_text = (qp["messages"][0]["content"] + " "
                     + qp["messages"][1]["content"]).lower()
        if any(name in full_text for name in names):
            targeted.append(qp)
    return targeted


def _sample_older(data, n_recent, k):
    """Randomly sample ``k`` items from ``data`` excluding its last ``n_recent`` entries.

    Returns [] unless strictly more than ``k`` older entries exist. An explicit
    end index is used instead of ``data[:-n_recent]`` because that slice is
    empty when ``n_recent == 0``, which would make ``random.sample`` raise.
    """
    cutoff = len(data) - n_recent
    if cutoff <= k:
        return []
    return random.sample(data[:cutoff], k)


def run_session4e():
    """Run session 4e end to end.

    Steps: load the checkpoint, consolidate from its teacher cache, check
    personality, teach SESSION4_MESSAGES with two-phase absorption, run a final
    absorption pass, run the recall test, then save checkpoint, logs, teacher
    cache and a results JSON.
    """
    print("=" * 60)
    print("SESSION 4e: Fixed Contrastive Focus + Phase 2 Targeting")
    print(f"Time: {datetime.now().isoformat()}")
    print("=" * 60)

    # ── Load model ──
    print("\n[1/8] Loading model from checkpoint...")
    mm = ModelManager(
        model_path=CHECKPOINT,
        checkpoint_path=CHECKPOINT,
    )
    mm.load()
    quiz_gen = QuizGenerator(mm)

    # ── Load existing data ──
    print("\n[2/8] Loading existing data...")
    replay_path = os.path.join(LOG_DIR, "replay_buffer.json")
    quiz_log_path = os.path.join(LOG_DIR, "quiz_pairs_log.json")
    all_training_data = []
    quiz_pairs_log = []
    if os.path.exists(replay_path):
        with open(replay_path, 'r', encoding="utf-8") as f:
            all_training_data = json.load(f)
        print(f" Replay buffer: {len(all_training_data)} examples")
    if os.path.exists(quiz_log_path):
        with open(quiz_log_path, 'r', encoding="utf-8") as f:
            quiz_pairs_log = json.load(f)
        print(f" Quiz pairs: {len(quiz_pairs_log)}")

    # ── Consolidation ──
    print("\n[3/8] Cascade consolidation from teacher cache...")
    teacher_cache_path = os.path.join(CHECKPOINT, "teacher_cache.pt")
    if os.path.exists(teacher_cache_path):
        # weights_only=False unpickles arbitrary objects. Acceptable only
        # because this cache is produced by this pipeline's own torch.save;
        # never point it at an untrusted file.
        teacher_cache = torch.load(teacher_cache_path, map_location="cpu", weights_only=False)
        print(f" Teacher cache: {len(teacher_cache)} items")
        loss = mm.distill(teacher_cache, epochs=CONSOLIDATION_EPOCHS)
        print(f" Consolidation loss: {loss:.4f}")
    else:
        print(" No teacher cache found — skipping")

    # ── Personality check ──
    print("\n[4/8] Personality check...")
    p_score = check_personality(mm)

    # ── Feed facts with TWO-PHASE absorption (FIXED) ──
    print("\n[5/8] Teaching facts (fixed contrastive focus + phase 2 targeting)...")
    conversation_buffer = []
    session_positive = []
    session_contrastive = []
    contrastive_count = 0
    positive_count = 0
    total_quiz_count = 0
    # Everything appended to all_training_data after this index belongs to
    # this session; the final pass replays only data from before it.
    session_start = len(all_training_data)

    for i, user_msg in enumerate(SESSION4_MESSAGES):
        print(f"\n --- Message {i+1}/{len(SESSION4_MESSAGES)} ---")
        print(f" Matt: {user_msg}")
        conversation_buffer.append({"role": "user", "content": user_msg})
        if len(conversation_buffer) > 20:
            conversation_buffer = conversation_buffer[-20:]
        response = mm.generate(conversation_buffer)
        conversation_buffer.append({"role": "assistant", "content": response})
        print(f" Claudia: {response[:150]}...")

        # Generate quizzes (now with focused contrastive generation)
        quizzes = quiz_gen.generate(user_msg, response)
        total_quiz_count += len(quizzes)
        print(f" Quizzes generated: {len(quizzes)} (session total: {total_quiz_count})")
        print(f" Known entities: {list(quiz_gen.known_entities.keys())}")

        # Separate positive from contrastive ("No. ..." answers)
        positive_batch = []
        contrastive_batch = []
        for qp in quizzes:
            if _is_contrastive(qp):
                contrastive_batch.append(qp)
            else:
                positive_batch.append(qp)
        positive_count += len(positive_batch)
        contrastive_count += len(contrastive_batch)

        # Print quiz breakdown
        print(f" Positive: {len(positive_batch)}, Contrastive: {len(contrastive_batch)}")
        for qi, qp in enumerate(quizzes[:8]):
            q = qp["messages"][0]["content"]
            a = qp["messages"][1]["content"]
            tag = " [C]" if _is_contrastive(qp) else ""
            print(f" Q{qi+1}: {q}{tag}")
            print(f" A{qi+1}: {a[:120]}")

        exchange = {"messages": [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": response},
        ]}
        all_training_data.append(exchange)
        all_training_data.extend(quizzes)
        quiz_pairs_log.extend(quizzes)
        session_positive.extend(positive_batch)
        session_contrastive.extend(contrastive_batch)

        # ── TWO-PHASE ABSORPTION ──
        # Phase 1: Positive facts + exchange + replay. The replay sample skips
        # the exchange and quizzes just appended (len(quizzes) + 1 entries).
        phase1_data = [exchange] + positive_batch
        phase1_data.extend(_sample_older(all_training_data, len(quizzes) + 1, 6))
        if quiz_gen.known_entities:
            phase1_data = ModelManager.cluster_by_entity(
                phase1_data, list(quiz_gen.known_entities.keys()))
        t0 = time.time()
        loss1 = mm.absorb(phase1_data)
        elapsed1 = time.time() - t0
        print(f" Phase 1 (positive): {len(phase1_data)} examples, {elapsed1:.1f}s, loss={loss1:.4f}")

        # Phase 2: Targeted contrastive correction, only for entities the
        # model currently gets wrong.
        if contrastive_batch:
            print(" Phase 2: Verifying entities before contrastive correction...")
            confused = quick_verify_entities(mm, quiz_gen.known_entities)
            if confused:
                targeted = _select_targeted(contrastive_batch, confused)
                if targeted:
                    t0 = time.time()
                    loss2 = mm.absorb(targeted)
                    elapsed2 = time.time() - t0
                    print(f" Phase 2 (contrastive): {len(targeted)}/{len(contrastive_batch)} targeted, "
                          f"{elapsed2:.1f}s, loss={loss2:.4f}")
                else:
                    print(" Phase 2: No contrastive pairs matched confused entities")
            else:
                print(" Phase 2: All entities correct — skipping contrastive!")

    total_classified = positive_count + contrastive_count
    # Guard against a session that produced no quizzes at all.
    ratio_str = f"{positive_count / total_classified:.0%}" if total_classified else "n/a"
    print(f"\n Session 4e totals:")
    print(f" Total quizzes: {total_quiz_count}")
    print(f" Positive: {positive_count}")
    print(f" Contrastive: {contrastive_count}")
    print(f" Ratio: {ratio_str} positive")
    print(f" Known entities: {list(quiz_gen.known_entities.keys())}")

    # ── Final two-phase absorption ──
    print("\n[6/8] Final two-phase absorption pass...")
    # Replay only pre-session data: this session's exchanges AND quizzes are
    # excluded (session positives are already in final_positive).
    old_sample = _sample_older(all_training_data, len(all_training_data) - session_start, 12)
    final_positive = session_positive + old_sample
    if quiz_gen.known_entities:
        final_positive = ModelManager.cluster_by_entity(
            final_positive, list(quiz_gen.known_entities.keys()))
    loss_fp = mm.absorb(final_positive)
    print(f" Final Phase 1 (positive): {len(final_positive)} examples, loss={loss_fp:.4f}")

    # Final Phase 2 with FIXED targeting
    print(" Final Phase 2: Verifying entities...")
    confused = quick_verify_entities(mm, quiz_gen.known_entities)
    if confused and session_contrastive:
        targeted = _select_targeted(session_contrastive, confused)
        if targeted:
            loss_fc = mm.absorb(targeted)
            print(f" Final Phase 2 (contrastive): {len(targeted)} targeted, loss={loss_fc:.4f}")
        else:
            print(" Final Phase 2: No matching contrastive pairs")
    else:
        status = "no confusion detected" if not confused else "no contrastive data"
        print(f" Final Phase 2: Skipped ({status})")

    # ── Recall test ──
    print("\n[7/8] RECALL TEST")
    print("=" * 60)
    direct_correct = 0
    direct_total = 0
    contrastive_correct = 0
    contrastive_total = 0
    results = []
    for question, keywords, qtype in RECALL_QUESTIONS:
        answer = mm.generate([{"role": "user", "content": question}], max_new_tokens=200)
        score = score_answer(answer, keywords)
        passed = score >= 0.5
        if qtype == "direct":
            direct_total += 1
            if passed:
                direct_correct += 1
        else:
            contrastive_total += 1
            if passed:
                contrastive_correct += 1
        status = "PASS" if passed else "FAIL"
        print(f" [{status}] ({qtype}) {question}")
        print(f" Expected: {keywords}")
        print(f" Got: {answer[:150]}")
        print(f" Score: {score:.2f}")
        print()
        results.append({
            "question": question,
            "expected": keywords,
            "answer": answer,
            "score": score,
            "passed": passed,
            "type": qtype,
        })

    print("=" * 60)
    print(f"DIRECT RECALL: {direct_correct}/{direct_total} ({direct_correct/direct_total:.0%})")
    print(f"CONTRASTIVE RECALL: {contrastive_correct}/{contrastive_total} ({contrastive_correct/contrastive_total:.0%})")
    total = direct_correct + contrastive_correct
    total_q = direct_total + contrastive_total
    print(f"TOTAL: {total}/{total_q} ({total/total_q:.0%})")
    # Baselines from sessions 4b/4d: 11/15 total, 5/9 direct.
    comparison = "IMPROVED" if total > 11 else ("SAME" if total == 11 else "REGRESSED")
    print(f"vs SESSION 4b/4d: {comparison} (was 11/15 = 73%)")
    print(f"DIRECT vs 4d: {'UP' if direct_correct > 5 else ('SAME' if direct_correct == 5 else 'DOWN')} (was 5/9)")
    print("=" * 60)

    # ── Save ──
    print("\n[8/8] Saving checkpoint...")
    version = f"session_{datetime.now().strftime('%Y%m%d_%H%M')}"
    path = os.path.join(CHECKPOINT_DIR, f"claudia_{version}")
    # NOTE(review): this removes EVERY claudia_* directory, including
    # CHECKPOINT that this run loaded from, BEFORE the new checkpoint is
    # written (presumably to free disk space). If merge_and_save fails, no
    # checkpoint remains on disk.
    for entry in os.listdir(CHECKPOINT_DIR):
        full = os.path.join(CHECKPOINT_DIR, entry)
        if os.path.isdir(full) and entry.startswith("claudia_"):
            size_gb = sum(
                os.path.getsize(os.path.join(dp, f))
                for dp, _, fns in os.walk(full) for f in fns
            ) / 1e9
            print(f" Removing old checkpoint: {entry} ({size_gb:.1f} GB)")
            shutil.rmtree(full)
    mm.merge_and_save(path)

    with open(os.path.join(LOG_DIR, "replay_buffer.json"), 'w', encoding="utf-8") as f:
        json.dump(all_training_data, f)
    with open(os.path.join(path, "replay_buffer.json"), 'w', encoding="utf-8") as f:
        json.dump(all_training_data, f)
    with open(os.path.join(LOG_DIR, "quiz_pairs_log.json"), 'w', encoding="utf-8") as f:
        json.dump(quiz_pairs_log, f)
    print(f" Replay buffer: {len(all_training_data)} examples")
    print(f" Quiz pairs log: {len(quiz_pairs_log)}")

    print(f" Caching teacher logits ({len(quiz_pairs_log)} quiz pairs)...")
    new_teacher_cache = mm.cache_teacher_logits(quiz_pairs_log)
    cache_path = os.path.join(path, "teacher_cache.pt")
    torch.save(new_teacher_cache, cache_path)
    size_mb = os.path.getsize(cache_path) / 1e6
    print(f" Teacher cache saved ({len(new_teacher_cache)} items, {size_mb:.1f} MB)")

    results_data = {
        "session": "4e",
        "fixes": [
            "contrastive focus: only generate pairs involving new entities",
            "phase 2 targeting: check both Q and A text for confused entities",
        ],
        "checkpoint": path,
        "direct_recall": f"{direct_correct}/{direct_total}",
        "contrastive_recall": f"{contrastive_correct}/{contrastive_total}",
        "total_recall": f"{total}/{total_q}",
        "vs_session4b": comparison,
        "personality_score": p_score,
        "total_quizzes": total_quiz_count,
        "positive_quizzes": positive_count,
        "contrastive_quizzes": contrastive_count,
        "total_quiz_pairs": len(quiz_pairs_log),
        "total_replay": len(all_training_data),
        "known_entities": dict(quiz_gen.known_entities),
        "timestamp": datetime.now().isoformat(),
        "results": results,
    }
    results_path = os.path.join(LOG_DIR, "session4e_results.json")
    with open(results_path, 'w', encoding="utf-8") as f:
        json.dump(results_data, f, indent=2)
    print(f" Results saved: {results_path}")
    print(f"\n Next run: --checkpoint {path}")
    print("\nSESSION 4e COMPLETE.")


if __name__ == "__main__":
    run_session4e()