bragee committed on
Commit
b464490
·
verified ·
1 Parent(s): 938711f

Upload model checkpoints and code

Browse files
Files changed (33) hide show
  1. README.md +69 -3
  2. checkpoints/best_operator.pt +3 -0
  3. checkpoints/final_operator.pt +3 -0
  4. configs/default.yaml +63 -0
  5. multi_manifold_retrieval/__init__.py +0 -0
  6. multi_manifold_retrieval/__pycache__/__init__.cpython-310.pyc +0 -0
  7. multi_manifold_retrieval/evaluation/__init__.py +0 -0
  8. multi_manifold_retrieval/evaluation/__pycache__/__init__.cpython-310.pyc +0 -0
  9. multi_manifold_retrieval/evaluation/__pycache__/attack_simulation.cpython-310.pyc +0 -0
  10. multi_manifold_retrieval/evaluation/__pycache__/retrieval_metrics.cpython-310.pyc +0 -0
  11. multi_manifold_retrieval/evaluation/__pycache__/spectral_analysis.cpython-310.pyc +0 -0
  12. multi_manifold_retrieval/evaluation/attack_simulation.py +332 -0
  13. multi_manifold_retrieval/evaluation/retrieval_metrics.py +99 -0
  14. multi_manifold_retrieval/evaluation/spectral_analysis.py +205 -0
  15. multi_manifold_retrieval/models/__init__.py +0 -0
  16. multi_manifold_retrieval/models/__pycache__/__init__.cpython-310.pyc +0 -0
  17. multi_manifold_retrieval/models/__pycache__/baseline.cpython-310.pyc +0 -0
  18. multi_manifold_retrieval/models/__pycache__/cross_manifold_operator.cpython-310.pyc +0 -0
  19. multi_manifold_retrieval/models/__pycache__/encoders.cpython-310.pyc +0 -0
  20. multi_manifold_retrieval/models/baseline.py +37 -0
  21. multi_manifold_retrieval/models/cross_manifold_operator.py +142 -0
  22. multi_manifold_retrieval/models/encoders.py +60 -0
  23. multi_manifold_retrieval/training/__init__.py +0 -0
  24. multi_manifold_retrieval/training/__pycache__/__init__.cpython-310.pyc +0 -0
  25. multi_manifold_retrieval/training/__pycache__/data.cpython-310.pyc +0 -0
  26. multi_manifold_retrieval/training/__pycache__/losses.cpython-310.pyc +0 -0
  27. multi_manifold_retrieval/training/__pycache__/train.cpython-310.pyc +0 -0
  28. multi_manifold_retrieval/training/data.py +168 -0
  29. multi_manifold_retrieval/training/losses.py +38 -0
  30. multi_manifold_retrieval/training/train.py +159 -0
  31. requirements.txt +10 -0
  32. results.json +136 -0
  33. run_experiment.py +261 -0
README.md CHANGED
@@ -1,3 +1,69 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Multi-Manifold Retrieval: Proof of Concept
2
+
3
+ A proof-of-concept implementation of the Multi-Manifold Retrieval defense against spectral poisoning attacks (GeoPoison-RAG) on Retrieval-Augmented Generation systems.
4
+
5
+ ## Core Idea
6
+
7
+ Standard RAG systems use a single shared embedding space for queries and documents, making the **document geometry identical to the retrieval geometry**. GeoPoison-RAG exploits this by computing the spectral structure (Fiedler vector) of the document graph Laplacian to find optimal adversarial placement.
8
+
9
+ Multi-Manifold Retrieval **decouples** these geometries by using:
10
+ - Separate query and document manifolds (M_Q and M_D)
11
+ - A non-decomposable cross-manifold relevance operator R(q, d)
12
+
13
+ This breaks the attack because the Laplacian the attacker computes (document space) no longer predicts the Laplacian governing retrieval (cross-manifold).
14
+
15
+ ## Project Structure
16
+
17
+ ```
18
+ multi_manifold_retrieval/
19
+ ├── models/
20
+ │ ├── cross_manifold_operator.py # Construction C: Attention-Geometric Hybrid
21
+ │ ├── encoders.py # Sentence-transformer wrapper
22
+ │ └── baseline.py # Standard cosine similarity baseline
23
+ ├── training/
24
+ │ ├── train.py # Training loop
25
+ │ ├── data.py # MS MARCO data loading
26
+ │ └── losses.py # Contrastive loss
27
+ ├── evaluation/
28
+ │ ├── spectral_analysis.py # L_D, L_R, spectral discrepancy, Fiedler alignment
29
+ │ ├── retrieval_metrics.py # MRR@10, Recall@100
30
+ │ └── attack_simulation.py # GeoPoison-RAG simulation
31
+ proofs/
32
+ ├── proof_theorem_4_3.tex # Spectral Decoupling theorem
33
+ └── proof_theorem_6_1.tex # Query Complexity Lower Bound theorem
34
+ configs/
35
+ └── default.yaml # Hyperparameters
36
+ run_experiment.py # End-to-end pipeline
37
+ ```
38
+
39
+ ## Setup
40
+
41
+ ```bash
42
+ pip install -r requirements.txt
43
+ ```
44
+
45
+ ## Running
46
+
47
+ Full experiment (train + evaluate + spectral analysis + attack):
48
+ ```bash
49
+ python run_experiment.py --config configs/default.yaml
50
+ ```
51
+
52
+ Skip training and load from checkpoint:
53
+ ```bash
54
+ python run_experiment.py --skip-train --checkpoint checkpoints/best_operator.pt
55
+ ```
56
+
57
+ ## Key Metrics
58
+
59
+ | Metric | Baseline (expected) | Multi-Manifold (expected) |
60
+ |--------|-------------------|--------------------------|
61
+ | Spectral discrepancy δ | ≈ 0 | > 0 (significant) |
62
+ | Fiedler alignment cos(θ) | ≈ 1 | < 0.5 |
63
+ | ASR@10 | > 0.8 | Significantly lower |
64
+ | MRR@10 | Reference | ≥ 80% of baseline |
65
+
66
+ ## Formal Proofs
67
+
68
+ - `proofs/proof_theorem_4_3.tex`: Proves that a non-decomposable R with positive cross-manifold curvature guarantees spectral decoupling, δ = Ω(κ_R · λ_2(L_D)).
69
+ - `proofs/proof_theorem_6_1.tex`: Proves that an adaptive adversary needs Ω(Vol(M_Q) / V_{d_Q}(ε/κ_R)) oracle queries to reconstruct R.
checkpoints/best_operator.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c63df76c2941f40d9f840057dcdc54d506b74cba5dc71d329117c965fdef783
3
+ size 6969493
checkpoints/final_operator.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa9be37ddd7e178a56f93eeb0a3259b8409399ee8be89118d84b53b8533318b1
3
+ size 6969525
configs/default.yaml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Multi-Manifold Retrieval - Default Configuration
2
+
3
+ seed: 42
4
+
5
+ # Encoder settings
6
+ encoder:
7
+ model_name: "sentence-transformers/all-MiniLM-L6-v2"
8
+ embedding_dim: 384
9
+ freeze: true # Freeze pretrained encoders
10
+
11
+ # Cross-manifold operator (Construction C)
12
+ cross_manifold:
13
+ num_heads: 4
14
+ head_dim: 96 # embedding_dim / num_heads
15
+ value_mlp_hidden: 256
16
+ value_mlp_layers: 2
17
+ dropout: 0.1
18
+
19
+ # Training
20
+ training:
21
+ batch_size: 64
22
+ learning_rate: 2.0e-4
23
+ weight_decay: 1.0e-2
24
+ epochs: 5
25
+ warmup_steps: 500
26
+ max_train_samples: 100000
27
+ num_negatives: 7
28
+ max_seq_length: 128
29
+ fp16: true
30
+ gradient_accumulation_steps: 1
31
+ log_every: 100
32
+ eval_every: 2000
33
+ save_dir: "checkpoints"
34
+
35
+ # Evaluation
36
+ evaluation:
37
+ max_eval_queries: 5000
38
+ metrics:
39
+ - mrr@10
40
+ - recall@100
41
+
42
+ # Spectral analysis
43
+ spectral:
44
+ num_documents: 1000
45
+ num_queries: 500
46
+ k_neighbors: 20 # For sparse Laplacian (optional)
47
+
48
+ # Attack simulation
49
+ attack:
50
+ target_domain: "medical"
51
+ num_target_queries: 100
52
+ top_k: 10
53
+ medical_keywords:
54
+ - "health"
55
+ - "medical"
56
+ - "doctor"
57
+ - "patient"
58
+ - "treatment"
59
+ - "disease"
60
+ - "symptom"
61
+ - "diagnosis"
62
+ - "medicine"
63
+ - "clinical"
multi_manifold_retrieval/__init__.py ADDED
File without changes
multi_manifold_retrieval/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (172 Bytes). View file
 
multi_manifold_retrieval/evaluation/__init__.py ADDED
File without changes
multi_manifold_retrieval/evaluation/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (183 Bytes). View file
 
multi_manifold_retrieval/evaluation/__pycache__/attack_simulation.cpython-310.pyc ADDED
Binary file (9.22 kB). View file
 
multi_manifold_retrieval/evaluation/__pycache__/retrieval_metrics.cpython-310.pyc ADDED
Binary file (2.68 kB). View file
 
multi_manifold_retrieval/evaluation/__pycache__/spectral_analysis.cpython-310.pyc ADDED
Binary file (5.13 kB). View file
 
