#!/usr/bin/env python3
"""End-to-end experiment: train → evaluate → spectral analysis → attack simulation."""
import os
import json
import random
import time
import yaml
import argparse
import numpy as np
import torch

from multi_manifold_retrieval.models.encoders import DualEncoder
from multi_manifold_retrieval.models.cross_manifold_operator import CrossManifoldOperator
from multi_manifold_retrieval.models.baseline import BaselineOperator
from multi_manifold_retrieval.training.train import train
from multi_manifold_retrieval.training.data import MSMARCOEvalDataset
from multi_manifold_retrieval.evaluation.spectral_analysis import run_spectral_analysis
from multi_manifold_retrieval.evaluation.retrieval_metrics import compute_retrieval_metrics
from multi_manifold_retrieval.evaluation.attack_simulation import (
    run_attack_simulation,
    select_domain_documents,
)


def set_seed(seed: int) -> None:
    """Seed every RNG used in the experiment (Python, NumPy, Torch CPU/CUDA)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _pick_device() -> str:
    """Return the best available torch device string: cuda > mps > cpu."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _print_banner(title: str) -> None:
    """Print a 60-char phase banner, matching the script's log format."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


def _build_models(args, config, device: str):
    """Phase 1: load encoder + operator from a checkpoint, or train from scratch.

    Returns:
        (encoder, operator, baseline_operator), with all torch modules moved
        to `device`.
    """
    if args.skip_train and os.path.exists(args.checkpoint):
        print(f"Loading encoder and operator from checkpoint: {args.checkpoint}")
        encoder = DualEncoder(
            model_name=config["encoder"]["model_name"],
            max_seq_length=config["training"]["max_seq_length"],
            freeze=config["encoder"]["freeze"],
        )
        cm_config = config["cross_manifold"]
        operator = CrossManifoldOperator(
            embedding_dim=encoder.embedding_dim,
            num_heads=cm_config["num_heads"],
            value_hidden_dim=cm_config["value_mlp_hidden"],
            value_num_layers=cm_config["value_mlp_layers"],
            dropout=cm_config["dropout"],
        )
        operator.load_state_dict(
            torch.load(args.checkpoint, map_location=device, weights_only=True)
        )
        # BUG FIX: this branch previously left the encoder on CPU, while the
        # train-from-scratch branch moved it to `device`; keep them consistent.
        encoder.model.to(device)
        operator.to(device)
    else:
        encoder, operator = train(config_path=args.config, device=device)
        encoder.model.to(device)
        operator.to(device)

    baseline_operator = BaselineOperator().to(device)
    return encoder, operator, baseline_operator


def _select_target_queries(queries, attack_config):
    """Phase 5 helper: choose target queries for the attack simulation.

    Prefers queries containing any configured medical keyword; tops up with
    random non-medical queries when fewer than `num_target_queries` match.
    """
    wanted = attack_config["num_target_queries"]
    keywords = attack_config["medical_keywords"]

    target_queries = []
    for q in queries:
        q_lower = q.lower()
        if any(kw in q_lower for kw in keywords):
            target_queries.append(q)
        if len(target_queries) >= wanted:
            break

    if len(target_queries) < wanted:
        # BUG FIX: the fallback threshold was a hard-coded 10, so runs that
        # found between 10 and wanted-1 medical queries silently came up
        # short of the configured target; compare against the config value,
        # which is what the message below already promised.
        print(f"Only found {len(target_queries)} medical queries; "
              f"using random queries to reach {attack_config['num_target_queries']}.")
        remaining = wanted - len(target_queries)
        other_queries = [q for q in queries if q not in target_queries]
        target_queries.extend(random.sample(other_queries, min(remaining, len(other_queries))))

    return target_queries


def main():
    """Run the full PoC pipeline.

    Phases: (1) train or load the models, (2) encode evaluation data,
    (3) retrieval-quality metrics vs. a cosine baseline, (4) spectral
    analysis, (5) attack simulation; finally print a summary and dump
    JSON-serializable results to --output.
    """
    parser = argparse.ArgumentParser(description="Multi-Manifold Retrieval PoC Experiment")
    parser.add_argument("--config", type=str, default="configs/default.yaml")
    parser.add_argument("--skip-train", action="store_true",
                        help="Skip training, load from checkpoint")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/best_operator.pt")
    parser.add_argument("--output", type=str, default="results.json")
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)

    set_seed(config["seed"])
    device = _pick_device()
    print(f"Using device: {device}")

    results = {"config": config, "device": device}

    # =========================================================================
    # Phase 1: Training
    # =========================================================================
    _print_banner("PHASE 1: TRAINING")
    encoder, operator, baseline_operator = _build_models(args, config, device)

    # =========================================================================
    # Phase 2: Evaluation Data Preparation
    # =========================================================================
    _print_banner("PHASE 2: EVALUATION DATA PREPARATION")
    eval_data = MSMARCOEvalDataset(
        tokenizer=encoder.tokenizer,
        max_queries=config["evaluation"]["max_eval_queries"],
        max_seq_length=config["training"]["max_seq_length"],
        seed=config["seed"],
    )

    # Encode all evaluation passages
    print(f"Encoding {len(eval_data.all_passages)} passages...")
    passage_embeddings = encoder.encode_documents(
        eval_data.all_passages,
        batch_size=128,
        show_progress=True,
    )

    # Encode evaluation queries
    print(f"Encoding {len(eval_data.queries)} queries...")
    query_embeddings = encoder.encode_queries(
        eval_data.queries,
        batch_size=128,
        show_progress=True,
    )

    # =========================================================================
    # Phase 3: Retrieval Quality Evaluation
    # =========================================================================
    _print_banner("PHASE 3: RETRIEVAL QUALITY")

    print("\n--- Multi-Manifold Model ---")
    metrics_mm = compute_retrieval_metrics(
        query_embeddings=query_embeddings,
        doc_embeddings=passage_embeddings,
        operator=operator,
        query_texts=eval_data.queries,
        positive_passages=eval_data.positives,
        all_passages=eval_data.all_passages,
        passage_embeddings=passage_embeddings,
        device=device,
    )
    results["retrieval_multi_manifold"] = metrics_mm

    print("\n--- Baseline (Cosine Similarity) ---")
    metrics_base = compute_retrieval_metrics(
        query_embeddings=query_embeddings,
        doc_embeddings=passage_embeddings,
        operator=baseline_operator,
        query_texts=eval_data.queries,
        positive_passages=eval_data.positives,
        all_passages=eval_data.all_passages,
        passage_embeddings=passage_embeddings,
        device=device,
    )
    results["retrieval_baseline"] = metrics_base

    # Check: multi-manifold within 80% of baseline
    if metrics_base["mrr@10"] > 0:
        ratio = metrics_mm["mrr@10"] / metrics_base["mrr@10"]
        print(f"\nMRR@10 ratio (mm/baseline): {ratio:.4f} "
              f"({'PASS' if ratio >= 0.8 else 'BELOW TARGET'}, target >= 0.8)")
        results["mrr_ratio"] = ratio

    # =========================================================================
    # Phase 4: Spectral Analysis
    # =========================================================================
    _print_banner("PHASE 4: SPECTRAL ANALYSIS")

    # Sample documents for spectral analysis (seeded via set_seed above)
    num_spectral_docs = min(config["spectral"]["num_documents"], len(eval_data.all_passages))
    num_spectral_queries = min(config["spectral"]["num_queries"], len(eval_data.queries))

    spectral_doc_indices = np.random.choice(
        len(eval_data.all_passages), num_spectral_docs, replace=False
    )
    spectral_query_indices = np.random.choice(
        len(eval_data.queries), num_spectral_queries, replace=False
    )

    spectral_doc_emb_np = passage_embeddings[spectral_doc_indices].cpu().numpy()
    spectral_doc_emb_torch = passage_embeddings[spectral_doc_indices]
    spectral_query_emb_torch = query_embeddings[spectral_query_indices]

    spectral_results = run_spectral_analysis(
        doc_embeddings_np=spectral_doc_emb_np,
        doc_embeddings_torch=spectral_doc_emb_torch,
        query_embeddings_torch=spectral_query_emb_torch,
        operator=operator,
        baseline_operator=baseline_operator,
        device=device,
    )
    results["spectral"] = {
        "multi_manifold": spectral_results["multi_manifold"],
        "baseline": spectral_results["baseline"],
        "num_documents": num_spectral_docs,
        "num_queries": num_spectral_queries,
    }

    # =========================================================================
    # Phase 5: Attack Simulation
    # =========================================================================
    _print_banner("PHASE 5: ATTACK SIMULATION")

    attack_config = config["attack"]
    target_queries = _select_target_queries(eval_data.queries, attack_config)
    print(f"Using {len(target_queries)} target queries for attack simulation.")

    attack_results = run_attack_simulation(
        encoder=encoder,
        operator=operator,
        baseline_operator=baseline_operator,
        passages=eval_data.all_passages,
        passage_embeddings_torch=passage_embeddings,
        target_query_texts=target_queries,
        medical_keywords=attack_config["medical_keywords"],
        top_k=attack_config["top_k"],
        device=device,
    )
    results["attack"] = attack_results

    # =========================================================================
    # Summary
    # =========================================================================
    _print_banner("EXPERIMENT SUMMARY")

    print("\n1. Retrieval Quality:")
    print(f"   Baseline MRR@10: {metrics_base['mrr@10']:.4f}")
    print(f"   Multi-Manifold MRR@10: {metrics_mm['mrr@10']:.4f}")
    if metrics_base["mrr@10"] > 0:
        print(f"   Ratio: {metrics_mm['mrr@10'] / metrics_base['mrr@10']:.4f}")

    print("\n2. Spectral Analysis:")
    print(f"   Baseline δ: {spectral_results['baseline']['spectral_discrepancy']:.4f}")
    print(f"   Multi-Manifold δ: {spectral_results['multi_manifold']['spectral_discrepancy']:.4f}")
    print(f"   Baseline cos(θ): {spectral_results['baseline']['fiedler_alignment']:.4f}")
    print(f"   Multi-Manifold cos(θ): {spectral_results['multi_manifold']['fiedler_alignment']:.4f}")

    if "error" not in attack_results:
        print("\n3. Attack Simulation:")
        print(f"   Baseline ASR@{attack_config['top_k']}: {attack_results['baseline_asr']:.4f}")
        print(f"   Multi-Manifold ASR@{attack_config['top_k']}: {attack_results['multi_manifold_asr']:.4f}")

    # Save results (exclude numpy arrays).  `default=str` is deliberate:
    # it stringifies any non-JSON-native leftovers (numpy scalars, etc.).
    save_results = {k: v for k, v in results.items() if k not in ("L_D", "L_R_mm", "L_R_base")}
    with open(args.output, "w") as f:
        json.dump(save_results, f, indent=2, default=str)
    print(f"\nResults saved to {args.output}")


if __name__ == "__main__":
    main()