| """
|
| Session 4e — Fixed Contrastive Focus + Better Phase 2 Targeting
|
| ================================================================
|
| Session 4d scored 11/15 (same as 4b). Two bugs fixed:
|
|
|
| 1. CONTRASTIVE DUPLICATION FIX: _generate_contrastive_quizzes now only generates
|
| pairs involving at least one NEW entity. Previously re-generated all pairs
|
| between existing entities every message, making 50% of quizzes contrastive.
|
|
|
| 2. PHASE 2 TARGETING FIX: When checking which contrastive pairs are relevant to
|
| confused entities, now checks BOTH question AND answer text (not just question).
|
| "Is Jordan an architect?" is relevant to Marcus confusion even though "Marcus"
|
| only appears in the answer.
|
|
|
| Expected improvement: More positive quizzes (better direct recall) + more
|
| effective contrastive correction (Phase 2 doesn't skip relevant pairs).
|
|
|
| Run: python3 -u test_session4e.py 2>&1 | tee /workspace/logs/distill2_s4e.log
|
| """
|
|
|
| import json
|
| import os
|
| import random
|
| import sys
|
| import time
|
| import torch
|
| from datetime import datetime
|
|
|
| sys.path.insert(0, "/workspace")
|
| from persistent_absorber import (
|
| ModelManager, QuizGenerator, check_personality,
|
| CONSOLIDATION_EPOCHS, MAX_TEACHER_CACHE
|
| )
|
|
|
# Checkpoint this session resumes from, plus the output locations for logs
# and for the new checkpoint written at the end of the run.
CHECKPOINT = "/workspace/checkpoints/claudia_session_20260323_0327"
LOG_DIR = "/workspace/logs"
CHECKPOINT_DIR = "/workspace/checkpoints"

# User turns fed to the model, in order, during the teaching phase.
SESSION4_MESSAGES = [
    "My friend Jordan is a marine biologist in San Diego. My sister Elena is a veterinarian in Portland.",
    "My cousin Marcus is an architect in Seattle. He designed the new library there.",
    "I work at a startup called NovaMind. I'm the CTO. We're based in Austin.",
    "My best friend Priya is a neurosurgeon in Chicago. She went to Johns Hopkins.",
    "Elena actually just got a new cat named Mochi. And Jordan got his dive certification renewed last month.",
]

# Direct-recall probes: (question, keywords expected in the answer).
_DIRECT_PROBES = [
    ("What does Matt's friend Jordan do?", ["marine biologist"]),
    ("Where does Matt's sister Elena live?", ["portland"]),
    ("What is Matt's cousin Marcus's job?", ["architect"]),
    ("Where does Marcus live?", ["seattle"]),
    ("What does Priya do?", ["neurosurgeon"]),
    ("Where does Priya live?", ["chicago"]),
    ("Where does Matt work?", ["novamind"]),
    ("What is Matt's job title?", ["cto"]),
    ("What is Elena's cat's name?", ["mochi"]),
]

# Contrastive probes: each attributes one entity's fact to a different entity.
# A good answer denies the claim and names the correct fact or owner.
_CONTRASTIVE_PROBES = [
    ("Is Elena a marine biologist?", ["no", "veterinarian", "jordan"]),
    ("Does Jordan live in Portland?", ["no", "san diego", "elena"]),
    ("Is Marcus a neurosurgeon?", ["no", "architect", "priya"]),
    ("Does Priya live in Seattle?", ["no", "chicago", "marcus"]),
    ("Is Matt an architect?", ["no", "cto", "marcus"]),
    ("Does Jordan work at NovaMind?", ["no", "marine biologist", "matt"]),
]

# Full recall test set: (question, expected keywords, "direct" | "contrastive"),
# with all direct probes first, followed by all contrastive probes.
RECALL_QUESTIONS = (
    [(question, keywords, "direct") for question, keywords in _DIRECT_PROBES]
    + [(question, keywords, "contrastive") for question, keywords in _CONTRASTIVE_PROBES]
)
|
|
|
|
|
def score_answer(answer, expected_keywords, whole_word=True):
    """Return the fraction of ``expected_keywords`` that appear in ``answer``.

    Matching is case-insensitive. With ``whole_word=True`` (the default), each
    keyword must appear as a whole word or phrase, optionally followed by a
    plural "s"/"es". This stops short keywords from being credited inside
    unrelated words: "no" inside "NovaMind", "not" or "know", and "cto" inside
    "doctor". ``whole_word=False`` restores plain substring matching, which
    is how scores from earlier sessions were computed. Use it when comparing
    directly against those numbers.

    Args:
        answer: Model-generated answer text.
        expected_keywords: Keywords/phrases to look for (expected lower-case).
        whole_word: Require word-boundary matches instead of raw substrings.

    Returns:
        Number of matched keywords divided by the number of keywords.
        Returns 0.0 when ``expected_keywords`` is empty, instead of dividing by zero.
    """
    import re

    if not expected_keywords:
        return 0.0

    answer_lower = answer.lower()

    def _found(keyword):
        if not whole_word:
            return keyword in answer_lower
        # \b on both ends prevents matches inside longer words. The optional
        # (e)s still accepts plurals such as "marine biologists".
        pattern = r"\b" + re.escape(keyword.lower()) + r"(?:e?s)?\b"
        return re.search(pattern, answer_lower) is not None

    hits = sum(1 for k in expected_keywords if _found(k))
    return hits / len(expected_keywords)