multi_manifold_retrieval/evaluation/attack_simulation.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Simplified GeoPoison-RAG attack simulation.
2
+
3
+ Realistic threat model (matching GeoPoison-RAG Phase 1):
4
+ - Attacker has shadow queries approximating target query distribution.
5
+ - Attacker has access to document embeddings.
6
+ - Attacker builds bipartite query-document graph using COSINE SIMILARITY
7
+ (their model of how retrieval works).
8
+ - Attacker computes Fiedler vector and places adversarial doc at the
9
+ spectral-optimal position in document space.
10
+
11
+ Defense argument:
12
+ - Baseline (cosine sim): attacker's model is correct → high ASR.
13
+ - Multi-manifold (R(q,d)): attacker's model is wrong because R ≠ cosine → lower ASR.
14
+ """
15
+
16
+ import numpy as np
17
+ import torch
18
+ from scipy.sparse.linalg import eigsh
19
+ from sklearn.metrics.pairwise import cosine_similarity
20
+
21
+ from multi_manifold_retrieval.evaluation.spectral_analysis import compute_document_laplacian
22
+
23
+
24
def select_domain_documents(
    passages: list[str],
    keywords: list[str],
    max_docs: int = 200,
) -> tuple[list[int], list[str]]:
    """Select documents belonging to a target domain by keyword matching.

    A passage is selected when any keyword occurs in it as a
    case-insensitive substring. Scanning stops as soon as ``max_docs``
    passages have been collected.

    Args:
        passages: Candidate passage texts.
        keywords: Domain keywords; matched case-insensitively.
        max_docs: Maximum number of passages to return. Values <= 0 yield
            empty results.

    Returns:
        A pair ``(indices, texts)``: positions of the selected passages in
        ``passages`` and the corresponding passage texts, in input order.
    """
    indices: list[int] = []
    texts: list[str] = []
    if max_docs <= 0:
        return indices, texts

    # The passage is lower-cased before matching, so the keywords must be
    # lower-cased too; otherwise a keyword such as "Doctor" can never match.
    lowered_keywords = [kw.lower() for kw in keywords]

    for i, text in enumerate(passages):
        text_lower = text.lower()
        if any(kw in text_lower for kw in lowered_keywords):
            indices.append(i)
            texts.append(text)
            if len(indices) >= max_docs:
                break
    return indices, texts
40
+
41
+
42
def build_bipartite_fiedler_placement(
    query_embs: np.ndarray,
    doc_embs: np.ndarray,
    t_nn: int = 20,
) -> tuple[np.ndarray, dict]:
    """GeoPoison-RAG Phase 1: bipartite spectral placement (cosine-based).

    The attacker:
    1. Builds a bipartite query-document graph from cosine similarity,
       keeping only the top-t documents per query.
    2. Computes the Fiedler vector of the normalized Laplacian.
    3. Extracts the document component of the Fiedler vector.
    4. Places the adversarial doc at the Fiedler-weighted centroid of the
       documents and L2-normalizes it.

    The placement is in DOCUMENT SPACE: the attacker optimizes where to
    place a document, guided by the query-document spectral structure,
    while assuming retrieval = cosine similarity.

    Args:
        query_embs: (nq, d) shadow-query embeddings.
        doc_embs: (nd, d) document embeddings.
        t_nn: Number of top-scoring documents kept per query
            (capped at nd - 1).

    Returns:
        adv_embedding: (d,) adversarial embedding (unit norm unless the
            weighted centroid is exactly zero).
        info: Diagnostics: method name, Fiedler eigenvalue, entropy and
            maximum of the placement weights, and the mean cosine of the
            adversarial embedding to the queries and to the documents.

    Note:
        The Fiedler vector is read from index 1 of the sorted eigenpairs,
        so the graph needs at least three nodes (nq + nd >= 3).
    """
    nq = query_embs.shape[0]
    nd = doc_embs.shape[0]

    # Cosine similarity between queries and documents (attacker's model)
    S = cosine_similarity(query_embs, doc_embs)  # (nq, nd)

    # Sparsify: keep top-t per query
    t = min(t_nn, nd - 1)
    S_sparse = np.zeros_like(S)
    for i in range(nq):
        top_idx = np.argpartition(S[i], -t)[-t:]
        S_sparse[i, top_idx] = S[i, top_idx]

    # Cosine similarity can be negative. A negative edge weight can drive a
    # node degree below zero, and sqrt(degree) then yields NaN throughout
    # the Laplacian. Clip to non-negative weights, as
    # compute_document_laplacian does for the document graph.
    np.clip(S_sparse, 0, None, out=S_sparse)

    # Build bipartite adjacency: A = [[0, S], [S^T, 0]]
    n = nq + nd
    A = np.zeros((n, n))
    A[:nq, nq:] = S_sparse
    A[nq:, :nq] = S_sparse.T

    # Normalized Laplacian: L = I - D^{-1/2} A D^{-1/2}
    degrees = A.sum(axis=1)
    degrees[degrees == 0] = 1.0  # isolated nodes: avoid division by zero
    D_inv_sqrt = np.diag(1.0 / np.sqrt(degrees))
    L = np.eye(n) - D_inv_sqrt @ A @ D_inv_sqrt

    # Fiedler vector (eigenvector of the 2nd smallest eigenvalue); sort
    # explicitly rather than relying on the solver's output order.
    k = min(3, n - 1)
    eigenvalues, eigenvectors = eigsh(L, k=k, which="SM")
    sorted_idx = np.argsort(eigenvalues)
    fiedler_vec = eigenvectors[:, sorted_idx[1]]
    fiedler_val = eigenvalues[sorted_idx[1]]

    # Extract document component and use its magnitude as weights
    doc_component = fiedler_vec[nq:]
    weights = np.abs(doc_component)
    weights = weights / (weights.sum() + 1e-12)

    # Fiedler-weighted centroid of documents
    adv_embedding = (weights[:, None] * doc_embs).sum(axis=0)

    # L2-normalize
    norm = np.linalg.norm(adv_embedding)
    if norm > 0:
        adv_embedding = adv_embedding / norm

    adv_row = adv_embedding.reshape(1, -1)
    info = {
        "method": "bipartite_fiedler",
        "fiedler_eigenvalue": float(fiedler_val),
        "weight_entropy": float(-np.sum(weights * np.log(weights + 1e-12))),
        "max_weight": float(weights.max()),
        "adv_mean_cos_to_queries": float(
            cosine_similarity(adv_row, query_embs).mean()
        ),
        "adv_mean_cos_to_docs": float(
            cosine_similarity(adv_row, doc_embs).mean()
        ),
    }

    return adv_embedding, info
118
+
119
+
120
def compute_doconly_fiedler_placement(doc_embs: np.ndarray) -> tuple[np.ndarray, dict]:
    """Document-only Fiedler placement (no query access).

    Weaker attacker that only has document embeddings and uses the
    document-space Laplacian L_D directly. With fewer than three documents
    there is no usable Fiedler vector, so the plain centroid is returned
    instead.

    Args:
        doc_embs: (n, d) document embeddings.

    Returns:
        adv_embedding: (d,) adversarial embedding (unit norm unless the
            centroid / weighted centroid is exactly zero).
        info: ``{"method": "centroid_fallback"}`` for n < 3, otherwise the
            method name and the Fiedler eigenvalue.

    Raises:
        ValueError: If ``doc_embs`` contains no documents.
    """
    n = doc_embs.shape[0]
    if n == 0:
        raise ValueError("doc_embs must contain at least one document")

    if n < 3:
        centroid = doc_embs.mean(axis=0)
        # Guard the normalization exactly like the main path below: a
        # zero centroid (e.g. two antipodal documents) would become NaN.
        norm = np.linalg.norm(centroid)
        if norm > 0:
            centroid = centroid / norm
        return centroid, {"method": "centroid_fallback"}

    L_D, _ = compute_document_laplacian(doc_embs)

    k = min(3, n - 1)
    eigenvalues, eigenvectors = eigsh(L_D, k=k, which="SM")
    # Sort explicitly; index 1 is then the 2nd smallest eigenpair (Fiedler).
    sorted_idx = np.argsort(eigenvalues)
    fiedler_vec = eigenvectors[:, sorted_idx[1]]
    fiedler_val = eigenvalues[sorted_idx[1]]

    # Magnitude of the Fiedler entries, normalized to sum to ~1
    weights = np.abs(fiedler_vec)
    weights = weights / (weights.sum() + 1e-12)

    adv_embedding = (weights[:, None] * doc_embs).sum(axis=0)
    norm = np.linalg.norm(adv_embedding)
    if norm > 0:
        adv_embedding = adv_embedding / norm

    return adv_embedding, {
        "method": "doconly_fiedler",
        "fiedler_eigenvalue": float(fiedler_val),
    }
151
+
152
+
153
def compute_asr_threshold(
    query_embeddings: torch.Tensor,
    corpus_embeddings: torch.Tensor,
    adv_embedding: torch.Tensor,
    operator,
    top_k: int = 10,
    device: str = "cpu",
    batch_size: int = 50,
) -> tuple[float, dict]:
    """Compute ASR@k using a per-query threshold (oracle-style).

    For each query, the k-th highest corpus score is the threshold. The
    attack succeeds on that query if the adversarial document's score is
    >= the threshold. (Per the original author's note, this mirrors the
    oracle check in gp_rag/plan_single.py.)

    Args:
        query_embeddings: (num_queries, d) query embeddings.
        corpus_embeddings: (num_docs, d) corpus embeddings.
        adv_embedding: (d,) adversarial document embedding.
        operator: Relevance operator. It is called as
            ``operator(q_batch, docs)`` for one score per row and as
            ``operator.compute_pairwise(q_batch, corpus)`` for a
            (batch, num_docs) score matrix; ``operator.eval()`` is invoked.
        top_k: Rank cut-off k.
        device: Device the scoring runs on.
        batch_size: Number of queries scored per step.

    Returns:
        asr: Fraction of queries on which the attack succeeds.
        info: Margin statistics (adversarial score minus threshold):
            mean, median, 25th percentile, and fraction of non-negative
            margins.

    Raises:
        ValueError: If ``query_embeddings`` contains no queries.
    """
    num_queries = query_embeddings.shape[0]
    if num_queries == 0:
        # The success rate is undefined without queries; fail with a clear
        # message rather than a ZeroDivisionError at the end.
        raise ValueError("query_embeddings must contain at least one query")

    corpus_emb = corpus_embeddings.to(device)
    adv_emb = adv_embedding.to(device)

    operator.eval()
    margin_chunks = []

    with torch.no_grad():
        for start in range(0, num_queries, batch_size):
            q_batch = query_embeddings[start:start + batch_size].to(device)
            bs = q_batch.shape[0]

            # Score adversarial document (one score per query in the batch)
            adv_expanded = adv_emb.unsqueeze(0).expand(bs, -1)
            adv_scores = operator(q_batch, adv_expanded).reshape(bs)

            # Score corpus documents
            corpus_scores = operator.compute_pairwise(q_batch, corpus_emb)

            # k-th highest corpus score = threshold
            topk_vals, _ = torch.topk(corpus_scores, top_k, dim=1)
            thresholds = topk_vals[:, -1]

            # One transfer per batch instead of a Python loop with .item()
            # per query. Subtract in float64 on the CPU so the margins
            # match a Python-float subtraction of the individual scores.
            batch_margins = adv_scores.cpu().double() - thresholds.cpu().double()
            margin_chunks.append(batch_margins.numpy())

    margins_arr = np.concatenate(margin_chunks)
    # adv >= threshold  <=>  margin >= 0
    successes = int((margins_arr >= 0).sum())
    asr = successes / num_queries

    info = {
        "mean_margin": float(margins_arr.mean()),
        "median_margin": float(np.median(margins_arr)),
        "p25_margin": float(np.percentile(margins_arr, 25)),
        "fraction_positive_margin": float((margins_arr >= 0).mean()),
    }

    return asr, info
209
+
210
+
211
def run_attack_simulation(
    encoder,
    operator,
    baseline_operator,
    passages: list[str],
    passage_embeddings_torch: torch.Tensor,
    target_query_texts: list[str],
    medical_keywords: list[str],
    top_k: int = 10,
    max_domain_docs: int = 200,
    device: str = "cpu",
) -> dict:
    """Run the GeoPoison-RAG attack simulation end to end.

    Two attacker models are evaluated, each against both the baseline
    operator and the cross-manifold operator:
      1. Bipartite Fiedler (realistic): shadow queries + documents, a
         cosine-based bipartite graph, placement in document space.
      2. Doc-only Fiedler (weaker): document embeddings only.

    Both attackers assume cosine similarity governs retrieval; the defense
    is meant to break that assumption through the operator R.

    Args:
        encoder: Object exposing ``encode_queries(texts, show_progress=...)``.
        operator: Cross-manifold relevance operator.
        baseline_operator: Baseline relevance operator.
        passages: Corpus passage texts.
        passage_embeddings_torch: Embeddings indexed like ``passages``.
        target_query_texts: Target (shadow) query strings.
        medical_keywords: Keywords used to pick the target-domain documents.
        top_k: Rank cut-off for ASR@k.
        max_domain_docs: Cap on the number of domain documents.
        device: Device used for scoring.

    Returns:
        Nested result dict with per-attack ASR values, margin statistics
        and placement diagnostics, or ``{"error": ...}`` when fewer than
        five domain documents are found.
    """
    print("\n=== Attack Simulation ===", flush=True)

    # Step 1: pick the target-domain slice of the corpus
    domain_indices, _ = select_domain_documents(
        passages, medical_keywords, max_domain_docs
    )
    num_domain = len(domain_indices)
    print(f"Selected {num_domain} domain documents.", flush=True)

    if num_domain < 5:
        print("Warning: Too few domain documents found.")
        return {"error": "insufficient domain documents"}

    domain_embs_np = passage_embeddings_torch[domain_indices].cpu().numpy()
    domain_corpus = passage_embeddings_torch[domain_indices]

    # Step 2: embed the attacker's shadow queries
    print(f"Encoding {len(target_query_texts)} target queries...", flush=True)
    query_embeddings = encoder.encode_queries(target_query_texts, show_progress=False)
    q_np = query_embeddings.cpu().numpy()

    # Step 3a: realistic attacker (queries + documents)
    print("\nComputing bipartite Fiedler placement (attacker has shadow queries)...", flush=True)
    adv_bipartite_np, bp_info = build_bipartite_fiedler_placement(
        q_np, domain_embs_np, t_nn=min(20, num_domain - 1)
    )
    adv_bipartite = torch.tensor(adv_bipartite_np, dtype=torch.float32)
    print(f" Fiedler eigenvalue: {bp_info['fiedler_eigenvalue']:.6f}", flush=True)
    print(f" Adv mean cos to queries: {bp_info['adv_mean_cos_to_queries']:.4f}", flush=True)
    print(f" Adv mean cos to docs: {bp_info['adv_mean_cos_to_docs']:.4f}", flush=True)

    # Step 3b: weaker attacker (documents only)
    print("\nComputing doc-only Fiedler placement (no query access)...", flush=True)
    adv_doconly_np, do_info = compute_doconly_fiedler_placement(domain_embs_np)
    adv_doconly = torch.tensor(adv_doconly_np, dtype=torch.float32)

    def _measure(adv):
        """Score one adversarial embedding under baseline, then multi-manifold."""
        base_asr, base_info = compute_asr_threshold(
            query_embeddings, domain_corpus, adv,
            baseline_operator, top_k, device
        )
        print(f" Baseline ASR@{top_k}: {base_asr:.4f} (mean margin: {base_info['mean_margin']:.4f})", flush=True)

        mm_asr, mm_info = compute_asr_threshold(
            query_embeddings, domain_corpus, adv,
            operator, top_k, device
        )
        print(f" Multi-manifold ASR@{top_k}: {mm_asr:.4f} (mean margin: {mm_info['mean_margin']:.4f})", flush=True)
        return base_asr, base_info, mm_asr, mm_info

    def _reduction(base, mm):
        # Relative ASR drop in percent; the floor avoids dividing by zero.
        return (1 - mm / max(base, 1e-9)) * 100

    # Step 4: ASR for the bipartite attack
    print(f"\n--- Bipartite Fiedler Attack (realistic GeoPoison-RAG) ---", flush=True)
    asr_bp_base, bp_base_info, asr_bp_mm, bp_mm_info = _measure(adv_bipartite)

    # Step 5: ASR for the doc-only attack
    print(f"\n--- Doc-only Fiedler Attack (weaker attacker) ---", flush=True)
    asr_do_base, do_base_info, asr_do_mm, do_mm_info = _measure(adv_doconly)

    # Summary
    results = {
        "bipartite_attack": {
            "baseline_asr": asr_bp_base,
            "multi_manifold_asr": asr_bp_mm,
            "baseline_margins": bp_base_info,
            "multi_manifold_margins": bp_mm_info,
            "placement_info": bp_info,
        },
        "doconly_attack": {
            "baseline_asr": asr_do_base,
            "multi_manifold_asr": asr_do_mm,
            "baseline_margins": do_base_info,
            "multi_manifold_margins": do_mm_info,
            "placement_info": do_info,
        },
        "num_domain_docs": num_domain,
        "num_target_queries": len(target_query_texts),
        "top_k": top_k,
        # For backward compat with summary printing
        "baseline_asr": asr_bp_base,
        "multi_manifold_asr": asr_bp_mm,
    }

    print(f"\n=== Attack Results Summary ===", flush=True)
    print(f" Baseline Multi-Manifold Reduction", flush=True)
    print(f" Bipartite (realistic): {asr_bp_base:.4f} {asr_bp_mm:.4f}"
          f" {_reduction(asr_bp_base, asr_bp_mm):.1f}%", flush=True)
    print(f" Doc-only (weaker): {asr_do_base:.4f} {asr_do_mm:.4f}"
          f" {_reduction(asr_do_base, asr_do_mm):.1f}%", flush=True)

    return results
multi_manifold_retrieval/evaluation/retrieval_metrics.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Retrieval evaluation metrics: MRR@10 and Recall@100."""
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
def compute_retrieval_metrics(
    query_embeddings: torch.Tensor,
    doc_embeddings: torch.Tensor,
    operator,
    query_texts: list[str],
    positive_passages: list[list[str]],
    all_passages: list[str],
    passage_embeddings: torch.Tensor,
    device: str = "cpu",
    batch_size: int = 32,
) -> dict:
    """Evaluate an operator with MRR@10 and Recall@100.

    Queries whose positive passages do not appear in ``all_passages`` are
    skipped and do not count towards the averages.

    Args:
        query_embeddings: (num_queries, d) query embeddings.
        doc_embeddings: Unused; ``passage_embeddings`` is scored instead.
        operator: Relevance operator providing ``eval()`` and
            ``compute_pairwise(queries, passages)``.
        query_texts: Query strings (not read by this function).
        positive_passages: Per query, the texts of its relevant passages.
        all_passages: Flat list of candidate passage texts.
        passage_embeddings: (num_passages, d) embeddings for all_passages.
        device: Computation device.
        batch_size: Number of queries scored per step.

    Returns:
        Dict with keys ``mrr@10``, ``recall@100`` and ``num_queries``
        (the number of queries that were actually evaluated).
    """
    total_queries = query_embeddings.shape[0]

    # Text -> position lookup, then the set of relevant positions per query
    text_to_pos = {text: pos for pos, text in enumerate(all_passages)}
    relevant_sets = [
        {text_to_pos[text] for text in pos_list if text in text_to_pos}
        for pos_list in positive_passages
    ]

    passage_embeddings = passage_embeddings.to(device)
    operator.eval()

    rr_total = 0.0
    recall_total = 0.0
    evaluated = 0

    with torch.no_grad():
        for lo in range(0, total_queries, batch_size):
            q_batch = query_embeddings[lo:lo + batch_size].to(device)  # (bs, d)

            # (bs, num_passages) relevance scores for this batch
            score_rows = operator.compute_pairwise(q_batch, passage_embeddings).cpu().numpy()

            for offset in range(score_rows.shape[0]):
                relevant = relevant_sets[lo + offset]
                if not relevant:
                    continue

                # Passage indices from best to worst score
                order = np.argsort(-score_rows[offset])

                # MRR@10: reciprocal rank of the first relevant hit, if any
                first_hit = next(
                    (pos for pos, doc_idx in enumerate(order[:10], start=1)
                     if doc_idx in relevant),
                    None,
                )
                rr_total += 0.0 if first_hit is None else 1.0 / first_hit

                # Recall@100: share of relevant passages inside the top 100
                hits = relevant & set(order[:100].tolist())
                recall_total += len(hits) / len(relevant)

                evaluated += 1

    if evaluated > 0:
        mrr_at_10 = rr_total / evaluated
        recall_at_100 = recall_total / evaluated
    else:
        mrr_at_10 = 0.0
        recall_at_100 = 0.0

    print(f"MRR@10: {mrr_at_10:.4f} | Recall@100: {recall_at_100:.4f} "
          f"({evaluated} queries)")

    return {
        "mrr@10": mrr_at_10,
        "recall@100": recall_at_100,
        "num_queries": evaluated,
    }
