# Multi-Manifold-Retrieval_POC / run_experiment.py
# Uploaded by bragee: "Upload model checkpoints and code" (commit b464490, verified)
#!/usr/bin/env python3
"""End-to-end experiment: train → evaluate → spectral analysis → attack simulation."""
import os
import json
import random
import time
import yaml
import argparse
import numpy as np
import torch
from multi_manifold_retrieval.models.encoders import DualEncoder
from multi_manifold_retrieval.models.cross_manifold_operator import CrossManifoldOperator
from multi_manifold_retrieval.models.baseline import BaselineOperator
from multi_manifold_retrieval.training.train import train
from multi_manifold_retrieval.training.data import MSMARCOEvalDataset
from multi_manifold_retrieval.evaluation.spectral_analysis import run_spectral_analysis
from multi_manifold_retrieval.evaluation.retrieval_metrics import compute_retrieval_metrics
from multi_manifold_retrieval.evaluation.attack_simulation import (
run_attack_simulation,
select_domain_documents,
)
def set_seed(seed: int):
    """Seed every random number generator the experiment draws from.

    Covers Python's ``random`` module, NumPy's global RNG and PyTorch's CPU RNG,
    plus all CUDA devices when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Only touch the CUDA generators when a CUDA runtime is actually present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def _print_header(title):
    """Print a phase banner: a blank line, a rule, the title, and another rule."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


def _select_device():
    """Return the best available torch device name: "cuda", then "mps", then "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps is absent on older PyTorch builds; guard the lookup.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


def _load_or_train(args, config, device):
    """Build the encoder and cross-manifold operator used for evaluation.

    With ``--skip-train`` and an existing checkpoint, the encoder is rebuilt
    from ``config`` and only the operator's state dict is restored from the
    checkpoint. Otherwise ``train`` is run on ``args.config``.

    Both modules are moved to ``device`` and switched to eval mode, because
    nothing after this phase trains. Switching also disables the operator's
    dropout. Returns ``(encoder, operator)``.
    """
    use_checkpoint = args.skip_train and os.path.exists(args.checkpoint)
    if args.skip_train and not use_checkpoint:
        print(f"WARNING: --skip-train given but checkpoint {args.checkpoint} "
              f"does not exist; training from scratch instead.")

    if use_checkpoint:
        print(f"Loading operator from checkpoint: {args.checkpoint}")
        encoder = DualEncoder(
            model_name=config["encoder"]["model_name"],
            max_seq_length=config["training"]["max_seq_length"],
            freeze=config["encoder"]["freeze"],
        )
        if not config["encoder"]["freeze"]:
            # The checkpoint only carries operator weights (see load below), so
            # any encoder fine-tuning done by train() is not restored here.
            print("WARNING: encoder is not frozen in the config, but only operator "
                  "weights are restored from the checkpoint.")
        cm_config = config["cross_manifold"]
        operator = CrossManifoldOperator(
            embedding_dim=encoder.embedding_dim,
            num_heads=cm_config["num_heads"],
            value_hidden_dim=cm_config["value_mlp_hidden"],
            value_num_layers=cm_config["value_mlp_layers"],
            dropout=cm_config["dropout"],
        )
        operator.load_state_dict(
            torch.load(args.checkpoint, map_location=device, weights_only=True)
        )
    else:
        encoder, operator = train(config_path=args.config, device=device)

    encoder.model.to(device)
    operator.to(device)
    # A freshly constructed module starts in training mode; evaluation must not
    # run with dropout active.
    encoder.model.eval()
    operator.eval()
    return encoder, operator


def _select_target_queries(queries, attack_config):
    """Choose the queries targeted by the attack simulation.

    Queries are scanned in order. Those containing any of
    ``attack_config["medical_keywords"]`` (case-insensitive) are kept until
    ``num_target_queries`` is reached. If fewer than ``min_medical_queries``
    (default 10) medical queries are found, the list is topped up toward
    ``num_target_queries`` with a ``random.sample`` of the other queries.
    """
    keywords = [kw.lower() for kw in attack_config["medical_keywords"]]
    num_targets = attack_config["num_target_queries"]
    min_medical = attack_config.get("min_medical_queries", 10)

    target_queries = []
    for q in queries:
        q_lower = q.lower()
        if any(kw in q_lower for kw in keywords):
            target_queries.append(q)
            if len(target_queries) >= num_targets:
                break

    if len(target_queries) < min_medical:
        print(f"Only found {len(target_queries)} medical queries; "
              f"using random queries to reach {num_targets}.")
        remaining = max(num_targets - len(target_queries), 0)
        chosen = set(target_queries)  # O(1) membership instead of a list scan
        other_queries = [q for q in queries if q not in chosen]
        target_queries.extend(random.sample(other_queries, min(remaining, len(other_queries))))
    return target_queries


def _json_default(obj):
    """Handle values ``json.dump`` cannot serialize on its own.

    NumPy scalars and single-element torch tensors become plain Python numbers;
    anything else is stringified, matching the previous ``default=str``.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, torch.Tensor) and obj.numel() == 1:
        return obj.item()
    return str(obj)