|
|
|
|
|
def quick_verify_entities(mm, known_entities, max_new_tokens=100):
    """Quick verification: probe the model on each known entity's job and city.

    For every entity whose info has a truthy ``job``, the model is asked what
    the entity does. For every entity with a truthy ``city``, it is asked
    where the entity lives. A probe fails when the expected value does not
    appear (case-insensitive substring) in the generated answer. Each failure
    is printed.

    Args:
        mm: Model wrapper exposing ``generate(messages, max_new_tokens=...)``.
        known_entities: Mapping of entity name -> info dict. The keys read
            here are ``job``, ``city`` and, optionally, ``relationship``.
            Presumably this is ``QuizGenerator.known_entities``; confirm the
            schema there.
        max_new_tokens: Generation budget for each probe answer.

    Returns:
        Set of entity names that the model gets WRONG on at least one probe.
    """
    confused = set()
    for name, info in known_entities.items():
        job = info.get("job")
        if job:
            relationship = info.get("relationship")
            # Fall back to a plain question when no relationship is recorded.
            # Previously info['relationship'] raised KeyError here, even though
            # job/city are already read with .get().
            if relationship:
                q = f"What does Matt's {relationship} {name} do?"
            else:
                q = f"What does {name} do?"
            ans = mm.generate([{"role": "user", "content": q}], max_new_tokens=max_new_tokens)
            if job.lower() not in ans.lower():
                confused.add(name)
                print(f" CONFUSED: {name} job — expected '{job}', got: {ans[:80]}")

        city = info.get("city")
        if city:
            q = f"Where does {name} live?"
            ans = mm.generate([{"role": "user", "content": q}], max_new_tokens=max_new_tokens)
            if city.lower() not in ans.lower():
                confused.add(name)
                print(f" CONFUSED: {name} city — expected '{city}', got: {ans[:80]}")
    return confused
|
|
|
|
|
def _is_contrastive(qp):
    """Return True if quiz pair ``qp`` is contrastive, i.e. its answer starts with "no."."""
    return qp["messages"][1]["content"].lower().startswith("no.")


def _select_targeted(contrastive_pairs, confused):
    """Return the contrastive pairs whose question OR answer mentions a confused entity.

    Both texts are searched because a confused entity may appear only in the
    answer (e.g. "Is Jordan an architect?" is relevant to Marcus confusion).
    """
    names = [name.lower() for name in confused]
    targeted = []
    for qp in contrastive_pairs:
        full_text = (qp["messages"][0]["content"] + " " +
                     qp["messages"][1]["content"]).lower()
        if any(name in full_text for name in names):
            targeted.append(qp)
    return targeted


def _sample_old(pool, k):
    """Randomly sample ``k`` replay examples from ``pool``.

    Returns [] unless ``pool`` holds strictly more than ``k`` examples.
    """
    if len(pool) <= k:
        return []
    return random.sample(pool, k)


def _run_recall_test(mm):
    """Ask every RECALL_QUESTIONS item, print per-question results and tally them.

    A question passes when ``score_answer`` returns at least 0.5.

    Returns:
        Tuple ``(direct_correct, direct_total, contrastive_correct,
        contrastive_total, results)``, where ``results`` is a list of
        per-question dicts.
    """
    direct_correct = 0
    direct_total = 0
    contrastive_correct = 0
    contrastive_total = 0
    results = []

    for question, keywords, qtype in RECALL_QUESTIONS:
        answer = mm.generate([{"role": "user", "content": question}], max_new_tokens=200)
        score = score_answer(answer, keywords)
        passed = score >= 0.5

        if qtype == "direct":
            direct_total += 1
            if passed:
                direct_correct += 1
        else:
            contrastive_total += 1
            if passed:
                contrastive_correct += 1

        status = "PASS" if passed else "FAIL"
        print(f" [{status}] ({qtype}) {question}")
        print(f" Expected: {keywords}")
        print(f" Got: {answer[:150]}")
        print(f" Score: {score:.2f}")
        print()

        results.append({
            "question": question,
            "expected": keywords,
            "answer": answer,
            "score": score,
            "passed": passed,
            "type": qtype,
        })

    return direct_correct, direct_total, contrastive_correct, contrastive_total, results


