| """
|
| Session 4f — Phase 3 Stubborn-Fact Retry + Cross-Session Persistence Test
|
| =========================================================================
|
| Session 4e scored 14/15 (93%). Only failure: Marcus city (San Diego instead
|
| of Seattle). Phase 2 detected the confusion every time but couldn't fix it.
|
|
|
NEW: Phase 3 — after final absorption, re-verify ALL entities. For any that
still fail, build a micro-focused training batch containing only that entity's
facts (each repeated 3x) plus its relevant contrastive pairs, and retry up to
3 times. This is brute-force but targeted — it only fires on the last
stubborn associations that the two-phase approach couldn't fix.
|
|
|
Also: after saving, the script immediately reloads the new checkpoint as a
fresh "session 5", re-runs consolidation, teaches a few new facts, and checks
that a subset of the session-4 facts survives. This tests whether cascade
distillation works across sessions.
|
|
|
| Run: python3 -u test_session4f.py 2>&1 | tee /workspace/logs/distill2_s4f.log
|
| """
|
|
|
| import json
|
| import os
|
| import random
|
| import sys
|
| import time
|
| import torch
|
| from datetime import datetime
|
|
|
| sys.path.insert(0, "/workspace")
|
| from persistent_absorber import (
|
| ModelManager, QuizGenerator, check_personality,
|
| CONSOLIDATION_EPOCHS, MAX_TEACHER_CACHE
|
| )
|
|
|
# Checkpoint this run starts from: used as both model_path and checkpoint_path,
# and its teacher_cache.pt (if present) drives the initial consolidation.
CHECKPOINT = "/workspace/checkpoints/claudia_session_20260323_0412"
# Directory for replay_buffer.json, quiz_pairs_log.json and session4f_results.json.
LOG_DIR = "/workspace/logs"
# New checkpoints are saved here as claudia_session_<YYYYmmdd_HHMM>; when one is
# saved, the other claudia_* directories in this folder are deleted.
CHECKPOINT_DIR = "/workspace/checkpoints"
|
|
|
# User turns taught during session 4. Each one is sent to the model as a user
# message; QuizGenerator then builds quiz pairs from the (message, response)
# exchange, and those pairs are absorbed in two phases.
SESSION4_MESSAGES = [
    "My friend Jordan is a marine biologist in San Diego. My sister Elena is a veterinarian in Portland.",
    "My cousin Marcus is an architect in Seattle. He designed the new library there.",
    "I work at a startup called NovaMind. I'm the CTO. We're based in Austin.",
    "My best friend Priya is a neurosurgeon in Chicago. She went to Johns Hopkins.",
    "Elena actually just got a new cat named Mochi. And Jordan got his dive certification renewed last month.",
]


# New facts taught after reloading the freshly saved checkpoint (the
# cross-session "session 5" part of the test).
SESSION5_MESSAGES = [
    "My neighbor Dave is a firefighter in Denver. He's been doing it for 15 years.",
    "My colleague Rina is a data scientist at Google. She's based in Mountain View.",
    "I just bought a Tesla Model 3. It's midnight blue.",
]
|
|
|
# Session-4 recall test: (question, expected keywords, question type).
# Keywords are lower-case and are looked for as substrings of the lower-cased
# answer by score_answer; run_session4f counts a question as passed when at
# least half of its keywords are found (score >= 0.5).
# "direct" questions probe one fact; "contrastive" ones ask about a swapped
# attribute and expect a rejection plus the correct fact/owner.
# NOTE(review): substring matching lets short keywords hit inside longer words
# ("no" is inside "NovaMind", "cto" inside "director"), which can inflate
# scores — e.g. for "Does Jordan work at NovaMind?".
RECALL_QUESTIONS = [
    ("What does Matt's friend Jordan do?", ["marine biologist"], "direct"),
    ("Where does Matt's sister Elena live?", ["portland"], "direct"),
    ("What is Matt's cousin Marcus's job?", ["architect"], "direct"),
    ("Where does Marcus live?", ["seattle"], "direct"),
    ("What does Priya do?", ["neurosurgeon"], "direct"),
    ("Where does Priya live?", ["chicago"], "direct"),
    ("Where does Matt work?", ["novamind"], "direct"),
    ("What is Matt's job title?", ["cto"], "direct"),
    ("What is Elena's cat's name?", ["mochi"], "direct"),
    ("Is Elena a marine biologist?", ["no", "veterinarian", "jordan"], "contrastive"),
    ("Does Jordan live in Portland?", ["no", "san diego", "elena"], "contrastive"),
    ("Is Marcus a neurosurgeon?", ["no", "architect", "priya"], "contrastive"),
    ("Does Priya live in Seattle?", ["no", "chicago", "marcus"], "contrastive"),
    ("Is Matt an architect?", ["no", "cto", "marcus"], "contrastive"),
    ("Does Jordan work at NovaMind?", ["no", "marine biologist", "matt"], "contrastive"),
]


