# Source: bertose-affinose-training-code / code/probes/probe_14_retrieval.py
# Commit 1d6f391 (verified) by supanthadey1: "Add BERTose and AFFINose training code release"
# File size: 29.8 kB
#!/usr/bin/env python3
"""
Probe 14: Embedding-Based Nearest-Neighbor Retrieval
=====================================================
Evaluates whether nearest neighbors in GlycanBERT's embedding space share
biological properties. This demonstrates practical utility: given any glycan,
retrieve structurally and functionally similar glycans from a reference set.
Metrics:
- Precision@k (k=1, 5, 10, 20) for glycan type, domain, and immunogenicity labels
- Mean Average Precision (mAP) per property
- Motif transfer accuracy (majority vote of k-NN motifs)
- Comparison against glycowork structural fingerprint baseline
Usage:
python probe_14_retrieval.py # Both V5-A and V6
python probe_14_retrieval.py --model v5 # V5-A only
"""
import sys, os, json, argparse, warnings
import numpy as np
import pandas as pd
from pathlib import Path
from collections import Counter
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torch
# Silence FutureWarning messages for the whole run.
warnings.filterwarnings('ignore', category=FutureWarning)
# ─── Project paths ────────────────────────────────────────────────────────────
# PROJECT_ROOT is parents[2] of this file, i.e. the grandparent of the directory
# containing it; every data, vocabulary and checkpoint path below hangs off it.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Vocabulary file handed to WURCSTokenizer in load_model().
VOCAB_PATH = PROJECT_ROOT / 'bert_training_v4' / 'data' / 'vocabulary.json'
# CSV read by run_retrieval(): rows need 'glycan' (IUPAC) and 'wurcs'; the
# optional 'glycan_type', 'domain' and 'immunogenicity' columns supply labels.
DATA_PATH = PROJECT_ROOT / 'bert_v6_contrastive' / 'analysis' / 'glycowork_iupac_wurcs_unified.csv'
# Checkpoint file per model version selectable with --model.
# NOTE(review): the 'v6' entry points at bert_v5.1_contrastive /
# best_v51_contrastive_model.pt -- confirm this is the intended "V6" model.
CHECKPOINTS = {
    'v5': PROJECT_ROOT / 'checkpoints_v5_bpe_topo' / 'best_v5_bpe_topo_model.pt',
    'v6': PROJECT_ROOT / 'bert_v5.1_contrastive' / 'checkpoints' / 'best_v51_contrastive_model.pt',
}
# Put the project root and bert_training_v4 on sys.path so the model and
# tokenizer imports below can be resolved from either location.
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / 'bert_training_v4'))
from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig
from downstream_tasks.utils.tokenizer import WURCSTokenizer
# ─── Nature-style plotting ────────────────────────────────────────────────────
# Global matplotlib defaults applied to every figure this script saves.
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'Helvetica', 'DejaVu Sans'],
    'font.size': 10, 'axes.titlesize': 12, 'axes.labelsize': 11,
    'xtick.labelsize': 9, 'ytick.labelsize': 9, 'legend.fontsize': 9,
    'figure.dpi': 300, 'savefig.dpi': 300, 'savefig.bbox': 'tight',
    'axes.linewidth': 0.8, 'axes.spines.top': False, 'axes.spines.right': False,
})
# ═════════════════════════════════════════════════════════════════════════════
# Model loading & embedding (identical to probes 8-13)
# ═════════════════════════════════════════════════════════════════════════════
def load_model(model_version, device):
    """Load a GlycanBERT checkpoint together with the WURCS tokenizer.

    Vocabulary size and hidden size are read from the checkpoint's token
    embedding matrix (falling back to 2200 / 768 when neither known key is
    present). The remaining architecture hyper-parameters are hard-coded
    below and must match the checkpoint.

    Args:
        model_version: key into ``CHECKPOINTS`` (``'v5'`` or ``'v6'``).
        device: torch device the model is moved to.

    Returns:
        ``(model, tokenizer)``: the model in eval mode on ``device`` and a
        ``WURCSTokenizer`` built from ``VOCAB_PATH``.
    """
    ckpt_path = CHECKPOINTS[model_version]
    print(f"Loading {model_version} from {ckpt_path}")
    # weights_only=False unpickles arbitrary objects: only load trusted,
    # locally produced checkpoints.
    state = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    sd = state.get('model_state_dict', state)
    if 'proj_head_state_dict' in state:
        # The checkpoint carries a separate projection head; drop any
        # 'proj_head*' keys from the encoder state dict as well (presumably a
        # contrastive head that MultimodalGlycanBERT does not define).
        sd = {k: v for k, v in sd.items() if not k.startswith('proj_head')}
    emb_weight = sd.get('seq_embeddings.token_embeddings.weight',
                        sd.get('token_embeddings.weight'))
    vocab_size = emb_weight.shape[0] if emb_weight is not None else 2200
    hidden = emb_weight.shape[1] if emb_weight is not None else 768
    config = MultimodalGlycanBERTConfig(
        seq_vocab_size=vocab_size, seq_hidden_size=hidden,
        seq_num_layers=12, seq_num_heads=12, seq_max_length=256,
        use_cnn_frontend=True, cnn_kernel_size=3,
    )
    model = MultimodalGlycanBERT(config)
    # strict=False tolerates partial checkpoints, but a silent mismatch would
    # leave randomly initialised weights inside the probe -- report it.
    incompatible = model.load_state_dict(sd, strict=False)
    if incompatible.missing_keys:
        print(f" WARNING: {len(incompatible.missing_keys)} model keys missing "
              f"from checkpoint (e.g. {incompatible.missing_keys[:3]})")
    if incompatible.unexpected_keys:
        print(f" WARNING: {len(incompatible.unexpected_keys)} checkpoint keys "
              f"not used by the model (e.g. {incompatible.unexpected_keys[:3]})")
    model = model.to(device).eval()
    print(f" Loaded: {sum(p.numel() for p in model.parameters()):,} params")
    tokenizer = WURCSTokenizer(str(VOCAB_PATH))
    return model, tokenizer