multi_manifold_retrieval/evaluation/spectral_analysis.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Spectral analysis: compute L_D, L_R, spectral discrepancy δ, and Fiedler alignment cos(θ)."""
2
+
3
+ import numpy as np
4
+ import torch
5
+ from scipy import sparse
6
+ from scipy.sparse.linalg import eigsh
7
+
8
+
9
def compute_document_laplacian(doc_embeddings: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Build the document-space graph Laplacian L_D = D - W.

    Args:
        doc_embeddings: (n, d) array of document embeddings. The Gram
            matrix equals cosine similarity only if the rows are
            L2-normalized, which the caller is expected to ensure.

    Returns:
        L_D: (n, n) unnormalized graph Laplacian.
        W_D: (n, n) edge-weight matrix: pairwise inner products with
            negatives clipped to zero and a zero diagonal.
    """
    # Inner products; negative similarities are dropped to zero so that
    # all edge weights are non-negative.
    W_D = np.clip(np.matmul(doc_embeddings, doc_embeddings.T), 0, None)
    np.fill_diagonal(W_D, 0)  # no self-loops

    # L = D - W with D the diagonal matrix of row sums
    row_sums = W_D.sum(axis=1)
    L_D = np.diag(row_sums) - W_D

    return L_D, W_D
32
+
33
+
34
def compute_retrieval_laplacian(
    doc_embeddings: torch.Tensor,
    query_embeddings: torch.Tensor,
    operator,
    device: str = "cpu",
    batch_size: int = 50,
) -> tuple[np.ndarray, np.ndarray]:
    """Build the retrieval Laplacian L_R induced by a relevance operator.

    Edge weights are query-averaged score products:

        (W_R)_{ij} = (1/|Q|) * sum_q R(q, d_i) * R(q, d_j)

    Args:
        doc_embeddings: (n, d) document embeddings tensor.
        query_embeddings: (m, d) query embeddings tensor.
        operator: Relevance operator providing ``eval()`` and
            ``compute_pairwise(queries, docs)``.
        device: Computation device.
        batch_size: Number of queries scored per step.

    Returns:
        L_R: (n, n) retrieval Laplacian.
        W_R: (n, n) retrieval similarity matrix (non-negative, zero diagonal).
    """
    num_docs = doc_embeddings.shape[0]
    num_queries = query_embeddings.shape[0]

    docs_dev = doc_embeddings.to(device)
    queries_dev = query_embeddings.to(device)

    # Running sum of R^T R over query batches, where R[k, i] = R(q_k, d_i)
    gram = np.zeros((num_docs, num_docs), dtype=np.float64)

    operator.eval()
    with torch.no_grad():
        for lo in range(0, num_queries, batch_size):
            # (bs, n) scores of this query batch against every document
            batch_scores = operator.compute_pairwise(
                queries_dev[lo:lo + batch_size], docs_dev
            )
            block = batch_scores.cpu().numpy().astype(np.float64)
            gram += block.T @ block

    # Average over queries, drop negative weights, remove self-loops
    W_R = np.clip(gram / num_queries, 0, None)
    np.fill_diagonal(W_R, 0)

    # L = D - W
    L_R = np.diag(W_R.sum(axis=1)) - W_R

    return L_R, W_R
