# claudia-memory-pipeline/tests/test_session4f.py
# Uploaded by msrcam via huggingface_hub ("Upload tests/test_session4f.py with huggingface_hub").
# Commit 1cb8031 (verified).
"""
Session 4f — Phase 3 Stubborn-Fact Retry + Cross-Session Persistence Test
=========================================================================
Session 4e scored 14/15 (93%). Only failure: Marcus city (San Diego instead
of Seattle). Phase 2 detected the confusion every time but couldn't fix it.
NEW: Phase 3 — after final absorption, re-verify ALL entities. For any that
still fail, create a micro-focused training batch with ONLY that specific
fact repeated 3-5x. This is brute-force but targeted — only fires on the
last stubborn associations that two-phase couldn't fix.
Also: after saving, immediately loads from checkpoint as a fresh "session 5"
and tests all 4e facts survive consolidation. This proves cascade distillation
works across sessions.
Run: python3 -u test_session4f.py 2>&1 | tee /workspace/logs/distill2_s4f.log
"""
import json
import os
import random
import sys
import time
import torch
from datetime import datetime
sys.path.insert(0, "/workspace")
from persistent_absorber import (
ModelManager, QuizGenerator, check_personality,
CONSOLIDATION_EPOCHS, MAX_TEACHER_CACHE
)
# Checkpoint the session starts from; it is passed to ModelManager as both the
# model path and the checkpoint path, and its teacher_cache.pt drives consolidation.
CHECKPOINT = "/workspace/checkpoints/claudia_session_20260323_0412"
# Directory holding replay_buffer.json, quiz_pairs_log.json and the results JSON.
LOG_DIR = "/workspace/logs"
# Parent directory for checkpoints; existing "claudia_*" directories in it are
# removed before the new checkpoint is written.
CHECKPOINT_DIR = "/workspace/checkpoints"
# Session 4 user messages, fed to the model one at a time in order.
SESSION4_MESSAGES = [
    "My friend Jordan is a marine biologist in San Diego. My sister Elena is a veterinarian in Portland.",
    "My cousin Marcus is an architect in Seattle. He designed the new library there.",
    "I work at a startup called NovaMind. I'm the CTO. We're based in Austin.",
    "My best friend Priya is a neurosurgeon in Chicago. She went to Johns Hopkins.",
    "Elena actually just got a new cat named Mochi. And Jordan got his dive certification renewed last month.",
]
# Session 5 NEW facts — different people, tests cross-session retention
SESSION5_MESSAGES = [
    "My neighbor Dave is a firefighter in Denver. He's been doing it for 15 years.",
    "My colleague Rina is a data scientist at Google. She's based in Mountain View.",
    "I just bought a Tesla Model 3. It's midnight blue.",
]
# Session 4 recall probes as (question, expected keywords, question type).
# Keywords are lowercase substrings; a probe passes when at least half of them
# appear in the answer (score >= 0.5). Any type other than "direct" is tallied
# as contrastive.
RECALL_QUESTIONS = [
    ("What does Matt's friend Jordan do?", ["marine biologist"], "direct"),
    ("Where does Matt's sister Elena live?", ["portland"], "direct"),
    ("What is Matt's cousin Marcus's job?", ["architect"], "direct"),
    ("Where does Marcus live?", ["seattle"], "direct"),
    ("What does Priya do?", ["neurosurgeon"], "direct"),
    ("Where does Priya live?", ["chicago"], "direct"),
    ("Where does Matt work?", ["novamind"], "direct"),
    ("What is Matt's job title?", ["cto"], "direct"),
    ("What is Elena's cat's name?", ["mochi"], "direct"),
    ("Is Elena a marine biologist?", ["no", "veterinarian", "jordan"], "contrastive"),
    ("Does Jordan live in Portland?", ["no", "san diego", "elena"], "contrastive"),
    ("Is Marcus a neurosurgeon?", ["no", "architect", "priya"], "contrastive"),
    ("Does Priya live in Seattle?", ["no", "chicago", "marcus"], "contrastive"),
    ("Is Matt an architect?", ["no", "cto", "marcus"], "contrastive"),
    ("Does Jordan work at NovaMind?", ["no", "marine biologist", "matt"], "contrastive"),
]
# Session 5 recall — tests both OLD (session 4) and NEW facts
# Same (question, expected keywords, type) layout; any type other than "old"
# is tallied as new.
SESSION5_RECALL = [
    # Old facts (should survive cross-session)
    ("What does Jordan do?", ["marine biologist"], "old"),
    ("Where does Elena live?", ["portland"], "old"),
    ("What is Marcus's job?", ["architect"], "old"),
    ("Where does Marcus live?", ["seattle"], "old"),
    ("What does Priya do?", ["neurosurgeon"], "old"),
    ("Where does Matt work?", ["novamind"], "old"),
    # New facts
    ("What does Matt's neighbor Dave do?", ["firefighter"], "new"),
    ("Where does Dave live?", ["denver"], "new"),
    ("What does Rina do?", ["data scientist"], "new"),
    ("Where does Rina work?", ["google"], "new"),
    ("What car did Matt buy?", ["tesla", "model 3"], "new"),
]
def score_answer(answer, expected_keywords):
    """Return the fraction of expected keywords that occur in ``answer``.

    Matching is a case-insensitive substring test, so a short keyword such as
    "no" also matches inside longer words (e.g. "know", "NovaMind").

    Args:
        answer: Model output text.
        expected_keywords: Sequence of keyword strings to look for.

    Returns:
        A float in [0.0, 1.0]; 0.0 when ``expected_keywords`` is empty.
    """
    # Nothing to look for: report no hits instead of dividing by zero.
    if not expected_keywords:
        return 0.0
    answer_lower = answer.lower()
    # Lowercase the keyword too so mixed-case keywords can still match.
    hits = sum(1 for k in expected_keywords if k.lower() in answer_lower)
    return hits / len(expected_keywords)