def get_cls_embeddings(model, tokenizer, wurcs_list, device,
                       batch_size=128, max_len=256):
    """Embed WURCS strings and return the position-0 vector of each sequence.

    Inputs that the tokenizer cannot handle, including outputs missing any of
    the three required fields, are skipped and counted. The returned array
    can therefore have FEWER rows than ``wurcs_list``. Callers that need row
    alignment with their inputs must filter those inputs beforehand.

    NOTE(review): the vector is taken from ``model.seq_embeddings(...)``. By
    its name this looks like the embedding stage rather than the full encoder
    stack. Confirm this is the representation the probe intends to measure.

    Args:
        model: object exposing ``seq_embeddings(ids, branch_depths=...,
            linkage_types=...)`` that returns a (batch, seq, hidden) tensor.
        tokenizer: object whose ``tokenize(w)`` returns a mapping with
            ``'token_ids'``, ``'branch_depths'`` and ``'linkage_types'``.
        wurcs_list: sequence of WURCS strings.
        device: torch device the batch tensors are moved to.
        batch_size: number of strings per forward pass.
        max_len: each token sequence is truncated to this length.

    Returns:
        ``(n_ok, hidden)`` numpy array, or ``np.zeros((0, 768))`` when nothing
        could be embedded.
    """
    all_embs, errors = [], 0
    for i in range(0, len(wurcs_list), batch_size):
        batch = wurcs_list[i:i + batch_size]
        token_ids_list, bd_list, lt_list = [], [], []
        for w in batch:
            try:
                tok = tokenizer.tokenize(w)
                # Read all three fields BEFORE appending, so a tokenizer output
                # missing a later field cannot leave the lists misaligned.
                ids = tok['token_ids'][:max_len]
                bd = tok['branch_depths'][:max_len]
                lt = tok['linkage_types'][:max_len]
            except Exception:
                errors += 1
                continue
            token_ids_list.append(ids)
            bd_list.append(bd)
            lt_list.append(lt)
        if not token_ids_list:
            continue
        # Right-pad every sequence with 0 up to the longest one in the batch.
        ml = max(len(x) for x in token_ids_list)
        ids_t = torch.zeros(len(token_ids_list), ml, dtype=torch.long)
        bd_t = torch.zeros_like(ids_t)
        lt_t = torch.zeros_like(ids_t)
        for j, (ids, bd, lt) in enumerate(zip(token_ids_list, bd_list, lt_list)):
            ids_t[j, :len(ids)] = torch.tensor(ids, dtype=torch.long)
            bd_t[j, :len(bd)] = torch.tensor(bd, dtype=torch.long)
            lt_t[j, :len(lt)] = torch.tensor(lt, dtype=torch.long)
        ids_t, bd_t, lt_t = ids_t.to(device), bd_t.to(device), lt_t.to(device)
        with torch.no_grad():
            seq_out = model.seq_embeddings(ids_t, branch_depths=bd_t,
                                           linkage_types=lt_t)
        all_embs.append(seq_out[:, 0, :].cpu().numpy())
        if (i // batch_size) % 10 == 0:
            print(f" Embedded {i+len(batch)}/{len(wurcs_list)} ({errors} errors)")
    print(f" Total: {sum(e.shape[0] for e in all_embs):,} ({errors} errors)")
    return np.vstack(all_embs) if all_embs else np.zeros((0, 768))
# ═════════════════════════════════════════════════════════════════════════════
# Retrieval metrics
# ═════════════════════════════════════════════════════════════════════════════
def compute_retrieval_metrics(embeddings, labels, label_name,
                              k_values=(1, 5, 10, 20)):
    """Compute leave-one-out Precision@k and mAP for nearest-neighbour retrieval.

    Each item is used once as a query. Every OTHER item is ranked by cosine
    similarity to it, and a retrieved item is relevant when it shares the
    query's label.

    Args:
        embeddings: (N, D) array-like of vectors.
        labels: length-N sequence of labels (strings or ints).
        label_name: name used in the printed report.
        k_values: Precision@k cut-offs. Values >= N are clamped to N - 1 and
            non-positive values are ignored. Duplicates are evaluated once.

    Returns:
        dict with:
            ``precision_at_k``: ``str(k) -> float``.
            ``mAP``: mean average precision over the queries that have at
                least one relevant item.
            ``n_queries_evaluated``: the number of such queries.
            ``random_baseline``: expected Precision@k of a uniformly random
                ranking of the other items.
            ``n_classes``, ``n_samples`` and ``class_distribution``.

    Raises:
        ValueError: if fewer than two items are given, or ``labels`` does not
            have one entry per embedding.
    """
    emb = np.asarray(embeddings)
    if not np.issubdtype(emb.dtype, np.floating):
        emb = emb.astype(np.float64)
    label_arr = np.asarray(labels)
    n = len(emb)
    if n < 2:
        raise ValueError(f"{label_name}: retrieval needs at least 2 items, got {n}")
    if len(label_arr) != n:
        raise ValueError(f"{label_name}: {len(label_arr)} labels for {n} embeddings")

    # Pairwise cosine similarity; all-zero vectors get similarity 0 to everything.
    print(f" Computing {n}x{n} cosine similarity matrix...")
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    unit = emb / norms
    sim_matrix = unit @ unit.T
    # -inf (rather than -1, which is a legal cosine value) sends self to the end.
    np.fill_diagonal(sim_matrix, -np.inf)
    sorted_indices = np.argsort(-sim_matrix, axis=1)
    del sim_matrix, unit

    # Effective cut-offs: clamp to the N - 1 candidates and drop duplicates.
    ks = []
    for k in k_values:
        k = min(int(k), n - 1)
        if k > 0 and k not in ks:
            ks.append(k)
    k_arr = np.array(ks, dtype=int)

    # Integer label codes make the per-query comparisons cheap.
    _, codes = np.unique(label_arr, return_inverse=True)
    codes = codes.reshape(-1)

    precision_sums = np.zeros(len(ks))
    ap_sum, n_queries = 0.0, 0
    for i in range(n):
        order = sorted_indices[i]
        order = order[order != i]              # exclude the query itself
        relevant = codes[order] == codes[i]    # relevance in rank order, length N-1
        hits = np.cumsum(relevant)
        precision_sums += hits[k_arr - 1] / k_arr
        ranks = np.flatnonzero(relevant) + 1   # 1-based ranks of relevant items
        if ranks.size:
            # AP = mean, over relevant items, of the precision at their rank.
            ap_sum += float(np.mean(np.arange(1, ranks.size + 1) / ranks))
            n_queries += 1
    precision_at_k = {k: float(s) / n for k, s in zip(ks, precision_sums)}
    # Queries whose class has no other member have no defined AP; average over
    # the rest instead of counting them as 0.
    mAP = ap_sum / n_queries if n_queries else 0.0

    # Expected precision for a random ranking of the N-1 candidates: a query
    # in a class of size c finds (c-1)/(N-1) relevant items on average.
    label_counts = Counter(labels)
    random_baseline = sum(c * (c - 1) for c in label_counts.values()) / (n * (n - 1))

    print(f"\n {label_name}:")
    print(f" Random baseline: {random_baseline:.4f}")
    for k, p in precision_at_k.items():
        lift = p / random_baseline if random_baseline > 0 else 0
        print(f" P@{k:>2d}: {p:.4f} (lift: {lift:.2f}x)")
    print(f" mAP: {mAP:.4f} (over {n_queries} queries with a relevant item)")
    largest = max(label_counts.values())
    print(f" # classes: {len(label_counts)}, largest: "
          f"{largest}/{n} ({largest/n*100:.1f}%)")
    return {
        'precision_at_k': {str(k): float(v) for k, v in precision_at_k.items()},
        'mAP': float(mAP),
        'n_queries_evaluated': n_queries,
        'random_baseline': float(random_baseline),
        'n_classes': len(label_counts),
        'n_samples': n,
        'class_distribution': {str(k): int(v) for k, v in label_counts.items()},
    }
def compute_motif_transfer(embeddings, motif_matrix, motif_names, k=5):
    """Predict each glycan's motifs by majority vote over its k nearest neighbours.

    Only motifs present in at least 10 and at most N - 10 glycans are
    evaluated; rarer or near-universal motifs are skipped as trivial. For
    each evaluated motif, a glycan is predicted to carry it when MORE than
    half of its ``k`` most cosine-similar other glycans carry it.

    Args:
        embeddings: (N, D) array-like of vectors.
        motif_matrix: (N, M) array-like; entries > 0 mean "motif present".
        motif_names: length-M sequence of motif names.
        k: number of neighbours (clamped to N - 1 so a glycan never votes for itself).

    Returns:
        dict with:
            ``per_motif``: name -> accuracy, majority-class
                ``baseline_accuracy``, positive-class ``f1``, ``n_positive``
                and ``lift``.
            ``mean_accuracy``, ``mean_f1`` and ``n_motifs_evaluated``.
    """
    emb = np.asarray(embeddings)
    if not np.issubdtype(emb.dtype, np.floating):
        emb = emb.astype(np.float64)
    present = np.asarray(motif_matrix) > 0
    n = len(emb)
    print(f"\n Motif transfer (k={k}, {len(motif_names)} motifs)...")

    # k nearest neighbours by cosine similarity, self excluded via -inf.
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    unit = emb / norms
    sim_matrix = unit @ unit.T
    np.fill_diagonal(sim_matrix, -np.inf)
    k_eff = max(min(k, n - 1), 0)
    neighbors = np.argsort(-sim_matrix, axis=1)[:, :k_eff]
    del sim_matrix, unit

    results = {}
    accuracies = []
    f1s = []
    for m_idx, mname in enumerate(motif_names):
        truth = present[:, m_idx]
        n_pos = int(truth.sum())
        if n_pos < 10 or n_pos > n - 10:  # skip trivial motifs
            continue
        # Vectorised majority vote: fraction of neighbours carrying the motif.
        predicted = present[neighbors, m_idx].mean(axis=1) > 0.5
        acc = float(np.mean(predicted == truth))
        tp = int(np.count_nonzero(predicted & truth))
        fp = int(np.count_nonzero(predicted & ~truth))
        fn = int(np.count_nonzero(~predicted & truth))
        # Positive-class F1; 0 when there is nothing to score.
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        # Baseline: always predict majority class
        baseline_acc = max(n_pos / n, 1.0 - n_pos / n)
        accuracies.append(acc)
        f1s.append(f1)
        results[mname] = {
            'accuracy': float(acc),
            'baseline_accuracy': float(baseline_acc),
            'f1': float(f1),
            'n_positive': n_pos,
            'lift': float(acc / baseline_acc) if baseline_acc > 0 else 0,
        }
    mean_acc = np.mean(accuracies) if accuracies else 0
    mean_f1 = np.mean(f1s) if f1s else 0
    print(f" Evaluated {len(results)} non-trivial motifs")
    print(f" Mean transfer accuracy: {mean_acc:.4f}")
    print(f" Mean transfer F1: {mean_f1:.4f}")
    # Top and bottom motifs
    if results:
        sorted_motifs = sorted(results.items(), key=lambda x: x[1]['f1'], reverse=True)
        print(f"\n Top 5 transferable motifs:")
        for name, r in sorted_motifs[:5]:
            print(f" {name}: F1={r['f1']:.3f}, acc={r['accuracy']:.3f} "
                  f"(baseline={r['baseline_accuracy']:.3f}, n_pos={r['n_positive']})")
        print(f"\n Bottom 5 (hardest to transfer):")
        for name, r in sorted_motifs[-5:]:
            print(f" {name}: F1={r['f1']:.3f}, acc={r['accuracy']:.3f} "
                  f"(baseline={r['baseline_accuracy']:.3f}, n_pos={r['n_positive']})")
    return {
        'per_motif': results,
        'mean_accuracy': float(mean_acc),
        'mean_f1': float(mean_f1),
        'n_motifs_evaluated': len(results),
    }
def compute_fingerprint_baseline(annotated_df, motif_cols, labels, label_name,
                                 k_values=(1, 5, 10, 20)):
    """Compute P@k using glycowork motif fingerprints as a non-ML baseline.

    Each row's motif-count vector (``annotated_df[motif_cols]``) is used as
    the retrieval feature. Every row queries all OTHER rows by cosine
    similarity, and a neighbour counts as correct when it shares the
    query's label.

    Args:
        annotated_df: DataFrame with one row per glycan, aligned with ``labels``.
        motif_cols: columns of ``annotated_df`` forming the fingerprint.
        labels: length-N sequence of labels.
        label_name: name used in the printed report.
        k_values: cut-offs. Values >= N are clamped to N - 1 and non-positive
            values are ignored. Duplicates are evaluated once.

    Returns:
        dict ``str(k) -> Precision@k``.

    Raises:
        ValueError: if fewer than two rows are given or ``labels`` does not
            have one entry per row.
    """
    fingerprints = annotated_df[motif_cols].to_numpy(dtype=float)
    # L2-normalise rows so a dot product is the cosine similarity; all-zero
    # fingerprints stay zero (similarity 0 to everything).
    norms = np.linalg.norm(fingerprints, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    fingerprints = fingerprints / norms
    label_arr = np.asarray(labels)
    n = len(fingerprints)
    if n < 2:
        raise ValueError(f"{label_name}: retrieval needs at least 2 items, got {n}")
    if len(label_arr) != n:
        raise ValueError(f"{label_name}: {len(label_arr)} labels for {n} fingerprints")

    sim_matrix = fingerprints @ fingerprints.T
    np.fill_diagonal(sim_matrix, -np.inf)  # self always ranks last
    sorted_indices = np.argsort(-sim_matrix, axis=1)

    ks = []
    for k in k_values:
        k = min(int(k), n - 1)
        if k > 0 and k not in ks:
            ks.append(k)
    precision_at_k = {}
    if ks:
        top = sorted_indices[:, :max(ks)]
        # matches[i, j]: does the j-th neighbour of query i share its label?
        matches = label_arr[top] == label_arr[:, None]
        precision_at_k = {k: float(matches[:, :k].mean()) for k in ks}

    print(f"\n Fingerprint baseline for {label_name}:")
    for k, p in precision_at_k.items():
        print(f" P@{k:>2d}: {p:.4f}")
    return {str(k): float(v) for k, v in precision_at_k.items()}
# ═════════════════════════════════════════════════════════════════════════════
# Visualization
# ═════════════════════════════════════════════════════════════════════════════
def plot_precision_curves(results, fingerprint_results, output_path):
    """Plot P@k curves for BERT embeddings versus the fingerprint baseline.

    Args:
        results: mapping property name -> metrics dict with
            ``'precision_at_k'`` (``str(k) -> float``) and
            ``'random_baseline'``. Keys other than the plotted properties
            (e.g. ``'model'``) are ignored.
        fingerprint_results: mapping property name -> ``{str(k): precision}``.
        output_path: path WITHOUT extension; ``.png`` and ``.pdf`` are written.
    """
    colors = {
        'Glycan Type': '#0072B2',
        'Domain': '#D55E00',
        'Immunogenicity': '#009E73',
    }
    fig, ax = plt.subplots(figsize=(10, 6))
    all_ks = set()
    for prop_name, metrics in results.items():
        color = colors.get(prop_name)
        if color is None:
            continue
        pk = metrics['precision_at_k']
        ks = sorted(int(k) for k in pk)
        all_ks.update(ks)
        ax.plot(ks, [pk[str(k)] for k in ks], color=color, linestyle='-',
                marker='o', linewidth=2, markersize=6,
                label=f'{prop_name} (BERT)')
        # Fingerprint baseline (dashed), on the k values it shares with BERT.
        fp = fingerprint_results.get(prop_name)
        if fp:
            fp_ks = [k for k in ks if str(k) in fp]
            ax.plot(fp_ks, [fp[str(k)] for k in fp_ks], color=color,
                    linestyle='--', marker='s', linewidth=1.5, markersize=5,
                    alpha=0.7, label=f'{prop_name} (fingerprint)')
        # Random baseline (dotted).
        ax.axhline(y=metrics['random_baseline'], color=color, linestyle=':', alpha=0.3)
    ax.set_xlabel('k (number of neighbors)')
    ax.set_ylabel('Precision@k')
    ax.set_title('Embedding Retrieval: BERT vs Structural Fingerprint',
                 fontsize=13, fontweight='bold')
    # Tick every k that any plotted curve uses (previously a NameError when
    # no property was plotted).
    if all_ks:
        ax.set_xticks(sorted(all_ks))
    ax.legend(frameon=False, fontsize=8, loc='center right')
    ax.set_ylim(0, 1.0)
    ax.grid(axis='y', alpha=0.2)
    fig.tight_layout()
    fig.savefig(str(output_path) + '.png', dpi=300, facecolor='white')
    fig.savefig(str(output_path) + '.pdf', facecolor='white')
    plt.close(fig)
    print(f" Saved: {output_path}.png")
def plot_retrieval_examples(sim_matrix, labels, df, output_path, n_examples=5):
    """Render a few random queries with their top-5 neighbours as a text figure.

    Args:
        sim_matrix: (N, N) similarity matrix. The caller is expected to have
            already suppressed self-similarity on the diagonal.
        labels: length-N labels shown next to each glycan.
        df: DataFrame whose first N rows correspond to ``sim_matrix`` rows.
            Its ``'glycan'`` column (if present) is shown, truncated.
        output_path: path WITHOUT extension; ``.png`` and ``.pdf`` are written.
        n_examples: number of queries (capped at N), drawn with seed 42.
    """
    n = len(labels)
    n_queries = min(n_examples, n)
    if n_queries <= 0:
        print(" No retrieval examples to plot")
        return
    # A private RandomState(42) draws the same indices as np.random.seed(42)
    # followed by np.random.choice, but leaves the global RNG untouched.
    rng = np.random.RandomState(42)
    query_indices = rng.choice(n, n_queries, replace=False)
    label_arr = np.asarray(labels)
    has_glycan = 'glycan' in df.columns

    fig, axes = plt.subplots(n_queries, 1, figsize=(14, 3 * n_queries),
                             squeeze=False)
    for ax, qi in zip(axes[:, 0], query_indices):
        # Only the query rows are needed, so sort just those.
        neighbors = np.argsort(-sim_matrix[qi])[:5]
        sims = sim_matrix[qi, neighbors]
        query_type = label_arr[qi]
        query_iupac = df.iloc[qi]['glycan'][:60] if has_glycan else '?'
        lines = [f"Query: [{query_type}] {query_iupac}"]
        for j, (ni, s) in enumerate(zip(neighbors, sims), start=1):
            nt = label_arr[ni]
            ni_iupac = df.iloc[ni]['glycan'][:40] if has_glycan else '?'
            match = '✓' if nt == query_type else '✗'
            lines.append(f" NN{j}: [{match}] sim={s:.3f} | {nt} | {ni_iupac}")
        ax.text(0.02, 0.5, '\n'.join(lines), transform=ax.transAxes, fontsize=8,
                verticalalignment='center', fontfamily='monospace',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')
    fig.suptitle('Retrieval Examples: Query → Top-5 Nearest Neighbors',
                 fontsize=13, fontweight='bold')
    fig.tight_layout()
    fig.savefig(str(output_path) + '.png', dpi=300, facecolor='white')
    fig.savefig(str(output_path) + '.pdf', facecolor='white')
    plt.close(fig)
    print(f" Saved: {output_path}.png")
def plot_motif_transfer_bar(motif_results, output_path, top_n=15):
    """Horizontal bar chart of the motifs with the highest k-NN transfer F1.

    The thin grey bars show each motif's ``baseline_accuracy`` (majority-class
    ACCURACY), not an F1 score. The legend and axis label say so, so the two
    series are not mistaken for the same metric.

    Args:
        motif_results: dict with a ``'per_motif'`` mapping name -> stats
            containing ``'f1'`` and ``'baseline_accuracy'``.
        output_path: path WITHOUT extension; ``.png`` and ``.pdf`` are written.
        top_n: maximum number of motifs shown.
    """
    per_motif = motif_results['per_motif']
    if not per_motif:
        return
    top = sorted(per_motif.items(), key=lambda item: item[1]['f1'],
                 reverse=True)[:top_n]
    names = [name[:25] for name, _ in top]
    f1s = [stats['f1'] for _, stats in top]
    baselines = [stats['baseline_accuracy'] for _, stats in top]
    fig, ax = plt.subplots(figsize=(10, 6))
    y_pos = np.arange(len(names))
    ax.barh(y_pos, f1s, height=0.6, color='#0072B2', alpha=0.8, label='BERT F1')
    ax.barh(y_pos, baselines, height=0.3, color='#999999', alpha=0.5,
            label='Majority-class accuracy')
    ax.set_yticks(y_pos)
    ax.set_yticklabels(names, fontsize=8)
    ax.set_xlabel('Score (blue: k-NN F1; grey: majority-class accuracy)')
    # Report how many motifs are actually shown, which may be fewer than top_n.
    ax.set_title(f'Motif Transfer via k-NN (top {len(top)} motifs)',
                 fontsize=12, fontweight='bold')
    ax.legend(frameon=False, loc='lower right')
    ax.invert_yaxis()
    fig.tight_layout()
    fig.savefig(str(output_path) + '.png', dpi=300, facecolor='white')
    fig.savefig(str(output_path) + '.pdf', facecolor='white')
    plt.close(fig)
    print(f" Saved: {output_path}.png")
# ═════════════════════════════════════════════════════════════════════════════
# Main
# ═════════════════════════════════════════════════════════════════════════════
def _tokenizes(tokenizer, wurcs):
    """Return True if ``tokenizer`` yields every field get_cls_embeddings reads."""
    try:
        tok = tokenizer.tokenize(wurcs)
        for key in ('token_ids', 'branch_depths', 'linkage_types'):
            _ = tok[key][:1]
    except Exception:
        return False
    return True


def _extract_labels(df):
    """Return ``(glycan_types, domains, immunogenicity)`` lists (each None if unavailable)."""
    glycan_types = None
    if 'glycan_type' in df.columns:
        glycan_types = df['glycan_type'].fillna('Unknown').tolist()
        type_counts = Counter(glycan_types)
        # Consolidate very rare types
        glycan_types = ['Other' if type_counts[t] < 30 else t for t in glycan_types]
        print(f" Glycan types: {dict(Counter(glycan_types))}")
    domains = None
    if 'domain' in df.columns:
        domains = df['domain'].fillna('Unknown').tolist()
        print(f" Domains: {dict(Counter(domains))}")
    immunogenicity = None
    if 'immunogenicity' in df.columns:
        imm = df['immunogenicity']
        if imm.notna().sum() > 100:
            # -1 marks unlabeled glycans; they are excluded before scoring.
            immunogenicity = imm.fillna(-1).astype(int).tolist()
            print(f" Immunogenicity: {Counter(immunogenicity)}")
    return glycan_types, domains, immunogenicity


def _to_serializable(obj):
    """``json.dump`` fallback converting numpy scalars/arrays to Python types."""
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _print_summary(model_name, all_results, fingerprint_results, motif_transfer):
    """Print the P@k / mAP table, fingerprint baselines and motif-transfer summary."""
    print(f"\n{'='*70}")
    print(f"SUMMARY — Probe 14 Retrieval ({model_name})")
    print(f"{'='*70}")
    print(f"\n{'Property':<20} {'P@1':>8} {'P@5':>8} {'P@10':>8} {'P@20':>8} {'mAP':>8} {'Baseline':>10}")
    print(f"{'─'*74}")
    for prop in ('Glycan Type', 'Domain', 'Immunogenicity'):
        if prop not in all_results:
            continue
        r = all_results[prop]
        pk = r['precision_at_k']
        print(f"{prop:<20} {pk.get('1', 0):>8.4f} {pk.get('5', 0):>8.4f} "
              f"{pk.get('10', 0):>8.4f} {pk.get('20', 0):>8.4f} "
              f"{r['mAP']:>8.4f} {r['random_baseline']:>10.4f}")
    if fingerprint_results:
        print(f"\n Fingerprint baselines:")
        for prop, fp in fingerprint_results.items():
            vals = ' '.join(f"P@{k}={v:.4f}"
                            for k, v in sorted(fp.items(), key=lambda x: int(x[0])))
            print(f" {prop}: {vals}")
    print(f"\n Motif transfer (k=5): mean F1={motif_transfer['mean_f1']:.4f}, "
          f"mean acc={motif_transfer['mean_accuracy']:.4f} "
          f"({motif_transfer['n_motifs_evaluated']} motifs)")


def run_retrieval(model_version, device, max_glycans=15000):
    """Run the full Probe 14 pipeline for one model version.

    Pipeline:
        1. Load glycans that have both IUPAC and WURCS.
        2. Drop those the WURCS tokenizer rejects.
        3. Subsample to ``max_glycans`` if needed.
        4. Annotate glycowork motifs and extract labels.
        5. Embed the glycans.
        6. Score retrieval (P@k, mAP), motif transfer and the fingerprint
           baseline.
        7. Plot and save JSON.

    Args:
        model_version: ``'v5'`` or ``'v6'`` (key into ``CHECKPOINTS``).
        device: torch device used for embedding.
        max_glycans: cap on the number of glycans evaluated. A random
            subsample is taken (seed 42, original row order kept). The
            retrieval metrics build dense N x N similarity matrices, so this
            bounds memory. ``None`` or ``0`` disables the cap.

    Returns:
        dict of all computed results (also written to ``retrieval_results.json``).

    Raises:
        RuntimeError: if motif annotations or embeddings do not line up
            one-to-one with the glycan rows.
    """
    import gc
    from glycowork.motif.annotate import annotate_dataset

    model_name = {'v5': 'V5-A', 'v6': 'V6'}[model_version]
    # NOTE(review): 'probe_results_v6' appears twice in this path; it may be
    # accidental, but existing outputs live there -- confirm before changing.
    output_dir = (PROJECT_ROOT / 'bert_v6_contrastive' / 'analysis' /
                  'probe_results_v6' / 'probe_results_v6' /
                  'probe_14_retrieval' / model_version)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"\n{'='*70}")
    print(f"Probe 14: Embedding Retrieval — GlycanBERT {model_name}")
    print(f"{'='*70}")

    # ── 1. Load data ──────────────────────────────────────────────────────
    print(f"\n1. Loading data from {DATA_PATH}")
    df = pd.read_csv(DATA_PATH)
    df = df[df['glycan'].notna() & df['wurcs'].notna()].reset_index(drop=True)
    print(f" {len(df)} glycans with IUPAC + WURCS")
    # Drop untokenizable glycans BEFORE deriving anything else from df.
    # get_cls_embeddings silently skips them, which would otherwise shift
    # every later embedding out of alignment with its labels and motifs.
    tokenizer = WURCSTokenizer(str(VOCAB_PATH))
    ok = np.array([_tokenizes(tokenizer, w) for w in df['wurcs']], dtype=bool)
    if not ok.all():
        print(f" Dropping {int((~ok).sum())} glycans that fail WURCS tokenization")
        df = df[ok].reset_index(drop=True)
    if max_glycans and len(df) > max_glycans:
        df = (df.sample(n=max_glycans, random_state=42)
              .sort_index()
              .reset_index(drop=True))
        print(f" Subsampled to {len(df)} glycans (max_glycans={max_glycans}, seed=42)")

    # ── 2. Annotate motifs ────────────────────────────────────────────────
    print(f"\n2. Annotating glycowork motifs...")
    iupac_list = df['glycan'].tolist()
    try:
        annotated = annotate_dataset(iupac_list, feature_set=['known'],
                                     condense=True)
    except TypeError:
        # Fall back to the default call when this glycowork version's
        # signature does not accept these keyword arguments.
        annotated = annotate_dataset(iupac_list)
    if len(annotated) != len(df):
        raise RuntimeError(f"Motif annotation returned {len(annotated)} rows "
                           f"for {len(df)} glycans")
    motif_cols = [c for c in annotated.columns if c != 'glycan']
    print(f" Found {len(motif_cols)} motif columns")

    # ── 3. Extract labels ─────────────────────────────────────────────────
    print(f"\n3. Extracting labels...")
    glycan_types, domains, immunogenicity = _extract_labels(df)
    # Label sets to score: name -> (row mask, or None for all rows; labels).
    label_sets = {}
    if glycan_types:
        label_sets['Glycan Type'] = (None, glycan_types)
    if domains:
        keep = np.array([d != 'Unknown' for d in domains], dtype=bool)
        if keep.sum() > 500:
            label_sets['Domain'] = (keep, [d for d, m in zip(domains, keep) if m])
    if immunogenicity:
        keep = np.array([v >= 0 for v in immunogenicity], dtype=bool)
        if keep.sum() > 200:
            label_sets['Immunogenicity'] = (
                keep, [v for v, m in zip(immunogenicity, keep) if m])

    # ── 4. Load model & extract embeddings ────────────────────────────────
    print(f"\n4. Loading model and extracting CLS embeddings...")
    model, _ = load_model(model_version, device)
    embeddings = get_cls_embeddings(model, tokenizer, df['wurcs'].tolist(), device)
    print(f" Embeddings: {embeddings.shape}")
    # Free GPU memory before building the large CPU-side similarity matrices.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    if embeddings.shape[0] != len(df):
        raise RuntimeError(f"Got {embeddings.shape[0]} embeddings for {len(df)} "
                           "glycans; rows would be misaligned with labels")

    # ── 5. BERT retrieval metrics ─────────────────────────────────────────
    print(f"\n{'─'*50}")
    print(f"5. BERT Embedding Retrieval")
    print(f"{'─'*50}")
    all_results = {'model': model_name}
    for prop, (keep, prop_labels) in label_sets.items():
        prop_embs = embeddings if keep is None else embeddings[keep]
        all_results[prop] = compute_retrieval_metrics(prop_embs, prop_labels, prop)

    # ── 6. Motif transfer ─────────────────────────────────────────────────
    print(f"\n{'─'*50}")
    print(f"6. Motif Transfer via k-NN")
    print(f"{'─'*50}")
    motif_matrix = (annotated[motif_cols].values > 0).astype(float)
    motif_transfer = compute_motif_transfer(embeddings, motif_matrix,
                                            motif_cols, k=5)
    all_results['motif_transfer'] = motif_transfer

    # ── 7. Fingerprint baseline (same label sets as step 5) ───────────────
    print(f"\n{'─'*50}")
    print(f"7. Glycowork Fingerprint Baseline")
    print(f"{'─'*50}")
    fingerprint_results = {}
    for prop, (keep, prop_labels) in label_sets.items():
        prop_annot = annotated if keep is None else annotated[keep]
        fingerprint_results[prop] = compute_fingerprint_baseline(
            prop_annot, motif_cols, prop_labels, prop)
    all_results['fingerprint_baseline'] = fingerprint_results

    # ── 8. Visualizations ─────────────────────────────────────────────────
    print(f"\n{'─'*50}")
    print(f"8. Generating Visualizations")
    print(f"{'─'*50}")
    plot_precision_curves(all_results, fingerprint_results,
                          output_dir / f'precision_at_k_{model_name.lower()}')
    plot_motif_transfer_bar(motif_transfer,
                            output_dir / f'motif_transfer_{model_name.lower()}')
    if glycan_types:
        from sklearn.metrics.pairwise import cosine_similarity
        sim_mat = cosine_similarity(embeddings[:1000])
        np.fill_diagonal(sim_mat, -1.0)
        plot_retrieval_examples(sim_mat, glycan_types[:1000], df.head(1000),
                                output_dir / f'retrieval_examples_{model_name.lower()}')

    # ── 9. Save results ───────────────────────────────────────────────────
    results_path = output_dir / 'retrieval_results.json'
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, default=_to_serializable)
    print(f"\n Results saved to: {results_path}")

    # ── 10. Summary table ─────────────────────────────────────────────────
    _print_summary(model_name, all_results, fingerprint_results, motif_transfer)
    return all_results
def main():
    """CLI entry point: run Probe 14 per selected model, then compare V5-A and V6 mAP."""
    parser = argparse.ArgumentParser(
        description='Probe 14: embedding-based nearest-neighbour retrieval')
    parser.add_argument('--model', choices=['v5', 'v6', 'both'], default='both',
                        help='checkpoint(s) to evaluate (default: both)')
    parser.add_argument('--max_glycans', type=int, default=15000,
                        help='forwarded to run_retrieval() as max_glycans')
    args = parser.parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    models = ['v5', 'v6'] if args.model == 'both' else [args.model]
    all_model_results = {mv: run_retrieval(mv, device, args.max_glycans)
                         for mv in models}
    # If both models, print comparison
    if len(all_model_results) == 2:
        print(f"\n{'='*70}")
        print(f"COMPARISON: V5-A vs V6")
        print(f"{'='*70}")
        for prop in ('Glycan Type', 'Domain', 'Immunogenicity'):
            v5_r = all_model_results.get('v5', {}).get(prop)
            v6_r = all_model_results.get('v6', {}).get(prop)
            if not (v5_r and v6_r):
                continue
            v5_map, v6_map = v5_r['mAP'], v6_r['mAP']
            # Report a tie as a tie instead of crediting V5-A.
            if v6_map > v5_map:
                verdict = 'V6 better'
            elif v5_map > v6_map:
                verdict = 'V5-A better'
            else:
                verdict = 'tie'
            print(f"\n {prop}:")
            print(f" V5-A mAP: {v5_map:.4f}")
            print(f" V6 mAP: {v6_map:.4f}")
            print(f" Δ: {v6_map - v5_map:+.4f} ({verdict})")
    print(f"\nProbe 14 complete!")
# Script entry point; the call must be indented under the guard.
if __name__ == '__main__':
    main()