Upload model checkpoints and code
- README.md +69 -3
- checkpoints/best_operator.pt +3 -0
- checkpoints/final_operator.pt +3 -0
- configs/default.yaml +63 -0
- multi_manifold_retrieval/__init__.py +0 -0
- multi_manifold_retrieval/__pycache__/__init__.cpython-310.pyc +0 -0
- multi_manifold_retrieval/evaluation/__init__.py +0 -0
- multi_manifold_retrieval/evaluation/__pycache__/__init__.cpython-310.pyc +0 -0
- multi_manifold_retrieval/evaluation/__pycache__/attack_simulation.cpython-310.pyc +0 -0
- multi_manifold_retrieval/evaluation/__pycache__/retrieval_metrics.cpython-310.pyc +0 -0
- multi_manifold_retrieval/evaluation/__pycache__/spectral_analysis.cpython-310.pyc +0 -0
- multi_manifold_retrieval/evaluation/attack_simulation.py +332 -0
- multi_manifold_retrieval/evaluation/retrieval_metrics.py +99 -0
- multi_manifold_retrieval/evaluation/spectral_analysis.py +205 -0
- multi_manifold_retrieval/models/__init__.py +0 -0
- multi_manifold_retrieval/models/__pycache__/__init__.cpython-310.pyc +0 -0
- multi_manifold_retrieval/models/__pycache__/baseline.cpython-310.pyc +0 -0
- multi_manifold_retrieval/models/__pycache__/cross_manifold_operator.cpython-310.pyc +0 -0
- multi_manifold_retrieval/models/__pycache__/encoders.cpython-310.pyc +0 -0
- multi_manifold_retrieval/models/baseline.py +37 -0
- multi_manifold_retrieval/models/cross_manifold_operator.py +142 -0
- multi_manifold_retrieval/models/encoders.py +60 -0
- multi_manifold_retrieval/training/__init__.py +0 -0
- multi_manifold_retrieval/training/__pycache__/__init__.cpython-310.pyc +0 -0
- multi_manifold_retrieval/training/__pycache__/data.cpython-310.pyc +0 -0
- multi_manifold_retrieval/training/__pycache__/losses.cpython-310.pyc +0 -0
- multi_manifold_retrieval/training/__pycache__/train.cpython-310.pyc +0 -0
- multi_manifold_retrieval/training/data.py +168 -0
- multi_manifold_retrieval/training/losses.py +38 -0
- multi_manifold_retrieval/training/train.py +159 -0
- requirements.txt +10 -0
- results.json +136 -0
- run_experiment.py +261 -0
README.md
CHANGED
|
@@ -1,3 +1,69 @@
|
|
| 1 |
-
-
|
| 2 |
-
|
| 3 |
-
---
# Multi-Manifold Retrieval: Proof of Concept

A proof-of-concept implementation of the Multi-Manifold Retrieval defense against spectral poisoning attacks (GeoPoison-RAG) on Retrieval-Augmented Generation systems.

## Core Idea

Standard RAG systems use a single shared embedding space for queries and documents, making the **document geometry identical to the retrieval geometry**. GeoPoison-RAG exploits this by computing the spectral structure (Fiedler vector) of the document graph Laplacian to find optimal adversarial placement.

Multi-Manifold Retrieval **decouples** these geometries by using:
- Separate query and document manifolds (M_Q and M_D)
- A non-decomposable cross-manifold relevance operator R(q, d)

This breaks the attack because the Laplacian the attacker computes (document space) no longer predicts the Laplacian governing retrieval (cross-manifold).
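
To make the "document graph Laplacian" concrete, here is a minimal standard-library sketch. It assumes an unnormalized Laplacian L = D - W built from cosine similarities clipped at zero; the helper names `cosine` and `graph_laplacian` are illustrative and are not part of this repository.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


def graph_laplacian(embeddings):
    """Unnormalized Laplacian L = D - W of a clipped cosine-similarity graph."""
    n = len(embeddings)
    # Edge weights: non-negative cosine similarity, no self-loops.
    W = [
        [0.0 if i == j else max(cosine(embeddings[i], embeddings[j]), 0.0)
         for j in range(n)]
        for i in range(n)
    ]
    degrees = [sum(row) for row in W]
    # Diagonal holds the degree, off-diagonal entries hold -W[i][j].
    return [
        [(degrees[i] if i == j else 0.0) - W[i][j] for j in range(n)]
        for i in range(n)
    ]


docs = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]  # toy 2-D "document embeddings"
L = graph_laplacian(docs)
```

Every row of `L` sums to zero, so the constant vector is an eigenvector with eigenvalue 0. The Fiedler vector the attacker relies on is the eigenvector of the next-smallest eigenvalue.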

## Project Structure

```
multi_manifold_retrieval/
├── models/
│   ├── cross_manifold_operator.py   # Construction C: Attention-Geometric Hybrid
│   ├── encoders.py                  # Sentence-transformer wrapper
│   └── baseline.py                  # Standard cosine similarity baseline
├── training/
│   ├── train.py                     # Training loop
│   ├── data.py                      # MS MARCO data loading
│   └── losses.py                    # Contrastive loss
└── evaluation/
    ├── spectral_analysis.py         # L_D, L_R, spectral discrepancy, Fiedler alignment
    ├── retrieval_metrics.py         # MRR@10, Recall@100
    └── attack_simulation.py         # GeoPoison-RAG simulation
proofs/
├── proof_theorem_4_3.tex            # Spectral Decoupling theorem
└── proof_theorem_6_1.tex            # Query Complexity Lower Bound theorem
configs/
└── default.yaml                     # Hyperparameters
run_experiment.py                    # End-to-end pipeline
```

## Setup

```bash
pip install -r requirements.txt
```

## Running

Full experiment (train + evaluate + spectral analysis + attack):
```bash
python run_experiment.py --config configs/default.yaml
```

Skip training and load from checkpoint:
```bash
python run_experiment.py --skip-train --checkpoint checkpoints/best_operator.pt
```

## Key Metrics

| Metric | Baseline (expected) | Multi-Manifold (expected) |
|--------|-------------------|--------------------------|
| Spectral discrepancy δ | ≈ 0 | > 0 (significant) |
| Fiedler alignment cos(θ) | ≈ 1 | < 0.5 |
| ASR@10 | > 0.8 | Significantly lower |
| MRR@10 | Reference | ≥ 80% of baseline |

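The Fiedler alignment row in the table compares two eigenvectors by the cosine of the angle between them. The sketch below is one plausible way to compute it, not the repository's implementation: the name `fiedler_alignment` is hypothetical, and taking the absolute value is an assumption made here because an eigenvector is only defined up to sign.

```python
import math


def fiedler_alignment(v_doc, v_ret):
    """|cos(theta)| between two Fiedler vectors; invariant to eigenvector sign."""
    dot = sum(a * b for a, b in zip(v_doc, v_ret))
    norm = math.sqrt(sum(a * a for a in v_doc)) * math.sqrt(sum(b * b for b in v_ret))
    return abs(dot) / norm if norm > 0 else 0.0


same_direction = fiedler_alignment([1.0, -1.0], [-2.0, 2.0])  # flipped sign, same axis
orthogonal = fiedler_alignment([1.0, 0.0], [0.0, 1.0])
```

A value near 1 means the document-space and retrieval-space Fiedler vectors point along the same axis, which is the baseline situation the table describes. A value well below 1 indicates the two spectral structures have decoupled.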
## Formal Proofs

- `proofs/proof_theorem_4_3.tex`: Proves that a non-decomposable R with positive cross-manifold curvature guarantees spectral decoupling δ = Ω(κ_R · λ_2(L_D)).
- `proofs/proof_theorem_6_1.tex`: Proves that an adaptive adversary needs Ω(Vol(M_Q) / V_{d_Q}(ε/κ_R)) oracle queries to reconstruct R.
checkpoints/best_operator.pt
ADDED
|
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8c63df76c2941f40d9f840057dcdc54d506b74cba5dc71d329117c965fdef783
size 6969493
checkpoints/final_operator.pt
ADDED
|
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fa9be37ddd7e178a56f93eeb0a3259b8409399ee8be89118d84b53b8533318b1
size 6969525
configs/default.yaml
ADDED
|
@@ -0,0 +1,63 @@
# Multi-Manifold Retrieval - Default Configuration

seed: 42

# Encoder settings
encoder:
  model_name: "sentence-transformers/all-MiniLM-L6-v2"
  embedding_dim: 384
  freeze: true  # Freeze pretrained encoders

# Cross-manifold operator (Construction C)
cross_manifold:
  num_heads: 4
  head_dim: 96  # embedding_dim / num_heads
  value_mlp_hidden: 256
  value_mlp_layers: 2
  dropout: 0.1

# Training
training:
  batch_size: 64
  learning_rate: 2.0e-4
  weight_decay: 1.0e-2
  epochs: 5
  warmup_steps: 500
  max_train_samples: 100000
  num_negatives: 7
  max_seq_length: 128
  fp16: true
  gradient_accumulation_steps: 1
  log_every: 100
  eval_every: 2000
  save_dir: "checkpoints"

# Evaluation
evaluation:
  max_eval_queries: 5000
  metrics:
    - mrr@10
    - recall@100

# Spectral analysis
spectral:
  num_documents: 1000
  num_queries: 500
  k_neighbors: 20  # For sparse Laplacian (optional)

# Attack simulation
attack:
  target_domain: "medical"
  num_target_queries: 100
  top_k: 10
  medical_keywords:
    - "health"
    - "medical"
    - "doctor"
    - "patient"
    - "treatment"
    - "disease"
    - "symptom"
    - "diagnosis"
    - "medicine"
    - "clinical"
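The comment on `head_dim` implies a constraint: `num_heads * head_dim` should equal `embedding_dim`. A small check like the following could catch a mismatch before training starts. This is a sketch that assumes the YAML has already been parsed into a nested dict; the function name `check_cross_manifold_dims` is hypothetical and does not appear in the repository.

```python
def check_cross_manifold_dims(config):
    """Return True if num_heads * head_dim equals the encoder embedding_dim."""
    embedding_dim = config["encoder"]["embedding_dim"]
    cm = config["cross_manifold"]
    return cm["num_heads"] * cm["head_dim"] == embedding_dim


# Values copied from configs/default.yaml above.
config = {
    "encoder": {"embedding_dim": 384},
    "cross_manifold": {"num_heads": 4, "head_dim": 96},
}
ok = check_cross_manifold_dims(config)
```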
multi_manifold_retrieval/__init__.py
ADDED
File without changes
multi_manifold_retrieval/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (172 Bytes)
multi_manifold_retrieval/evaluation/__init__.py
ADDED
File without changes
multi_manifold_retrieval/evaluation/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (183 Bytes)
multi_manifold_retrieval/evaluation/__pycache__/attack_simulation.cpython-310.pyc
ADDED
Binary file (9.22 kB)
multi_manifold_retrieval/evaluation/__pycache__/retrieval_metrics.cpython-310.pyc
ADDED
Binary file (2.68 kB)
multi_manifold_retrieval/evaluation/__pycache__/spectral_analysis.cpython-310.pyc
ADDED
Binary file (5.13 kB)
multi_manifold_retrieval/evaluation/attack_simulation.py
ADDED
|
@@ -0,0 +1,332 @@
"""Simplified GeoPoison-RAG attack simulation.

Realistic threat model (matching GeoPoison-RAG Phase 1):
- Attacker has shadow queries approximating target query distribution.
- Attacker has access to document embeddings.
- Attacker builds bipartite query-document graph using COSINE SIMILARITY
  (their model of how retrieval works).
- Attacker computes Fiedler vector and places adversarial doc at the
  spectral-optimal position in document space.

Defense argument:
- Baseline (cosine sim): attacker's model is correct → high ASR.
- Multi-manifold (R(q,d)): attacker's model is wrong because R ≠ cosine → lower ASR.
"""

import numpy as np
import torch
from scipy.sparse.linalg import eigsh
from sklearn.metrics.pairwise import cosine_similarity

from multi_manifold_retrieval.evaluation.spectral_analysis import compute_document_laplacian


def select_domain_documents(
    passages: list[str],
    keywords: list[str],
    max_docs: int = 200,
) -> tuple[list[int], list[str]]:
    """Select documents belonging to a target domain by keyword matching."""
    indices = []
    texts = []
    for i, text in enumerate(passages):
        text_lower = text.lower()
        # Substring match against the lowercased passage.
        if any(kw in text_lower for kw in keywords):
            indices.append(i)
            texts.append(text)
            if len(indices) >= max_docs:
                break
    return indices, texts


def build_bipartite_fiedler_placement(
    query_embs: np.ndarray,
    doc_embs: np.ndarray,
    t_nn: int = 20,
) -> tuple[np.ndarray, dict]:
    """GeoPoison-RAG Phase 1: bipartite spectral placement (cosine-based).

    The attacker:
    1. Builds bipartite query-document graph using cosine similarity.
    2. Computes Fiedler vector of the normalized Laplacian.
    3. Extracts document component of Fiedler vector.
    4. Places adversarial doc at Fiedler-weighted centroid of documents.

    The placement is in DOCUMENT SPACE: the attacker optimizes where to
    place a document, guided by the query-document spectral structure.
    But the attacker assumes retrieval = cosine similarity.
    """
    nq = query_embs.shape[0]
    nd = doc_embs.shape[0]

    # Cosine similarity between queries and documents (attacker's model).
    # Negative similarities are clipped to 0 so node degrees stay non-negative
    # and the D^{-1/2} normalization below cannot produce NaNs.
    S = np.clip(cosine_similarity(query_embs, doc_embs), 0.0, None)  # (nq, nd)

    # Sparsify: keep top-t per query
    t = min(t_nn, nd - 1)
    S_sparse = np.zeros_like(S)
    for i in range(nq):
        top_idx = np.argpartition(S[i], -t)[-t:]
        S_sparse[i, top_idx] = S[i, top_idx]

    # Build bipartite adjacency: A = [[0, S], [S^T, 0]]
    n = nq + nd
    A = np.zeros((n, n))
    A[:nq, nq:] = S_sparse
    A[nq:, :nq] = S_sparse.T

    # Normalized Laplacian: L = I - D^{-1/2} A D^{-1/2}
    degrees = A.sum(axis=1)
    degrees[degrees == 0] = 1.0
    D_inv_sqrt = np.diag(1.0 / np.sqrt(degrees))
    L = np.eye(n) - D_inv_sqrt @ A @ D_inv_sqrt

    # Fiedler vector (eigenvector of the 2nd smallest eigenvalue)
    k = min(3, n - 1)
    eigenvalues, eigenvectors = eigsh(L, k=k, which="SM")
    sorted_idx = np.argsort(eigenvalues)
    fiedler_vec = eigenvectors[:, sorted_idx[1]]
    fiedler_val = eigenvalues[sorted_idx[1]]

    # Extract document component and use as weights
    doc_component = fiedler_vec[nq:]
    weights = np.abs(doc_component)
    weights = weights / (weights.sum() + 1e-12)

    # Fiedler-weighted centroid of documents
    adv_embedding = (weights[:, None] * doc_embs).sum(axis=0)

    # L2-normalize
    norm = np.linalg.norm(adv_embedding)
    if norm > 0:
        adv_embedding = adv_embedding / norm

    info = {
        "method": "bipartite_fiedler",
        "fiedler_eigenvalue": float(fiedler_val),
        "weight_entropy": float(-np.sum(weights * np.log(weights + 1e-12))),
        "max_weight": float(weights.max()),
        "adv_mean_cos_to_queries": float(
            cosine_similarity(adv_embedding.reshape(1, -1), query_embs).mean()
        ),
        "adv_mean_cos_to_docs": float(
            cosine_similarity(adv_embedding.reshape(1, -1), doc_embs).mean()
        ),
    }

    return adv_embedding, info


def compute_doconly_fiedler_placement(doc_embs: np.ndarray) -> tuple[np.ndarray, dict]:
    """Document-only Fiedler placement (no query access).

    Weaker attacker that only has document embeddings.
    Uses document-space Laplacian L_D directly.
    """
    n = doc_embs.shape[0]
    if n < 3:
        centroid = doc_embs.mean(axis=0)
        # Guard against a zero-norm centroid before normalizing.
        norm = np.linalg.norm(centroid)
        if norm > 0:
            centroid = centroid / norm
        return centroid, {"method": "centroid_fallback"}

    L_D, _ = compute_document_laplacian(doc_embs)

    k = min(3, n - 1)
    eigenvalues, eigenvectors = eigsh(L_D, k=k, which="SM")
    sorted_idx = np.argsort(eigenvalues)
    fiedler_vec = eigenvectors[:, sorted_idx[1]]
    fiedler_val = eigenvalues[sorted_idx[1]]

    weights = np.abs(fiedler_vec)
    weights = weights / (weights.sum() + 1e-12)

    adv_embedding = (weights[:, None] * doc_embs).sum(axis=0)
    norm = np.linalg.norm(adv_embedding)
    if norm > 0:
        adv_embedding = adv_embedding / norm

    return adv_embedding, {
        "method": "doconly_fiedler",
        "fiedler_eigenvalue": float(fiedler_val),
    }


def compute_asr_threshold(
    query_embeddings: torch.Tensor,
    corpus_embeddings: torch.Tensor,
    adv_embedding: torch.Tensor,
    operator,
    top_k: int = 10,
    device: str = "cpu",
    batch_size: int = 50,
) -> tuple[float, dict]:
    """Compute ASR@k using per-query threshold (oracle-style).

    For each query, the k-th highest corpus score is the threshold.
    Attack succeeds if the adversarial doc's score >= threshold.
    Mirrors gp_rag/plan_single.py oracle check.
    """
    num_queries = query_embeddings.shape[0]
    corpus_emb = corpus_embeddings.to(device)
    adv_emb = adv_embedding.to(device)
    # torch.topk raises if k exceeds the number of corpus documents, so clamp it.
    k = min(top_k, corpus_emb.shape[0])

    operator.eval()
    successes = 0
    margins = []

    with torch.no_grad():
        for start in range(0, num_queries, batch_size):
            end = min(start + batch_size, num_queries)
            q_batch = query_embeddings[start:end].to(device)
            bs = q_batch.shape[0]

            # Score adversarial document
            adv_expanded = adv_emb.unsqueeze(0).expand(bs, -1)
            adv_scores = operator(q_batch, adv_expanded)

            # Score corpus documents
            corpus_scores = operator.compute_pairwise(q_batch, corpus_emb)

            # k-th highest corpus score = threshold
            topk_vals, _ = torch.topk(corpus_scores, k, dim=1)
            thresholds = topk_vals[:, -1]

            for j in range(bs):
                margin = float(adv_scores[j].item() - thresholds[j].item())
                margins.append(margin)
                if adv_scores[j] >= thresholds[j]:
                    successes += 1

    asr = successes / num_queries
    margins_arr = np.array(margins)
    info = {
        "mean_margin": float(margins_arr.mean()),
        "median_margin": float(np.median(margins_arr)),
        "p25_margin": float(np.percentile(margins_arr, 25)),
        "fraction_positive_margin": float((margins_arr >= 0).mean()),
    }

    return asr, info


def run_attack_simulation(
    encoder,
    operator,
    baseline_operator,
    passages: list[str],
    passage_embeddings_torch: torch.Tensor,
    target_query_texts: list[str],
    medical_keywords: list[str],
    top_k: int = 10,
    max_domain_docs: int = 200,
    device: str = "cpu",
) -> dict:
    """Run GeoPoison-RAG attack simulation.

    Tests two attacker models:
    1. Bipartite Fiedler (realistic): attacker has shadow queries + docs,
       builds cosine-based bipartite graph, optimizes in document space.
    2. Doc-only Fiedler (weaker): attacker has only document embeddings.

    Both assume cosine similarity governs retrieval. The defense breaks
    this assumption via the cross-manifold operator R.
    """
    print("\n=== Attack Simulation ===", flush=True)

    # Step 1: Select target domain documents
    domain_indices, domain_texts = select_domain_documents(
        passages, medical_keywords, max_domain_docs
    )
    print(f"Selected {len(domain_indices)} domain documents.", flush=True)

    if len(domain_indices) < 5:
        print("Warning: Too few domain documents found.")
        return {"error": "insufficient domain documents"}

    domain_embs_np = passage_embeddings_torch[domain_indices].cpu().numpy()
    domain_corpus = passage_embeddings_torch[domain_indices]

    # Step 2: Encode target queries (attacker's shadow queries)
    print(f"Encoding {len(target_query_texts)} target queries...", flush=True)
    query_embeddings = encoder.encode_queries(target_query_texts, show_progress=False)
    q_np = query_embeddings.cpu().numpy()

    # Step 3a: Bipartite Fiedler placement (realistic attacker)
    print("\nComputing bipartite Fiedler placement (attacker has shadow queries)...", flush=True)
    adv_bipartite_np, bp_info = build_bipartite_fiedler_placement(
        q_np, domain_embs_np, t_nn=min(20, len(domain_indices) - 1)
    )
    adv_bipartite = torch.tensor(adv_bipartite_np, dtype=torch.float32)
    print(f"  Fiedler eigenvalue: {bp_info['fiedler_eigenvalue']:.6f}", flush=True)
    print(f"  Adv mean cos to queries: {bp_info['adv_mean_cos_to_queries']:.4f}", flush=True)
    print(f"  Adv mean cos to docs: {bp_info['adv_mean_cos_to_docs']:.4f}", flush=True)

    # Step 3b: Doc-only Fiedler placement (weaker attacker)
    print("\nComputing doc-only Fiedler placement (no query access)...", flush=True)
    adv_doconly_np, do_info = compute_doconly_fiedler_placement(domain_embs_np)
    adv_doconly = torch.tensor(adv_doconly_np, dtype=torch.float32)

    # Step 4: Measure ASR for bipartite attack
    print("\n--- Bipartite Fiedler Attack (realistic GeoPoison-RAG) ---", flush=True)

    asr_bp_base, bp_base_info = compute_asr_threshold(
        query_embeddings, domain_corpus, adv_bipartite,
        baseline_operator, top_k, device
    )
    print(f"  Baseline ASR@{top_k}: {asr_bp_base:.4f} (mean margin: {bp_base_info['mean_margin']:.4f})", flush=True)

    asr_bp_mm, bp_mm_info = compute_asr_threshold(
        query_embeddings, domain_corpus, adv_bipartite,
        operator, top_k, device
    )
    print(f"  Multi-manifold ASR@{top_k}: {asr_bp_mm:.4f} (mean margin: {bp_mm_info['mean_margin']:.4f})", flush=True)

    # Step 5: Measure ASR for doc-only attack
    print("\n--- Doc-only Fiedler Attack (weaker attacker) ---", flush=True)

    asr_do_base, do_base_info = compute_asr_threshold(
        query_embeddings, domain_corpus, adv_doconly,
        baseline_operator, top_k, device
    )
    print(f"  Baseline ASR@{top_k}: {asr_do_base:.4f} (mean margin: {do_base_info['mean_margin']:.4f})", flush=True)

    asr_do_mm, do_mm_info = compute_asr_threshold(
        query_embeddings, domain_corpus, adv_doconly,
        operator, top_k, device
    )
    print(f"  Multi-manifold ASR@{top_k}: {asr_do_mm:.4f} (mean margin: {do_mm_info['mean_margin']:.4f})", flush=True)

    # Summary
    results = {
        "bipartite_attack": {
            "baseline_asr": asr_bp_base,
            "multi_manifold_asr": asr_bp_mm,
            "baseline_margins": bp_base_info,
            "multi_manifold_margins": bp_mm_info,
            "placement_info": bp_info,
        },
        "doconly_attack": {
            "baseline_asr": asr_do_base,
            "multi_manifold_asr": asr_do_mm,
            "baseline_margins": do_base_info,
            "multi_manifold_margins": do_mm_info,
            "placement_info": do_info,
        },
        "num_domain_docs": len(domain_indices),
        "num_target_queries": len(target_query_texts),
        "top_k": top_k,
        # For backward compat with summary printing
        "baseline_asr": asr_bp_base,
        "multi_manifold_asr": asr_bp_mm,
    }

    def _reduction(base, mm):
        # Relative ASR reduction in percent; the floor avoids division by zero.
        return (1 - mm / max(base, 1e-9)) * 100

    print("\n=== Attack Results Summary ===", flush=True)
    print("                          Baseline    Multi-Manifold    Reduction", flush=True)
    print(f"  Bipartite (realistic):  {asr_bp_base:.4f}      {asr_bp_mm:.4f}"
          f"            {_reduction(asr_bp_base, asr_bp_mm):.1f}%", flush=True)
    print(f"  Doc-only (weaker):      {asr_do_base:.4f}      {asr_do_mm:.4f}"
          f"            {_reduction(asr_do_base, asr_do_mm):.1f}%", flush=True)

    return results
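The per-query threshold rule used for ASR@k can be illustrated without torch. The following is a toy sketch on plain Python lists, not the repository's code path: `asr_at_k` and `reduction_percent` are hypothetical names, the scores are made up, and each query is assumed to have a non-empty list of corpus scores.

```python
def asr_at_k(adv_scores, corpus_scores, k):
    """Fraction of queries whose adversarial score reaches the k-th best corpus score."""
    successes = 0
    for adv, scores in zip(adv_scores, corpus_scores):
        kth = min(k, len(scores))  # clamp k to the corpus size
        threshold = sorted(scores, reverse=True)[kth - 1]
        if adv >= threshold:
            successes += 1
    return successes / len(adv_scores) if adv_scores else 0.0


def reduction_percent(base, defended):
    """Relative drop in ASR, in percent, with a floor to avoid division by zero."""
    return (1 - defended / max(base, 1e-9)) * 100


adv = [0.9, 0.2, 0.5]                      # adversarial doc score per query
corpus = [[0.8, 0.7, 0.1],                 # corpus scores per query
          [0.6, 0.5, 0.4],
          [0.9, 0.5, 0.3]]
asr = asr_at_k(adv, corpus, k=2)           # queries 0 and 2 succeed
```

Ties count as success here, matching the `>=` comparison in `compute_asr_threshold`.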
multi_manifold_retrieval/evaluation/retrieval_metrics.py
ADDED
|
@@ -0,0 +1,99 @@
"""Retrieval evaluation metrics: MRR@10 and Recall@100."""

import numpy as np
import torch


def compute_retrieval_metrics(
    query_embeddings: torch.Tensor,
    doc_embeddings: torch.Tensor,
    operator,
    query_texts: list[str],
    positive_passages: list[list[str]],
    all_passages: list[str],
    passage_embeddings: torch.Tensor,
    device: str = "cpu",
    batch_size: int = 32,
) -> dict:
    """Compute MRR@10 and Recall@100.

    Args:
        query_embeddings: (num_queries, d) query embeddings.
        doc_embeddings: Not used directly (passage_embeddings used instead).
        operator: Relevance operator (cross-manifold or baseline).
        query_texts: List of query strings.
        positive_passages: List of lists of positive passage texts per query.
        all_passages: Flat list of all candidate passages.
        passage_embeddings: (num_passages, d) embeddings for all_passages.
        device: Computation device.
        batch_size: Batch size for scoring.

    Returns:
        Dict with mrr@10 and recall@100.
    """
    num_queries = query_embeddings.shape[0]
    num_passages = passage_embeddings.shape[0]

    # Build positive passage index: for each query, which passage indices are relevant
    passage_to_idx = {text: idx for idx, text in enumerate(all_passages)}
    positive_indices = []
    for pos_list in positive_passages:
        indices = set()
        for text in pos_list:
            if text in passage_to_idx:
                indices.add(passage_to_idx[text])
        positive_indices.append(indices)

    passage_embeddings = passage_embeddings.to(device)
    operator.eval()

    mrr_sum = 0.0
    recall_100_sum = 0.0
    valid_queries = 0

    with torch.no_grad():
        for i in range(0, num_queries, batch_size):
            end = min(i + batch_size, num_queries)
            q_batch = query_embeddings[i:end].to(device)  # (bs, d)

            # Score all passages: (bs, num_passages)
            scores = operator.compute_pairwise(q_batch, passage_embeddings)
            scores_np = scores.cpu().numpy()

            for j in range(scores_np.shape[0]):
                query_idx = i + j
                pos_set = positive_indices[query_idx]
                if not pos_set:
                    continue

                # Rank by score (descending)
                ranked = np.argsort(-scores_np[j])

                # MRR@10
                rr = 0.0
                for rank, doc_idx in enumerate(ranked[:10]):
                    if doc_idx in pos_set:
                        rr = 1.0 / (rank + 1)
                        break
                mrr_sum += rr

                # Recall@100
                top_100 = set(ranked[:100].tolist())
                recall = len(pos_set & top_100) / len(pos_set)
                recall_100_sum += recall

                valid_queries += 1

    mrr_at_10 = mrr_sum / valid_queries if valid_queries > 0 else 0.0
    recall_at_100 = recall_100_sum / valid_queries if valid_queries > 0 else 0.0

    results = {
        "mrr@10": mrr_at_10,
        "recall@100": recall_at_100,
        "num_queries": valid_queries,
    }

    print(f"MRR@10: {mrr_at_10:.4f} | Recall@100: {recall_at_100:.4f} "
          f"({valid_queries} queries)")

    return results
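For a single query, the two metrics reduce to a few lines. This is a minimal sketch on a toy ranking, assuming passages are identified by integer ids; `mrr_at_k` and `recall_at_k` are illustrative helper names, not functions from the repository.

```python
def mrr_at_k(ranked, positives, k=10):
    """Reciprocal rank of the first relevant id within the top k, else 0."""
    for rank, doc_id in enumerate(ranked[:k], start=1):
        if doc_id in positives:
            return 1.0 / rank
    return 0.0


def recall_at_k(ranked, positives, k=100):
    """Share of relevant ids that appear within the top k."""
    if not positives:
        return 0.0
    return len(positives & set(ranked[:k])) / len(positives)


ranked = [7, 3, 9, 1]      # passage ids sorted by descending score
positives = {9, 4}         # relevant passage ids for this query
rr = mrr_at_k(ranked, positives)       # first hit is at rank 3
rec = recall_at_k(ranked, positives)   # 1 of the 2 positives is retrieved
```

`compute_retrieval_metrics` averages these per-query values over all queries that have at least one positive passage in the candidate pool.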
multi_manifold_retrieval/evaluation/spectral_analysis.py
ADDED
|
@@ -0,0 +1,205 @@
"""Spectral analysis: compute L_D, L_R, spectral discrepancy δ, and Fiedler alignment cos(θ)."""

import numpy as np
import torch
from scipy import sparse
from scipy.sparse.linalg import eigsh


def compute_document_laplacian(doc_embeddings: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Compute the document-space graph Laplacian L_D.

    Args:
        doc_embeddings: (n, d) array of L2-normalized document embeddings.

    Returns:
        L_D: (n, n) graph Laplacian.
        W_D: (n, n) cosine similarity matrix.
    """
    # Cosine similarity (embeddings are already L2-normalized)
    W_D = doc_embeddings @ doc_embeddings.T

    # Clip negative similarities to 0 to ensure non-negative edge weights
    W_D = np.clip(W_D, 0, None)
    np.fill_diagonal(W_D, 0)  # No self-loops

    # Degree matrix and Laplacian
    degrees = W_D.sum(axis=1)
    D_D = np.diag(degrees)
    L_D = D_D - W_D

    return L_D, W_D


def compute_retrieval_laplacian(
    doc_embeddings: torch.Tensor,
    query_embeddings: torch.Tensor,
    operator,
    device: str = "cpu",
    batch_size: int = 50,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute the retrieval Laplacian L_R.

    (W_R)_{ij} = (1/|Q|) * sum_q R(q, d_i) * R(q, d_j)

    Args:
        doc_embeddings: (n, d) document embeddings tensor.
        query_embeddings: (m, d) query embeddings tensor.
        operator: Cross-manifold operator or baseline operator.
        device: Computation device.
        batch_size: Number of queries to process at once.

    Returns:
        L_R: (n, n) retrieval Laplacian.
        W_R: (n, n) retrieval similarity matrix.
    """
    n = doc_embeddings.shape[0]
    m = query_embeddings.shape[0]

    doc_embeddings = doc_embeddings.to(device)
    query_embeddings = query_embeddings.to(device)

    # Accumulate W_R = (1/m) * R^T R where R_{ki} = R(q_k, d_i)
    W_R = np.zeros((n, n), dtype=np.float64)

    operator.eval()
    with torch.no_grad():
        for start in range(0, m, batch_size):
            end = min(start + batch_size, m)
            q_batch = query_embeddings[start:end]  # (bs, d)

            # Compute R(q, d) for all docs: (bs, n)
            scores = operator.compute_pairwise(q_batch, doc_embeddings)
            scores_np = scores.cpu().numpy().astype(np.float64)  # (bs, n)

            # Outer product accumulation: W_R += scores^T @ scores
            W_R += scores_np.T @ scores_np

    W_R /= m

    # Ensure non-negative and zero diagonal
    W_R = np.clip(W_R, 0, None)
    np.fill_diagonal(W_R, 0)

    # Laplacian
    degrees = W_R.sum(axis=1)
    D_R = np.diag(degrees)
    L_R = D_R - W_R

    return L_R, W_R

def compute_spectral_discrepancy(L_D: np.ndarray, L_R: np.ndarray,
                                 num_eigenvalues: int = 50) -> float:
    """Compute spectral discrepancy δ = ||σ(L_D) - σ(L_R)||_2.

    Uses the smallest num_eigenvalues eigenvalues (normalized).

    Args:
        L_D: Document-space Laplacian.
        L_R: Retrieval Laplacian.
        num_eigenvalues: Number of eigenvalues to compare.

    Returns:
        δ: Spectral discrepancy.
    """
    n = L_D.shape[0]
    k = min(num_eigenvalues, n - 2)

    # Compute smallest eigenvalues (Laplacians have smallest eigenvalue = 0)
    eigs_D = eigsh(L_D, k=k, which="SM", return_eigenvectors=False)
    eigs_R = eigsh(L_R, k=k, which="SM", return_eigenvectors=False)

    # Sort
    eigs_D = np.sort(eigs_D)
    eigs_R = np.sort(eigs_R)

    # Normalize so max eigenvalue = 1
    max_D = eigs_D[-1] if eigs_D[-1] > 0 else 1.0
    max_R = eigs_R[-1] if eigs_R[-1] > 0 else 1.0
    eigs_D_norm = eigs_D / max_D
    eigs_R_norm = eigs_R / max_R

    delta = np.linalg.norm(eigs_D_norm - eigs_R_norm)
    return delta

def compute_fiedler_alignment(L_D: np.ndarray, L_R: np.ndarray) -> float:
    """Compute Fiedler vector alignment cos(θ) = |v_2(L_D)^T v_2(L_R)| / (||v_2(L_D)|| * ||v_2(L_R)||).

    Args:
        L_D: Document-space Laplacian.
        L_R: Retrieval Laplacian.

    Returns:
        cos(θ): Absolute cosine of angle between Fiedler vectors (1 = aligned, 0 = orthogonal).
    """
    # Compute the two smallest eigenvalues/vectors
    vals_D, vecs_D = eigsh(L_D, k=2, which="SM")
    vals_R, vecs_R = eigsh(L_R, k=2, which="SM")

    # Fiedler vector = eigenvector for the 2nd smallest eigenvalue.
    # Sort explicitly instead of relying on the order in which eigsh returns them.
    v2_D = vecs_D[:, np.argsort(vals_D)[1]]
    v2_R = vecs_R[:, np.argsort(vals_R)[1]]

    # Normalize
    v2_D = v2_D / np.linalg.norm(v2_D)
    v2_R = v2_R / np.linalg.norm(v2_R)

    # Absolute cosine similarity
    cos_theta = np.abs(np.dot(v2_D, v2_R))
    return cos_theta

def run_spectral_analysis(
    doc_embeddings_np: np.ndarray,
    doc_embeddings_torch: torch.Tensor,
    query_embeddings_torch: torch.Tensor,
    operator,
    baseline_operator,
    device: str = "cpu",
) -> dict:
    """Run full spectral analysis for both multi-manifold and baseline.

    Returns dict with all metrics.
    """
    print("Computing document-space Laplacian L_D...")
    L_D, W_D = compute_document_laplacian(doc_embeddings_np)

    print("Computing retrieval Laplacian L_R (multi-manifold)...")
    L_R_mm, W_R_mm = compute_retrieval_laplacian(
        doc_embeddings_torch, query_embeddings_torch, operator, device
    )

    print("Computing retrieval Laplacian L_R (baseline)...")
    L_R_base, W_R_base = compute_retrieval_laplacian(
        doc_embeddings_torch, query_embeddings_torch, baseline_operator, device
    )

    print("Computing spectral discrepancy and Fiedler alignment...")
    num_eigs = min(50, doc_embeddings_np.shape[0] - 2)

    delta_mm = compute_spectral_discrepancy(L_D, L_R_mm, num_eigs)
    delta_base = compute_spectral_discrepancy(L_D, L_R_base, num_eigs)
    cos_theta_mm = compute_fiedler_alignment(L_D, L_R_mm)
    cos_theta_base = compute_fiedler_alignment(L_D, L_R_base)

    results = {
        "multi_manifold": {
            "spectral_discrepancy": delta_mm,
            "fiedler_alignment": cos_theta_mm,
        },
        "baseline": {
            "spectral_discrepancy": delta_base,
            "fiedler_alignment": cos_theta_base,
        },
        "L_D": L_D,
        "L_R_mm": L_R_mm,
        "L_R_base": L_R_base,
    }

    print("\n=== Spectral Analysis Results ===")
    print(f"Multi-Manifold: δ = {delta_mm:.4f}, cos(θ) = {cos_theta_mm:.4f}")
    print(f"Baseline: δ = {delta_base:.4f}, cos(θ) = {cos_theta_base:.4f}")

    return results
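As a sanity check on the definitions in this file, the following minimal sketch rebuilds a Laplacian and the Fiedler alignment on a toy corpus. It uses dense `numpy.linalg.eigh` instead of `eigsh`, which is only practical for small n. The helper names `toy_laplacian` and `fiedler_alignment` and the random embeddings are illustrative assumptions, not part of the repository.

```python
import numpy as np


def toy_laplacian(embeddings: np.ndarray) -> np.ndarray:
    """Graph Laplacian L = D - W from clipped cosine similarities (hypothetical helper)."""
    W = np.clip(embeddings @ embeddings.T, 0, None)  # non-negative edge weights
    np.fill_diagonal(W, 0)                           # no self-loops
    return np.diag(W.sum(axis=1)) - W


def fiedler_alignment(L_a: np.ndarray, L_b: np.ndarray) -> float:
    """Absolute cosine of the angle between the Fiedler vectors of two Laplacians."""
    # eigh returns eigenvalues in ascending order, so column 1 is the Fiedler vector.
    _, vecs_a = np.linalg.eigh(L_a)
    _, vecs_b = np.linalg.eigh(L_b)
    v_a, v_b = vecs_a[:, 1], vecs_b[:, 1]
    return float(abs(v_a @ v_b) / (np.linalg.norm(v_a) * np.linalg.norm(v_b)))


rng = np.random.default_rng(0)
emb = rng.normal(size=(8, 4))
emb /= np.linalg.norm(emb, axis=1, keepdims=True)  # L2-normalize the rows

L = toy_laplacian(emb)
row_sums = L.sum(axis=1)                  # each row of a graph Laplacian sums to zero
self_alignment = fiedler_alignment(L, L)  # a Laplacian is perfectly aligned with itself
```

The absolute value matters because an eigenvector is only defined up to sign, so two solvers can return v and -v for the same Laplacian.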
|
multi_manifold_retrieval/models/__init__.py
ADDED
|
File without changes
|
multi_manifold_retrieval/models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (179 Bytes)
|
multi_manifold_retrieval/models/__pycache__/baseline.cpython-310.pyc
ADDED
|
Binary file (1.6 kB)
|
multi_manifold_retrieval/models/__pycache__/cross_manifold_operator.cpython-310.pyc
ADDED
|
Binary file (5.11 kB)
|
multi_manifold_retrieval/models/__pycache__/encoders.cpython-310.pyc
ADDED
|
Binary file (2.76 kB)
|
multi_manifold_retrieval/models/baseline.py
ADDED
|
@@ -0,0 +1,37 @@
"""Standard dual encoder baseline: R(q, d) = cosine_similarity(q, d)."""

import torch
import torch.nn as nn


class BaselineOperator(nn.Module):
    """Decomposable baseline: R(q, d) = q^T d (cosine similarity on normalized embeddings)."""

    def forward(self, q: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
        """Compute cosine similarity.

        Args:
            q: Query embeddings, shape (batch_size, embedding_dim), L2-normalized.
            d: Document embeddings, shape (batch_size, num_docs, embedding_dim) or
                (batch_size, embedding_dim).

        Returns:
            Similarity scores.
        """
        if d.dim() == 2:
            return torch.sum(q * d, dim=-1)
        else:
            return torch.einsum("bd,bnd->bn", q, d)

    def compute_pairwise(self, q: torch.Tensor,
                         docs: torch.Tensor) -> torch.Tensor:
        """Compute cosine similarity for all query-document pairs.

        Args:
            q: (num_queries, embedding_dim)
            docs: (num_docs, embedding_dim)

        Returns:
            (num_queries, num_docs)
        """
        return torch.mm(q, docs.t())
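The baseline relies on the inputs being L2-normalized: only then does the plain matrix product equal cosine similarity. The sketch below checks that equivalence in NumPy on random data. The helpers `l2_normalize` and `pairwise_scores` are hypothetical stand-ins for illustration and are not taken from the repository.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row to unit Euclidean norm."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def pairwise_scores(q: np.ndarray, docs: np.ndarray) -> np.ndarray:
    """(num_queries, dim) x (num_docs, dim) -> (num_queries, num_docs)."""
    return q @ docs.T


rng = np.random.default_rng(0)
raw_queries = rng.normal(size=(3, 5))
raw_docs = rng.normal(size=(4, 5))

# Dot product of normalized rows.
scores = pairwise_scores(l2_normalize(raw_queries), l2_normalize(raw_docs))

# Reference cosine similarity computed pair by pair on the unnormalized vectors.
reference = np.array([
    [np.dot(q, d) / (np.linalg.norm(q) * np.linalg.norm(d)) for d in raw_docs]
    for q in raw_queries
])
```

If unnormalized embeddings were passed instead, the scores would still be valid dot products but would no longer be bounded by 1 in absolute value.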
|
multi_manifold_retrieval/models/cross_manifold_operator.py
ADDED
|
@@ -0,0 +1,142 @@
"""Construction C: Attention-Geometric Hybrid cross-manifold operator.

R(q, d) = sum_{h=1}^{H} softmax((W_Q^h q)^T (W_K^h d) / sqrt(d_h)) * v^h(q, d)

where v^h(q, d) is a learned query-dependent value function parameterized as
a small MLP taking [q; d; q * d] as input.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class ValueMLP(nn.Module):
    """Learned value function v^h(q, d) for a single attention head.

    Takes concatenation [q; d; q * d] as input, outputs a scalar.
    """

    def __init__(self, embedding_dim: int, hidden_dim: int = 256,
                 num_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        input_dim = 3 * embedding_dim  # [q; d; q*d]
        layers = []
        in_dim = input_dim
        for _ in range(num_layers):
            layers.extend([
                nn.Linear(in_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ])
            in_dim = hidden_dim
        layers.append(nn.Linear(hidden_dim, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, q: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
        """Compute v^h(q, d).

        Args:
            q: Query embeddings, shape (batch, embed_dim) or (batch, num_docs, embed_dim)
            d: Document embeddings, shape (batch, embed_dim) or (batch, num_docs, embed_dim)

        Returns:
            Scalar values, shape matching the batch/doc dimensions.
        """
        x = torch.cat([q, d, q * d], dim=-1)
        return self.mlp(x).squeeze(-1)

class CrossManifoldOperator(nn.Module):
    """Attention-Geometric Hybrid (Construction C).

    Implements the cross-manifold relevance operator R(q, d) as a
    multi-head attention mechanism with learned value functions.
    """

    def __init__(self, embedding_dim: int, num_heads: int = 4,
                 value_hidden_dim: int = 256, value_num_layers: int = 2,
                 dropout: float = 0.1):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads
        assert embedding_dim % num_heads == 0, \
            f"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}"

        # Per-head query and key projections
        self.W_Q = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)

        # Per-head value MLPs
        self.value_mlps = nn.ModuleList([
            ValueMLP(embedding_dim, value_hidden_dim, value_num_layers, dropout)
            for _ in range(num_heads)
        ])

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.W_Q.weight)
        nn.init.xavier_uniform_(self.W_K.weight)

    def forward(self, q: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
        """Compute R(q, d) = cross-manifold relevance score.

        Args:
            q: Query embeddings, shape (batch_size, embedding_dim)
            d: Document embeddings, shape (batch_size, num_docs, embedding_dim)
                or (batch_size, embedding_dim) for single document.

        Returns:
            Relevance scores, shape (batch_size, num_docs) or (batch_size,).
        """
        single_doc = d.dim() == 2
        if single_doc:
            d = d.unsqueeze(1)  # (batch, 1, embed_dim)

        batch_size, num_docs, _ = d.shape

        # Project queries and keys: (batch, embed_dim) -> (batch, num_heads, head_dim)
        q_proj = self.W_Q(q).view(batch_size, self.num_heads, self.head_dim)
        # (batch, num_docs, embed_dim) -> (batch, num_docs, num_heads, head_dim)
        d_proj = self.W_K(d).view(batch_size, num_docs, self.num_heads, self.head_dim)

        # Attention scores: (batch, num_docs, num_heads)
        scale = self.head_dim ** 0.5
        attn = torch.einsum("bhd,bnhd->bnh", q_proj, d_proj) / scale

        # Softmax over heads (not over documents): each head contributes a
        # weighted value, and the weighting is query-key dependent.
        attn_weights = F.softmax(attn, dim=-1)  # (batch, num_docs, num_heads)

        # Expand q for value MLPs: (batch, num_docs, embed_dim)
        q_expanded = q.unsqueeze(1).expand(-1, num_docs, -1)

        # Compute per-head values and weight them
        total = torch.zeros(batch_size, num_docs, device=q.device)
        for h in range(self.num_heads):
            v_h = self.value_mlps[h](q_expanded, d)  # (batch, num_docs)
            total = total + attn_weights[:, :, h] * v_h

        if single_doc:
            total = total.squeeze(1)

        return total

    def compute_pairwise(self, q: torch.Tensor,
                         docs: torch.Tensor) -> torch.Tensor:
        """Compute R(q, d) for all query-document pairs.

        Args:
            q: Query embeddings, shape (num_queries, embedding_dim)
            docs: Document embeddings, shape (num_docs, embedding_dim)

        Returns:
            Relevance matrix, shape (num_queries, num_docs).
        """
        # Expand docs for each query
        num_queries = q.shape[0]
        num_docs = docs.shape[0]
        docs_expanded = docs.unsqueeze(0).expand(num_queries, -1, -1)
        return self.forward(q, docs_expanded)
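One consequence of taking the softmax over heads rather than over documents is that each document's score depends only on the query and that document, not on which other documents are in the batch. The NumPy sketch below illustrates this for a single query. The function `head_mixture_scores` and the stand-in value functions are assumptions for illustration; the real operator uses learned `ValueMLP` modules instead.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def head_mixture_scores(q, docs, W_Q, W_K, num_heads, value_fn):
    """Score one query against N docs: sum over heads of weight[n, h] * value[n, h]."""
    head_dim = q.shape[0] // num_heads
    q_proj = (W_Q @ q).reshape(num_heads, head_dim)           # (H, d_h)
    d_proj = (docs @ W_K.T).reshape(-1, num_heads, head_dim)  # (N, H, d_h)
    attn = np.einsum("hd,nhd->nh", q_proj, d_proj) / np.sqrt(head_dim)
    weights = softmax(attn, axis=-1)                          # normalized over heads
    values = np.stack([value_fn(h, q, docs) for h in range(num_heads)], axis=-1)
    return (weights * values).sum(axis=-1), weights


def constant_value(h, q, docs):
    """Stand-in value function that returns 1 for every document."""
    return np.ones(docs.shape[0])


def dot_value(h, q, docs):
    """Stand-in value function: a head-scaled dot product (not the learned MLP)."""
    return (h + 1) * (docs @ q)


rng = np.random.default_rng(0)
dim, num_heads, num_docs = 8, 2, 5
W_Q = rng.normal(size=(dim, dim))
W_K = rng.normal(size=(dim, dim))
q = rng.normal(size=dim)
docs = rng.normal(size=(num_docs, dim))

const_scores, weights = head_mixture_scores(q, docs, W_Q, W_K, num_heads, constant_value)
full_scores, _ = head_mixture_scores(q, docs, W_Q, W_K, num_heads, dot_value)
subset_scores, _ = head_mixture_scores(q, docs[:2], W_Q, W_K, num_heads, dot_value)
```

With a constant value function every score collapses to 1, because the head weights sum to 1 per document. Scoring a subset of the documents reproduces the corresponding entries of the full result, which is what makes batched pairwise scoring consistent with scoring documents one at a time.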
|
multi_manifold_retrieval/models/encoders.py
ADDED
|
@@ -0,0 +1,60 @@
"""Wrapper around sentence-transformers for document and query encoders."""

import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer


class DualEncoder(nn.Module):
    """Frozen pretrained encoders for query and document manifolds.

    Uses the same pretrained model for both query and document encoding
    (separate manifolds are induced by the cross-manifold operator, not
    by separate encoder weights).
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
                 max_seq_length: int = 128, freeze: bool = True):
        super().__init__()
        self.model = SentenceTransformer(model_name)
        self.model.max_seq_length = max_seq_length
        self.embedding_dim = self.model.get_sentence_embedding_dimension()

        if freeze:
            for param in self.model.parameters():
                param.requires_grad = False

    def encode_queries(self, texts: list[str], batch_size: int = 64,
                       show_progress: bool = False) -> torch.Tensor:
        """Encode query texts to embeddings on M_Q."""
        embeddings = self.model.encode(
            texts, batch_size=batch_size, show_progress_bar=show_progress,
            convert_to_tensor=True, normalize_embeddings=True,
        )
        return embeddings

    def encode_documents(self, texts: list[str], batch_size: int = 64,
                         show_progress: bool = False) -> torch.Tensor:
        """Encode document texts to embeddings on M_D."""
        embeddings = self.model.encode(
            texts, batch_size=batch_size, show_progress_bar=show_progress,
            convert_to_tensor=True, normalize_embeddings=True,
        )
        return embeddings

    def forward_queries(self, input_ids: torch.Tensor,
                        attention_mask: torch.Tensor) -> torch.Tensor:
        """Forward pass for query token IDs (for training)."""
        features = {"input_ids": input_ids, "attention_mask": attention_mask}
        out = self.model.forward(features)
        embeddings = out["sentence_embedding"]
        return nn.functional.normalize(embeddings, p=2, dim=-1)

    def forward_documents(self, input_ids: torch.Tensor,
                          attention_mask: torch.Tensor) -> torch.Tensor:
        """Forward pass for document token IDs (for training)."""
        return self.forward_queries(input_ids, attention_mask)

    @property
    def tokenizer(self):
        return self.model.tokenizer
|
multi_manifold_retrieval/training/__init__.py
ADDED
|
File without changes
|
multi_manifold_retrieval/training/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (181 Bytes)
|
multi_manifold_retrieval/training/__pycache__/data.cpython-310.pyc
ADDED
|
Binary file (5.49 kB)
|
multi_manifold_retrieval/training/__pycache__/losses.cpython-310.pyc
ADDED
|
Binary file (1.69 kB)
|
multi_manifold_retrieval/training/__pycache__/train.cpython-310.pyc
ADDED
|
Binary file (4.15 kB)
|
multi_manifold_retrieval/training/data.py
ADDED
|
@@ -0,0 +1,168 @@
"""MS MARCO data loading for training and evaluation."""

import random
from typing import Optional

import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset


class MSMARCOTripleDataset(Dataset):
    """MS MARCO passage ranking dataset with hard negatives.

    Each example yields (query, positive_passage, [negative_passages]).
    """

    def __init__(self, tokenizer, max_samples: int = 100_000,
                 num_negatives: int = 7, max_seq_length: int = 128,
                 split: str = "train", seed: int = 42):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.num_negatives = num_negatives

        # Load MS MARCO dataset
        print(f"Loading MS MARCO ({split} split, max {max_samples} samples)...")
        dataset = load_dataset("ms_marco", "v2.1", split=split, trust_remote_code=True)

        # Filter to examples with at least one selected passage
        self.examples = []
        for i, ex in enumerate(dataset):
            if len(self.examples) >= max_samples:
                break
            passages = ex["passages"]
            selected = [j for j, s in enumerate(passages["is_selected"]) if s == 1]
            if selected:
                self.examples.append({
                    "query": ex["query"],
                    "positive": passages["passage_text"][selected[0]],
                    "negatives": [
                        passages["passage_text"][j]
                        for j in range(len(passages["passage_text"]))
                        if j not in selected
                    ],
                })

        print(f"Loaded {len(self.examples)} training examples.")
        self.rng = random.Random(seed)

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict:
        ex = self.examples[idx]
        # Sample negatives (from in-passage negatives, pad with random if needed)
        available_negs = ex["negatives"]
        if len(available_negs) >= self.num_negatives:
            negs = self.rng.sample(available_negs, self.num_negatives)
        else:
            negs = available_negs[:]
            # Pad with random negatives from other examples
            while len(negs) < self.num_negatives:
                rand_ex = self.examples[self.rng.randint(0, len(self.examples) - 1)]
                if rand_ex["positive"] != ex["positive"]:
                    negs.append(rand_ex["positive"])

        return {
            "query": ex["query"],
            "positive": ex["positive"],
            "negatives": negs,
        }

def collate_fn(batch: list[dict], tokenizer, max_seq_length: int = 128) -> dict:
    """Collate batch into tokenized tensors."""
    queries = [b["query"] for b in batch]
    positives = [b["positive"] for b in batch]
    all_negatives = []
    for b in batch:
        all_negatives.extend(b["negatives"])

    # Tokenize
    q_enc = tokenizer(
        queries, padding=True, truncation=True,
        max_length=max_seq_length, return_tensors="pt",
    )
    p_enc = tokenizer(
        positives, padding=True, truncation=True,
        max_length=max_seq_length, return_tensors="pt",
    )
    n_enc = tokenizer(
        all_negatives, padding=True, truncation=True,
        max_length=max_seq_length, return_tensors="pt",
    )

    num_negatives = len(batch[0]["negatives"])
    return {
        "query_input_ids": q_enc["input_ids"],
        "query_attention_mask": q_enc["attention_mask"],
        "pos_input_ids": p_enc["input_ids"],
        "pos_attention_mask": p_enc["attention_mask"],
        "neg_input_ids": n_enc["input_ids"],
        "neg_attention_mask": n_enc["attention_mask"],
        "num_negatives": num_negatives,
    }


def get_dataloader(tokenizer, max_samples: int = 100_000,
                   num_negatives: int = 7, batch_size: int = 64,
                   max_seq_length: int = 128, split: str = "train",
                   seed: int = 42, num_workers: int = 0) -> DataLoader:
    """Create a DataLoader for MS MARCO training."""
    dataset = MSMARCOTripleDataset(
        tokenizer=tokenizer, max_samples=max_samples,
        num_negatives=num_negatives, max_seq_length=max_seq_length,
        split=split, seed=seed,
    )

    def _collate(batch):
        return collate_fn(batch, tokenizer, max_seq_length)

    return DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        collate_fn=_collate, num_workers=num_workers,
        drop_last=True,
    )

class MSMARCOEvalDataset:
    """MS MARCO dev set for evaluation."""

    def __init__(self, tokenizer, max_queries: int = 5000,
                 max_seq_length: int = 128, seed: int = 42):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

        print(f"Loading MS MARCO dev set (max {max_queries} queries)...")
        dataset = load_dataset("ms_marco", "v2.1", split="validation", trust_remote_code=True)

        self.queries = []
        self.positives = []  # list of list of positive passage texts
        self.all_passages = []  # flat list of all passages for retrieval
        self.passage_set = set()

        rng = random.Random(seed)
        indices = list(range(len(dataset)))
        rng.shuffle(indices)

        for i in indices:
            if len(self.queries) >= max_queries:
                break
            ex = dataset[i]
            passages = ex["passages"]
            selected = [j for j, s in enumerate(passages["is_selected"]) if s == 1]
            if not selected:
                continue

            self.queries.append(ex["query"])
            pos_texts = [passages["passage_text"][j] for j in selected]
            self.positives.append(pos_texts)

            # Add all passages to the corpus
            for text in passages["passage_text"]:
                if text not in self.passage_set:
                    self.passage_set.add(text)
                    self.all_passages.append(text)

        print(f"Loaded {len(self.queries)} eval queries, "
              f"{len(self.all_passages)} unique passages.")
|
multi_manifold_retrieval/training/losses.py
ADDED
|
@@ -0,0 +1,38 @@
"""Contrastive loss for training the cross-manifold operator."""

import torch
import torch.nn as nn
import torch.nn.functional as F


class ContrastiveLoss(nn.Module):
    """Cross-entropy contrastive loss over (query, positive, negatives) triples.

    Given a query q, positive document d+, and negative documents d1-, ..., dK-,
    the loss is:
        -log(exp(R(q, d+)/tau) / (exp(R(q, d+)/tau) + sum_k exp(R(q, dk-)/tau)))
    """

    def __init__(self, temperature: float = 0.05):
        super().__init__()
        self.temperature = temperature

    def forward(self, pos_scores: torch.Tensor,
                neg_scores: torch.Tensor) -> torch.Tensor:
        """Compute contrastive loss.

        Args:
            pos_scores: (batch_size,), R(q, d+) for each query.
            neg_scores: (batch_size, num_negatives), R(q, dk-) for each query.

        Returns:
            Scalar loss.
        """
        # Concatenate: (batch_size, 1 + num_negatives)
        all_scores = torch.cat([pos_scores.unsqueeze(1), neg_scores], dim=1)
        all_scores = all_scores / self.temperature

        # Target: index 0 is the positive
        targets = torch.zeros(all_scores.shape[0], dtype=torch.long,
                              device=all_scores.device)
        return F.cross_entropy(all_scores, targets)
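The loss formula in the docstring can be checked by hand with a NumPy sketch. It assumes the mean reduction that `F.cross_entropy` applies by default; the function name `contrastive_loss_np` and the toy scores are illustrative only. Two limiting cases are easy to verify: when all scores are equal the loss is log(1 + K), and when the positive clearly outscores the negatives the loss approaches 0.

```python
import numpy as np


def contrastive_loss_np(pos_scores, neg_scores, temperature=0.05):
    """Mean of -log softmax(scores / tau)[0], with the positive in column 0."""
    logits = np.concatenate([pos_scores[:, None], neg_scores], axis=1) / temperature
    logits = logits - logits.max(axis=1, keepdims=True)  # stabilize the log-sum-exp
    log_probs = logits[:, 0] - np.log(np.exp(logits).sum(axis=1))
    return float(-log_probs.mean())


batch_size, num_negatives = 4, 7

# Case 1: the operator cannot tell positives from negatives.
uniform_loss = contrastive_loss_np(
    np.zeros(batch_size), np.zeros((batch_size, num_negatives))
)

# Case 2: every positive scores 1.0 and every negative scores 0.0.
separated_loss = contrastive_loss_np(
    np.ones(batch_size), np.zeros((batch_size, num_negatives))
)
```

With tau = 0.05, a score gap of 1.0 becomes a logit gap of 20, so the low temperature makes even modest score differences nearly saturate the softmax.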
|
multi_manifold_retrieval/training/train.py
ADDED
|
@@ -0,0 +1,159 @@
"""Training loop for the cross-manifold operator."""

import os
import time
import yaml

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

from multi_manifold_retrieval.models.encoders import DualEncoder
from multi_manifold_retrieval.models.cross_manifold_operator import CrossManifoldOperator
from multi_manifold_retrieval.training.data import get_dataloader
from multi_manifold_retrieval.training.losses import ContrastiveLoss


def train(config_path: str = "configs/default.yaml",
          device: str = "cuda" if torch.cuda.is_available() else "cpu"):
    """Train the cross-manifold operator on MS MARCO."""
    with open(config_path) as f:
        config = yaml.safe_load(f)

    torch.manual_seed(config["seed"])

    # Initialize encoder
    print("Initializing encoder...", flush=True)
    encoder = DualEncoder(
        model_name=config["encoder"]["model_name"],
        max_seq_length=config["training"]["max_seq_length"],
        freeze=config["encoder"]["freeze"],
    )
    embedding_dim = encoder.embedding_dim

    # Initialize cross-manifold operator
    cm_config = config["cross_manifold"]
    operator = CrossManifoldOperator(
        embedding_dim=embedding_dim,
        num_heads=cm_config["num_heads"],
        value_hidden_dim=cm_config["value_mlp_hidden"],
        value_num_layers=cm_config["value_mlp_layers"],
        dropout=cm_config["dropout"],
    ).to(device)

    print(f"Encoder initialized (dim={embedding_dim}).", flush=True)

    # Loss
    loss_fn = ContrastiveLoss(temperature=0.05)

    # Data
    print("Loading training data...", flush=True)
    t0 = time.time()
    train_loader = get_dataloader(
        tokenizer=encoder.tokenizer,
        max_samples=config["training"]["max_train_samples"],
        num_negatives=config["training"]["num_negatives"],
        batch_size=config["training"]["batch_size"],
        max_seq_length=config["training"]["max_seq_length"],
        split="train",
        seed=config["seed"],
    )

    print(f"Data loaded in {time.time()-t0:.1f}s. Batches per epoch: {len(train_loader)}", flush=True)

    # Optimizer (only train the cross-manifold operator)
    optimizer = AdamW(
        operator.parameters(),
        lr=config["training"]["learning_rate"],
        weight_decay=config["training"]["weight_decay"],
    )

    total_steps = len(train_loader) * config["training"]["epochs"]
    scheduler = OneCycleLR(
        optimizer,
        max_lr=config["training"]["learning_rate"],
        total_steps=total_steps,
        pct_start=min(config["training"]["warmup_steps"] / total_steps, 0.1),
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
# Training loop
|
| 81 |
+
print(f"Moving encoder to {device}...", flush=True)
|
| 82 |
+
encoder.model.to(device)
|
| 83 |
+
encoder.model.eval()
|
| 84 |
+
operator.train()
|
| 85 |
+
print(f"Starting training: {config['training']['epochs']} epochs, {total_steps} total steps", flush=True)
|
| 86 |
+
|
| 87 |
+
save_dir = config["training"]["save_dir"]
|
| 88 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 89 |
+
log_every = config["training"]["log_every"]
|
| 90 |
+
|
| 91 |
+
global_step = 0
|
| 92 |
+
best_loss = float("inf")
|
| 93 |
+
|
| 94 |
+
for epoch in range(config["training"]["epochs"]):
|
| 95 |
+
epoch_loss = 0.0
|
| 96 |
+
epoch_start = time.time()
|
| 97 |
+
|
| 98 |
+
for batch_idx, batch in enumerate(train_loader):
|
| 99 |
+
# Move to device
|
| 100 |
+
q_ids = batch["query_input_ids"].to(device)
|
| 101 |
+
q_mask = batch["query_attention_mask"].to(device)
|
| 102 |
+
p_ids = batch["pos_input_ids"].to(device)
|
| 103 |
+
p_mask = batch["pos_attention_mask"].to(device)
|
| 104 |
+
n_ids = batch["neg_input_ids"].to(device)
|
| 105 |
+
n_mask = batch["neg_attention_mask"].to(device)
|
| 106 |
+
num_neg = batch["num_negatives"]
|
| 107 |
+
|
| 108 |
+
# Encode (no grad for frozen encoder)
|
| 109 |
+
with torch.no_grad():
|
| 110 |
+
q_emb = encoder.forward_queries(q_ids, q_mask) # (B, D)
|
| 111 |
+
p_emb = encoder.forward_documents(p_ids, p_mask) # (B, D)
|
| 112 |
+
n_emb = encoder.forward_documents(n_ids, n_mask) # (B*K, D)
|
| 113 |
+
|
| 114 |
+
batch_size = q_emb.shape[0]
|
| 115 |
+
n_emb = n_emb.view(batch_size, num_neg, -1) # (B, K, D)
|
| 116 |
+
|
| 117 |
+
# Compute relevance scores via cross-manifold operator
|
| 118 |
+
pos_scores = operator(q_emb, p_emb) # (B,)
|
| 119 |
+
neg_scores = operator(q_emb, n_emb) # (B, K)
|
| 120 |
+
|
| 121 |
+
# Loss
|
| 122 |
+
loss = loss_fn(pos_scores, neg_scores)
|
| 123 |
+
|
| 124 |
+
# Backward
|
| 125 |
+
optimizer.zero_grad()
|
| 126 |
+
loss.backward()
|
| 127 |
+
torch.nn.utils.clip_grad_norm_(operator.parameters(), 1.0)
|
| 128 |
+
optimizer.step()
|
| 129 |
+
scheduler.step()
|
| 130 |
+
|
| 131 |
+
epoch_loss += loss.item()
|
| 132 |
+
global_step += 1
|
| 133 |
+
|
| 134 |
+
if global_step % log_every == 0:
|
| 135 |
+
avg_loss = epoch_loss / (batch_idx + 1)
|
| 136 |
+
lr = scheduler.get_last_lr()[0]
|
| 137 |
+
print(f" Step {global_step} | Loss: {loss.item():.4f} | "
|
| 138 |
+
f"Avg: {avg_loss:.4f} | LR: {lr:.2e}", flush=True)
|
| 139 |
+
|
| 140 |
+
epoch_time = time.time() - epoch_start
|
| 141 |
+
avg_loss = epoch_loss / len(train_loader)
|
| 142 |
+
print(f"Epoch {epoch+1}/{config['training']['epochs']} | "
|
| 143 |
+
f"Avg Loss: {avg_loss:.4f} | Time: {epoch_time:.1f}s", flush=True)
|
| 144 |
+
|
| 145 |
+
# Save best
|
| 146 |
+
if avg_loss < best_loss:
|
| 147 |
+
best_loss = avg_loss
|
| 148 |
+
torch.save(operator.state_dict(), os.path.join(save_dir, "best_operator.pt"))
|
| 149 |
+
print(f" Saved best model (loss={best_loss:.4f})")
|
| 150 |
+
|
| 151 |
+
# Save final
|
| 152 |
+
torch.save(operator.state_dict(), os.path.join(save_dir, "final_operator.pt"))
|
| 153 |
+
print(f"Training complete. Best loss: {best_loss:.4f}")
|
| 154 |
+
|
| 155 |
+
return encoder, operator
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
if __name__ == "__main__":
|
| 159 |
+
train()
|
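The `pct_start` expression in `train()` turns the configured warmup step count into the fraction of the schedule that `OneCycleLR` spends ramping the learning rate up, capped at 10%. Below is a minimal worked example of that arithmetic. It assumes the values recorded in the `config` block of `results.json` in this commit (100,000 training samples, batch size 64, 5 epochs, 500 warmup steps) and assumes the loader keeps the final partial batch, which this diff does not confirm. The helper name `one_cycle_warmup_fraction` is hypothetical and not part of the repository.

```python
import math


def one_cycle_warmup_fraction(num_samples, batch_size, epochs, warmup_steps, cap=0.1):
    """Return (total_steps, pct_start) using the same formula as train().

    Assumption: the final partial batch is kept (ceil division). A loader
    built with drop_last=True would give floor(num_samples / batch_size).
    """
    batches_per_epoch = math.ceil(num_samples / batch_size)
    total_steps = batches_per_epoch * epochs
    # Same expression as in train(): warmup share of the run, capped at `cap`.
    pct_start = min(warmup_steps / total_steps, cap)
    return total_steps, pct_start


total_steps, pct_start = one_cycle_warmup_fraction(100_000, 64, 5, 500)
print(total_steps, round(pct_start, 4))
```

Under these assumptions the warmup covers roughly 6.4% of the run, so the 10% cap does not apply. The cap only takes effect when `warmup_steps` exceeds a tenth of the total step count, for example on a much smaller training subset.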
requirements.txt
ADDED
@@ -0,0 +1,10 @@
+torch>=2.0
+transformers>=4.30
+sentence-transformers>=2.2
+datasets>=2.14
+faiss-cpu>=1.7
+numpy>=1.24
+scipy>=1.10
+scikit-learn>=1.3
+pyyaml>=6.0
+tqdm>=4.65
results.json
ADDED
@@ -0,0 +1,136 @@
+{
+  "config": {
+    "seed": 42,
+    "encoder": {
+      "model_name": "sentence-transformers/all-MiniLM-L6-v2",
+      "embedding_dim": 384,
+      "freeze": true
+    },
+    "cross_manifold": {
+      "num_heads": 4,
+      "head_dim": 96,
+      "value_mlp_hidden": 256,
+      "value_mlp_layers": 2,
+      "dropout": 0.1
+    },
+    "training": {
+      "batch_size": 64,
+      "learning_rate": 0.0002,
+      "weight_decay": 0.01,
+      "epochs": 5,
+      "warmup_steps": 500,
+      "max_train_samples": 100000,
+      "num_negatives": 7,
+      "max_seq_length": 128,
+      "fp16": true,
+      "gradient_accumulation_steps": 1,
+      "log_every": 100,
+      "eval_every": 2000,
+      "save_dir": "checkpoints"
+    },
+    "evaluation": {
+      "max_eval_queries": 5000,
+      "metrics": [
+        "mrr@10",
+        "recall@100"
+      ]
+    },
+    "spectral": {
+      "num_documents": 1000,
+      "num_queries": 500,
+      "k_neighbors": 20
+    },
+    "attack": {
+      "target_domain": "medical",
+      "num_target_queries": 100,
+      "top_k": 10,
+      "medical_keywords": [
+        "health",
+        "medical",
+        "doctor",
+        "patient",
+        "treatment",
+        "disease",
+        "symptom",
+        "diagnosis",
+        "medicine",
+        "clinical"
+      ]
+    }
+  },
+  "device": "cuda",
+  "retrieval_multi_manifold": {
+    "mrr@10": 0.6002776984126992,
+    "recall@100": 0.9901,
+    "num_queries": 5000
+  },
+  "retrieval_baseline": {
+    "mrr@10": 0.5828701587301599,
+    "recall@100": 0.9942,
+    "num_queries": 5000
+  },
+  "mrr_ratio": 1.0298652099816936,
+  "spectral": {
+    "multi_manifold": {
+      "spectral_discrepancy": 0.05735765351097603,
+      "fiedler_alignment": 0.03973450829139222
+    },
+    "baseline": {
+      "spectral_discrepancy": 0.22395483470893326,
+      "fiedler_alignment": 0.7848795751112227
+    },
+    "num_documents": 1000,
+    "num_queries": 500
+  },
+  "attack": {
+    "bipartite_attack": {
+      "baseline_asr": 0.51,
+      "multi_manifold_asr": 0.19,
+      "baseline_margins": {
+        "mean_margin": -0.00035160839557647706,
+        "median_margin": 0.0012769699096679688,
+        "p25_margin": -0.03736262768507004,
+        "fraction_positive_margin": 0.51
+      },
+      "multi_manifold_margins": {
+        "mean_margin": -0.0419993931055069,
+        "median_margin": -0.04468509554862976,
+        "p25_margin": -0.06919527053833008,
+        "fraction_positive_margin": 0.19
+      },
+      "placement_info": {
+        "method": "bipartite_fiedler",
+        "fiedler_eigenvalue": 0.12380694040243122,
+        "weight_entropy": 4.995662314784581,
+        "max_weight": 0.016811827898423226,
+        "adv_mean_cos_to_queries": 0.2539085502225744,
+        "adv_mean_cos_to_docs": 0.2686595998305845
+      }
+    },
+    "doconly_attack": {
+      "baseline_asr": 0.03,
+      "multi_manifold_asr": 0.03,
+      "baseline_margins": {
+        "mean_margin": -0.1627955549955368,
+        "median_margin": -0.1696268543601036,
+        "p25_margin": -0.21040004305541515,
+        "fraction_positive_margin": 0.03
+      },
+      "multi_manifold_margins": {
+        "mean_margin": -0.12464183956384658,
+        "median_margin": -0.12341519445180893,
+        "p25_margin": -0.1708945743739605,
+        "fraction_positive_margin": 0.03
+      },
+      "placement_info": {
+        "method": "doconly_fiedler",
+        "fiedler_eigenvalue": 3.8886430263519287
+      }
+    },
+    "num_domain_docs": 200,
+    "num_target_queries": 100,
+    "top_k": 10,
+    "baseline_asr": 0.51,
+    "multi_manifold_asr": 0.19
+  }
+}
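The headline comparisons in `results.json` can be re-derived from the raw fields. The sketch below copies a handful of values from the file into a literal dict so that it runs on its own and recomputes the stored `mrr_ratio`, the recall@100 difference, and the relative drop in attack success rate (ASR). The `summarize` helper and its output key names are hypothetical and not part of the repository.

```python
# Values copied from results.json; only the fields needed here are included.
results = {
    "retrieval_multi_manifold": {"mrr@10": 0.6002776984126992, "recall@100": 0.9901},
    "retrieval_baseline": {"mrr@10": 0.5828701587301599, "recall@100": 0.9942},
    "attack": {"baseline_asr": 0.51, "multi_manifold_asr": 0.19},
}


def summarize(results):
    """Derive comparison figures from the raw metrics (illustrative helper)."""
    mm = results["retrieval_multi_manifold"]
    base = results["retrieval_baseline"]
    attack = results["attack"]
    return {
        # Same quantity that run_experiment.py stores as "mrr_ratio".
        "mrr_ratio": mm["mrr@10"] / base["mrr@10"],
        # Negative means the multi-manifold model recalls slightly less.
        "recall_delta": mm["recall@100"] - base["recall@100"],
        # Share of baseline attack successes that no longer succeed.
        "asr_relative_reduction": 1.0 - attack["multi_manifold_asr"] / attack["baseline_asr"],
    }


summary = summarize(results)
for name, value in summary.items():
    print(f"{name}: {value:.4f}")
```

Read this way, the file reports an MRR@10 about 3% higher than the cosine baseline, a recall@100 about 0.4 points lower, and a bipartite-attack ASR that falls from 0.51 to 0.19. The top-level `attack.baseline_asr` and `attack.multi_manifold_asr` fields repeat the `bipartite_attack` numbers. The `doconly_attack` ASR is 0.03 for both models.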
run_experiment.py
ADDED
@@ -0,0 +1,261 @@
+#!/usr/bin/env python3
+"""End-to-end experiment: train → evaluate → spectral analysis → attack simulation."""
+
+import os
+import json
+import random
+import time
+import yaml
+import argparse
+
+import numpy as np
+import torch
+
+from multi_manifold_retrieval.models.encoders import DualEncoder
+from multi_manifold_retrieval.models.cross_manifold_operator import CrossManifoldOperator
+from multi_manifold_retrieval.models.baseline import BaselineOperator
+from multi_manifold_retrieval.training.train import train
+from multi_manifold_retrieval.training.data import MSMARCOEvalDataset
+from multi_manifold_retrieval.evaluation.spectral_analysis import run_spectral_analysis
+from multi_manifold_retrieval.evaluation.retrieval_metrics import compute_retrieval_metrics
+from multi_manifold_retrieval.evaluation.attack_simulation import (
+    run_attack_simulation,
+    select_domain_documents,
+)
+
+
+def set_seed(seed: int):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Multi-Manifold Retrieval PoC Experiment")
+    parser.add_argument("--config", type=str, default="configs/default.yaml")
+    parser.add_argument("--skip-train", action="store_true", help="Skip training, load from checkpoint")
+    parser.add_argument("--checkpoint", type=str, default="checkpoints/best_operator.pt")
+    parser.add_argument("--output", type=str, default="results.json")
+    args = parser.parse_args()
+
+    with open(args.config) as f:
+        config = yaml.safe_load(f)
+
+    set_seed(config["seed"])
+    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+    print(f"Using device: {device}")
+
+    results = {"config": config, "device": device}
+
+    # =========================================================================
+    # Phase 1: Training
+    # =========================================================================
+    print("\n" + "=" * 60)
+    print("PHASE 1: TRAINING")
+    print("=" * 60)
+
+    if args.skip_train and os.path.exists(args.checkpoint):
+        print(f"Loading encoder and operator from checkpoint: {args.checkpoint}")
+        encoder = DualEncoder(
+            model_name=config["encoder"]["model_name"],
+            max_seq_length=config["training"]["max_seq_length"],
+            freeze=config["encoder"]["freeze"],
+        )
+        cm_config = config["cross_manifold"]
+        operator = CrossManifoldOperator(
+            embedding_dim=encoder.embedding_dim,
+            num_heads=cm_config["num_heads"],
+            value_hidden_dim=cm_config["value_mlp_hidden"],
+            value_num_layers=cm_config["value_mlp_layers"],
+            dropout=cm_config["dropout"],
+        )
+        operator.load_state_dict(torch.load(args.checkpoint, map_location=device, weights_only=True))
+        operator.to(device)
+    else:
+        encoder, operator = train(config_path=args.config, device=device)
+
+    encoder.model.to(device)
+    operator.to(device)
+    baseline_operator = BaselineOperator().to(device)
+
+    # =========================================================================
+    # Phase 2: Evaluation Data Preparation
+    # =========================================================================
+    print("\n" + "=" * 60)
+    print("PHASE 2: EVALUATION DATA PREPARATION")
+    print("=" * 60)
+
+    eval_data = MSMARCOEvalDataset(
+        tokenizer=encoder.tokenizer,
+        max_queries=config["evaluation"]["max_eval_queries"],
+        max_seq_length=config["training"]["max_seq_length"],
+        seed=config["seed"],
+    )
+
+    # Encode all evaluation passages
+    print(f"Encoding {len(eval_data.all_passages)} passages...")
+    passage_embeddings = encoder.encode_documents(
+        eval_data.all_passages, batch_size=128, show_progress=True,
+    )
+
+    # Encode evaluation queries
+    print(f"Encoding {len(eval_data.queries)} queries...")
+    query_embeddings = encoder.encode_queries(
+        eval_data.queries, batch_size=128, show_progress=True,
+    )
+
+    # =========================================================================
+    # Phase 3: Retrieval Quality Evaluation
+    # =========================================================================
+    print("\n" + "=" * 60)
+    print("PHASE 3: RETRIEVAL QUALITY")
+    print("=" * 60)
+
+    print("\n--- Multi-Manifold Model ---")
+    metrics_mm = compute_retrieval_metrics(
+        query_embeddings=query_embeddings,
+        doc_embeddings=passage_embeddings,
+        operator=operator,
+        query_texts=eval_data.queries,
+        positive_passages=eval_data.positives,
+        all_passages=eval_data.all_passages,
+        passage_embeddings=passage_embeddings,
+        device=device,
+    )
+    results["retrieval_multi_manifold"] = metrics_mm
+
+    print("\n--- Baseline (Cosine Similarity) ---")
+    metrics_base = compute_retrieval_metrics(
+        query_embeddings=query_embeddings,
+        doc_embeddings=passage_embeddings,
+        operator=baseline_operator,
+        query_texts=eval_data.queries,
+        positive_passages=eval_data.positives,
+        all_passages=eval_data.all_passages,
+        passage_embeddings=passage_embeddings,
+        device=device,
+    )
+    results["retrieval_baseline"] = metrics_base
+
+    # Check: multi-manifold within 80% of baseline
+    if metrics_base["mrr@10"] > 0:
+        ratio = metrics_mm["mrr@10"] / metrics_base["mrr@10"]
+        print(f"\nMRR@10 ratio (mm/baseline): {ratio:.4f} "
+              f"({'PASS' if ratio >= 0.8 else 'BELOW TARGET'}, target >= 0.8)")
+        results["mrr_ratio"] = ratio
+
+    # =========================================================================
+    # Phase 4: Spectral Analysis
+    # =========================================================================
+    print("\n" + "=" * 60)
+    print("PHASE 4: SPECTRAL ANALYSIS")
+    print("=" * 60)
+
+    # Sample documents for spectral analysis
+    num_spectral_docs = min(config["spectral"]["num_documents"], len(eval_data.all_passages))
+    num_spectral_queries = min(config["spectral"]["num_queries"], len(eval_data.queries))
+
+    spectral_doc_indices = np.random.choice(
+        len(eval_data.all_passages), num_spectral_docs, replace=False
+    )
+    spectral_query_indices = np.random.choice(
+        len(eval_data.queries), num_spectral_queries, replace=False
+    )
+
+    spectral_doc_emb_np = passage_embeddings[spectral_doc_indices].cpu().numpy()
+    spectral_doc_emb_torch = passage_embeddings[spectral_doc_indices]
+    spectral_query_emb_torch = query_embeddings[spectral_query_indices]
+
+    spectral_results = run_spectral_analysis(
+        doc_embeddings_np=spectral_doc_emb_np,
+        doc_embeddings_torch=spectral_doc_emb_torch,
+        query_embeddings_torch=spectral_query_emb_torch,
+        operator=operator,
+        baseline_operator=baseline_operator,
+        device=device,
+    )
+
+    results["spectral"] = {
+        "multi_manifold": spectral_results["multi_manifold"],
+        "baseline": spectral_results["baseline"],
+        "num_documents": num_spectral_docs,
+        "num_queries": num_spectral_queries,
+    }
+
+    # =========================================================================
+    # Phase 5: Attack Simulation
+    # =========================================================================
+    print("\n" + "=" * 60)
+    print("PHASE 5: ATTACK SIMULATION")
+    print("=" * 60)
+
+    attack_config = config["attack"]
+
+    # Select target queries (medical domain)
+    target_queries = []
+    for q in eval_data.queries:
+        q_lower = q.lower()
+        if any(kw in q_lower for kw in attack_config["medical_keywords"]):
+            target_queries.append(q)
+        if len(target_queries) >= attack_config["num_target_queries"]:
+            break
+
+    if len(target_queries) < 10:
+        # Fall back: use random queries if not enough medical ones
+        print(f"Only found {len(target_queries)} medical queries; "
+              f"using random queries to reach {attack_config['num_target_queries']}.")
+        remaining = attack_config["num_target_queries"] - len(target_queries)
+        other_queries = [q for q in eval_data.queries if q not in target_queries]
+        target_queries.extend(random.sample(other_queries, min(remaining, len(other_queries))))
+
+    print(f"Using {len(target_queries)} target queries for attack simulation.")
+
+    attack_results = run_attack_simulation(
+        encoder=encoder,
+        operator=operator,
+        baseline_operator=baseline_operator,
+        passages=eval_data.all_passages,
+        passage_embeddings_torch=passage_embeddings,
+        target_query_texts=target_queries,
+        medical_keywords=attack_config["medical_keywords"],
+        top_k=attack_config["top_k"],
+        device=device,
+    )
+    results["attack"] = attack_results
+
+    # =========================================================================
+    # Summary
+    # =========================================================================
+    print("\n" + "=" * 60)
+    print("EXPERIMENT SUMMARY")
+    print("=" * 60)
+
+    print(f"\n1. Retrieval Quality:")
+    print(f" Baseline MRR@10: {metrics_base['mrr@10']:.4f}")
+    print(f" Multi-Manifold MRR@10: {metrics_mm['mrr@10']:.4f}")
+    if metrics_base["mrr@10"] > 0:
+        print(f" Ratio: {metrics_mm['mrr@10']/metrics_base['mrr@10']:.4f}")
+
+    print(f"\n2. Spectral Analysis:")
+    print(f" Baseline δ: {spectral_results['baseline']['spectral_discrepancy']:.4f}")
+    print(f" Multi-Manifold δ: {spectral_results['multi_manifold']['spectral_discrepancy']:.4f}")
+    print(f" Baseline cos(θ): {spectral_results['baseline']['fiedler_alignment']:.4f}")
+    print(f" Multi-Manifold cos(θ): {spectral_results['multi_manifold']['fiedler_alignment']:.4f}")
+
+    if "error" not in attack_results:
+        print(f"\n3. Attack Simulation:")
+        print(f" Baseline ASR@{attack_config['top_k']}: {attack_results['baseline_asr']:.4f}")
+        print(f" Multi-Manifold ASR@{attack_config['top_k']}: {attack_results['multi_manifold_asr']:.4f}")
+
+    # Save results (exclude numpy arrays)
+    save_results = {k: v for k, v in results.items()
+                    if k not in ("L_D", "L_R_mm", "L_R_base")}
+    with open(args.output, "w") as f:
+        json.dump(save_results, f, indent=2, default=str)
+    print(f"\nResults saved to {args.output}")
+
+
+if __name__ == "__main__":
+    main()
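Phase 3 of `run_experiment.py` reports `mrr@10` and `recall@100` through `compute_retrieval_metrics`, whose body lives in `evaluation/retrieval_metrics.py` and is not shown in this part of the diff. For readers unfamiliar with the metrics, here is a minimal sketch of their standard definitions over ranked ID lists. It is not the repository's implementation, and the function names, argument layout, and toy data are all assumptions made for illustration.

```python
def mrr_at_k(rankings, relevant, k=10):
    """Mean reciprocal rank of the first relevant item within the top k.

    rankings: one ranked list of doc IDs per query (best first).
    relevant: one set of relevant doc IDs per query.
    A query with no relevant item in its top k contributes 0.
    """
    total = 0.0
    for ranking, rel in zip(rankings, relevant):
        for rank, doc_id in enumerate(ranking[:k], start=1):
            if doc_id in rel:
                total += 1.0 / rank
                break
    return total / len(rankings)


def recall_at_k(rankings, relevant, k=100):
    """Average fraction of each query's relevant items found in the top k."""
    total = 0.0
    for ranking, rel in zip(rankings, relevant):
        total += len(set(ranking[:k]) & rel) / len(rel)
    return total / len(rankings)


# Toy data: three queries, one relevant document each.
rankings = [["d3", "d1", "d2"], ["d5", "d4", "d6"], ["d7", "d8", "d9"]]
relevant = [{"d1"}, {"d5"}, {"d0"}]

print(mrr_at_k(rankings, relevant, k=10))    # (1/2 + 1 + 0) / 3
print(recall_at_k(rankings, relevant, k=2))  # (1 + 1 + 0) / 3
```

The script divides the two models' MRR@10 values to get `mrr_ratio` and prints PASS when that ratio is at least 0.8.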