| |
| """End-to-end experiment: train → evaluate → spectral analysis → attack simulation.""" |
|
|
| import os |
| import json |
| import random |
| import time |
| import yaml |
| import argparse |
|
|
| import numpy as np |
| import torch |
|
|
| from multi_manifold_retrieval.models.encoders import DualEncoder |
| from multi_manifold_retrieval.models.cross_manifold_operator import CrossManifoldOperator |
| from multi_manifold_retrieval.models.baseline import BaselineOperator |
| from multi_manifold_retrieval.training.train import train |
| from multi_manifold_retrieval.training.data import MSMARCOEvalDataset |
| from multi_manifold_retrieval.evaluation.spectral_analysis import run_spectral_analysis |
| from multi_manifold_retrieval.evaluation.retrieval_metrics import compute_retrieval_metrics |
| from multi_manifold_retrieval.evaluation.attack_simulation import ( |
| run_attack_simulation, |
| select_domain_documents, |
| ) |
|
|
|
|
def set_seed(seed: int):
    """Seed every global RNG the experiment touches with ``seed``.

    Covers Python's ``random`` module, NumPy's legacy global generator and
    PyTorch's CPU generator; when CUDA is available, all CUDA devices are
    seeded as well.
    """
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    # Same order as seeding each library one after another.
    for seed_fn in seeders:
        seed_fn(seed)
|
|
|
|
def _print_banner(title):
    """Print a phase header framed by two 60-character '=' rules."""
    rule = "=" * 60
    print("\n" + rule)
    print(title)
    print(rule)