90
+
91
+
92
def compute_spectral_discrepancy(L_D: np.ndarray, L_R: np.ndarray,
                                 num_eigenvalues: int = 50) -> float:
    """Compute the spectral discrepancy δ = ||σ(L_D) - σ(L_R)||_2.

    Only the smallest eigenvalues of each matrix are compared, and each
    spectrum is first rescaled by its largest retained eigenvalue.

    Args:
        L_D: Document-space Laplacian.
        L_R: Retrieval Laplacian.
        num_eigenvalues: Number of eigenvalues to compare (capped at n - 2,
            with n the size of L_D).

    Returns:
        δ: Euclidean distance between the two rescaled spectra.
    """
    k = min(num_eigenvalues, L_D.shape[0] - 2)

    # Smallest-magnitude end of each spectrum, in ascending order
    low_D = np.sort(eigsh(L_D, k=k, which="SM", return_eigenvectors=False))
    low_R = np.sort(eigsh(L_R, k=k, which="SM", return_eigenvectors=False))

    def _rescale(values):
        # Divide by the largest retained eigenvalue; leave an all-zero
        # (or non-positive) spectrum unscaled.
        top = values[-1]
        return values / (top if top > 0 else 1.0)

    return np.linalg.norm(_rescale(low_D) - _rescale(low_R))
125
+
126
+
127
def compute_fiedler_alignment(L_D: np.ndarray, L_R: np.ndarray) -> float:
    """Compute Fiedler vector alignment cos(θ) = |v_2(L_D)^T v_2(L_R)| / (||v_2(L_D)|| * ||v_2(L_R)||).

    Args:
        L_D: Document-space Laplacian.
        L_R: Retrieval Laplacian.

    Returns:
        cos(θ): Absolute cosine of the angle between the two Fiedler
        vectors (1 = aligned, 0 = orthogonal).
    """

    def _fiedler_vector(laplacian: np.ndarray) -> np.ndarray:
        """Return the unit eigenvector of the 2nd smallest eigenvalue."""
        values, vectors = eigsh(laplacian, k=2, which="SM")
        # Do not rely on the order in which the solver returns the two
        # eigenpairs: sort by eigenvalue so that index 1 really is the
        # second-smallest one.
        order = np.argsort(values)
        vec = vectors[:, order[1]]
        return vec / np.linalg.norm(vec)

    v2_D = _fiedler_vector(L_D)
    v2_R = _fiedler_vector(L_R)

    # Absolute value: an eigenvector's sign is arbitrary
    return float(np.abs(np.dot(v2_D, v2_R)))
152
+
153
+
154
def run_spectral_analysis(
    doc_embeddings_np: np.ndarray,
    doc_embeddings_torch: torch.Tensor,
    query_embeddings_torch: torch.Tensor,
    operator,
    baseline_operator,
    device: str = "cpu",
) -> dict:
    """Run the spectral comparison for the multi-manifold and baseline operators.

    Builds the document Laplacian L_D and one retrieval Laplacian per
    operator, then measures spectral discrepancy and Fiedler alignment of
    each retrieval Laplacian against L_D.

    Args:
        doc_embeddings_np: Document embeddings as a NumPy array.
        doc_embeddings_torch: The same documents as a tensor.
        query_embeddings_torch: Query embeddings tensor.
        operator: Cross-manifold relevance operator.
        baseline_operator: Baseline relevance operator.
        device: Computation device.

    Returns:
        Dict with per-operator metrics under ``multi_manifold`` and
        ``baseline`` plus the three Laplacians (``L_D``, ``L_R_mm``,
        ``L_R_base``).
    """
    print("Computing document-space Laplacian L_D...")
    L_D, _ = compute_document_laplacian(doc_embeddings_np)

    print("Computing retrieval Laplacian L_R (multi-manifold)...")
    L_R_mm, _ = compute_retrieval_laplacian(
        doc_embeddings_torch, query_embeddings_torch, operator, device
    )

    print("Computing retrieval Laplacian L_R (baseline)...")
    L_R_base, _ = compute_retrieval_laplacian(
        doc_embeddings_torch, query_embeddings_torch, baseline_operator, device
    )

    print("Computing spectral discrepancy and Fiedler alignment...")
    # At most 50 eigenvalues, fewer for small document sets
    num_eigs = min(50, doc_embeddings_np.shape[0] - 2)

    delta_mm = compute_spectral_discrepancy(L_D, L_R_mm, num_eigs)
    delta_base = compute_spectral_discrepancy(L_D, L_R_base, num_eigs)
    cos_theta_mm = compute_fiedler_alignment(L_D, L_R_mm)
    cos_theta_base = compute_fiedler_alignment(L_D, L_R_base)

    print(f"\n=== Spectral Analysis Results ===")
    print(f"Multi-Manifold: δ = {delta_mm:.4f}, cos(θ) = {cos_theta_mm:.4f}")
    print(f"Baseline: δ = {delta_base:.4f}, cos(θ) = {cos_theta_base:.4f}")

    return {
        "multi_manifold": {
            "spectral_discrepancy": delta_mm,
            "fiedler_alignment": cos_theta_mm,
        },
        "baseline": {
            "spectral_discrepancy": delta_base,
            "fiedler_alignment": cos_theta_base,
        },
        "L_D": L_D,
        "L_R_mm": L_R_mm,
        "L_R_base": L_R_base,
    }
multi_manifold_retrieval/models/__init__.py ADDED
File without changes
multi_manifold_retrieval/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (179 Bytes). View file
 
multi_manifold_retrieval/models/__pycache__/baseline.cpython-310.pyc ADDED
Binary file (1.6 kB). View file
 
multi_manifold_retrieval/models/__pycache__/cross_manifold_operator.cpython-310.pyc ADDED
Binary file (5.11 kB). View file
 
multi_manifold_retrieval/models/__pycache__/encoders.cpython-310.pyc ADDED
Binary file (2.76 kB). View file
 