# Cross-session recall test run on the reloaded model: "old" questions re-check
# session-4 facts after consolidation; "new" ones check SESSION5_MESSAGES facts.
# Scored the same way as RECALL_QUESTIONS.
SESSION5_RECALL = [
    # Session-4 facts that should survive the checkpoint round-trip.
    ("What does Jordan do?", ["marine biologist"], "old"),
    ("Where does Elena live?", ["portland"], "old"),
    ("What is Marcus's job?", ["architect"], "old"),
    ("Where does Marcus live?", ["seattle"], "old"),
    ("What does Priya do?", ["neurosurgeon"], "old"),
    ("Where does Matt work?", ["novamind"], "old"),
    # Facts taught in session 5.
    ("What does Matt's neighbor Dave do?", ["firefighter"], "new"),
    ("Where does Dave live?", ["denver"], "new"),
    ("What does Rina do?", ["data scientist"], "new"),
    ("Where does Rina work?", ["google"], "new"),
    ("What car did Matt buy?", ["tesla", "model 3"], "new"),
]
|
|
|
|
|
| def score_answer(answer, expected_keywords):
|
| answer_lower = answer.lower()
|
| hits = sum(1 for k in expected_keywords if k in answer_lower)
|
| return hits / len(expected_keywords)
|
|
|
|
|
def quick_verify_entities(mm, known_entities):
    """Probe the model on each known entity's facts and collect the misses.

    For every ``name -> info`` entry in *known_entities*:

    * if ``info["job"]`` is truthy, ask
      "What does Matt's <relationship> <name> do?" (reads ``info["relationship"]``);
    * if ``info["city"]`` is truthy, ask "Where does <name> live?".

    A probe fails when the expected value, lower-cased, is not a substring of
    the lower-cased answer; every failure is printed as a CONFUSED line.

    Args:
        mm: model wrapper exposing ``generate(messages, max_new_tokens=...)``.
        known_entities: mapping of entity name -> info dict.

    Returns:
        Set of entity names with at least one failed probe.
    """
    def ask(question):
        return mm.generate([{"role": "user", "content": question}], max_new_tokens=100)

    wrong = set()
    for entity, facts in known_entities.items():
        # Build the probes first (job before city) and then run them in order.
        probes = []
        if facts.get("job"):
            probes.append(("job", facts["job"], f"What does Matt's {facts['relationship']} {entity} do?"))
        if facts.get("city"):
            probes.append(("city", facts["city"], f"Where does {entity} live?"))

        for field, expected, question in probes:
            reply = ask(question)
            if expected.lower() in reply.lower():
                continue
            wrong.add(entity)
            print(f" CONFUSED: {entity} {field} — expected '{expected}', got: {reply[:80]}")

    return wrong