def _print_summary(metrics_mm, metrics_base, ratio, spectral_results, attack_results, top_k):
    """Print the headline number from every phase of the experiment."""
    _print_header("EXPERIMENT SUMMARY")
    print("\n1. Retrieval Quality:")
    print(f" Baseline MRR@10: {metrics_base['mrr@10']:.4f}")
    print(f" Multi-Manifold MRR@10: {metrics_mm['mrr@10']:.4f}")
    if ratio is not None:
        print(f" Ratio: {ratio:.4f}")
    base_spec = spectral_results["baseline"]
    mm_spec = spectral_results["multi_manifold"]
    print("\n2. Spectral Analysis:")
    print(f" Baseline δ: {base_spec['spectral_discrepancy']:.4f}")
    print(f" Multi-Manifold δ: {mm_spec['spectral_discrepancy']:.4f}")
    print(f" Baseline cos(θ): {base_spec['fiedler_alignment']:.4f}")
    print(f" Multi-Manifold cos(θ): {mm_spec['fiedler_alignment']:.4f}")
    if "error" not in attack_results:
        print("\n3. Attack Simulation:")
        print(f" Baseline ASR@{top_k}: {attack_results['baseline_asr']:.4f}")
        print(f" Multi-Manifold ASR@{top_k}: {attack_results['multi_manifold_asr']:.4f}")


