# claudia-memory-pipeline / tests/test_session4d.py
# Hugging Face Hub page header (uploader avatar: msrcam)
# Commit message: Upload tests/test_session4d.py with huggingface_hub
# Commit: 19205d9 (verified)
"""
Session 4d — Two-Phase Absorption + Quiz Diversity + Entity Clustering
======================================================================
Tests 3 improvements over session 4b (best: 11/15 = 73%, contrastive 6/6 = 100%):
1. TWO-PHASE ABSORPTION: Positive facts first, then targeted contrastive
correction only on entities that fail verification. Prevents contrastive
gradients from fighting positive ones.
2. QUIZ DIVERSITY: Multiple question formats per fact (randomly varied).
Creates multiple retrieval paths to same fact, strengthening direct recall
without increasing total quiz count.
3. ENTITY CLUSTERING: Group training data by entity. Train all Jordan facts
together, then all Elena facts, etc. Reduces cross-entity contamination.
Key constraint: Keep total quizzes ~35-40 (session 4c proved >40 hurts).
Run: python3 test_session4d.py 2>&1 | tee /workspace/logs/distill2_s4d.log
"""
import json
import os
import random
import sys
import time
import torch
from datetime import datetime
sys.path.insert(0, "/workspace")
from persistent_absorber import (
ModelManager, QuizGenerator, check_personality,
CONSOLIDATION_EPOCHS, MAX_TEACHER_CACHE
)
# Checkpoint directory the session starts from. run_session4d() passes it as
# both model_path and checkpoint_path and looks for teacher_cache.pt inside it.
CHECKPOINT = "/workspace/checkpoints/claudia_session_20260323_0049"
# Directory holding replay_buffer.json, quiz_pairs_log.json and the results file.
LOG_DIR = "/workspace/logs"
# Parent directory for checkpoints; run_session4d() removes existing
# "claudia_*" sub-directories here and writes the new checkpoint into it.
CHECKPOINT_DIR = "/workspace/checkpoints"
# Same facts as session 4/4b/4c
# Each string is fed to the model as one user turn, in order.
SESSION4_MESSAGES = [
"My friend Jordan is a marine biologist in San Diego. My sister Elena is a veterinarian in Portland.",
"My cousin Marcus is an architect in Seattle. He designed the new library there.",
"I work at a startup called NovaMind. I'm the CTO. We're based in Austin.",
"My best friend Priya is a neurosurgeon in Chicago. She went to Johns Hopkins.",
"Elena actually just got a new cat named Mochi. And Jordan got his dive certification renewed last month.",
]
# Recall probes as (question, expected lowercase keywords, type) tuples.
# run_session4d() scores each answer with score_answer() and counts it as a
# pass when at least half of the keywords are found. Type is "direct" or
# "contrastive"; the two types are tallied separately.
RECALL_QUESTIONS = [
# Direct recall: the answer should state the fact itself.
("What does Matt's friend Jordan do?", ["marine biologist"], "direct"),
("Where does Matt's sister Elena live?", ["portland"], "direct"),
("What is Matt's cousin Marcus's job?", ["architect"], "direct"),
("Where does Marcus live?", ["seattle"], "direct"),
("What does Priya do?", ["neurosurgeon"], "direct"),
("Where does Priya live?", ["chicago"], "direct"),
("Where does Matt work?", ["novamind"], "direct"),
("What is Matt's job title?", ["cto"], "direct"),
("What is Elena's cat's name?", ["mochi"], "direct"),
# Contrastive recall: the question attributes one entity's fact to another;
# keywords are "no", the correct fact, and the entity the fact belongs to.
("Is Elena a marine biologist?", ["no", "veterinarian", "jordan"], "contrastive"),
("Does Jordan live in Portland?", ["no", "san diego", "elena"], "contrastive"),
("Is Marcus a neurosurgeon?", ["no", "architect", "priya"], "contrastive"),
("Does Priya live in Seattle?", ["no", "chicago", "marcus"], "contrastive"),
("Is Matt an architect?", ["no", "cto", "marcus"], "contrastive"),
("Does Jordan work at NovaMind?", ["no", "marine biologist", "matt"], "contrastive"),
]
def score_answer(answer, expected_keywords, *, whole_word=False):
    """Return the fraction of expected keywords found in *answer*.

    Matching is case-insensitive: both the answer and every keyword are
    lowercased before comparison.

    Args:
        answer: Model output to score.
        expected_keywords: Sized collection of keyword strings.
        whole_word: When False (default) a keyword counts if it occurs
            anywhere in the answer as a substring, so short keywords can
            match inside unrelated words ("cto" in "doctor", "no" in
            "know"). When True a keyword only counts if it is not directly
            preceded or followed by a word character.

    Returns:
        A float in [0.0, 1.0]; 0.0 when *expected_keywords* is empty.
    """
    # Nothing to look for: report 0.0 instead of dividing by zero.
    if not expected_keywords:
        return 0.0

    answer_lower = answer.lower()
    keywords = [k.lower() for k in expected_keywords]

    if whole_word:
        import re  # local import: the module does not import re at top level

        hits = sum(
            1
            for k in keywords
            if re.search(r"(?<!\w)" + re.escape(k) + r"(?!\w)", answer_lower)
        )
    else:
        hits = sum(1 for k in keywords if k in answer_lower)

    return hits / len(keywords)