multi_manifold_retrieval/models/baseline.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Standard dual encoder baseline: R(q, d) = cosine_similarity(q, d)."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
class BaselineOperator(nn.Module):
    """Decomposable dual-encoder baseline scoring with a plain inner product.

    R(q, d) = q^T d, which coincides with cosine similarity when both
    inputs are L2-normalized.
    """

    def forward(self, q: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
        """Score every query against its own document(s).

        Args:
            q: Query embeddings, shape (batch_size, embedding_dim).
            d: Document embeddings, either one document per query with shape
                (batch_size, embedding_dim), or a candidate list per query
                with shape (batch_size, num_docs, embedding_dim).

        Returns:
            Inner-product scores of shape (batch_size,) when ``d`` is 2-D,
            otherwise (batch_size, num_docs).
        """
        # Candidate-list case first; the 2-D case is a row-wise dot product.
        if d.dim() != 2:
            return torch.einsum("bd,bnd->bn", q, d)
        return (q * d).sum(dim=-1)

    def compute_pairwise(self, q: torch.Tensor,
                         docs: torch.Tensor) -> torch.Tensor:
        """Score all query-document combinations.

        Args:
            q: Query embeddings, shape (num_queries, embedding_dim).
            docs: Document embeddings, shape (num_docs, embedding_dim).

        Returns:
            Score matrix of shape (num_queries, num_docs).
        """
        docs_transposed = docs.t()
        return torch.mm(q, docs_transposed)
multi_manifold_retrieval/models/cross_manifold_operator.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Construction C: Attention-Geometric Hybrid cross-manifold operator.
2
+
3
+ R(q, d) = sum_{h=1}^{H} softmax_h((W_Q^h q)^T (W_K^h d) / sqrt(d_h)) * v^h(q, d), with the softmax taken across the H heads
4
+
5
+ where v^h(q, d) is a learned query-dependent value function parameterized as
6
+ a small MLP taking [q; d; q * d] as input.
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
class ValueMLP(nn.Module):
    """Learned value function v^h(q, d) for a single attention head.

    The input is the concatenation [q; d; q * d] along the last dimension;
    the output is one scalar per (query, document) pair.
    """

    def __init__(self, embedding_dim: int, hidden_dim: int = 256,
                 num_layers: int = 2, dropout: float = 0.1):
        """Build the MLP.

        Args:
            embedding_dim: Size of each of the q and d embeddings.
            hidden_dim: Width of every hidden layer.
            num_layers: Number of hidden (Linear -> GELU -> Dropout) stages.
                With 0 the network reduces to a single linear map.
            dropout: Dropout probability applied after every hidden stage.
        """
        super().__init__()
        in_dim = 3 * embedding_dim  # [q; d; q*d]
        layers = []
        for _ in range(num_layers):
            layers.extend([
                nn.Linear(in_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ])
            in_dim = hidden_dim
        # Use the running width rather than hidden_dim so that the output
        # layer still matches the input when there are no hidden stages.
        # For num_layers >= 1 this is identical to the previous layout, so
        # existing checkpoints keep loading.
        layers.append(nn.Linear(in_dim, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, q: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
        """Compute v^h(q, d).

        Args:
            q: Query embeddings, shape (batch, embed_dim) or
                (batch, num_docs, embed_dim).
            d: Document embeddings with the same shape as ``q``.

        Returns:
            Scalar values with the input shape minus its last dimension.
        """
        x = torch.cat([q, d, q * d], dim=-1)
        return self.mlp(x).squeeze(-1)
48
+
49
+
50
class CrossManifoldOperator(nn.Module):
    """Attention-Geometric Hybrid (Construction C).

    Implements the cross-manifold relevance operator R(q, d) as a
    multi-head attention mechanism with learned value functions: each head
    produces a query-key logit, the logits are normalised across heads, and
    the resulting weights mix the per-head ValueMLP outputs.
    """

    def __init__(self, embedding_dim: int, num_heads: int = 4,
                 value_hidden_dim: int = 256, value_num_layers: int = 2,
                 dropout: float = 0.1):
        """Create the projections and the per-head value networks.

        Args:
            embedding_dim: Dimension of the query/document embeddings.
            num_heads: Number of attention heads; must be positive and
                divide ``embedding_dim``.
            value_hidden_dim: Hidden width of every ValueMLP.
            value_num_layers: Number of hidden stages in every ValueMLP.
            dropout: Dropout probability inside every ValueMLP.

        Raises:
            ValueError: If ``num_heads`` is not positive or does not divide
                ``embedding_dim``.
        """
        super().__init__()
        # Validate with a real exception (asserts vanish under ``python -O``)
        # and do it before head_dim is derived from the two values.
        if num_heads <= 0 or embedding_dim % num_heads != 0:
            raise ValueError(
                f"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}"
            )
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads

        # Query and key projections; the output is split into heads in forward().
        self.W_Q = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)

        # One value MLP per head.
        self.value_mlps = nn.ModuleList([
            ValueMLP(embedding_dim, value_hidden_dim, value_num_layers, dropout)
            for _ in range(num_heads)
        ])

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise the query and key projection matrices."""
        nn.init.xavier_uniform_(self.W_Q.weight)
        nn.init.xavier_uniform_(self.W_K.weight)

    def forward(self, q: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
        """Compute R(q, d) = cross-manifold relevance score.

        Args:
            q: Query embeddings, shape (batch_size, embedding_dim).
            d: Document embeddings, shape (batch_size, num_docs, embedding_dim)
                or (batch_size, embedding_dim) for a single document.

        Returns:
            Relevance scores, shape (batch_size, num_docs) or (batch_size,).
        """
        single_doc = d.dim() == 2
        if single_doc:
            d = d.unsqueeze(1)  # (batch, 1, embed_dim)

        batch_size, num_docs, _ = d.shape

        # (batch, embed_dim) -> (batch, num_heads, head_dim)
        q_proj = self.W_Q(q).view(batch_size, self.num_heads, self.head_dim)
        # (batch, num_docs, embed_dim) -> (batch, num_docs, num_heads, head_dim)
        d_proj = self.W_K(d).view(batch_size, num_docs, self.num_heads, self.head_dim)

        # Scaled dot-product logits per head: (batch, num_docs, num_heads)
        scale = self.head_dim ** 0.5
        attn = torch.einsum("bhd,bnhd->bnh", q_proj, d_proj) / scale

        # Softmax over heads (not over documents): for every (query, doc)
        # pair the head weights sum to one and depend on the query-key match.
        attn_weights = F.softmax(attn, dim=-1)

        # Broadcast the query next to every candidate document for the MLPs.
        q_expanded = q.unsqueeze(1).expand(-1, num_docs, -1)

        # Weighted sum of the per-head values.
        total = torch.zeros(batch_size, num_docs, device=q.device)
        for h, value_mlp in enumerate(self.value_mlps):
            v_h = value_mlp(q_expanded, d)  # (batch, num_docs)
            total = total + attn_weights[:, :, h] * v_h

        if single_doc:
            total = total.squeeze(1)

        return total

    def compute_pairwise(self, q: torch.Tensor,
                         docs: torch.Tensor) -> torch.Tensor:
        """Compute R(q, d) for all query-document pairs.

        Args:
            q: Query embeddings, shape (num_queries, embedding_dim).
            docs: Document embeddings, shape (num_docs, embedding_dim).

        Returns:
            Relevance matrix, shape (num_queries, num_docs).
        """
        # Give every query a (non-copying) view of the full document set.
        num_queries = q.shape[0]
        docs_expanded = docs.unsqueeze(0).expand(num_queries, -1, -1)
        return self.forward(q, docs_expanded)
multi_manifold_retrieval/models/encoders.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Wrapper around sentence-transformers for document and query encoders."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from sentence_transformers import SentenceTransformer
6
+
7
+
8
class DualEncoder(nn.Module):
    """Pretrained sentence encoder shared by the query and document sides.

    A single SentenceTransformer instance encodes both queries and
    documents; the separation into two manifolds is left to the
    cross-manifold operator rather than to distinct encoder weights.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
                 max_seq_length: int = 128, freeze: bool = True):
        """Load the backbone and optionally freeze all of its parameters.

        Args:
            model_name: Name passed to ``SentenceTransformer``.
            max_seq_length: Value assigned to the backbone's ``max_seq_length``.
            freeze: When true, every backbone parameter gets
                ``requires_grad = False``.
        """
        super().__init__()
        backbone = SentenceTransformer(model_name)
        backbone.max_seq_length = max_seq_length
        self.model = backbone
        self.embedding_dim = backbone.get_sentence_embedding_dimension()

        if freeze:
            self.model.requires_grad_(False)

    def encode_queries(self, texts: list[str], batch_size: int = 64,
                       show_progress: bool = False) -> torch.Tensor:
        """Encode raw query strings into normalized embedding tensors (M_Q)."""
        return self.model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=show_progress,
            convert_to_tensor=True,
            normalize_embeddings=True,
        )

    def encode_documents(self, texts: list[str], batch_size: int = 64,
                         show_progress: bool = False) -> torch.Tensor:
        """Encode raw document strings into normalized embedding tensors (M_D)."""
        return self.model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=show_progress,
            convert_to_tensor=True,
            normalize_embeddings=True,
        )

    def forward_queries(self, input_ids: torch.Tensor,
                        attention_mask: torch.Tensor) -> torch.Tensor:
        """Embed already-tokenized queries (used by the training loop).

        Returns the backbone's ``sentence_embedding`` output, L2-normalized
        along the last dimension.
        """
        model_out = self.model.forward(
            {"input_ids": input_ids, "attention_mask": attention_mask}
        )
        return nn.functional.normalize(model_out["sentence_embedding"], p=2, dim=-1)

    def forward_documents(self, input_ids: torch.Tensor,
                          attention_mask: torch.Tensor) -> torch.Tensor:
        """Embed already-tokenized documents; same path as the queries."""
        return self.forward_queries(input_ids, attention_mask)

    @property
    def tokenizer(self):
        """Tokenizer object exposed by the underlying backbone."""
        return self.model.tokenizer
multi_manifold_retrieval/training/__init__.py ADDED
File without changes
multi_manifold_retrieval/training/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (181 Bytes). View file
 
multi_manifold_retrieval/training/__pycache__/data.cpython-310.pyc ADDED
Binary file (5.49 kB). View file
 
multi_manifold_retrieval/training/__pycache__/losses.cpython-310.pyc ADDED
Binary file (1.69 kB). View file
 
multi_manifold_retrieval/training/__pycache__/train.cpython-310.pyc ADDED
Binary file (4.15 kB). View file
 