def _select_device():
    """Return the preferred torch device name: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _load_or_train(args, config, device):
    """Return ``(encoder, operator)``, restored from ``args.checkpoint`` or trained.

    The checkpoint is used only when ``--skip-train`` is set AND the file
    exists; otherwise ``train`` is run on ``args.config``. Only the operator's
    ``state_dict`` is read from the checkpoint; the encoder is rebuilt from
    ``config["encoder"]["model_name"]``, so any fine-tuned encoder weights are
    not restored here.
    """
    if not (args.skip_train and os.path.exists(args.checkpoint)):
        if args.skip_train:
            # Make the fallback visible instead of silently training.
            print(f"Checkpoint not found at {args.checkpoint}; training instead.")
        return train(config_path=args.config, device=device)

    print(f"Loading operator from checkpoint: {args.checkpoint}")
    if not config["encoder"]["freeze"]:
        print("Warning: encoder.freeze is false, but only operator weights are "
              "restored from the checkpoint; the encoder is rebuilt from model_name.")

    encoder = DualEncoder(
        model_name=config["encoder"]["model_name"],
        max_seq_length=config["training"]["max_seq_length"],
        freeze=config["encoder"]["freeze"],
    )
    cm_config = config["cross_manifold"]
    operator = CrossManifoldOperator(
        embedding_dim=encoder.embedding_dim,
        num_heads=cm_config["num_heads"],
        value_hidden_dim=cm_config["value_mlp_hidden"],
        value_num_layers=cm_config["value_mlp_layers"],
        dropout=cm_config["dropout"],
    )
    state_dict = torch.load(args.checkpoint, map_location=device, weights_only=True)
    operator.load_state_dict(state_dict)
    return encoder, operator


def _select_target_queries(queries, keywords, num_target, rng, min_keyword_matches=10):
    """Pick up to ``num_target`` attack target queries.

    Queries containing any of ``keywords`` (case-insensitive) are taken in
    ``queries`` order. If fewer than ``min(min_keyword_matches, num_target)``
    match, the list is topped up to ``num_target`` with queries not already
    selected, sampled with ``rng`` (a ``random.Random``).
    """
    lowered_keywords = [kw.lower() for kw in keywords]
    targets = []
    for query in queries:
        if len(targets) >= num_target:
            break
        query_lower = query.lower()
        if any(kw in query_lower for kw in lowered_keywords):
            targets.append(query)

    # Capping by num_target avoids a misleading "top up" when the target is small.
    if len(targets) < min(min_keyword_matches, num_target):
        print(f"Only found {len(targets)} medical queries; "
              f"using random queries to reach {num_target}.")
        remaining = num_target - len(targets)
        chosen = set(targets)  # O(1) membership instead of scanning the list
        others = [q for q in queries if q not in chosen]
        targets.extend(rng.sample(others, min(remaining, len(others))))
    return targets


def _print_summary(metrics_base, metrics_mm, mrr_ratio, spectral_results, attack_results, top_k):
    """Print the final human-readable summary of all experiment phases."""
    _print_banner("EXPERIMENT SUMMARY")

    print("\n1. Retrieval Quality:")
    print(f" Baseline MRR@10: {metrics_base['mrr@10']:.4f}")
    print(f" Multi-Manifold MRR@10: {metrics_mm['mrr@10']:.4f}")
    if mrr_ratio is not None:
        print(f" Ratio: {mrr_ratio:.4f}")

    base_spec = spectral_results["baseline"]
    mm_spec = spectral_results["multi_manifold"]
    print("\n2. Spectral Analysis:")
    print(f" Baseline δ: {base_spec['spectral_discrepancy']:.4f}")
    print(f" Multi-Manifold δ: {mm_spec['spectral_discrepancy']:.4f}")
    print(f" Baseline cos(θ): {base_spec['fiedler_alignment']:.4f}")
    print(f" Multi-Manifold cos(θ): {mm_spec['fiedler_alignment']:.4f}")

    if "error" not in attack_results:
        print("\n3. Attack Simulation:")
        print(f" Baseline ASR@{top_k}: {attack_results['baseline_asr']:.4f}")
        print(f" Multi-Manifold ASR@{top_k}: {attack_results['multi_manifold_asr']:.4f}")
    else:
        print(f"\n3. Attack Simulation: skipped ({attack_results['error']})")


def main():
    """Run the multi-manifold retrieval proof-of-concept and save a JSON report.

    Phases:
      1. Train the cross-manifold operator, or load it from ``--checkpoint``
         when ``--skip-train`` is given and the file exists.
      2. Encode the MS MARCO evaluation passages and queries.
      3. Compare MRR@10 of the operator against the cosine-similarity baseline.
      4. Run spectral analysis on a seeded random subset of passages/queries.
      5. Run the attack simulation on keyword-matched target queries.
    Results (including the loaded config) are written to ``--output``.
    """
    parser = argparse.ArgumentParser(description="Multi-Manifold Retrieval PoC Experiment")
    parser.add_argument("--config", type=str, default="configs/default.yaml")
    parser.add_argument("--skip-train", action="store_true", help="Skip training, load from checkpoint")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/best_operator.pt")
    parser.add_argument("--output", type=str, default="results.json")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    seed = config["seed"]
    set_seed(seed)
    device = _select_device()
    print(f"Using device: {device}")

    results = {"config": config, "device": device}

    # ---- Phase 1: training / checkpoint loading ----
    _print_banner("PHASE 1: TRAINING")
    encoder, operator = _load_or_train(args, config, device)

    encoder.model.to(device)
    operator.to(device)
    baseline_operator = BaselineOperator().to(device)
    # Everything below is evaluation: eval mode turns off the operator's
    # dropout so scores are deterministic.
    operator.eval()
    baseline_operator.eval()

    # ---- Phase 2: evaluation data ----
    _print_banner("PHASE 2: EVALUATION DATA PREPARATION")
    eval_data = MSMARCOEvalDataset(
        tokenizer=encoder.tokenizer,
        max_queries=config["evaluation"]["max_eval_queries"],
        max_seq_length=config["training"]["max_seq_length"],
        seed=seed,
    )

    print(f"Encoding {len(eval_data.all_passages)} passages...")
    passage_embeddings = encoder.encode_documents(
        eval_data.all_passages, batch_size=128, show_progress=True,
    )
    print(f"Encoding {len(eval_data.queries)} queries...")
    query_embeddings = encoder.encode_queries(
        eval_data.queries, batch_size=128, show_progress=True,
    )

    # ---- Phase 3: retrieval quality ----
    _print_banner("PHASE 3: RETRIEVAL QUALITY")
    # Both runs use identical inputs; only the scoring operator differs.
    metric_inputs = dict(
        query_embeddings=query_embeddings,
        doc_embeddings=passage_embeddings,
        query_texts=eval_data.queries,
        positive_passages=eval_data.positives,
        all_passages=eval_data.all_passages,
        passage_embeddings=passage_embeddings,
        device=device,
    )

    print("\n--- Multi-Manifold Model ---")
    metrics_mm = compute_retrieval_metrics(operator=operator, **metric_inputs)
    results["retrieval_multi_manifold"] = metrics_mm

    print("\n--- Baseline (Cosine Similarity) ---")
    metrics_base = compute_retrieval_metrics(operator=baseline_operator, **metric_inputs)
    results["retrieval_baseline"] = metrics_base

    mrr_ratio = None
    if metrics_base["mrr@10"] > 0:
        mrr_ratio = metrics_mm["mrr@10"] / metrics_base["mrr@10"]
        print(f"\nMRR@10 ratio (mm/baseline): {mrr_ratio:.4f} "
              f"({'PASS' if mrr_ratio >= 0.8 else 'BELOW TARGET'}, target >= 0.8)")
        results["mrr_ratio"] = mrr_ratio

    # ---- Phase 4: spectral analysis ----
    _print_banner("PHASE 4: SPECTRAL ANALYSIS")
    num_spectral_docs = min(config["spectral"]["num_documents"], len(eval_data.all_passages))
    num_spectral_queries = min(config["spectral"]["num_queries"], len(eval_data.queries))

    # Dedicated seeded generator: the sampled subset must not depend on how much
    # of the global RNG stream training (or skipping it) consumed earlier.
    sample_rng = np.random.default_rng(seed)
    spectral_doc_indices = sample_rng.choice(
        len(eval_data.all_passages), num_spectral_docs, replace=False
    )
    spectral_query_indices = sample_rng.choice(
        len(eval_data.queries), num_spectral_queries, replace=False
    )

    spectral_doc_emb_torch = passage_embeddings[spectral_doc_indices]
    spectral_query_emb_torch = query_embeddings[spectral_query_indices]
    spectral_doc_emb_np = spectral_doc_emb_torch.detach().cpu().numpy()

    spectral_results = run_spectral_analysis(
        doc_embeddings_np=spectral_doc_emb_np,
        doc_embeddings_torch=spectral_doc_emb_torch,
        query_embeddings_torch=spectral_query_emb_torch,
        operator=operator,
        baseline_operator=baseline_operator,
        device=device,
    )
    results["spectral"] = {
        "multi_manifold": spectral_results["multi_manifold"],
        "baseline": spectral_results["baseline"],
        "num_documents": num_spectral_docs,
        "num_queries": num_spectral_queries,
    }

    # ---- Phase 5: attack simulation ----
    _print_banner("PHASE 5: ATTACK SIMULATION")
    attack_config = config["attack"]
    target_queries = _select_target_queries(
        eval_data.queries,
        attack_config["medical_keywords"],
        attack_config["num_target_queries"],
        rng=random.Random(seed),  # seeded independently of the global stream
    )
    print(f"Using {len(target_queries)} target queries for attack simulation.")

    attack_results = run_attack_simulation(
        encoder=encoder,
        operator=operator,
        baseline_operator=baseline_operator,
        passages=eval_data.all_passages,
        passage_embeddings_torch=passage_embeddings,
        target_query_texts=target_queries,
        medical_keywords=attack_config["medical_keywords"],
        top_k=attack_config["top_k"],
        device=device,
    )
    results["attack"] = attack_results

    _print_summary(
        metrics_base, metrics_mm, mrr_ratio,
        spectral_results, attack_results, attack_config["top_k"],
    )

    # Never persist entries under these keys (none are set in this function today).
    save_results = {k: v for k, v in results.items()
                    if k not in ("L_D", "L_R_mm", "L_R_base")}
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(save_results, f, indent=2, default=str)
    print(f"\nResults saved to {args.output}")
|
|
|
|
if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
|
|