def quick_verify_entities(mm, known_entities):
    """Probe the model on each entity's job and city and collect failures.

    For every entity that has a truthy "job" and/or "city" entry, one
    question per fact is sent through ``mm.generate`` and the expected value
    is looked up in the answer as a case-insensitive substring.

    Args:
        mm: Object exposing ``generate(messages, max_new_tokens=...)`` that
            returns the answer text.
        known_entities: Mapping of entity name to an info dict. Only the
            "job", "city" and "relationship" keys are read; all three are
            optional.

    Returns:
        Set of entity names for which at least one probe did not contain
        the expected value.
    """
    confused = set()
    for name, info in known_entities.items():
        job = info.get("job")
        if job:
            # The relationship is optional like the other fields; without it
            # ask a plain question instead of raising KeyError or embedding
            # "None" in the prompt.
            relationship = info.get("relationship")
            if relationship:
                q = f"What does Matt's {relationship} {name} do?"
            else:
                q = f"What does {name} do?"
            ans = mm.generate([{"role": "user", "content": q}], max_new_tokens=100)
            if job.lower() not in ans.lower():
                confused.add(name)
                print(f" CONFUSED: {name} job — expected '{job}', got: {ans[:80]}")

        city = info.get("city")
        if city:
            q = f"Where does {name} live?"
            ans = mm.generate([{"role": "user", "content": q}], max_new_tokens=100)
            if city.lower() not in ans.lower():
                confused.add(name)
                print(f" CONFUSED: {name} city — expected '{city}', got: {ans[:80]}")
    return confused
def _sample_replay(pool, k):
    """Return k random examples from pool, or [] unless pool holds more than k."""
    return random.sample(pool, k) if len(pool) > k else []


def _contrastive_for_entities(contrastive_pairs, confused):
    """Keep the contrastive pairs whose question mentions a confused entity name."""
    confused_lower = [name.lower() for name in confused]
    return [
        qp
        for qp in contrastive_pairs
        if any(name in qp["messages"][0]["content"].lower() for name in confused_lower)
    ]