def quick_verify_entities(mm, known_entities):
    """Probe the model about every known entity and return the confused names.

    For each entity, its job (if set) and then its city (if set) are checked by
    asking the model a question and testing whether the expected value occurs
    in the answer (case-insensitive substring). Each failure is printed.

    Args:
        mm: Object exposing ``generate(messages, max_new_tokens=...)`` that
            returns the answer text.
        known_entities: Mapping of entity name to an info dict; the keys read
            here are "job", "city" and "relationship".

    Returns:
        Set of entity names for which at least one probe failed.
    """
    confused = set()
    for name, info in known_entities.items():
        relationship = info.get("relationship")
        probes = []
        if info.get("job"):
            # Fall back to a plain question when no relationship is recorded
            # instead of raising KeyError.
            if relationship:
                question = f"What does Matt's {relationship} {name} do?"
            else:
                question = f"What does {name} do?"
            probes.append(("job", question, info["job"]))
        if info.get("city"):
            probes.append(("city", f"Where does {name} live?", info["city"]))

        for field, question, expected in probes:
            ans = mm.generate([{"role": "user", "content": question}], max_new_tokens=100)
            if expected.lower() not in ans.lower():
                confused.add(name)
                print(f" CONFUSED: {name} {field} — expected '{expected}', got: {ans[:80]}")
    return confused
def _stubborn_fact_probes(name, info):
    """Build (question, answer, expected) triples for one entity's job and city facts."""
    relationship = info.get("relationship")
    probes = []
    if info.get("job"):
        if relationship:
            question = f"What does Matt's {relationship} {name} do?"
            answer = f"Matt's {relationship} {name} is a {info['job']}."
        else:
            # No relationship recorded: use a plain phrasing instead of raising KeyError.
            question = f"What does {name} do?"
            answer = f"{name} is a {info['job']}."
        probes.append((question, answer, info["job"]))
    if info.get("city"):
        answer = f"{name} lives in {info['city']}."
        if relationship:
            answer += f" {name} is Matt's {relationship}."
        probes.append((f"Where does {name} live?", answer, info["city"]))
    return probes