def main():
    """Run the experiment end to end: train, evaluate, spectral analysis, attack simulation.

    CLI options:
      --config      YAML experiment config.
      --skip-train  Load the operator from --checkpoint instead of training.
      --checkpoint  Path to the operator state dict.
      --output      Destination of the JSON results.
    """
    parser = argparse.ArgumentParser(description="Multi-Manifold Retrieval PoC Experiment")
    parser.add_argument("--config", type=str, default="configs/default.yaml")
    parser.add_argument("--skip-train", action="store_true", help="Skip training, load from checkpoint")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/best_operator.pt")
    parser.add_argument("--output", type=str, default="results.json")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    set_seed(config["seed"])

    device = _select_device()
    print(f"Using device: {device}")
    results = {"config": config, "device": device}

    # Phase 1: Training (or checkpoint load)
    _print_header("PHASE 1: TRAINING")
    encoder, operator = _load_or_train(args, config, device)
    baseline_operator = BaselineOperator().to(device)
    baseline_operator.eval()

    # Phase 2: Evaluation data preparation
    _print_header("PHASE 2: EVALUATION DATA PREPARATION")
    eval_data = MSMARCOEvalDataset(
        tokenizer=encoder.tokenizer,
        max_queries=config["evaluation"]["max_eval_queries"],
        max_seq_length=config["training"]["max_seq_length"],
        seed=config["seed"],
    )
    print(f"Encoding {len(eval_data.all_passages)} passages...")
    passage_embeddings = encoder.encode_documents(
        eval_data.all_passages, batch_size=128, show_progress=True,
    )
    print(f"Encoding {len(eval_data.queries)} queries...")
    query_embeddings = encoder.encode_queries(
        eval_data.queries, batch_size=128, show_progress=True,
    )

    # Phase 3: Retrieval quality
    _print_header("PHASE 3: RETRIEVAL QUALITY")
    # Both models are scored on identical inputs; only the operator differs.
    retrieval_inputs = dict(
        query_embeddings=query_embeddings,
        doc_embeddings=passage_embeddings,
        query_texts=eval_data.queries,
        positive_passages=eval_data.positives,
        all_passages=eval_data.all_passages,
        passage_embeddings=passage_embeddings,
        device=device,
    )
    print("\n--- Multi-Manifold Model ---")
    metrics_mm = compute_retrieval_metrics(operator=operator, **retrieval_inputs)
    results["retrieval_multi_manifold"] = metrics_mm
    print("\n--- Baseline (Cosine Similarity) ---")
    metrics_base = compute_retrieval_metrics(operator=baseline_operator, **retrieval_inputs)
    results["retrieval_baseline"] = metrics_base

    # Target: multi-manifold MRR@10 should be at least `target` x the baseline's.
    target = config["evaluation"].get("mrr_ratio_target", 0.8)
    ratio = None
    if metrics_base["mrr@10"] > 0:
        ratio = metrics_mm["mrr@10"] / metrics_base["mrr@10"]
        print(f"\nMRR@10 ratio (mm/baseline): {ratio:.4f} "
              f"({'PASS' if ratio >= target else 'BELOW TARGET'}, target >= {target})")
        results["mrr_ratio"] = ratio

    # Phase 4: Spectral analysis on a random subset of documents and queries
    _print_header("PHASE 4: SPECTRAL ANALYSIS")
    num_spectral_docs = min(config["spectral"]["num_documents"], len(eval_data.all_passages))
    num_spectral_queries = min(config["spectral"]["num_queries"], len(eval_data.queries))
    spectral_doc_indices = np.random.choice(
        len(eval_data.all_passages), num_spectral_docs, replace=False
    )
    spectral_query_indices = np.random.choice(
        len(eval_data.queries), num_spectral_queries, replace=False
    )
    spectral_doc_emb_torch = passage_embeddings[spectral_doc_indices]
    spectral_doc_emb_np = spectral_doc_emb_torch.cpu().numpy()
    spectral_query_emb_torch = query_embeddings[spectral_query_indices]
    spectral_results = run_spectral_analysis(
        doc_embeddings_np=spectral_doc_emb_np,
        doc_embeddings_torch=spectral_doc_emb_torch,
        query_embeddings_torch=spectral_query_emb_torch,
        operator=operator,
        baseline_operator=baseline_operator,
        device=device,
    )
    # Only the per-model summaries are kept; other keys of spectral_results
    # are not copied into the saved results.
    results["spectral"] = {
        "multi_manifold": spectral_results["multi_manifold"],
        "baseline": spectral_results["baseline"],
        "num_documents": num_spectral_docs,
        "num_queries": num_spectral_queries,
    }

    # Phase 5: Attack simulation
    _print_header("PHASE 5: ATTACK SIMULATION")
    attack_config = config["attack"]
    target_queries = _select_target_queries(eval_data.queries, attack_config)
    print(f"Using {len(target_queries)} target queries for attack simulation.")
    attack_results = run_attack_simulation(
        encoder=encoder,
        operator=operator,
        baseline_operator=baseline_operator,
        passages=eval_data.all_passages,
        passage_embeddings_torch=passage_embeddings,
        target_query_texts=target_queries,
        medical_keywords=attack_config["medical_keywords"],
        top_k=attack_config["top_k"],
        device=device,
    )
    results["attack"] = attack_results

    _print_summary(metrics_mm, metrics_base, ratio, spectral_results,
                   attack_results, attack_config["top_k"])

    # Defensive: never write raw Laplacian matrices if they reach the top level.
    save_results = {k: v for k, v in results.items()
                    if k not in ("L_D", "L_R_mm", "L_R_base")}
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(save_results, f, indent=2, default=_json_default)
    print(f"\nResults saved to {args.output}")
if __name__ == "__main__":
    # Execute the pipeline only when run as a script, never on import.
    main()