def run_session4d():
    """Run the full session-4d experiment end to end.

    Steps, as announced by the progress banners:
      1. Load the model from CHECKPOINT.
      2. Load the replay buffer and quiz-pair log from LOG_DIR, if present.
      3. Distill from the checkpoint's teacher cache, if present.
      4. Run the personality check.
      5. Feed SESSION4_MESSAGES one by one; after each message absorb the
         positive quizzes (phase 1) and then only those contrastive quizzes
         that mention entities failing quick_verify_entities (phase 2).
      6. Repeat the two phases once over the whole session's quizzes.
      7. Score RECALL_QUESTIONS and print direct/contrastive/total recall.
      8. Remove existing "claudia_*" checkpoints, save the new one, and
         write the replay buffer, quiz log, teacher cache and results JSON.

    Takes no arguments and returns None; everything is reported via print
    and files on disk.
    """
    # Local import: only the checkpoint cleanup needs it and the module has
    # no top-level shutil import.
    import shutil

    print("=" * 60)
    print("SESSION 4d: Two-Phase + Diversity + Clustering")
    print(f"Time: {datetime.now().isoformat()}")
    print("=" * 60)

    # ── Load model ──
    print("\n[1/8] Loading model from checkpoint...")
    mm = ModelManager(
        model_path=CHECKPOINT,
        checkpoint_path=CHECKPOINT,
    )
    mm.load()
    quiz_gen = QuizGenerator(mm)

    # ── Load existing data ──
    print("\n[2/8] Loading existing data...")
    replay_path = os.path.join(LOG_DIR, "replay_buffer.json")
    quiz_log_path = os.path.join(LOG_DIR, "quiz_pairs_log.json")
    all_training_data = []
    quiz_pairs_log = []
    if os.path.exists(replay_path):
        with open(replay_path, 'r', encoding="utf-8") as f:
            all_training_data = json.load(f)
        print(f" Replay buffer: {len(all_training_data)} examples")
    if os.path.exists(quiz_log_path):
        with open(quiz_log_path, 'r', encoding="utf-8") as f:
            quiz_pairs_log = json.load(f)
        print(f" Quiz pairs: {len(quiz_pairs_log)}")

    # Everything before this index predates the session; the final replay
    # sample is drawn from this range only.
    pre_session_count = len(all_training_data)

    # ── Consolidation ──
    print("\n[3/8] Cascade consolidation from teacher cache...")
    teacher_cache_path = os.path.join(CHECKPOINT, "teacher_cache.pt")
    if os.path.exists(teacher_cache_path):
        # NOTE(review): weights_only=False unpickles arbitrary objects; this
        # is only safe because the cache is a file this pipeline wrote itself.
        teacher_cache = torch.load(teacher_cache_path, map_location="cpu", weights_only=False)
        print(f" Teacher cache: {len(teacher_cache)} items")
        loss = mm.distill(teacher_cache, epochs=CONSOLIDATION_EPOCHS)
        print(f" Consolidation loss: {loss:.4f}")
    else:
        print(" No teacher cache found — skipping")

    # ── Personality check ──
    print("\n[4/8] Personality check...")
    p_score = check_personality(mm)

    # ── Feed facts with TWO-PHASE absorption ──
    print("\n[5/8] Teaching facts with TWO-PHASE absorption + diversity + clustering...")
    conversation_buffer = []
    session_quizzes = []
    session_positive = []  # Positive quizzes for this session
    session_contrastive = []  # Contrastive quizzes for this session
    contrastive_count = 0
    total_quiz_count = 0

    for i, user_msg in enumerate(SESSION4_MESSAGES):
        print(f"\n --- Message {i+1}/{len(SESSION4_MESSAGES)} ---")
        print(f" Matt: {user_msg}")
        conversation_buffer.append({"role": "user", "content": user_msg})
        # Keep only the 20 most recent turns as generation context.
        if len(conversation_buffer) > 20:
            conversation_buffer = conversation_buffer[-20:]
        response = mm.generate(conversation_buffer)
        conversation_buffer.append({"role": "assistant", "content": response})
        print(f" Claudia: {response[:150]}...")

        # Generate quizzes (with diversity from updated QuizGenerator)
        quizzes = quiz_gen.generate(user_msg, response)
        total_quiz_count += len(quizzes)
        print(f" Quizzes generated: {len(quizzes)} (session total: {total_quiz_count})")
        print(f" Known entities: {list(quiz_gen.known_entities.keys())}")

        # Separate positive from contrastive: an answer that starts with
        # "No." marks the pair as contrastive.
        positive_batch = []
        contrastive_batch = []
        for qp in quizzes:
            q = qp["messages"][0]["content"]
            a = qp["messages"][1]["content"]
            is_contrastive = a.lower().startswith("no.")
            if is_contrastive:
                contrastive_batch.append(qp)
                contrastive_count += 1
                tag = " [CONTRASTIVE]"
            else:
                positive_batch.append(qp)
                tag = ""
            # Print first few
            if len(positive_batch) + len(contrastive_batch) <= 6:
                print(f" Q: {q}{tag}")
                print(f" A: {a[:120]}")

        # Build the exchange item
        exchange = {"messages": [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": response},
        ]}

        # Everything accumulated before this message is the replay pool.
        prior_count = len(all_training_data)

        # Add to tracking lists
        all_training_data.append(exchange)
        all_training_data.extend(quizzes)
        quiz_pairs_log.extend(quizzes)
        session_quizzes.extend(quizzes)
        session_positive.extend(positive_batch)
        session_contrastive.extend(contrastive_batch)

        # ── TWO-PHASE ABSORPTION (per-message) ──
        # Phase 1: Positive facts + exchange + small replay
        phase1_data = [exchange] + positive_batch
        phase1_data.extend(_sample_replay(all_training_data[:prior_count], 6))

        # ENTITY CLUSTERING: order phase1 data by entity
        if quiz_gen.known_entities:
            phase1_data = ModelManager.cluster_by_entity(
                phase1_data, list(quiz_gen.known_entities.keys()))

        t0 = time.time()
        loss1 = mm.absorb(phase1_data)
        elapsed1 = time.time() - t0
        print(f" Phase 1 (positive): {len(phase1_data)} examples, {elapsed1:.1f}s, loss={loss1:.4f}")

        # Phase 2: Targeted contrastive correction
        if contrastive_batch:
            # Quick verify which entities are confused
            print(f" Phase 2: Verifying entities before contrastive correction...")
            confused = quick_verify_entities(mm, quiz_gen.known_entities)
            if confused:
                # Only train contrastive pairs that mention confused entities
                targeted = _contrastive_for_entities(contrastive_batch, confused)
                if targeted:
                    t0 = time.time()
                    loss2 = mm.absorb(targeted)
                    elapsed2 = time.time() - t0
                    print(f" Phase 2 (contrastive): {len(targeted)}/{len(contrastive_batch)} targeted, "
                          f"{elapsed2:.1f}s, loss={loss2:.4f}")
                else:
                    print(f" Phase 2: No contrastive pairs for confused entities — skipping")
            else:
                print(f" Phase 2: All entities correct — skipping contrastive training!")

    print(f"\n Session 4d totals:")
    print(f" Quizzes: {total_quiz_count}")
    print(f" Positive: {len(session_positive)}")
    print(f" Contrastive: {contrastive_count}")
    print(f" Known entities: {list(quiz_gen.known_entities.keys())}")

    # ── Final two-phase absorption ──
    print("\n[6/8] Final two-phase absorption pass...")
    # Phase 1: All positive quizzes + replay.
    # The replay pool is the data that existed before this session. Slicing
    # with [:-len(session_quizzes)] would miss the per-message exchange
    # items and, with zero session quizzes, becomes [:-0] (an empty list),
    # which makes random.sample raise ValueError.
    old_sample = _sample_replay(all_training_data[:pre_session_count], 12)
    final_positive = session_positive + old_sample
    if quiz_gen.known_entities:
        final_positive = ModelManager.cluster_by_entity(
            final_positive, list(quiz_gen.known_entities.keys()))
    loss_fp = mm.absorb(final_positive)
    print(f" Final Phase 1 (positive): {len(final_positive)} examples, loss={loss_fp:.4f}")

    # Phase 2: Verify then targeted contrastive
    print(f" Final Phase 2: Verifying entities...")
    confused = quick_verify_entities(mm, quiz_gen.known_entities)
    if confused and session_contrastive:
        targeted = _contrastive_for_entities(session_contrastive, confused)
        if targeted:
            loss_fc = mm.absorb(targeted)
            print(f" Final Phase 2 (contrastive): {len(targeted)} targeted, loss={loss_fc:.4f}")
        else:
            print(f" Final Phase 2: No contrastive pairs needed")
    else:
        status = "no confusion detected" if not confused else "no contrastive data"
        print(f" Final Phase 2: Skipped ({status})")

    # ── Recall test ──
    print("\n[7/8] RECALL TEST")
    print("=" * 60)
    direct_correct = 0
    direct_total = 0
    contrastive_correct = 0
    contrastive_total = 0
    results = []
    for question, keywords, qtype in RECALL_QUESTIONS:
        answer = mm.generate([{"role": "user", "content": question}], max_new_tokens=200)
        score = score_answer(answer, keywords)
        # A question passes when at least half of its keywords are present.
        passed = score >= 0.5
        if qtype == "direct":
            direct_total += 1
            if passed:
                direct_correct += 1
        else:
            contrastive_total += 1
            if passed:
                contrastive_correct += 1
        status = "PASS" if passed else "FAIL"
        print(f" [{status}] ({qtype}) {question}")
        print(f" Expected: {keywords}")
        print(f" Got: {answer[:150]}")
        print(f" Score: {score:.2f}")
        print()
        results.append({
            "question": question,
            "expected": keywords,
            "answer": answer,
            "score": score,
            "passed": passed,
            "type": qtype,
        })

    print("=" * 60)
    print(f"DIRECT RECALL: {direct_correct}/{direct_total} ({direct_correct/direct_total:.0%})")
    print(f"CONTRASTIVE RECALL: {contrastive_correct}/{contrastive_total} ({contrastive_correct/contrastive_total:.0%})")
    total = direct_correct + contrastive_correct
    total_q = direct_total + contrastive_total
    print(f"TOTAL: {total}/{total_q} ({total/total_q:.0%})")
    # 11 correct is the session-4b baseline this run is compared against.
    s4b_comparison = "IMPROVED" if total > 11 else ("SAME" if total == 11 else "REGRESSED")
    print(f"vs SESSION 4b: {s4b_comparison} (was 11/15 = 73%)")
    print("=" * 60)

    # ── Save ──
    print("\n[8/8] Saving checkpoint...")
    version = f"session_{datetime.now().strftime('%Y%m%d_%H%M')}"
    path = os.path.join(CHECKPOINT_DIR, f"claudia_{version}")
    # NOTE(review): every existing "claudia_*" checkpoint is deleted BEFORE
    # the new one is written, so a failure in merge_and_save leaves no
    # checkpoint on disk. Presumably ordered this way to free disk space
    # first; confirm before reordering.
    for entry in os.listdir(CHECKPOINT_DIR):
        full = os.path.join(CHECKPOINT_DIR, entry)
        if os.path.isdir(full) and entry.startswith("claudia_"):
            size_gb = sum(
                os.path.getsize(os.path.join(dp, f))
                for dp, _, fns in os.walk(full) for f in fns
            ) / 1e9
            print(f" Removing old checkpoint: {entry} ({size_gb:.1f} GB)")
            shutil.rmtree(full)
    mm.merge_and_save(path)

    # The replay buffer is stored both in the shared log directory and
    # alongside the new checkpoint.
    with open(os.path.join(LOG_DIR, "replay_buffer.json"), 'w', encoding="utf-8") as f:
        json.dump(all_training_data, f)
    with open(os.path.join(path, "replay_buffer.json"), 'w', encoding="utf-8") as f:
        json.dump(all_training_data, f)
    with open(os.path.join(LOG_DIR, "quiz_pairs_log.json"), 'w', encoding="utf-8") as f:
        json.dump(quiz_pairs_log, f)
    print(f" Replay buffer: {len(all_training_data)} examples")
    print(f" Quiz pairs log: {len(quiz_pairs_log)}")

    print(f" Caching teacher logits ({len(quiz_pairs_log)} quiz pairs)...")
    new_teacher_cache = mm.cache_teacher_logits(quiz_pairs_log)
    cache_path = os.path.join(path, "teacher_cache.pt")
    torch.save(new_teacher_cache, cache_path)
    size_mb = os.path.getsize(cache_path) / 1e6
    print(f" Teacher cache saved ({len(new_teacher_cache)} items, {size_mb:.1f} MB)")

    results_data = {
        "session": "4d",
        "improvements": [
            "two-phase absorption (positive first, targeted contrastive second)",
            "quiz format diversity (random question formats per fact)",
            "entity-clustered absorption ordering",
        ],
        "checkpoint": path,
        "direct_recall": f"{direct_correct}/{direct_total}",
        "contrastive_recall": f"{contrastive_correct}/{contrastive_total}",
        "total_recall": f"{total}/{total_q}",
        "vs_session4b": s4b_comparison,
        "personality_score": p_score,
        "session_quizzes": len(session_quizzes),
        "positive_quizzes": len(session_positive),
        "contrastive_quizzes": contrastive_count,
        "total_quiz_pairs": len(quiz_pairs_log),
        "total_replay": len(all_training_data),
        "known_entities": dict(quiz_gen.known_entities),
        "timestamp": datetime.now().isoformat(),
        "results": results,
    }
    results_path = os.path.join(LOG_DIR, "session4d_results.json")
    with open(results_path, 'w', encoding="utf-8") as f:
        json.dump(results_data, f, indent=2)
    print(f" Results saved: {results_path}")
    print(f"\n Next run: --checkpoint {path}")
    print("\nSESSION 4d COMPLETE.")
# Script entry point: run the session only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    run_session4d()