claudia-memory-pipeline / tests /test_session4.py
msrcam's picture
Upload tests/test_session4.py with huggingface_hub
36d805b verified
"""
Session 4 Automated Test — Contrastive Disambiguation
=====================================================
Tests the new contrastive quiz generation against entity confusion.
Plan:
1. Load from checkpoint claudia_session_20260322_2206
2. Run consolidation distillation (locks in sessions 1-3)
3. Feed 5 multi-person facts designed to trigger entity confusion
4. Generate quizzes including contrastive pairs
5. Absorb all data
6. Test recall on:
a) Direct questions (same as before)
b) Cross-entity confusion questions ("Is X a [Y's job]?")
7. Save checkpoint + teacher cache
8. Report results
Run: python3 test_session4.py 2>&1 | tee /workspace/logs/distill2_s4.log
"""
import json
import os
import sys
import time
import torch
from datetime import datetime
# Add workspace to path
sys.path.insert(0, "/workspace")
from persistent_absorber import (
ModelManager, QuizGenerator, check_personality,
CONSOLIDATION_EPOCHS, MAX_TEACHER_CACHE
)
# Checkpoint directory this session starts from. It is passed to ModelManager as
# both model_path and checkpoint_path, and teacher_cache.pt is looked up inside it.
CHECKPOINT = "/workspace/checkpoints/claudia_session_20260322_2206"
# Directory holding replay_buffer.json, quiz_pairs_log.json and session4_results.json.
LOG_DIR = "/workspace/logs"
# Parent directory for checkpoints: the new checkpoint is written here, and every
# existing "claudia_*" subdirectory in it is removed by run_session4 before saving.
CHECKPOINT_DIR = "/workspace/checkpoints"
# ═══════════════════════════════════════════════════════════════════════
# SESSION 4 FACTS — deliberately designed to trigger entity confusion
# Multiple people with jobs + cities that could be swapped
# ═══════════════════════════════════════════════════════════════════════
# Each entry is one user turn; run_session4 feeds them to the model in order,
# generating a response and quizzes after every message.
SESSION4_MESSAGES = [
    # Message 1: Two people, similar structure, different details
    "My friend Jordan is a marine biologist in San Diego. My sister Elena is a veterinarian in Portland.",
    # Message 2: Add a third person using the same "job in city" pattern
    # (Seattle is a new city; no city is shared between people in this set)
    "My cousin Marcus is an architect in Seattle. He designed the new library there.",
    # Message 3: Matt's own details (to contrast against friends)
    "I work at a startup called NovaMind. I'm the CTO. We're based in Austin.",
    # Message 4: More people with specific details
    "My best friend Priya is a neurosurgeon in Chicago. She went to Johns Hopkins.",
    # Message 5: A fact that connects to earlier entities
    "Elena actually just got a new cat named Mochi. And Jordan got his dive certification renewed last month.",
]
# ═══════════════════════════════════════════════════════════════════════
# RECALL TEST QUESTIONS — direct + contrastive
# ═══════════════════════════════════════════════════════════════════════
# Each entry is (question, expected_keywords, type). The answer is scored with
# score_answer() and counts as a pass when at least half of the keywords are
# found (score >= 0.5). Contrastive entries list three keywords, so two of the
# three must appear for a pass.
RECALL_QUESTIONS = [
    # Direct recall
    ("What does Matt's friend Jordan do?", ["marine biologist"], "direct"),
    ("Where does Matt's sister Elena live?", ["portland"], "direct"),
    ("What is Matt's cousin Marcus's job?", ["architect"], "direct"),
    ("Where does Marcus live?", ["seattle"], "direct"),
    ("What does Priya do?", ["neurosurgeon"], "direct"),
    ("Where does Priya live?", ["chicago"], "direct"),
    ("Where does Matt work?", ["novamind"], "direct"),
    ("What is Matt's job title?", ["cto"], "direct"),
    ("What is Elena's cat's name?", ["mochi"], "direct"),
    # Cross-entity confusion tests (THE critical ones): each question attaches
    # one person's attribute to another person; the keywords are the denial,
    # the correct attribute, and the entity the attribute really belongs to.
    ("Is Elena a marine biologist?", ["no", "veterinarian", "jordan"], "contrastive"),
    ("Does Jordan live in Portland?", ["no", "san diego", "elena"], "contrastive"),
    ("Is Marcus a neurosurgeon?", ["no", "architect", "priya"], "contrastive"),
    ("Does Priya live in Seattle?", ["no", "chicago", "marcus"], "contrastive"),
    ("Is Matt an architect?", ["no", "cto", "marcus"], "contrastive"),
    ("Does Jordan work at NovaMind?", ["no", "marine biologist", "matt"], "contrastive"),
]
def score_answer(answer, expected_keywords):
    """Return the fraction of expected keywords that appear in ``answer``.

    Matching is case-insensitive on both sides and requires each keyword to
    stand alone as a word or phrase, so "no" is not credited for "NovaMind"
    and "cto" is not credited for "doctor".

    Args:
        answer: Model output text to score.
        expected_keywords: Keywords (or multi-word phrases) to look for.

    Returns:
        A float in [0.0, 1.0]: 1.0 when every keyword is found, a partial
        fraction otherwise. An empty keyword list scores 0.0 rather than
        raising ZeroDivisionError, so a question with no keywords cannot pass.
    """
    import re  # function-scope import: keeps this block self-contained

    if not expected_keywords:
        return 0.0

    answer_lower = answer.lower()
    hits = sum(
        1
        for keyword in expected_keywords
        # Look-arounds instead of \b so phrases that start or end with
        # punctuation are still bounded correctly.
        if re.search(r"(?<!\w)" + re.escape(keyword.lower()) + r"(?!\w)", answer_lower)
    )
    return hits / len(expected_keywords)