|
|
|
|
|
def phase3_stubborn_retry(mm, quiz_gen, confused_entities, all_contrastive, max_retries=3):
    """Phase 3: micro-focused retraining on facts that survived phases 1 and 2.

    Each round builds one batch from every still-confused entity that has an
    entry in ``quiz_gen.known_entities``:

    * 3 copies of "What does Matt's <relationship> <name> do?" -> job answer
      (when a job is recorded);
    * 3 copies of "Where does <name> live?" -> city answer (when a city is
      recorded);
    * every pair in *all_contrastive* whose question or answer mentions the
      name (case-insensitive substring). A pair that mentions several
      confused entities is included once per entity.

    After absorbing the batch, the entities are re-probed with the shorter
    "What does <name> do?" / "Where does <name> live?" questions, and the ones
    that answer correctly drop out. Entities with no known info are dropped
    without any training. The loop stops early once nothing is left.

    Args:
        mm: model wrapper exposing ``absorb(batch)`` (returns a loss) and
            ``generate(messages, max_new_tokens=...)``.
        quiz_gen: object whose ``known_entities`` maps name -> info dict.
        confused_entities: iterable of entity names to fix; it is not mutated.
        all_contrastive: candidate contrastive quiz pairs.
        max_retries: maximum number of train/re-probe rounds.

    Returns:
        Set of entity names still confused after the last round (empty on
        success).
    """
    remaining = set(confused_entities)
    rounds = 0

    while remaining and rounds < max_retries:
        print(f" Phase 3 retry {rounds+1}/{max_retries}: fixing {remaining}")

        retry_batch = []
        for name in remaining:
            info = quiz_gen.known_entities.get(name, {})
            if not info:
                continue

            if info.get("job"):
                retry_batch.extend(
                    {"messages": [
                        {"role": "user", "content": f"What does Matt's {info['relationship']} {name} do?"},
                        {"role": "assistant", "content": f"Matt's {info['relationship']} {name} is a {info['job']}."},
                    ]}
                    for _ in range(3)
                )
            if info.get("city"):
                retry_batch.extend(
                    {"messages": [
                        {"role": "user", "content": f"Where does {name} live?"},
                        {"role": "assistant", "content": f"{name} lives in {info['city']}. {name} is Matt's {info['relationship']}."},
                    ]}
                    for _ in range(3)
                )

            lowered = name.lower()
            retry_batch.extend(
                qp for qp in all_contrastive
                if lowered in (qp["messages"][0]["content"] + " " +
                               qp["messages"][1]["content"]).lower()
            )

        if retry_batch:
            loss = mm.absorb(retry_batch)
            print(f" Phase 3: Trained {len(retry_batch)} items, loss={loss:.4f}")

        # Re-probe with the name-only phrasing, which differs from the trained
        # "Matt's <relationship>" phrasing for jobs.
        still_confused = set()
        for name in remaining:
            info = quiz_gen.known_entities.get(name, {})
            if info.get("job"):
                ans = mm.generate([{"role": "user", "content": f"What does {name} do?"}], max_new_tokens=100)
                if info["job"].lower() not in ans.lower():
                    still_confused.add(name)
            if info.get("city"):
                ans = mm.generate([{"role": "user", "content": f"Where does {name} live?"}], max_new_tokens=100)
                if info["city"].lower() not in ans.lower():
                    still_confused.add(name)

        remaining = still_confused
        rounds += 1

    # Report after the loop so success on the final round is also announced.
    if remaining:
        print(f" Phase 3: Still confused after {rounds} retries: {remaining}")
    else:
        print(f" Phase 3: All entities correct after {rounds} retries!")
    return remaining
|
|
|
|
|
def _exchange(user_msg, response):
    """Build a single user/assistant training example from one conversation turn."""
    return {"messages": [
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": response},
    ]}


def _is_contrastive(qp):
    """Return True when a quiz pair's answer starts with "No." (case-insensitive)."""
    return qp["messages"][1]["content"].lower().startswith("no.")


def _split_quizzes(quizzes):
    """Split quiz pairs into (positive, contrastive) lists, preserving order."""
    positive = [qp for qp in quizzes if not _is_contrastive(qp)]
    contrastive = [qp for qp in quizzes if _is_contrastive(qp)]
    return positive, contrastive