def phase3_stubborn_retry(mm, quiz_gen, confused_entities, all_contrastive, max_retries=3):
    """Phase 3: micro-focused retry on stubborn facts.

    For every still-confused entity a concentrated batch is built from:
    - 3x each correct fact (job and/or city) as a question/answer pair
    - 1x every contrastive pair whose text mentions the entity's name
    The batch is trained with ``mm.absorb`` and the entities are re-verified
    with the same questions they were just trained on. This repeats until
    nothing is confused or ``max_retries`` rounds have run.

    Args:
        mm: Object exposing ``absorb(batch)`` (its return value is printed as
            the loss) and ``generate(messages, max_new_tokens=...)``.
        quiz_gen: Object whose ``known_entities`` maps entity name to an info
            dict with optional "job", "city" and "relationship" keys.
        confused_entities: Iterable of entity names to fix; not mutated.
        all_contrastive: List of ``{"messages": [user, assistant]}`` pairs.
        max_retries: Maximum number of train/verify rounds.

    Returns:
        Set of entity names still confused after the last round. Names with no
        entry in ``quiz_gen.known_entities`` cannot be trained or verified and
        are dropped after the first round.
    """
    remaining = set(confused_entities)
    attempts = 0
    while remaining and attempts < max_retries:
        attempts += 1
        print(f" Phase 3 retry {attempts}/{max_retries}: fixing {remaining}")

        retry_batch = []
        probes_by_name = {}
        for name in remaining:
            info = quiz_gen.known_entities.get(name, {})
            if not info:
                continue
            probes = _stubborn_fact_probes(name, info)
            probes_by_name[name] = probes

            # Positive reinforcement: repeat each correct fact 3x.
            for question, answer, _expected in probes:
                for _ in range(3):
                    retry_batch.append({"messages": [
                        {"role": "user", "content": question},
                        {"role": "assistant", "content": answer},
                    ]})

            # Add the contrastive pairs that mention this entity.
            needle = name.lower()
            for qp in all_contrastive:
                full_text = (qp["messages"][0]["content"] + " " +
                             qp["messages"][1]["content"]).lower()
                if needle in full_text:
                    retry_batch.append(qp)

        if retry_batch:
            loss = mm.absorb(retry_batch)
            print(f" Phase 3: Trained {len(retry_batch)} items, loss={loss:.4f}")

        # Re-verify with the exact questions that were trained, so the check
        # measures the same association that was flagged and reinforced.
        still_confused = set()
        for name, probes in probes_by_name.items():
            for question, _answer, expected in probes:
                ans = mm.generate([{"role": "user", "content": question}], max_new_tokens=100)
                if expected.lower() not in ans.lower():
                    still_confused.add(name)
        remaining = still_confused

    if remaining:
        print(f" Phase 3: Still confused after {max_retries} retries: {remaining}")
    else:
        print(f" Phase 3: All entities correct after {attempts} retries!")
    return remaining
def _split_quiz_pairs(quizzes):
    """Split quiz pairs into (positive, contrastive); contrastive answers start with "no."."""
    positive, contrastive = [], []
    for qp in quizzes:
        if qp["messages"][1]["content"].lower().startswith("no."):
            contrastive.append(qp)
        else:
            positive.append(qp)
    return positive, contrastive


def _pairs_mentioning(pairs, names):
    """Return the pairs whose question or answer text mentions any of ``names``."""
    needles = [n.lower() for n in names]
    selected = []
    for qp in pairs:
        full_text = (qp["messages"][0]["content"] + " " +
                     qp["messages"][1]["content"]).lower()
        if any(needle in full_text for needle in needles):
            selected.append(qp)
    return selected


def _run_recall(mm, questions, primary_type, show_score=False):
    """Ask every recall question, print PASS/FAIL, and tally by question type.

    Returns (primary_correct, primary_total, other_correct, other_total, results),
    where "primary" counts questions whose type equals ``primary_type``.
    """
    primary = [0, 0]  # [correct, total]
    other = [0, 0]
    results = []
    for question, keywords, qtype in questions:
        answer = mm.generate([{"role": "user", "content": question}], max_new_tokens=200)
        score = score_answer(answer, keywords)
        passed = score >= 0.5
        bucket = primary if qtype == primary_type else other
        bucket[1] += 1
        if passed:
            bucket[0] += 1
        status = "PASS" if passed else "FAIL"
        print(f" [{status}] ({qtype}) {question}")
        print(f" Got: {answer[:120]}")
        if show_score:
            print(f" Score: {score:.2f}")
        results.append({
            "question": question, "expected": keywords,
            "answer": answer, "score": score,
            "passed": passed, "type": qtype,
        })
    return primary[0], primary[1], other[0], other[1], results