def _sample_replay(pool, k):
    """Return ``k`` random items from ``pool``, or ``[]`` unless ``pool`` holds more than ``k``."""
    import random

    return random.sample(pool, k) if len(pool) > k else []


def _ratio(correct, total):
    """Return ``correct / total``, or 0.0 when ``total`` is zero."""
    return correct / total if total else 0.0


def run_session4():
    """Run the session-4 contrastive disambiguation test end to end.

    Steps, in order:
      1. Load the model from CHECKPOINT and build a QuizGenerator.
      2. Load the existing replay buffer and quiz log from LOG_DIR (if present).
      3. Distill from CHECKPOINT/teacher_cache.pt when that file exists.
      4. Run the personality check.
      5. Feed each SESSION4_MESSAGES entry, generate quizzes, and absorb the
         new examples plus a small replay sample after every message.
      6. Absorb all session-4 quizzes once more with a replay sample drawn
         only from data that existed before this session.
      7. Ask every RECALL_QUESTIONS entry and score the answers.

    Afterwards it removes every "claudia_*" directory under CHECKPOINT_DIR,
    saves a new checkpoint, rewrites the replay buffer and quiz log, caches
    teacher logits, and writes session4_results.json to LOG_DIR.

    Returns:
        None. All results are printed and written to disk.
    """
    import shutil

    print("=" * 60)
    print("SESSION 4: Contrastive Disambiguation Test")
    print(f"Time: {datetime.now().isoformat()}")
    print("=" * 60)

    # ── Step 1: Load model from checkpoint ──
    print("\n[1/7] Loading model from checkpoint...")
    mm = ModelManager(
        model_path=CHECKPOINT,
        checkpoint_path=CHECKPOINT,
    )
    mm.load()
    quiz_gen = QuizGenerator(mm)

    # ── Step 2: Load existing replay + quiz data ──
    print("\n[2/7] Loading existing data...")
    replay_path = os.path.join(LOG_DIR, "replay_buffer.json")
    quiz_log_path = os.path.join(LOG_DIR, "quiz_pairs_log.json")
    all_training_data = []
    quiz_pairs_log = []
    if os.path.exists(replay_path):
        with open(replay_path, 'r') as f:
            all_training_data = json.load(f)
        print(f" Replay buffer: {len(all_training_data)} examples")
    if os.path.exists(quiz_log_path):
        with open(quiz_log_path, 'r') as f:
            quiz_pairs_log = json.load(f)
        print(f" Quiz pairs: {len(quiz_pairs_log)}")
    # Everything loaded so far predates this session; remember the boundary so
    # the final replay sample can be drawn from genuinely old data only.
    pre_session_count = len(all_training_data)

    # ── Step 3: Cascade consolidation from teacher cache ──
    print("\n[3/7] Cascade consolidation from teacher cache...")
    teacher_cache_path = os.path.join(CHECKPOINT, "teacher_cache.pt")
    teacher_cache = None
    if os.path.exists(teacher_cache_path):
        # NOTE(review): weights_only=False unpickles arbitrary objects; this is
        # only safe because the cache is a local file written by this pipeline.
        teacher_cache = torch.load(teacher_cache_path, map_location="cpu", weights_only=False)
        print(f" Teacher cache: {len(teacher_cache)} items")
        loss = mm.distill(teacher_cache, epochs=CONSOLIDATION_EPOCHS)
        print(f" Consolidation loss: {loss:.4f}")
    else:
        print(" No teacher cache found — skipping consolidation")

    # ── Step 4: Personality check ──
    print("\n[4/7] Personality check...")
    p_score = check_personality(mm)

    # ── Step 5: Feed session 4 facts + generate quizzes ──
    print("\n[5/7] Teaching session 4 facts...")
    conversation_buffer = []
    session_quizzes = []
    for i, user_msg in enumerate(SESSION4_MESSAGES):
        print(f"\n --- Message {i+1}/{len(SESSION4_MESSAGES)} ---")
        print(f" Matt: {user_msg}")
        conversation_buffer.append({"role": "user", "content": user_msg})
        # Keep only the 20 most recent turns as generation context.
        if len(conversation_buffer) > 20:
            conversation_buffer = conversation_buffer[-20:]

        # Generate response
        response = mm.generate(conversation_buffer)
        conversation_buffer.append({"role": "assistant", "content": response})
        print(f" Claudia: {response[:150]}...")

        # Generate quizzes (including contrastive!)
        quizzes = quiz_gen.generate(user_msg, response)
        print(f" Quizzes generated: {len(quizzes)}")
        for qi, qp in enumerate(quizzes):
            q = qp["messages"][0]["content"]
            a = qp["messages"][1]["content"]
            print(f" Q{qi+1}: {q}")
            print(f" A{qi+1}: {a[:120]}")

        # Store exchange
        exchange = {"messages": [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": response},
        ]}
        all_training_data.append(exchange)
        all_training_data.extend(quizzes)
        quiz_pairs_log.extend(quizzes)
        session_quizzes.extend(quizzes)

        # Absorb after each message (like the real pipeline):
        # new data + small replay of everything that came before it.
        new_data = [exchange] + quizzes
        old_sample = _sample_replay(all_training_data[:-len(new_data)], 8)
        absorb_data = new_data + old_sample
        t0 = time.time()
        loss = mm.absorb(absorb_data)
        elapsed = time.time() - t0
        print(f" Absorbed {len(absorb_data)} examples in {elapsed:.1f}s, loss={loss:.4f}")

    print(f"\n Session 4 total quizzes: {len(session_quizzes)}")
    # A quiz counts as contrastive when "no." occurs within the first five
    # characters of its answer.
    contrastive_count = sum(1 for q in session_quizzes
                            if "no." in q["messages"][1]["content"].lower()[:5])
    print(f" Contrastive quizzes: {contrastive_count}")

    # ── Step 6: Final absorption pass with all quiz pairs ──
    print("\n[6/7] Final absorption pass (all session 4 quizzes)...")
    # Train on all session 4 quizzes + small old replay. The pool is the
    # pre-session prefix: slicing off the last len(session_quizzes) items would
    # neither isolate old data (exchanges and quizzes are interleaved) nor
    # survive an empty quiz list ([:-0] is an empty slice).
    old_sample = _sample_replay(all_training_data[:pre_session_count], 16)
    final_data = session_quizzes + old_sample
    loss = mm.absorb(final_data)
    print(f" Final absorption: {len(final_data)} examples, loss={loss:.4f}")

    # ── Step 7: Recall test ──
    print("\n[7/7] RECALL TEST")
    print("=" * 60)
    direct_correct = 0
    direct_total = 0
    contrastive_correct = 0
    contrastive_total = 0
    results = []
    for question, keywords, qtype in RECALL_QUESTIONS:
        # Each question is asked in a fresh single-turn context.
        answer = mm.generate([{"role": "user", "content": question}], max_new_tokens=200)
        score = score_answer(answer, keywords)
        passed = score >= 0.5
        if qtype == "direct":
            direct_total += 1
            if passed:
                direct_correct += 1
        else:
            contrastive_total += 1
            if passed:
                contrastive_correct += 1
        status = "PASS" if passed else "FAIL"
        print(f" [{status}] ({qtype}) {question}")
        print(f" Expected: {keywords}")
        print(f" Got: {answer[:150]}")
        print(f" Score: {score:.2f}")
        print()
        results.append({
            "question": question,
            "expected": keywords,
            "answer": answer,
            "score": score,
            "passed": passed,
            "type": qtype,
        })

    print("=" * 60)
    print(f"DIRECT RECALL: {direct_correct}/{direct_total} ({_ratio(direct_correct, direct_total):.0%})")
    print(f"CONTRASTIVE RECALL: {contrastive_correct}/{contrastive_total} ({_ratio(contrastive_correct, contrastive_total):.0%})")
    total = direct_correct + contrastive_correct
    total_q = direct_total + contrastive_total
    print(f"TOTAL: {total}/{total_q} ({_ratio(total, total_q):.0%})")
    print("=" * 60)

    # ── Save checkpoint ──
    print("\n--- Saving checkpoint ---")
    version = f"session_{datetime.now().strftime('%Y%m%d_%H%M')}"
    path = os.path.join(CHECKPOINT_DIR, f"claudia_{version}")
    # Clean up old checkpoints first.
    # NOTE(review): this deletes every claudia_* directory, including the one
    # the model was loaded from, before the new checkpoint exists. If
    # merge_and_save fails, no checkpoint is left on disk. Presumably done to
    # free disk space; confirm, or save first and delete afterwards.
    for entry in os.listdir(CHECKPOINT_DIR):
        full = os.path.join(CHECKPOINT_DIR, entry)
        if os.path.isdir(full) and entry.startswith("claudia_"):
            size_gb = sum(
                os.path.getsize(os.path.join(dp, fn))
                for dp, _, fns in os.walk(full) for fn in fns
            ) / 1e9
            print(f" Removing old checkpoint: {entry} ({size_gb:.1f} GB)")
            shutil.rmtree(full)
    mm.merge_and_save(path)

    # Save replay buffer (once in LOG_DIR, once alongside the checkpoint)
    with open(os.path.join(LOG_DIR, "replay_buffer.json"), 'w') as f:
        json.dump(all_training_data, f)
    with open(os.path.join(path, "replay_buffer.json"), 'w') as f:
        json.dump(all_training_data, f)
    # Save quiz pairs log
    with open(os.path.join(LOG_DIR, "quiz_pairs_log.json"), 'w') as f:
        json.dump(quiz_pairs_log, f)
    print(f" Replay buffer: {len(all_training_data)} examples")
    print(f" Quiz pairs log: {len(quiz_pairs_log)}")

    # Cache teacher logits for next session
    print(f" Caching teacher logits ({len(quiz_pairs_log)} quiz pairs)...")
    new_teacher_cache = mm.cache_teacher_logits(quiz_pairs_log)
    cache_path = os.path.join(path, "teacher_cache.pt")
    torch.save(new_teacher_cache, cache_path)
    size_mb = os.path.getsize(cache_path) / 1e6
    print(f" Teacher cache saved ({len(new_teacher_cache)} items, {size_mb:.1f} MB)")

    # Save results
    results_data = {
        "session": 4,
        "checkpoint": path,
        "direct_recall": f"{direct_correct}/{direct_total}",
        "contrastive_recall": f"{contrastive_correct}/{contrastive_total}",
        "total_recall": f"{total}/{total_q}",
        "personality_score": p_score,
        "session_quizzes": len(session_quizzes),
        "contrastive_quizzes": contrastive_count,
        "total_quiz_pairs": len(quiz_pairs_log),
        "total_replay": len(all_training_data),
        "timestamp": datetime.now().isoformat(),
        "results": results,
    }
    results_path = os.path.join(LOG_DIR, "session4_results.json")
    with open(results_path, 'w') as f:
        json.dump(results_data, f, indent=2)
    print(f" Results saved: {results_path}")
    print(f"\n Next run: --checkpoint {path}")
    print("\nSESSION 4 COMPLETE.")
# Script entry point: run the full session-4 test only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    run_session4()