multi_manifold_retrieval/training/data.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MS MARCO data loading for training and evaluation."""
2
+
3
+ import random
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from datasets import load_dataset
9
+
10
+
11
class MSMARCOTripleDataset(Dataset):
    """MS MARCO passage ranking dataset with hard negatives.

    Each example yields (query, positive_passage, [negative_passages]).
    """

    def __init__(self, tokenizer, max_samples: int = 100_000,
                 num_negatives: int = 7, max_seq_length: int = 128,
                 split: str = "train", seed: int = 42):
        """Load the split and keep the first ``max_samples`` usable examples.

        Args:
            tokenizer: Stored on the instance; not used while loading.
            max_samples: Upper bound on the number of retained examples.
            num_negatives: Number of negatives returned per item.
            max_seq_length: Stored on the instance; not used while loading.
            split: Dataset split passed to ``load_dataset``.
            seed: Seed of the private RNG used for negative sampling.
        """
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.num_negatives = num_negatives

        # NOTE(review): trust_remote_code=True executes the dataset's loading
        # script; keep only while the dataset source is trusted.
        print(f"Loading MS MARCO ({split} split, max {max_samples} samples)...")
        dataset = load_dataset("ms_marco", "v2.1", split=split, trust_remote_code=True)

        # Keep only examples that have at least one selected passage.
        self.examples = []
        for ex in dataset:
            if len(self.examples) >= max_samples:
                break
            passages = ex["passages"]
            texts = passages["passage_text"]
            selected = [j for j, s in enumerate(passages["is_selected"]) if s == 1]
            if not selected:
                continue
            selected_set = set(selected)
            self.examples.append({
                "query": ex["query"],
                # The first selected passage is the positive ...
                "positive": texts[selected[0]],
                # ... and every non-selected passage is a hard negative.
                "negatives": [
                    text for j, text in enumerate(texts)
                    if j not in selected_set
                ],
            })

        print(f"Loaded {len(self.examples)} training examples.")
        self.rng = random.Random(seed)

    def __len__(self) -> int:
        """Return the number of retained examples."""
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict:
        """Return one training triple with exactly ``num_negatives`` negatives.

        Negatives are sampled from the example's own non-selected passages.
        If there are too few, the list is padded with positives of randomly
        drawn other examples (duplicates are possible).

        Raises:
            RuntimeError: If padding is required but no example in the
                dataset has a positive passage different from this one.
        """
        ex = self.examples[idx]
        available_negs = ex["negatives"]
        if len(available_negs) >= self.num_negatives:
            negs = self.rng.sample(available_negs, self.num_negatives)
        else:
            negs = available_negs[:]
            feasibility_checked = False
            while len(negs) < self.num_negatives:
                rand_ex = self.examples[self.rng.randint(0, len(self.examples) - 1)]
                if rand_ex["positive"] != ex["positive"]:
                    negs.append(rand_ex["positive"])
                elif not feasibility_checked:
                    # On the first rejected draw make sure a usable candidate
                    # exists at all; otherwise this loop would spin forever.
                    if all(other["positive"] == ex["positive"]
                           for other in self.examples):
                        raise RuntimeError(
                            "Cannot pad negatives: no example with a different "
                            "positive passage is available."
                        )
                    feasibility_checked = True

        return {
            "query": ex["query"],
            "positive": ex["positive"],
            "negatives": negs,
        }
71
+
72
+
73
def collate_fn(batch: list[dict], tokenizer, max_seq_length: int = 128) -> dict:
    """Tokenize a list of dataset items into one batch dictionary.

    Queries, positives and the flattened negatives (item after item) are
    each tokenized in a separate call with padding, truncation to
    ``max_seq_length`` and ``return_tensors="pt"``. The number of negatives
    per item is read from the first item of the batch.
    """

    def _tokenize(texts):
        return tokenizer(
            texts, padding=True, truncation=True,
            max_length=max_seq_length, return_tensors="pt",
        )

    query_texts = [item["query"] for item in batch]
    positive_texts = [item["positive"] for item in batch]
    negative_texts = [neg for item in batch for neg in item["negatives"]]

    query_enc = _tokenize(query_texts)
    positive_enc = _tokenize(positive_texts)
    negative_enc = _tokenize(negative_texts)

    negatives_per_item = len(batch[0]["negatives"])
    return {
        "query_input_ids": query_enc["input_ids"],
        "query_attention_mask": query_enc["attention_mask"],
        "pos_input_ids": positive_enc["input_ids"],
        "pos_attention_mask": positive_enc["attention_mask"],
        "neg_input_ids": negative_enc["input_ids"],
        "neg_attention_mask": negative_enc["attention_mask"],
        "num_negatives": negatives_per_item,
    }
105
+
106
+
107
def get_dataloader(tokenizer, max_samples: int = 100_000,
                   num_negatives: int = 7, batch_size: int = 64,
                   max_seq_length: int = 128, split: str = "train",
                   seed: int = 42, num_workers: int = 0) -> DataLoader:
    """Create a shuffled DataLoader over MSMARCOTripleDataset.

    Args:
        tokenizer: Tokenizer handed to both the dataset and the collate step.
        max_samples: Maximum number of examples kept by the dataset.
        num_negatives: Negatives returned per example.
        batch_size: Number of examples per batch.
        max_seq_length: Truncation length used when tokenizing.
        split: Dataset split to load.
        seed: Seed for the dataset's negative-sampling RNG.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A DataLoader with ``shuffle=True`` and ``drop_last=True`` whose
        batches are produced by ``collate_fn``.
    """
    from functools import partial

    dataset = MSMARCOTripleDataset(
        tokenizer=tokenizer, max_samples=max_samples,
        num_negatives=num_negatives, max_seq_length=max_seq_length,
        split=split, seed=seed,
    )

    # A partial over the module-level collate_fn can be pickled, unlike a
    # nested closure, so worker processes started with the "spawn" method
    # (num_workers > 0) are able to receive it.
    collate = partial(collate_fn, tokenizer=tokenizer, max_seq_length=max_seq_length)

    return DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        collate_fn=collate, num_workers=num_workers,
        drop_last=True,
    )
126
+
127
+
128
class MSMARCOEvalDataset:
    """Evaluation queries and passage corpus built from the MS MARCO validation split."""

    def __init__(self, tokenizer, max_queries: int = 5000,
                 max_seq_length: int = 128, seed: int = 42):
        """Collect up to ``max_queries`` queries that have a selected passage.

        Rows are visited in an order shuffled with ``seed``. For every kept
        row the query, the texts of its selected passages, and all of its
        passages (de-duplicated across rows) are recorded.
        """
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

        print(f"Loading MS MARCO dev set (max {max_queries} queries)...")
        dataset = load_dataset("ms_marco", "v2.1", split="validation", trust_remote_code=True)

        self.queries = []
        self.positives = []      # per query: list of positive passage texts
        self.all_passages = []   # retrieval corpus, first-seen order
        self.passage_set = set() # membership index for all_passages

        visit_order = list(range(len(dataset)))
        random.Random(seed).shuffle(visit_order)

        for row_idx in visit_order:
            if len(self.queries) >= max_queries:
                break
            record = dataset[row_idx]
            passages = record["passages"]
            texts = passages["passage_text"]
            positive_ids = [
                j for j, flag in enumerate(passages["is_selected"]) if flag == 1
            ]
            # Rows without any selected passage cannot be evaluated.
            if not positive_ids:
                continue

            self.queries.append(record["query"])
            self.positives.append([texts[j] for j in positive_ids])

            # Every passage of a kept row joins the corpus exactly once.
            for text in texts:
                if text in self.passage_set:
                    continue
                self.passage_set.add(text)
                self.all_passages.append(text)

        print(f"Loaded {len(self.queries)} eval queries, "
              f"{len(self.all_passages)} unique passages.")
multi_manifold_retrieval/training/losses.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Contrastive loss for training the cross-manifold operator."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
class ContrastiveLoss(nn.Module):
    """Softmax cross-entropy over one positive and K negative scores.

    For a query q with positive d+ and negatives d1-, ..., dK- the loss is
        -log(exp(R(q, d+)/tau) / (exp(R(q, d+)/tau) + sum_k exp(R(q, dk-)/tau)))
    averaged over the batch.
    """

    def __init__(self, temperature: float = 0.05):
        super().__init__()
        self.temperature = temperature

    def forward(self, pos_scores: torch.Tensor,
                neg_scores: torch.Tensor) -> torch.Tensor:
        """Compute the contrastive loss.

        Args:
            pos_scores: (batch_size,) scores R(q, d+).
            neg_scores: (batch_size, num_negatives) scores R(q, dk-).

        Returns:
            Scalar loss tensor.
        """
        # Column 0 holds the positive, the remaining columns the negatives.
        logits = torch.cat((pos_scores[:, None], neg_scores), dim=1) / self.temperature
        positive_index = logits.new_zeros(logits.size(0), dtype=torch.long)
        return F.cross_entropy(logits, positive_index)
multi_manifold_retrieval/training/train.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training loop for the cross-manifold operator."""
2
+
3
+ import os
4
+ import time
5
+ import yaml
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.optim import AdamW
10
+ from torch.optim.lr_scheduler import OneCycleLR
11
+
12
+ from multi_manifold_retrieval.models.encoders import DualEncoder
13
+ from multi_manifold_retrieval.models.cross_manifold_operator import CrossManifoldOperator
14
+ from multi_manifold_retrieval.training.data import get_dataloader
15
+ from multi_manifold_retrieval.training.losses import ContrastiveLoss
16
+
17
+
18
def train(config_path: str = "configs/default.yaml",
          device: str = "cuda" if torch.cuda.is_available() else "cpu"):
    """Train a CrossManifoldOperator on MS MARCO triples.

    The YAML config at ``config_path`` supplies the seed and the encoder,
    operator and training settings. The encoder only runs under
    ``torch.no_grad()``; the optimizer updates the operator alone. After
    every epoch the operator is saved as ``best_operator.pt`` if the mean
    training loss improved, and as ``final_operator.pt`` once training ends.

    Returns:
        The ``(encoder, operator)`` pair.
    """
    with open(config_path) as f:
        config = yaml.safe_load(f)

    torch.manual_seed(config["seed"])

    # Encoder
    print("Initializing encoder...", flush=True)
    encoder = DualEncoder(
        model_name=config["encoder"]["model_name"],
        max_seq_length=config["training"]["max_seq_length"],
        freeze=config["encoder"]["freeze"],
    )
    train_cfg = config["training"]
    embedding_dim = encoder.embedding_dim

    # Cross-manifold operator
    op_cfg = config["cross_manifold"]
    operator = CrossManifoldOperator(
        embedding_dim=embedding_dim,
        num_heads=op_cfg["num_heads"],
        value_hidden_dim=op_cfg["value_mlp_hidden"],
        value_num_layers=op_cfg["value_mlp_layers"],
        dropout=op_cfg["dropout"],
    ).to(device)

    print(f"Encoder initialized (dim={embedding_dim}).", flush=True)

    loss_fn = ContrastiveLoss(temperature=0.05)

    # Data
    print("Loading training data...", flush=True)
    load_start = time.time()
    train_loader = get_dataloader(
        tokenizer=encoder.tokenizer,
        max_samples=train_cfg["max_train_samples"],
        num_negatives=train_cfg["num_negatives"],
        batch_size=train_cfg["batch_size"],
        max_seq_length=train_cfg["max_seq_length"],
        split="train",
        seed=config["seed"],
    )

    print(f"Data loaded in {time.time()-load_start:.1f}s. Batches per epoch: {len(train_loader)}", flush=True)

    # Only the operator's parameters are optimized.
    optimizer = AdamW(
        operator.parameters(),
        lr=train_cfg["learning_rate"],
        weight_decay=train_cfg["weight_decay"],
    )

    num_epochs = train_cfg["epochs"]
    total_steps = len(train_loader) * num_epochs
    # The warm-up share of the one-cycle schedule is capped at 10%.
    warmup_fraction = min(train_cfg["warmup_steps"] / total_steps, 0.1)
    scheduler = OneCycleLR(
        optimizer,
        max_lr=train_cfg["learning_rate"],
        total_steps=total_steps,
        pct_start=warmup_fraction,
    )

    print(f"Moving encoder to {device}...", flush=True)
    encoder.model.to(device)
    encoder.model.eval()
    operator.train()
    print(f"Starting training: {num_epochs} epochs, {total_steps} total steps", flush=True)

    save_dir = train_cfg["save_dir"]
    os.makedirs(save_dir, exist_ok=True)
    log_every = train_cfg["log_every"]

    tensor_keys = (
        "query_input_ids", "query_attention_mask",
        "pos_input_ids", "pos_attention_mask",
        "neg_input_ids", "neg_attention_mask",
    )

    global_step = 0
    best_loss = float("inf")

    for epoch in range(num_epochs):
        running_loss = 0.0
        epoch_start = time.time()

        for batch_idx, batch in enumerate(train_loader):
            on_device = {key: batch[key].to(device) for key in tensor_keys}
            num_neg = batch["num_negatives"]

            # Embeddings are computed without building a graph.
            with torch.no_grad():
                q_emb = encoder.forward_queries(
                    on_device["query_input_ids"], on_device["query_attention_mask"]
                )  # (B, D)
                p_emb = encoder.forward_documents(
                    on_device["pos_input_ids"], on_device["pos_attention_mask"]
                )  # (B, D)
                n_emb = encoder.forward_documents(
                    on_device["neg_input_ids"], on_device["neg_attention_mask"]
                )  # (B*K, D)

            # Regroup the flattened negatives per query: (B, K, D).
            n_emb = n_emb.view(q_emb.shape[0], num_neg, -1)

            pos_scores = operator(q_emb, p_emb)  # (B,)
            neg_scores = operator(q_emb, n_emb)  # (B, K)
            loss = loss_fn(pos_scores, neg_scores)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(operator.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            running_loss += loss.item()
            global_step += 1

            if global_step % log_every == 0:
                avg_loss = running_loss / (batch_idx + 1)
                current_lr = scheduler.get_last_lr()[0]
                print(f" Step {global_step} | Loss: {loss.item():.4f} | "
                      f"Avg: {avg_loss:.4f} | LR: {current_lr:.2e}", flush=True)

        epoch_time = time.time() - epoch_start
        avg_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Avg Loss: {avg_loss:.4f} | Time: {epoch_time:.1f}s", flush=True)

        # Keep the checkpoint with the lowest mean training loss so far.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(operator.state_dict(), os.path.join(save_dir, "best_operator.pt"))
            print(f" Saved best model (loss={best_loss:.4f})")

    torch.save(operator.state_dict(), os.path.join(save_dir, "final_operator.pt"))
    print(f"Training complete. Best loss: {best_loss:.4f}")

    return encoder, operator


if __name__ == "__main__":
    train()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0
2
+ transformers>=4.30
3
+ sentence-transformers>=2.2
4
+ datasets>=2.14
5
+ faiss-cpu>=1.7
6
+ numpy>=1.24
7
+ scipy>=1.10
8
+ scikit-learn>=1.3
9
+ pyyaml>=6.0
10
+ tqdm>=4.65
results.json ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "config": {
3
+ "seed": 42,
4
+ "encoder": {
5
+ "model_name": "sentence-transformers/all-MiniLM-L6-v2",
6
+ "embedding_dim": 384,
7
+ "freeze": true
8
+ },
9
+ "cross_manifold": {
10
+ "num_heads": 4,
11
+ "head_dim": 96,
12
+ "value_mlp_hidden": 256,
13
+ "value_mlp_layers": 2,
14
+ "dropout": 0.1
15
+ },
16
+ "training": {
17
+ "batch_size": 64,
18
+ "learning_rate": 0.0002,
19
+ "weight_decay": 0.01,
20
+ "epochs": 5,
21
+ "warmup_steps": 500,
22
+ "max_train_samples": 100000,
23
+ "num_negatives": 7,
24
+ "max_seq_length": 128,
25
+ "fp16": true,
26
+ "gradient_accumulation_steps": 1,
27
+ "log_every": 100,
28
+ "eval_every": 2000,
29
+ "save_dir": "checkpoints"
30
+ },
31
+ "evaluation": {
32
+ "max_eval_queries": 5000,
33
+ "metrics": [
34
+ "mrr@10",
35
+ "recall@100"
36
+ ]
37
+ },
38
+ "spectral": {
39
+ "num_documents": 1000,
40
+ "num_queries": 500,
41
+ "k_neighbors": 20
42
+ },
43
+ "attack": {
44
+ "target_domain": "medical",
45
+ "num_target_queries": 100,
46
+ "top_k": 10,
47
+ "medical_keywords": [
48
+ "health",
49
+ "medical",
50
+ "doctor",
51
+ "patient",
52
+ "treatment",
53
+ "disease",
54
+ "symptom",
55
+ "diagnosis",
56
+ "medicine",
57
+ "clinical"
58
+ ]
59
+ }
60
+ },
61
+ "device": "cuda",
62
+ "retrieval_multi_manifold": {
63
+ "mrr@10": 0.6002776984126992,
64
+ "recall@100": 0.9901,
65
+ "num_queries": 5000
66
+ },
67
+ "retrieval_baseline": {
68
+ "mrr@10": 0.5828701587301599,
69
+ "recall@100": 0.9942,
70
+ "num_queries": 5000
71
+ },
72
+ "mrr_ratio": 1.0298652099816936,
73
+ "spectral": {
74
+ "multi_manifold": {
75
+ "spectral_discrepancy": 0.05735765351097603,
76
+ "fiedler_alignment": 0.03973450829139222
77
+ },
78
+ "baseline": {
79
+ "spectral_discrepancy": 0.22395483470893326,
80
+ "fiedler_alignment": 0.7848795751112227
81
+ },
82
+ "num_documents": 1000,
83
+ "num_queries": 500
84
+ },
85
+ "attack": {
86
+ "bipartite_attack": {
87
+ "baseline_asr": 0.51,
88
+ "multi_manifold_asr": 0.19,
89
+ "baseline_margins": {
90
+ "mean_margin": -0.00035160839557647706,
91
+ "median_margin": 0.0012769699096679688,
92
+ "p25_margin": -0.03736262768507004,
93
+ "fraction_positive_margin": 0.51
94
+ },
95
+ "multi_manifold_margins": {
96
+ "mean_margin": -0.0419993931055069,
97
+ "median_margin": -0.04468509554862976,
98
+ "p25_margin": -0.06919527053833008,
99
+ "fraction_positive_margin": 0.19
100
+ },
101
+ "placement_info": {
102
+ "method": "bipartite_fiedler",
103
+ "fiedler_eigenvalue": 0.12380694040243122,
104
+ "weight_entropy": 4.995662314784581,
105
+ "max_weight": 0.016811827898423226,
106
+ "adv_mean_cos_to_queries": 0.2539085502225744,
107
+ "adv_mean_cos_to_docs": 0.2686595998305845
108
+ }
109
+ },
110
+ "doconly_attack": {
111
+ "baseline_asr": 0.03,
112
+ "multi_manifold_asr": 0.03,
113
+ "baseline_margins": {
114
+ "mean_margin": -0.1627955549955368,
115
+ "median_margin": -0.1696268543601036,
116
+ "p25_margin": -0.21040004305541515,
117
+ "fraction_positive_margin": 0.03
118
+ },
119
+ "multi_manifold_margins": {
120
+ "mean_margin": -0.12464183956384658,
121
+ "median_margin": -0.12341519445180893,
122
+ "p25_margin": -0.1708945743739605,
123
+ "fraction_positive_margin": 0.03
124
+ },
125
+ "placement_info": {
126
+ "method": "doconly_fiedler",
127
+ "fiedler_eigenvalue": 3.8886430263519287
128
+ }
129
+ },
130
+ "num_domain_docs": 200,
131
+ "num_target_queries": 100,
132
+ "top_k": 10,
133
+ "baseline_asr": 0.51,
134
+ "multi_manifold_asr": 0.19
135
+ }
136
+ }
run_experiment.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """End-to-end experiment: train → evaluate → spectral analysis → attack simulation."""
3
+
4
+ import os
5
+ import json
6
+ import random
7
+ import time
8
+ import yaml
9
+ import argparse
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+ from multi_manifold_retrieval.models.encoders import DualEncoder
15
+ from multi_manifold_retrieval.models.cross_manifold_operator import CrossManifoldOperator
16
+ from multi_manifold_retrieval.models.baseline import BaselineOperator
17
+ from multi_manifold_retrieval.training.train import train
18
+ from multi_manifold_retrieval.training.data import MSMARCOEvalDataset
19
+ from multi_manifold_retrieval.evaluation.spectral_analysis import run_spectral_analysis
20
+ from multi_manifold_retrieval.evaluation.retrieval_metrics import compute_retrieval_metrics
21
+ from multi_manifold_retrieval.evaluation.attack_simulation import (
22
+ run_attack_simulation,
23
+ select_domain_documents,
24
+ )
25
+
26
+
27
def set_seed(seed: int):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch,
    and additionally every CUDA device when CUDA is available.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