def run_session4e():
    """Run Session 4e end to end. Steps match the [n/8] progress banners:

    1. Load the model from CHECKPOINT.
    2. Load the replay buffer and the quiz-pair log from LOG_DIR, if present.
    3. Distill from the checkpoint's teacher cache, if present.
    4. Run the personality check.
    5. For each SESSION4_MESSAGES entry:
       - generate a reply and quizzes;
       - phase 1: absorb the exchange and its positive quizzes, plus replay;
       - phase 2: absorb only the contrastive quizzes that mention entities
         the model currently gets wrong.
    6. Run a final two-phase pass over the session's quizzes plus replay of
       pre-session data.
    7. Run the recall test over RECALL_QUESTIONS.
    8. Save a new merged checkpoint, the replay buffer, the quiz log, the
       teacher cache and a results JSON.
    """
    print("=" * 60)
    print("SESSION 4e: Fixed Contrastive Focus + Phase 2 Targeting")
    print(f"Time: {datetime.now().isoformat()}")
    print("=" * 60)

    # [1/8] Model
    print("\n[1/8] Loading model from checkpoint...")
    mm = ModelManager(
        model_path=CHECKPOINT,
        checkpoint_path=CHECKPOINT,
    )
    mm.load()

    quiz_gen = QuizGenerator(mm)

    # [2/8] Persistent data from previous sessions
    print("\n[2/8] Loading existing data...")
    replay_path = os.path.join(LOG_DIR, "replay_buffer.json")
    quiz_log_path = os.path.join(LOG_DIR, "quiz_pairs_log.json")

    all_training_data = []
    quiz_pairs_log = []

    if os.path.exists(replay_path):
        with open(replay_path, 'r') as f:
            all_training_data = json.load(f)
        print(f" Replay buffer: {len(all_training_data)} examples")

    if os.path.exists(quiz_log_path):
        with open(quiz_log_path, 'r') as f:
            quiz_pairs_log = json.load(f)
        print(f" Quiz pairs: {len(quiz_pairs_log)}")

    # Items before this index predate the session. The final pass replays
    # only from this region.
    pre_session_len = len(all_training_data)

    # [3/8] Consolidation
    print("\n[3/8] Cascade consolidation from teacher cache...")
    teacher_cache_path = os.path.join(CHECKPOINT, "teacher_cache.pt")
    if os.path.exists(teacher_cache_path):
        # weights_only=False unpickles arbitrary objects. This is only acceptable
        # because the file is presumably written by a previous run's step 8
        # (torch.save below) and is never external input.
        teacher_cache = torch.load(teacher_cache_path, map_location="cpu", weights_only=False)
        print(f" Teacher cache: {len(teacher_cache)} items")
        loss = mm.distill(teacher_cache, epochs=CONSOLIDATION_EPOCHS)
        print(f" Consolidation loss: {loss:.4f}")
    else:
        print(" No teacher cache found — skipping")

    # [4/8] Personality
    print("\n[4/8] Personality check...")
    p_score = check_personality(mm)

    # [5/8] Teaching loop
    print("\n[5/8] Teaching facts (fixed contrastive focus + phase 2 targeting)...")
    conversation_buffer = []
    session_quizzes = []
    session_positive = []
    session_contrastive = []
    contrastive_count = 0
    positive_count = 0
    total_quiz_count = 0

    for i, user_msg in enumerate(SESSION4_MESSAGES):
        print(f"\n --- Message {i+1}/{len(SESSION4_MESSAGES)} ---")
        print(f" Matt: {user_msg}")

        conversation_buffer.append({"role": "user", "content": user_msg})
        # Keep only the most recent 20 turns as generation context.
        if len(conversation_buffer) > 20:
            conversation_buffer = conversation_buffer[-20:]

        response = mm.generate(conversation_buffer)
        conversation_buffer.append({"role": "assistant", "content": response})
        print(f" Claudia: {response[:150]}...")

        quizzes = quiz_gen.generate(user_msg, response)
        total_quiz_count += len(quizzes)
        print(f" Quizzes generated: {len(quizzes)} (session total: {total_quiz_count})")
        print(f" Known entities: {list(quiz_gen.known_entities.keys())}")

        positive_batch = [qp for qp in quizzes if not _is_contrastive(qp)]
        contrastive_batch = [qp for qp in quizzes if _is_contrastive(qp)]
        positive_count += len(positive_batch)
        contrastive_count += len(contrastive_batch)

        print(f" Positive: {len(positive_batch)}, Contrastive: {len(contrastive_batch)}")
        # Show at most the first 8 quizzes; [C] marks contrastive pairs.
        for qi, qp in enumerate(quizzes[:8]):
            q = qp["messages"][0]["content"]
            a = qp["messages"][1]["content"]
            tag = " [C]" if _is_contrastive(qp) else ""
            print(f" Q{qi+1}: {q}{tag}")
            print(f" A{qi+1}: {a[:120]}")

        exchange = {"messages": [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": response},
        ]}

        all_training_data.append(exchange)
        all_training_data.extend(quizzes)
        quiz_pairs_log.extend(quizzes)
        session_quizzes.extend(quizzes)
        session_positive.extend(positive_batch)
        session_contrastive.extend(contrastive_batch)

        # Phase 1: the exchange plus its positive quizzes, and up to 6 replayed
        # examples. The replay pool is everything added before this message
        # (the slice drops the exchange and quizzes just appended).
        phase1_data = [exchange] + positive_batch
        phase1_data.extend(_sample_old(all_training_data[:-len(quizzes) - 1], 6))

        if quiz_gen.known_entities:
            phase1_data = ModelManager.cluster_by_entity(
                phase1_data, list(quiz_gen.known_entities.keys()))

        t0 = time.time()
        loss1 = mm.absorb(phase1_data)
        elapsed1 = time.time() - t0
        print(f" Phase 1 (positive): {len(phase1_data)} examples, {elapsed1:.1f}s, loss={loss1:.4f}")

        # Phase 2: runs only when the model is still confused after phase 1.
        # It absorbs only the contrastive pairs that mention a confused entity.
        if contrastive_batch:
            print(" Phase 2: Verifying entities before contrastive correction...")
            confused = quick_verify_entities(mm, quiz_gen.known_entities)

            if confused:
                targeted = _select_targeted(contrastive_batch, confused)
                if targeted:
                    t0 = time.time()
                    loss2 = mm.absorb(targeted)
                    elapsed2 = time.time() - t0
                    print(f" Phase 2 (contrastive): {len(targeted)}/{len(contrastive_batch)} targeted, "
                          f"{elapsed2:.1f}s, loss={loss2:.4f}")
                else:
                    print(" Phase 2: No contrastive pairs matched confused entities")
            else:
                print(" Phase 2: All entities correct — skipping contrastive!")

    print("\n Session 4e totals:")
    print(f" Total quizzes: {total_quiz_count}")
    print(f" Positive: {positive_count}")
    print(f" Contrastive: {contrastive_count}")
    classified = positive_count + contrastive_count
    if classified:
        print(f" Ratio: {positive_count/classified:.0%} positive")
    else:
        # Guard: previously a ZeroDivisionError when no quizzes were generated.
        print(" Ratio: n/a (no quizzes generated)")
    print(f" Known entities: {list(quiz_gen.known_entities.keys())}")

    # [6/8] Final pass
    print("\n[6/8] Final two-phase absorption pass...")
    # Replay only data that predates this session. Slicing off just the last
    # len(session_quizzes) items is not enough, because the session also
    # appended one exchange per message. That left early session items,
    # including contrastive pairs, in the "old" pool, which fed them into this
    # positive phase. With no session quizzes the slice was also empty, and
    # random.sample raised.
    old_sample = _sample_old(all_training_data[:pre_session_len], 12)

    final_positive = session_positive + old_sample
    if quiz_gen.known_entities:
        final_positive = ModelManager.cluster_by_entity(
            final_positive, list(quiz_gen.known_entities.keys()))
    loss_fp = mm.absorb(final_positive)
    print(f" Final Phase 1 (positive): {len(final_positive)} examples, loss={loss_fp:.4f}")

    print(" Final Phase 2: Verifying entities...")
    confused = quick_verify_entities(mm, quiz_gen.known_entities)
    if confused and session_contrastive:
        targeted = _select_targeted(session_contrastive, confused)
        if targeted:
            loss_fc = mm.absorb(targeted)
            print(f" Final Phase 2 (contrastive): {len(targeted)} targeted, loss={loss_fc:.4f}")
        else:
            print(" Final Phase 2: No matching contrastive pairs")
    else:
        status = "no confusion detected" if not confused else "no contrastive data"
        print(f" Final Phase 2: Skipped ({status})")

    # [7/8] Recall test
    print("\n[7/8] RECALL TEST")
    print("=" * 60)

    (direct_correct, direct_total,
     contrastive_correct, contrastive_total, results) = _run_recall_test(mm)

    print("=" * 60)
    print(f"DIRECT RECALL: {direct_correct}/{direct_total} ({direct_correct/direct_total:.0%})")
    print(f"CONTRASTIVE RECALL: {contrastive_correct}/{contrastive_total} ({contrastive_correct/contrastive_total:.0%})")
    total = direct_correct + contrastive_correct
    total_q = direct_total + contrastive_total
    print(f"TOTAL: {total}/{total_q} ({total/total_q:.0%})")
    # The baselines (11 total, 5 direct) are hard-coded results from sessions 4b/4d.
    comparison = "IMPROVED" if total > 11 else ("SAME" if total == 11 else "REGRESSED")
    print(f"vs SESSION 4b/4d: {comparison} (was 11/15 = 73%)")
    print(f"DIRECT vs 4d: {'UP' if direct_correct > 5 else ('SAME' if direct_correct == 5 else 'DOWN')} (was 5/9)")
    print("=" * 60)

    # [8/8] Persist
    print("\n[8/8] Saving checkpoint...")
    version = f"session_{datetime.now().strftime('%Y%m%d_%H%M')}"
    path = os.path.join(CHECKPOINT_DIR, f"claudia_{version}")

    import shutil

    # NOTE(review): this deletes every claudia_* checkpoint, including the one
    # loaded above, BEFORE the new one is written. If merge_and_save fails, no
    # checkpoint remains. This is presumably done to free disk space; confirm
    # before moving the deletion after the save.
    for entry in os.listdir(CHECKPOINT_DIR):
        full = os.path.join(CHECKPOINT_DIR, entry)
        if os.path.isdir(full) and entry.startswith("claudia_"):
            size_gb = sum(
                os.path.getsize(os.path.join(dirpath, fname))
                for dirpath, _, fnames in os.walk(full) for fname in fnames
            ) / 1e9
            print(f" Removing old checkpoint: {entry} ({size_gb:.1f} GB)")
            shutil.rmtree(full)

    mm.merge_and_save(path)

    with open(replay_path, 'w') as f:
        json.dump(all_training_data, f)
    with open(os.path.join(path, "replay_buffer.json"), 'w') as f:
        json.dump(all_training_data, f)
    with open(quiz_log_path, 'w') as f:
        json.dump(quiz_pairs_log, f)

    print(f" Replay buffer: {len(all_training_data)} examples")
    print(f" Quiz pairs log: {len(quiz_pairs_log)}")

    print(f" Caching teacher logits ({len(quiz_pairs_log)} quiz pairs)...")
    # NOTE(review): MAX_TEACHER_CACHE is imported by this script but unused.
    # The cache covers the entire quiz log; confirm whether it should be capped.
    new_teacher_cache = mm.cache_teacher_logits(quiz_pairs_log)
    cache_path = os.path.join(path, "teacher_cache.pt")
    torch.save(new_teacher_cache, cache_path)
    size_mb = os.path.getsize(cache_path) / 1e6
    print(f" Teacher cache saved ({len(new_teacher_cache)} items, {size_mb:.1f} MB)")

    results_data = {
        "session": "4e",
        "fixes": [
            "contrastive focus: only generate pairs involving new entities",
            "phase 2 targeting: check both Q and A text for confused entities",
        ],
        "checkpoint": path,
        "direct_recall": f"{direct_correct}/{direct_total}",
        "contrastive_recall": f"{contrastive_correct}/{contrastive_total}",
        "total_recall": f"{total}/{total_q}",
        "vs_session4b": comparison,
        "personality_score": p_score,
        "total_quizzes": total_quiz_count,
        "positive_quizzes": positive_count,
        "contrastive_quizzes": contrastive_count,
        "total_quiz_pairs": len(quiz_pairs_log),
        "total_replay": len(all_training_data),
        "known_entities": dict(quiz_gen.known_entities),
        "timestamp": datetime.now().isoformat(),
        "results": results,
    }
    results_path = os.path.join(LOG_DIR, "session4e_results.json")
    with open(results_path, 'w') as f:
        json.dump(results_data, f, indent=2)
    print(f" Results saved: {results_path}")
    print(f"\n Next run: --checkpoint {path}")
    print("\nSESSION 4e COMPLETE.")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: importing this module defines everything but does not
    # start a session.
    run_session4e()
|
|
|