def run_session4f():
    """Run the session 4f experiment end to end.

    Steps, matching the printed [n/10] markers:
      1. Load the model from CHECKPOINT.
      2. Load the replay buffer and quiz log from LOG_DIR, if present.
      3. Distill from the checkpoint's teacher cache, if present.
      4. Run the personality check.
      5. Feed SESSION4_MESSAGES; per message, absorb positives (phase 1) and
         then the contrastive pairs for confused entities (phase 2).
      6. Final two-phase absorption over the whole session.
      7. Phase 3 stubborn-fact retry for entities that are still confused.
      8. Session 4 recall test.
      9. Remove old "claudia_*" checkpoints, save a new one, and write the
         replay buffer, quiz log and teacher cache.
     10. Reload from the new checkpoint, teach SESSION5_MESSAGES and test
         recall of old and new facts; write session4f_results.json.
    """
    import gc
    import shutil

    print("=" * 60)
    print("SESSION 4f: Phase 3 Stubborn Retry + Cross-Session Test")
    print(f"Time: {datetime.now().isoformat()}")
    print("=" * 60)

    # ── Load model ──
    print("\n[1/10] Loading model from checkpoint...")
    mm = ModelManager(model_path=CHECKPOINT, checkpoint_path=CHECKPOINT)
    mm.load()
    quiz_gen = QuizGenerator(mm)

    # ── Load existing data ──
    print("\n[2/10] Loading existing data...")
    replay_path = os.path.join(LOG_DIR, "replay_buffer.json")
    quiz_log_path = os.path.join(LOG_DIR, "quiz_pairs_log.json")
    all_training_data = []
    quiz_pairs_log = []
    if os.path.exists(replay_path):
        with open(replay_path, 'r') as f:
            all_training_data = json.load(f)
        print(f" Replay buffer: {len(all_training_data)} examples")
    if os.path.exists(quiz_log_path):
        with open(quiz_log_path, 'r') as f:
            quiz_pairs_log = json.load(f)
        print(f" Quiz pairs: {len(quiz_pairs_log)}")

    # ── Consolidation ──
    print("\n[3/10] Cascade consolidation...")
    teacher_cache_path = os.path.join(CHECKPOINT, "teacher_cache.pt")
    if os.path.exists(teacher_cache_path):
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # acceptable because this file was written locally by a previous run.
        teacher_cache = torch.load(teacher_cache_path, map_location="cpu", weights_only=False)
        print(f" Teacher cache: {len(teacher_cache)} items")
        loss = mm.distill(teacher_cache, epochs=CONSOLIDATION_EPOCHS)
        print(f" Consolidation loss: {loss:.4f}")
    else:
        print(" No teacher cache — skipping")

    # ── Personality ──
    print("\n[4/10] Personality check...")
    p_score = check_personality(mm)

    # ── Feed facts (same as 4e) ──
    print("\n[5/10] Teaching facts (two-phase + focused contrastive)...")
    conversation_buffer = []
    session_quizzes = []
    session_positive = []
    session_contrastive = []
    total_quiz_count = 0
    for i, user_msg in enumerate(SESSION4_MESSAGES):
        print(f"\n --- Message {i+1}/{len(SESSION4_MESSAGES)} ---")
        print(f" Matt: {user_msg}")
        conversation_buffer.append({"role": "user", "content": user_msg})
        if len(conversation_buffer) > 20:
            conversation_buffer = conversation_buffer[-20:]
        response = mm.generate(conversation_buffer)
        conversation_buffer.append({"role": "assistant", "content": response})
        print(f" Claudia: {response[:120]}...")

        quizzes = quiz_gen.generate(user_msg, response)
        total_quiz_count += len(quizzes)
        print(f" Quizzes: {len(quizzes)} (total: {total_quiz_count}), Entities: {list(quiz_gen.known_entities.keys())}")
        positive_batch, contrastive_batch = _split_quiz_pairs(quizzes)
        print(f" Pos: {len(positive_batch)}, Contrastive: {len(contrastive_batch)}")

        exchange = {"messages": [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": response},
        ]}
        all_training_data.append(exchange)
        all_training_data.extend(quizzes)
        quiz_pairs_log.extend(quizzes)
        session_quizzes.extend(quizzes)
        session_positive.extend(positive_batch)
        session_contrastive.extend(contrastive_batch)

        # Phase 1: Positive
        phase1_data = [exchange] + positive_batch
        # Everything before this message's exchange and quizzes counts as old.
        older_end = len(all_training_data) - len(quizzes) - 1
        old_sample = random.sample(all_training_data[:older_end], 6) if older_end > 6 else []
        phase1_data.extend(old_sample)
        if quiz_gen.known_entities:
            phase1_data = ModelManager.cluster_by_entity(
                phase1_data, list(quiz_gen.known_entities.keys()))
        loss1 = mm.absorb(phase1_data)
        print(f" Phase 1: {len(phase1_data)} examples, loss={loss1:.4f}")

        # Phase 2: Targeted contrastive
        if contrastive_batch:
            confused = quick_verify_entities(mm, quiz_gen.known_entities)
            if confused:
                targeted = _pairs_mentioning(contrastive_batch, confused)
                if targeted:
                    loss2 = mm.absorb(targeted)
                    print(f" Phase 2: {len(targeted)}/{len(contrastive_batch)} targeted, loss={loss2:.4f}")
                else:
                    print(" Phase 2: No matching pairs — skipping")
            else:
                print(" Phase 2: All correct — skipping!")

    # ── Final two-phase ──
    print("\n[6/10] Final two-phase absorption...")
    # Use an explicit end index: a negative-length slice degenerates to [:-0]
    # (an empty list) when no quizzes were generated, which made random.sample
    # raise ValueError.
    prior_end = len(all_training_data) - len(session_quizzes)
    old_sample = random.sample(all_training_data[:prior_end], 12) if prior_end > 12 else []
    final_positive = session_positive + old_sample
    if quiz_gen.known_entities:
        final_positive = ModelManager.cluster_by_entity(
            final_positive, list(quiz_gen.known_entities.keys()))
    loss_fp = mm.absorb(final_positive)
    print(f" Final Phase 1: {len(final_positive)} examples, loss={loss_fp:.4f}")
    confused = quick_verify_entities(mm, quiz_gen.known_entities)
    if confused and session_contrastive:
        targeted = _pairs_mentioning(session_contrastive, confused)
        if targeted:
            loss_fc = mm.absorb(targeted)
            print(f" Final Phase 2: {len(targeted)} targeted, loss={loss_fc:.4f}")

    # ── PHASE 3: Stubborn-fact retry ──
    print("\n[7/10] PHASE 3: Stubborn-fact retry...")
    confused = quick_verify_entities(mm, quiz_gen.known_entities)
    if confused:
        phase3_stubborn_retry(mm, quiz_gen, confused, session_contrastive, max_retries=3)
    else:
        print(" Phase 3: Not needed — all entities correct!")

    # ── Recall test (session 4 facts) ──
    print("\n[8/10] SESSION 4 RECALL TEST")
    print("=" * 60)
    (direct_correct, direct_total,
     contrastive_correct, contrastive_total,
     results) = _run_recall(mm, RECALL_QUESTIONS, "direct", show_score=True)
    total = direct_correct + contrastive_correct
    total_q = direct_total + contrastive_total
    print("=" * 60)
    print(f"DIRECT: {direct_correct}/{direct_total} ({direct_correct/direct_total:.0%})")
    print(f"CONTRASTIVE: {contrastive_correct}/{contrastive_total} ({contrastive_correct/contrastive_total:.0%})")
    print(f"TOTAL: {total}/{total_q} ({total/total_q:.0%})")
    print(f"vs 4e: {'IMPROVED' if total > 14 else ('SAME' if total == 14 else 'REGRESSED')} (was 14/15)")
    print("=" * 60)

    # ── Save checkpoint ──
    print("\n[9/10] Saving checkpoint...")
    version = f"session_{datetime.now().strftime('%Y%m%d_%H%M')}"
    path = os.path.join(CHECKPOINT_DIR, f"claudia_{version}")
    # NOTE(review): every existing claudia_* checkpoint (including the one this
    # run loaded from) is deleted BEFORE the new one is saved — presumably to
    # free disk space. If merge_and_save fails, no checkpoint remains; confirm
    # this trade-off is intended.
    for entry in os.listdir(CHECKPOINT_DIR):
        full = os.path.join(CHECKPOINT_DIR, entry)
        if os.path.isdir(full) and entry.startswith("claudia_"):
            size_gb = sum(
                os.path.getsize(os.path.join(dp, f))
                for dp, _, fns in os.walk(full) for f in fns
            ) / 1e9
            print(f" Removing old: {entry} ({size_gb:.1f} GB)")
            shutil.rmtree(full)
    mm.merge_and_save(path)
    with open(replay_path, 'w') as f:
        json.dump(all_training_data, f)
    with open(os.path.join(path, "replay_buffer.json"), 'w') as f:
        json.dump(all_training_data, f)
    with open(quiz_log_path, 'w') as f:
        json.dump(quiz_pairs_log, f)
    print(f" Caching teacher logits ({len(quiz_pairs_log)} pairs)...")
    new_teacher_cache = mm.cache_teacher_logits(quiz_pairs_log)
    cache_path = os.path.join(path, "teacher_cache.pt")
    torch.save(new_teacher_cache, cache_path)
    size_mb = os.path.getsize(cache_path) / 1e6
    print(f" Teacher cache: {len(new_teacher_cache)} items, {size_mb:.1f} MB")

    # ── CROSS-SESSION TEST: Session 5 ──
    print("\n[10/10] CROSS-SESSION TEST: Loading fresh from checkpoint...")
    print("=" * 60)
    # Destroy current model, reload from checkpoint (simulates new session).
    # quiz_gen was constructed from mm and presumably keeps a reference to it,
    # so it must be dropped as well or the old model cannot be freed.
    del quiz_gen
    del mm
    gc.collect()
    torch.cuda.empty_cache()
    mm2 = ModelManager(model_path=path, checkpoint_path=path)
    mm2.load()
    quiz_gen2 = QuizGenerator(mm2)

    # Consolidation from teacher cache (same as real session start)
    teacher_cache2 = torch.load(cache_path, map_location="cpu", weights_only=False)
    print(f" Session 5 consolidation: {len(teacher_cache2)} items")
    loss5 = mm2.distill(teacher_cache2, epochs=CONSOLIDATION_EPOCHS)
    print(f" Consolidation loss: {loss5:.4f}")

    # Feed session 5 NEW facts
    print("\n --- Session 5: Teaching NEW facts ---")
    conv_buf = []
    for i, user_msg in enumerate(SESSION5_MESSAGES):
        print(f"\n S5 Message {i+1}/{len(SESSION5_MESSAGES)}: {user_msg}")
        conv_buf.append({"role": "user", "content": user_msg})
        response = mm2.generate(conv_buf)
        conv_buf.append({"role": "assistant", "content": response})
        print(f" Claudia: {response[:100]}...")
        quizzes = quiz_gen2.generate(user_msg, response)
        print(f" Quizzes: {len(quizzes)}")
        exchange = {"messages": [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": response},
        ]}
        # Separate positive/contrastive
        positive, contrastive = _split_quiz_pairs(quizzes)
        # Two-phase absorption
        phase1 = [exchange] + positive
        loss1 = mm2.absorb(phase1)
        print(f" Phase 1: {len(phase1)} items, loss={loss1:.4f}")
        if contrastive:
            confused = quick_verify_entities(mm2, quiz_gen2.known_entities)
            if confused:
                targeted = _pairs_mentioning(contrastive, confused)
                if targeted:
                    loss2 = mm2.absorb(targeted)
                    print(f" Phase 2: {len(targeted)} targeted, loss={loss2:.4f}")

    # Session 5 recall: test BOTH old and new facts
    print("\n --- Session 5 RECALL TEST (old + new facts) ---")
    print("=" * 60)
    (old_correct, old_total,
     new_correct, new_total,
     s5_results) = _run_recall(mm2, SESSION5_RECALL, "old")
    s5_total = old_correct + new_correct
    s5_total_q = old_total + new_total
    print("=" * 60)
    print(f"OLD FACTS (session 4): {old_correct}/{old_total} ({old_correct/old_total:.0%})")
    print(f"NEW FACTS (session 5): {new_correct}/{new_total} ({new_correct/new_total:.0%})")
    print(f"TOTAL: {s5_total}/{s5_total_q} ({s5_total/s5_total_q:.0%})")
    print("=" * 60)

    # Save all results
    results_data = {
        "session": "4f",
        "session4_recall": f"{total}/{total_q}",
        "session4_direct": f"{direct_correct}/{direct_total}",
        "session4_contrastive": f"{contrastive_correct}/{contrastive_total}",
        "session5_old_recall": f"{old_correct}/{old_total}",
        "session5_new_recall": f"{new_correct}/{new_total}",
        "session5_total": f"{s5_total}/{s5_total_q}",
        "personality_score": p_score,
        "total_quizzes": total_quiz_count,
        "checkpoint": path,
        "timestamp": datetime.now().isoformat(),
        "session4_results": results,
        "session5_results": s5_results,
    }
    results_path = os.path.join(LOG_DIR, "session4f_results.json")
    with open(results_path, 'w') as f:
        json.dump(results_data, f, indent=2)
    print(f"\nResults saved: {results_path}")
    print("\nSESSION 4f + CROSS-SESSION TEST COMPLETE.")
# Run the full experiment only when executed as a script, not on import.
if __name__ == "__main__":
    run_session4f()