33
+
34
+
35
def main():
    """Run the end-to-end experiment and write the collected results as JSON.

    Phases:
        1. Train the cross-manifold operator, or load it from a checkpoint.
        2. Encode the evaluation passages and queries.
        3. Compare retrieval metrics of the multi-manifold operator against
           the baseline operator.
        4. Run the spectral analysis on a random sample of documents/queries.
        5. Run the attack simulation on a set of target queries.

    A short summary is printed at the end and the ``results`` dictionary is
    dumped to the path given by ``--output``.

    Command-line flags:
        --config      Path to the YAML experiment config.
        --skip-train  Load the operator from ``--checkpoint`` instead of training.
        --checkpoint  Operator state-dict file used together with ``--skip-train``.
        --output      Destination of the JSON results file.
    """
    parser = argparse.ArgumentParser(description="Multi-Manifold Retrieval PoC Experiment")
    parser.add_argument("--config", type=str, default="configs/default.yaml")
    parser.add_argument("--skip-train", action="store_true", help="Skip training, load from checkpoint")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/best_operator.pt")
    parser.add_argument("--output", type=str, default="results.json")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    set_seed(config["seed"])
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")

    results = {"config": config, "device": device}

    # =========================================================================
    # Phase 1: Training
    # =========================================================================
    print("\n" + "=" * 60)
    print("PHASE 1: TRAINING")
    print("=" * 60)

    load_checkpoint = args.skip_train and os.path.exists(args.checkpoint)
    if args.skip_train and not load_checkpoint:
        # Keep the best-effort fallback to training, but never do it silently:
        # a full training run is an expensive surprise for the user.
        print(f"WARNING: --skip-train was given but checkpoint '{args.checkpoint}' "
              f"does not exist; training from scratch instead.")

    if load_checkpoint:
        print(f"Loading encoder and operator from checkpoint: {args.checkpoint}")
        # NOTE: only the operator weights come from the checkpoint file; the
        # encoder is rebuilt from the model name in the config.
        encoder = DualEncoder(
            model_name=config["encoder"]["model_name"],
            max_seq_length=config["training"]["max_seq_length"],
            freeze=config["encoder"]["freeze"],
        )
        cm_config = config["cross_manifold"]
        operator = CrossManifoldOperator(
            embedding_dim=encoder.embedding_dim,
            num_heads=cm_config["num_heads"],
            value_hidden_dim=cm_config["value_mlp_hidden"],
            value_num_layers=cm_config["value_mlp_layers"],
            dropout=cm_config["dropout"],
        )
        operator.load_state_dict(torch.load(args.checkpoint, map_location=device, weights_only=True))
    else:
        encoder, operator = train(config_path=args.config, device=device)

    encoder.model.to(device)
    operator.to(device)
    # Everything below is evaluation. A freshly constructed module starts in
    # training mode, so switch to eval mode to disable dropout; otherwise the
    # reported metrics would be noisy and non-reproducible.
    operator.eval()
    baseline_operator = BaselineOperator().to(device)

    # =========================================================================
    # Phase 2: Evaluation Data Preparation
    # =========================================================================
    print("\n" + "=" * 60)
    print("PHASE 2: EVALUATION DATA PREPARATION")
    print("=" * 60)

    eval_data = MSMARCOEvalDataset(
        tokenizer=encoder.tokenizer,
        max_queries=config["evaluation"]["max_eval_queries"],
        max_seq_length=config["training"]["max_seq_length"],
        seed=config["seed"],
    )

    # Encode all evaluation passages
    print(f"Encoding {len(eval_data.all_passages)} passages...")
    passage_embeddings = encoder.encode_documents(
        eval_data.all_passages, batch_size=128, show_progress=True,
    )

    # Encode evaluation queries
    print(f"Encoding {len(eval_data.queries)} queries...")
    query_embeddings = encoder.encode_queries(
        eval_data.queries, batch_size=128, show_progress=True,
    )

    # =========================================================================
    # Phase 3: Retrieval Quality Evaluation
    # =========================================================================
    print("\n" + "=" * 60)
    print("PHASE 3: RETRIEVAL QUALITY")
    print("=" * 60)

    print("\n--- Multi-Manifold Model ---")
    metrics_mm = compute_retrieval_metrics(
        query_embeddings=query_embeddings,
        doc_embeddings=passage_embeddings,
        operator=operator,
        query_texts=eval_data.queries,
        positive_passages=eval_data.positives,
        all_passages=eval_data.all_passages,
        passage_embeddings=passage_embeddings,
        device=device,
    )
    results["retrieval_multi_manifold"] = metrics_mm

    print("\n--- Baseline (Cosine Similarity) ---")
    metrics_base = compute_retrieval_metrics(
        query_embeddings=query_embeddings,
        doc_embeddings=passage_embeddings,
        operator=baseline_operator,
        query_texts=eval_data.queries,
        positive_passages=eval_data.positives,
        all_passages=eval_data.all_passages,
        passage_embeddings=passage_embeddings,
        device=device,
    )
    results["retrieval_baseline"] = metrics_base

    # Check: multi-manifold within 80% of baseline.
    # The ratio is undefined when the baseline MRR is zero; it is then left
    # out of both the results and the summary.
    ratio = None
    if metrics_base["mrr@10"] > 0:
        ratio = metrics_mm["mrr@10"] / metrics_base["mrr@10"]
        print(f"\nMRR@10 ratio (mm/baseline): {ratio:.4f} "
              f"({'PASS' if ratio >= 0.8 else 'BELOW TARGET'}, target >= 0.8)")
        results["mrr_ratio"] = ratio

    # =========================================================================
    # Phase 4: Spectral Analysis
    # =========================================================================
    print("\n" + "=" * 60)
    print("PHASE 4: SPECTRAL ANALYSIS")
    print("=" * 60)

    # Sample documents for spectral analysis (capped by what is available)
    num_spectral_docs = min(config["spectral"]["num_documents"], len(eval_data.all_passages))
    num_spectral_queries = min(config["spectral"]["num_queries"], len(eval_data.queries))

    spectral_doc_indices = np.random.choice(
        len(eval_data.all_passages), num_spectral_docs, replace=False
    )
    spectral_query_indices = np.random.choice(
        len(eval_data.queries), num_spectral_queries, replace=False
    )

    spectral_doc_emb_np = passage_embeddings[spectral_doc_indices].cpu().numpy()
    spectral_doc_emb_torch = passage_embeddings[spectral_doc_indices]
    spectral_query_emb_torch = query_embeddings[spectral_query_indices]

    spectral_results = run_spectral_analysis(
        doc_embeddings_np=spectral_doc_emb_np,
        doc_embeddings_torch=spectral_doc_emb_torch,
        query_embeddings_torch=spectral_query_emb_torch,
        operator=operator,
        baseline_operator=baseline_operator,
        device=device,
    )

    # Keep only the two per-model entries plus the sample sizes in the results.
    results["spectral"] = {
        "multi_manifold": spectral_results["multi_manifold"],
        "baseline": spectral_results["baseline"],
        "num_documents": num_spectral_docs,
        "num_queries": num_spectral_queries,
    }

    # =========================================================================
    # Phase 5: Attack Simulation
    # =========================================================================
    print("\n" + "=" * 60)
    print("PHASE 5: ATTACK SIMULATION")
    print("=" * 60)

    attack_config = config["attack"]
    medical_keywords = attack_config["medical_keywords"]
    num_target_queries = attack_config["num_target_queries"]

    # Select target queries (medical domain): the first queries, in dataset
    # order, that contain any of the configured keywords.
    target_queries = []
    for q in eval_data.queries:
        q_lower = q.lower()
        if any(kw in q_lower for kw in medical_keywords):
            target_queries.append(q)
        if len(target_queries) >= num_target_queries:
            break

    if len(target_queries) < 10:
        # Fall back: use random queries if not enough medical ones
        print(f"Only found {len(target_queries)} medical queries; "
              f"using random queries to reach {attack_config['num_target_queries']}.")
        remaining = num_target_queries - len(target_queries)
        # Set membership keeps this filter linear in the number of queries.
        already_selected = set(target_queries)
        other_queries = [q for q in eval_data.queries if q not in already_selected]
        target_queries.extend(random.sample(other_queries, min(remaining, len(other_queries))))

    print(f"Using {len(target_queries)} target queries for attack simulation.")

    attack_results = run_attack_simulation(
        encoder=encoder,
        operator=operator,
        baseline_operator=baseline_operator,
        passages=eval_data.all_passages,
        passage_embeddings_torch=passage_embeddings,
        target_query_texts=target_queries,
        medical_keywords=medical_keywords,
        top_k=attack_config["top_k"],
        device=device,
    )
    results["attack"] = attack_results

    # =========================================================================
    # Summary
    # =========================================================================
    print("\n" + "=" * 60)
    print("EXPERIMENT SUMMARY")
    print("=" * 60)

    print("\n1. Retrieval Quality:")
    print(f" Baseline MRR@10: {metrics_base['mrr@10']:.4f}")
    print(f" Multi-Manifold MRR@10: {metrics_mm['mrr@10']:.4f}")
    if ratio is not None:
        print(f" Ratio: {ratio:.4f}")

    print("\n2. Spectral Analysis:")
    print(f" Baseline δ: {spectral_results['baseline']['spectral_discrepancy']:.4f}")
    print(f" Multi-Manifold δ: {spectral_results['multi_manifold']['spectral_discrepancy']:.4f}")
    print(f" Baseline cos(θ): {spectral_results['baseline']['fiedler_alignment']:.4f}")
    print(f" Multi-Manifold cos(θ): {spectral_results['multi_manifold']['fiedler_alignment']:.4f}")

    # The ASR figures are printed only when the result carries no "error" key.
    if "error" not in attack_results:
        print("\n3. Attack Simulation:")
        print(f" Baseline ASR@{attack_config['top_k']}: {attack_results['baseline_asr']:.4f}")
        print(f" Multi-Manifold ASR@{attack_config['top_k']}: {attack_results['multi_manifold_asr']:.4f}")

    # Save results (exclude numpy arrays)
    save_results = {k: v for k, v in results.items()
                    if k not in ("L_D", "L_R_mm", "L_R_base")}
    with open(args.output, "w", encoding="utf-8") as f:
        # default=str stringifies any value json cannot serialise natively.
        json.dump(save_results, f, indent=2, default=str)
    print(f"\nResults saved to {args.output}")
258
+
259
+
260
# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()