def _pairs_mentioning(pairs, names):
    """Return the pairs whose question or answer mentions any of *names* (case-insensitive)."""
    lowered = [n.lower() for n in names]
    hits = []
    for qp in pairs:
        text = (qp["messages"][0]["content"] + " " +
                qp["messages"][1]["content"]).lower()
        if any(n in text for n in lowered):
            hits.append(qp)
    return hits


def _sample_history(history, k):
    """Return *k* random items from *history*, or [] unless it holds more than *k* items."""
    return random.sample(history, k) if len(history) > k else []


def _cluster(examples, quiz_gen):
    """Pass *examples* through ModelManager.cluster_by_entity; no-op when no entities are known."""
    if not quiz_gen.known_entities:
        return examples
    return ModelManager.cluster_by_entity(examples, list(quiz_gen.known_entities.keys()))


def _load_json_if_exists(path):
    """Return the parsed JSON at *path*, or None when the file does not exist."""
    if not os.path.exists(path):
        return None
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _dump_json(obj, path, indent=None):
    """Write *obj* as JSON to *path*."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=indent)


def _consolidate(mm, cache_path, label):
    """Distill *mm* from the teacher cache at *cache_path*; return the loss."""
    # weights_only=False unpickles arbitrary objects; acceptable only because the
    # cache is written by this pipeline itself — never point this at untrusted files.
    cache = torch.load(cache_path, map_location="cpu", weights_only=False)
    print(f" {label}: {len(cache)} items")
    loss = mm.distill(cache, epochs=CONSOLIDATION_EPOCHS)
    print(f" Consolidation loss: {loss:.4f}")
    return loss


def _fmt_ratio(correct, total):
    """Format "correct/total (pct)"; the percentage reads "n/a" when total is 0."""
    pct = f"{correct / total:.0%}" if total else "n/a"
    return f"{correct}/{total} ({pct})"


def _run_recall(mm, questions, primary_type, show_score):
    """Ask each (question, keywords, type) and grade it with score_answer (pass >= 0.5).

    Returns:
        (results, primary, other): per-question result dicts, plus
        [correct, total] counts for questions whose type equals *primary_type*
        and for all remaining questions.
    """
    results = []
    primary = [0, 0]
    other = [0, 0]
    for question, keywords, qtype in questions:
        answer = mm.generate([{"role": "user", "content": question}], max_new_tokens=200)
        score = score_answer(answer, keywords)
        passed = score >= 0.5

        bucket = primary if qtype == primary_type else other
        bucket[1] += 1
        if passed:
            bucket[0] += 1

        status = "PASS" if passed else "FAIL"
        print(f" [{status}] ({qtype}) {question}")
        print(f" Got: {answer[:120]}")
        if show_score:
            print(f" Score: {score:.2f}")

        results.append({
            "question": question, "expected": keywords,
            "answer": answer, "score": score,
            "passed": passed, "type": qtype,
        })
    return results, primary, other


def _prune_old_checkpoints(checkpoint_dir, keep):
    """Delete every claudia_* directory in *checkpoint_dir* except the paths in *keep*."""
    import shutil

    keep_abs = {os.path.abspath(p) for p in keep}
    for entry in os.listdir(checkpoint_dir):
        full = os.path.join(checkpoint_dir, entry)
        if not (os.path.isdir(full) and entry.startswith("claudia_")):
            continue
        if os.path.abspath(full) in keep_abs:
            continue
        size_gb = sum(
            os.path.getsize(os.path.join(dp, f))
            for dp, _, fns in os.walk(full) for f in fns
        ) / 1e9
        print(f" Removing old: {entry} ({size_gb:.1f} GB)")
        shutil.rmtree(full)


def _save_checkpoint(mm, all_training_data, quiz_pairs_log):
    """Save merged weights, replay data and a fresh teacher cache; return (path, cache_path)."""
    version = f"session_{datetime.now().strftime('%Y%m%d_%H%M')}"
    path = os.path.join(CHECKPOINT_DIR, f"claudia_{version}")

    mm.merge_and_save(path)

    _dump_json(all_training_data, os.path.join(LOG_DIR, "replay_buffer.json"))
    _dump_json(all_training_data, os.path.join(path, "replay_buffer.json"))
    _dump_json(quiz_pairs_log, os.path.join(LOG_DIR, "quiz_pairs_log.json"))

    print(f" Caching teacher logits ({len(quiz_pairs_log)} pairs)...")
    new_teacher_cache = mm.cache_teacher_logits(quiz_pairs_log)
    cache_path = os.path.join(path, "teacher_cache.pt")
    torch.save(new_teacher_cache, cache_path)
    size_mb = os.path.getsize(cache_path) / 1e6
    print(f" Teacher cache: {len(new_teacher_cache)} items, {size_mb:.1f} MB")

    # Prune only after the new checkpoint is completely written, so a failed
    # save never leaves us with zero checkpoints. NOTE: this briefly needs disk
    # space for both the old and the new checkpoint.
    _prune_old_checkpoints(CHECKPOINT_DIR, keep={path})
    return path, cache_path


def _teach_session4(mm, quiz_gen, all_training_data, quiz_pairs_log):
    """Teach SESSION4_MESSAGES with per-message two-phase absorption.

    Appends each exchange and its quizzes to *all_training_data*, and the
    quizzes to *quiz_pairs_log* (both mutated in place).

    Returns:
        (session_positive, session_contrastive, total_quiz_count)
    """
    conversation_buffer = []
    session_positive = []
    session_contrastive = []
    total_quiz_count = 0

    for i, user_msg in enumerate(SESSION4_MESSAGES):
        print(f"\n --- Message {i+1}/{len(SESSION4_MESSAGES)} ---")
        print(f" Matt: {user_msg}")

        conversation_buffer.append({"role": "user", "content": user_msg})
        # Generation context is capped at the 20 most recent messages.
        conversation_buffer = conversation_buffer[-20:]

        response = mm.generate(conversation_buffer)
        conversation_buffer.append({"role": "assistant", "content": response})
        print(f" Claudia: {response[:120]}...")

        quizzes = quiz_gen.generate(user_msg, response)
        total_quiz_count += len(quizzes)
        print(f" Quizzes: {len(quizzes)} (total: {total_quiz_count}), Entities: {list(quiz_gen.known_entities.keys())}")

        positive_batch, contrastive_batch = _split_quizzes(quizzes)
        print(f" Pos: {len(positive_batch)}, Contrastive: {len(contrastive_batch)}")

        exchange = _exchange(user_msg, response)
        all_training_data.append(exchange)
        all_training_data.extend(quizzes)
        quiz_pairs_log.extend(quizzes)
        session_positive.extend(positive_batch)
        session_contrastive.extend(contrastive_batch)

        # Phase 1: the new exchange + positive facts, plus a small replay of
        # everything that was in the buffer before this message.
        history = all_training_data[:len(all_training_data) - len(quizzes) - 1]
        phase1_data = [exchange] + positive_batch + _sample_history(history, 6)
        phase1_data = _cluster(phase1_data, quiz_gen)
        loss1 = mm.absorb(phase1_data)
        print(f" Phase 1: {len(phase1_data)} examples, loss={loss1:.4f}")

        # Phase 2: only train contrastive pairs about entities the model confuses.
        if not contrastive_batch:
            continue
        confused = quick_verify_entities(mm, quiz_gen.known_entities)
        if not confused:
            print(" Phase 2: All correct — skipping!")
            continue
        targeted = _pairs_mentioning(contrastive_batch, confused)
        if targeted:
            loss2 = mm.absorb(targeted)
            print(f" Phase 2: {len(targeted)}/{len(contrastive_batch)} targeted, loss={loss2:.4f}")
        else:
            print(" Phase 2: No matching pairs — skipping")

    return session_positive, session_contrastive, total_quiz_count


def _teach_session5(mm, quiz_gen):
    """Teach SESSION5_MESSAGES to the reloaded model (two-phase, no replay)."""
    conv_buf = []
    for i, user_msg in enumerate(SESSION5_MESSAGES):
        print(f"\n S5 Message {i+1}/{len(SESSION5_MESSAGES)}: {user_msg}")
        conv_buf.append({"role": "user", "content": user_msg})
        response = mm.generate(conv_buf)
        conv_buf.append({"role": "assistant", "content": response})
        print(f" Claudia: {response[:100]}...")

        quizzes = quiz_gen.generate(user_msg, response)
        print(f" Quizzes: {len(quizzes)}")

        positive, contrastive = _split_quizzes(quizzes)
        phase1 = [_exchange(user_msg, response)] + positive
        loss1 = mm.absorb(phase1)
        print(f" Phase 1: {len(phase1)} items, loss={loss1:.4f}")

        if not contrastive:
            continue
        confused = quick_verify_entities(mm, quiz_gen.known_entities)
        targeted = _pairs_mentioning(contrastive, confused) if confused else []
        if targeted:
            loss2 = mm.absorb(targeted)
            print(f" Phase 2: {len(targeted)} targeted, loss={loss2:.4f}")


def run_session4f():
    """Run session 4f end to end, then a fresh-load "session 5" persistence test.

    Steps: load CHECKPOINT -> cascade consolidation from its teacher cache ->
    personality check -> teach SESSION4_MESSAGES (two-phase) -> final
    two-phase absorption -> Phase 3 stubborn-fact retry -> session-4 recall
    test -> save a new checkpoint + teacher cache (then prune older
    claudia_* checkpoints) -> reload from the new checkpoint, consolidate,
    teach SESSION5_MESSAGES and run SESSION5_RECALL. A summary is written to
    LOG_DIR/session4f_results.json.
    """
    import gc

    print("=" * 60)
    print("SESSION 4f: Phase 3 Stubborn Retry + Cross-Session Test")
    print(f"Time: {datetime.now().isoformat()}")
    print("=" * 60)

    print("\n[1/10] Loading model from checkpoint...")
    mm = ModelManager(model_path=CHECKPOINT, checkpoint_path=CHECKPOINT)
    mm.load()
    quiz_gen = QuizGenerator(mm)

    print("\n[2/10] Loading existing data...")
    all_training_data = _load_json_if_exists(os.path.join(LOG_DIR, "replay_buffer.json"))
    if all_training_data is None:
        all_training_data = []
    else:
        print(f" Replay buffer: {len(all_training_data)} examples")
    quiz_pairs_log = _load_json_if_exists(os.path.join(LOG_DIR, "quiz_pairs_log.json"))
    if quiz_pairs_log is None:
        quiz_pairs_log = []
    else:
        print(f" Quiz pairs: {len(quiz_pairs_log)}")

    print("\n[3/10] Cascade consolidation...")
    teacher_cache_path = os.path.join(CHECKPOINT, "teacher_cache.pt")
    if os.path.exists(teacher_cache_path):
        _consolidate(mm, teacher_cache_path, "Teacher cache")
    else:
        print(" No teacher cache — skipping")

    print("\n[4/10] Personality check...")
    p_score = check_personality(mm)

    print("\n[5/10] Teaching facts (two-phase + focused contrastive)...")
    # Everything before this index came from earlier sessions.
    prior_len = len(all_training_data)
    session_positive, session_contrastive, total_quiz_count = _teach_session4(
        mm, quiz_gen, all_training_data, quiz_pairs_log)

    print("\n[6/10] Final two-phase absorption...")
    # Replay only pre-session data: slicing by this session's quiz count would
    # also pull in this session's exchanges, and with zero quizzes the slice
    # becomes empty and random.sample raises ValueError.
    final_positive = session_positive + _sample_history(all_training_data[:prior_len], 12)
    final_positive = _cluster(final_positive, quiz_gen)
    loss_fp = mm.absorb(final_positive)
    print(f" Final Phase 1: {len(final_positive)} examples, loss={loss_fp:.4f}")

    confused = quick_verify_entities(mm, quiz_gen.known_entities)
    if confused and session_contrastive:
        targeted = _pairs_mentioning(session_contrastive, confused)
        if targeted:
            loss_fc = mm.absorb(targeted)
            print(f" Final Phase 2: {len(targeted)} targeted, loss={loss_fc:.4f}")

    print("\n[7/10] PHASE 3: Stubborn-fact retry...")
    confused = quick_verify_entities(mm, quiz_gen.known_entities)
    if confused:
        phase3_stubborn_retry(mm, quiz_gen, confused, session_contrastive, max_retries=3)
    else:
        print(" Phase 3: Not needed — all entities correct!")

    print("\n[8/10] SESSION 4 RECALL TEST")
    print("=" * 60)
    results, (direct_correct, direct_total), (contrastive_correct, contrastive_total) = _run_recall(
        mm, RECALL_QUESTIONS, primary_type="direct", show_score=True)

    total = direct_correct + contrastive_correct
    total_q = direct_total + contrastive_total
    print("=" * 60)
    print(f"DIRECT: {_fmt_ratio(direct_correct, direct_total)}")
    print(f"CONTRASTIVE: {_fmt_ratio(contrastive_correct, contrastive_total)}")
    print(f"TOTAL: {_fmt_ratio(total, total_q)}")
    print(f"vs 4e: {'IMPROVED' if total > 14 else ('SAME' if total == 14 else 'REGRESSED')} (was 14/15)")
    print("=" * 60)

    print("\n[9/10] Saving checkpoint...")
    path, cache_path = _save_checkpoint(mm, all_training_data, quiz_pairs_log)

    print("\n[10/10] CROSS-SESSION TEST: Loading fresh from checkpoint...")
    print("=" * 60)

    # Drop every reference to the session-4 model before loading a second copy.
    # quiz_gen was built with QuizGenerator(mm) and may keep mm alive, so it
    # must go too; gc.collect() breaks any cycles before the CUDA cache is freed.
    del quiz_gen, mm
    gc.collect()
    torch.cuda.empty_cache()

    mm2 = ModelManager(model_path=path, checkpoint_path=path)
    mm2.load()
    quiz_gen2 = QuizGenerator(mm2)

    _consolidate(mm2, cache_path, "Session 5 consolidation")

    print("\n --- Session 5: Teaching NEW facts ---")
    _teach_session5(mm2, quiz_gen2)

    print("\n --- Session 5 RECALL TEST (old + new facts) ---")
    print("=" * 60)
    s5_results, (old_correct, old_total), (new_correct, new_total) = _run_recall(
        mm2, SESSION5_RECALL, primary_type="old", show_score=False)

    s5_total = old_correct + new_correct
    s5_total_q = old_total + new_total
    print("=" * 60)
    print(f"OLD FACTS (session 4): {_fmt_ratio(old_correct, old_total)}")
    print(f"NEW FACTS (session 5): {_fmt_ratio(new_correct, new_total)}")
    print(f"TOTAL: {_fmt_ratio(s5_total, s5_total_q)}")
    print("=" * 60)

    results_data = {
        "session": "4f",
        "session4_recall": f"{total}/{total_q}",
        "session4_direct": f"{direct_correct}/{direct_total}",
        "session4_contrastive": f"{contrastive_correct}/{contrastive_total}",
        "session5_old_recall": f"{old_correct}/{old_total}",
        "session5_new_recall": f"{new_correct}/{new_total}",
        "session5_total": f"{s5_total}/{s5_total_q}",
        "personality_score": p_score,
        "total_quizzes": total_quiz_count,
        "checkpoint": path,
        "timestamp": datetime.now().isoformat(),
        "session4_results": results,
        "session5_results": s5_results,
    }
    results_path = os.path.join(LOG_DIR, "session4f_results.json")
    _dump_json(results_data, results_path, indent=2)
    print(f"\nResults saved: {results_path}")
    print("\nSESSION 4f + CROSS-SESSION TEST COMPLETE.")


if __name__ == "__main__":
    run_session4f()
|
|
|