#!/usr/bin/env python3
"""
Probe 14: Embedding-Based Nearest-Neighbor Retrieval
=====================================================
Evaluates whether nearest neighbors in GlycanBERT's embedding space
share biological properties. This demonstrates practical utility: given any
glycan, retrieve structurally and functionally similar glycans from a
reference set.

Metrics:
  - Precision@k (k=1, 5, 10, 20) for glycan type, domain, motif transfer
  - Mean Average Precision (mAP) per property
  - Motif transfer accuracy (majority vote of k-NN motifs)
  - Comparison against glycowork structural fingerprint baseline

Usage:
    python probe_14_retrieval.py              # Both V5-A and V6
    python probe_14_retrieval.py --model v5   # V5-A only
"""

import sys
import os
import json
import argparse
import warnings
from pathlib import Path
from collections import Counter

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # headless backend; must be selected before importing pyplot
import matplotlib.pyplot as plt
import torch

warnings.filterwarnings('ignore', category=FutureWarning)

# ─── Project paths ────────────────────────────────────────────────────────────
PROJECT_ROOT = Path(__file__).resolve().parents[2]
VOCAB_PATH = PROJECT_ROOT / 'bert_training_v4' / 'data' / 'vocabulary.json'
DATA_PATH = PROJECT_ROOT / 'bert_v6_contrastive' / 'analysis' / 'glycowork_iupac_wurcs_unified.csv'

CHECKPOINTS = {
    'v5': PROJECT_ROOT / 'checkpoints_v5_bpe_topo' / 'best_v5_bpe_topo_model.pt',
    'v6': PROJECT_ROOT / 'bert_v5.1_contrastive' / 'checkpoints' / 'best_v51_contrastive_model.pt',
}

sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / 'bert_training_v4'))

from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig
from downstream_tasks.utils.tokenizer import WURCSTokenizer

# ─── Nature-style plotting ────────────────────────────────────────────────────
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'Helvetica', 'DejaVu Sans'],
    'font.size': 10,
    'axes.titlesize': 12,
    'axes.labelsize': 11,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'legend.fontsize': 9,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'axes.linewidth': 0.8,
    'axes.spines.top': False,
    'axes.spines.right': False,
})


# ═════════════════════════════════════════════════════════════════════════════
# Model loading & embedding (identical to probes 8-13)
# ═════════════════════════════════════════════════════════════════════════════

def load_model(model_version, device):
    """Load a GlycanBERT checkpoint and its tokenizer.

    Args:
        model_version: key into ``CHECKPOINTS`` ('v5' or 'v6').
        device: torch device the model is moved to.

    Returns:
        (model, tokenizer) with the model in eval mode.

    Raises:
        KeyError: if ``model_version`` is not a key of ``CHECKPOINTS``.
    """
    ckpt_path = CHECKPOINTS[model_version]
    print(f"Loading {model_version} from {ckpt_path}")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from a trusted source.
    state = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    sd = state.get('model_state_dict', state)
    if 'proj_head_state_dict' in state:
        # Drop projection-head weights; the backbone config below has no such module.
        sd = {k: v for k, v in sd.items() if not k.startswith('proj_head')}

    # Infer vocabulary and hidden sizes from the token-embedding table when present.
    emb_weight = sd.get('seq_embeddings.token_embeddings.weight',
                        sd.get('token_embeddings.weight'))
    vocab_size = emb_weight.shape[0] if emb_weight is not None else 2200
    hidden = emb_weight.shape[1] if emb_weight is not None else 768

    config = MultimodalGlycanBERTConfig(
        seq_vocab_size=vocab_size,
        seq_hidden_size=hidden,
        seq_num_layers=12,
        seq_num_heads=12,
        seq_max_length=256,
        use_cnn_frontend=True,
        cnn_kernel_size=3,
    )
    model = MultimodalGlycanBERT(config)

    # strict=False tolerates key mismatches; report them so a silently
    # uninitialised model does not go unnoticed.
    incompatible = model.load_state_dict(sd, strict=False)
    missing = list(getattr(incompatible, 'missing_keys', []) or [])
    unexpected = list(getattr(incompatible, 'unexpected_keys', []) or [])
    if missing or unexpected:
        print(f" WARNING: non-strict load: {len(missing)} missing, "
              f"{len(unexpected)} unexpected keys")

    model = model.to(device).eval()
    print(f" Loaded: {sum(p.numel() for p in model.parameters()):,} params")
    tokenizer = WURCSTokenizer(str(VOCAB_PATH))
    return model, tokenizer


def get_cls_embeddings(model, tokenizer, wurcs_list, device, batch_size=128,
                       max_len=256, return_indices=False):
    """Embed WURCS strings and return the position-0 vector of each sequence.

    Strings that fail tokenization are skipped and counted, so the output
    can have fewer rows than ``wurcs_list``. Pass ``return_indices=True`` to
    learn which inputs survived and realign labels.

    NOTE(review): the vector is read from ``model.seq_embeddings(...)`` at
    position 0. Confirm that this submodule yields the intended (contextual)
    representation rather than raw input embeddings.

    Args:
        model: model exposing ``seq_embeddings(ids, branch_depths=, linkage_types=)``.
        tokenizer: object whose ``tokenize(w)`` returns a mapping with
            'token_ids', 'branch_depths' and 'linkage_types'.
        wurcs_list: list of WURCS strings.
        device: torch device for the forward pass.
        batch_size: number of strings per forward pass.
        max_len: sequences are truncated to this many tokens.
        return_indices: if True, also return the positions in ``wurcs_list``
            of the successfully embedded strings.

    Returns:
        ``(n_ok, D)`` array, or ``(array, indices)`` when ``return_indices``.
        When nothing could be embedded the array has shape ``(0, 768)``.
    """
    all_embs, kept, errors = [], [], 0
    for start in range(0, len(wurcs_list), batch_size):
        batch = wurcs_list[start:start + batch_size]
        token_ids_list, bd_list, lt_list = [], [], []
        for offset, w in enumerate(batch):
            try:
                tok = tokenizer.tokenize(w)
                ids = tok['token_ids'][:max_len]
                bd = tok['branch_depths'][:max_len]
                lt = tok['linkage_types'][:max_len]
            except Exception:
                # Best-effort: skip glycans the tokenizer cannot handle.
                errors += 1
                continue
            # Append only after all three fields were read, so the three
            # lists can never get out of step.
            token_ids_list.append(ids)
            bd_list.append(bd)
            lt_list.append(lt)
            kept.append(start + offset)
        if not token_ids_list:
            continue

        # Right-pad every sequence in the batch with zeros.
        ml = max(len(x) for x in token_ids_list)
        ids_t = torch.zeros(len(token_ids_list), ml, dtype=torch.long)
        bd_t = torch.zeros_like(ids_t)
        lt_t = torch.zeros_like(ids_t)
        for j, (ids, bd, lt) in enumerate(zip(token_ids_list, bd_list, lt_list)):
            ids_t[j, :len(ids)] = torch.tensor(ids, dtype=torch.long)
            bd_t[j, :len(bd)] = torch.tensor(bd, dtype=torch.long)
            lt_t[j, :len(lt)] = torch.tensor(lt, dtype=torch.long)
        ids_t, bd_t, lt_t = ids_t.to(device), bd_t.to(device), lt_t.to(device)

        with torch.no_grad():
            seq_out = model.seq_embeddings(ids_t, branch_depths=bd_t, linkage_types=lt_t)
        all_embs.append(seq_out[:, 0, :].cpu().numpy())

        if (start // batch_size) % 10 == 0:
            print(f" Embedded {start+len(batch)}/{len(wurcs_list)} ({errors} errors)")

    print(f" Total: {sum(e.shape[0] for e in all_embs):,} ({errors} errors)")
    embs = np.vstack(all_embs) if all_embs else np.zeros((0, 768))
    if return_indices:
        return embs, np.asarray(kept, dtype=np.int64)
    return embs


# ═════════════════════════════════════════════════════════════════════════════
# Retrieval metrics
# ═════════════════════════════════════════════════════════════════════════════

def _rank_neighbors(features):
    """Return an (N, N-1) array of neighbor indices, most cosine-similar first.

    The diagonal is set to -inf before sorting so each item ranks last in its
    own row; that last column is then dropped, which guarantees a query is
    never retrieved as its own neighbor.
    """
    from sklearn.metrics.pairwise import cosine_similarity

    sim = cosine_similarity(features)
    np.fill_diagonal(sim, -np.inf)
    return np.argsort(-sim, axis=1)[:, :-1]


def _precision_at_k(neighbors, codes, k_values):
    """Mean fraction of the top-k neighbors sharing the query's label code.

    ``k`` is clamped to the number of available neighbors; non-positive
    values are skipped. Keys of the result are the effective (clamped) k.
    """
    max_rank = neighbors.shape[1]
    result = {}
    for k in k_values:
        k_eff = min(int(k), max_rank)
        if k_eff < 1:
            continue
        same = codes[neighbors[:, :k_eff]] == codes[:, None]
        result[k_eff] = float(same.mean())
    return result


def _encode_labels(labels, label_counts):
    """Map arbitrary hashable labels to dense integer codes (order of first occurrence)."""
    code_of = {lab: idx for idx, lab in enumerate(label_counts)}
    return np.fromiter((code_of[lab] for lab in labels), dtype=np.int64,
                       count=len(labels))


def compute_retrieval_metrics(embeddings, labels, label_name, k_values=(1, 5, 10, 20)):
    """Compute Precision@k and mAP for a given set of labels.

    Args:
        embeddings: (N, D) array
        labels: sequence of N hashable labels (strings or ints)
        label_name: name for reporting
        k_values: iterable of k values; each is clamped to N-1

    Returns:
        dict with precision_at_k, mAP, random_baseline, n_classes, n_samples
        and class_distribution. With fewer than two samples precision_at_k is
        empty and mAP is 0.

    Raises:
        ValueError: if ``labels`` and ``embeddings`` differ in length.
    """
    n = len(embeddings)
    if len(labels) != n:
        raise ValueError(
            f"{label_name}: got {len(labels)} labels for {n} embeddings")

    label_counts = Counter(labels)
    codes = _encode_labels(labels, label_counts)
    class_sizes = np.fromiter(label_counts.values(), dtype=np.int64,
                              count=len(label_counts))

    # Random baseline (expected precision under random retrieval)
    random_baseline = sum((c / n) ** 2 for c in label_counts.values()) if n else 0.0

    precision_at_k = {}
    mAP = 0.0
    if n >= 2:
        print(f" Computing {n}x{n} cosine similarity matrix...")
        neighbors = _rank_neighbors(embeddings)
        precision_at_k = _precision_at_k(neighbors, codes, k_values)

        # Mean Average Precision. Queries with no other item of the same
        # label contribute 0 but still count in the denominator.
        ap_sum = 0.0
        for i in range(n):
            if class_sizes[codes[i]] - 1 <= 0:
                continue
            # 1-based ranks at which same-label items appear for this query.
            hit_ranks = np.flatnonzero(codes[neighbors[i]] == codes[i]) + 1
            ap_sum += float(np.mean(np.arange(1, hit_ranks.size + 1) / hit_ranks))
        mAP = ap_sum / n

    print(f"\n {label_name}:")
    print(f" Random baseline: {random_baseline:.4f}")
    for k, p in precision_at_k.items():
        lift = p / random_baseline if random_baseline > 0 else 0
        print(f" P@{k:>2d}: {p:.4f} (lift: {lift:.2f}x)")
    print(f" mAP: {mAP:.4f}")
    if label_counts:
        largest = max(label_counts.values())
        print(f" # classes: {len(label_counts)}, largest: "
              f"{largest}/{n} ({largest/n*100:.1f}%)")

    return {
        'precision_at_k': {str(k): float(v) for k, v in precision_at_k.items()},
        'mAP': float(mAP),
        'random_baseline': float(random_baseline),
        'n_classes': len(label_counts),
        'n_samples': n,
        'class_distribution': {str(k): int(v) for k, v in label_counts.items()},
    }


def compute_motif_transfer(embeddings, motif_matrix, motif_names, k=5):
    """Predict each glycan's motifs via majority vote of k nearest neighbors.

    Motifs with fewer than 10 positives or fewer than 10 negatives are
    skipped as trivial.

    Args:
        embeddings: (N, D) array
        motif_matrix: (N, M) binary array of motif presence
        motif_names: list of M motif names
        k: number of neighbors

    Returns:
        dict with per-motif accuracy / baseline accuracy / F1 / lift, the
        mean accuracy and F1, and the number of motifs evaluated.
    """
    from sklearn.metrics import f1_score

    n = len(embeddings)
    print(f"\n Motif transfer (k={k}, {len(motif_names)} motifs)...")

    if n >= 2:
        neighbors = _rank_neighbors(embeddings)[:, :k]
    else:
        neighbors = np.zeros((n, 0), dtype=np.int64)

    results = {}
    accuracies = []
    f1s = []

    for m_idx, mname in enumerate(motif_names):
        true_labels = motif_matrix[:, m_idx]
        n_pos = int(true_labels.sum())
        if n_pos < 10 or n_pos > n - 10:  # skip trivial motifs
            continue

        # Predict via majority vote of k-NN (strict majority: ties predict 0).
        votes = true_labels[neighbors].mean(axis=1)
        predicted = (votes > 0.5).astype(float)

        acc = np.mean(predicted == true_labels)
        f1 = f1_score(true_labels, predicted, zero_division=0)

        # Baseline: always predict majority class
        baseline_acc = max(n_pos / n, 1.0 - n_pos / n)

        accuracies.append(acc)
        f1s.append(f1)
        results[mname] = {
            'accuracy': float(acc),
            'baseline_accuracy': float(baseline_acc),
            'f1': float(f1),
            'n_positive': n_pos,
            'lift': float(acc / baseline_acc) if baseline_acc > 0 else 0,
        }

    mean_acc = np.mean(accuracies) if accuracies else 0
    mean_f1 = np.mean(f1s) if f1s else 0

    print(f" Evaluated {len(results)} non-trivial motifs")
    print(f" Mean transfer accuracy: {mean_acc:.4f}")
    print(f" Mean transfer F1: {mean_f1:.4f}")

    # Top and bottom motifs
    if results:
        sorted_motifs = sorted(results.items(), key=lambda x: x[1]['f1'], reverse=True)
        print(f"\n Top 5 transferable motifs:")
        for name, r in sorted_motifs[:5]:
            print(f" {name}: F1={r['f1']:.3f}, acc={r['accuracy']:.3f} "
                  f"(baseline={r['baseline_accuracy']:.3f}, n_pos={r['n_positive']})")
        print(f"\n Bottom 5 (hardest to transfer):")
        for name, r in sorted_motifs[-5:]:
            print(f" {name}: F1={r['f1']:.3f}, acc={r['accuracy']:.3f} "
                  f"(baseline={r['baseline_accuracy']:.3f}, n_pos={r['n_positive']})")

    return {
        'per_motif': results,
        'mean_accuracy': float(mean_acc),
        'mean_f1': float(mean_f1),
        'n_motifs_evaluated': len(results),
    }


def compute_fingerprint_baseline(annotated_df, motif_cols, labels, label_name,
                                 k_values=(1, 5, 10, 20)):
    """Compute P@k using glycowork motif fingerprints as a non-ML baseline.

    Args:
        annotated_df: DataFrame holding one motif-count row per glycan.
        motif_cols: columns of ``annotated_df`` used as the fingerprint.
        labels: sequence of labels, one per row of ``annotated_df``.
        label_name: name for reporting.
        k_values: iterable of k values; each is clamped to N-1.

    Returns:
        dict mapping str(k) to precision@k (empty with fewer than two rows).

    Raises:
        ValueError: if ``labels`` and ``annotated_df`` differ in length.
    """
    fingerprints = annotated_df[motif_cols].values.astype(float)
    n = len(fingerprints)
    if len(labels) != n:
        raise ValueError(
            f"{label_name}: got {len(labels)} labels for {n} fingerprints")

    # Handle zero vectors: leave them as all-zero rows instead of dividing by 0.
    norms = np.linalg.norm(fingerprints, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    fingerprints = fingerprints / norms

    precision_at_k = {}
    if n >= 2:
        codes = _encode_labels(labels, Counter(labels))
        precision_at_k = _precision_at_k(_rank_neighbors(fingerprints), codes, k_values)

    print(f"\n Fingerprint baseline for {label_name}:")
    for k, p in precision_at_k.items():
        print(f" P@{k:>2d}: {p:.4f}")

    return {str(k): float(v) for k, v in precision_at_k.items()}


# ═════════════════════════════════════════════════════════════════════════════
# Visualization
# ═════════════════════════════════════════════════════════════════════════════

def _save_figure(output_path):
    """Write the current figure to ``<output_path>.png`` and ``.pdf``, then close it."""
    plt.savefig(str(output_path) + '.png', dpi=300, facecolor='white')
    plt.savefig(str(output_path) + '.pdf', facecolor='white')
    plt.close()
    print(f" Saved: {output_path}.png")


def plot_precision_curves(results, fingerprint_results, output_path):
    """Plot P@k curves for BERT vs fingerprint baseline across properties.

    Only the entries of ``results`` named 'Glycan Type', 'Domain' or
    'Immunogenicity' are drawn; other keys are ignored.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    colors = {
        'Glycan Type': '#0072B2',
        'Domain': '#D55E00',
        'Immunogenicity': '#009E73',
    }

    plotted_ks = set()
    for prop_name, metrics in results.items():
        if prop_name not in colors:
            continue
        color = colors[prop_name]
        pk = metrics['precision_at_k']
        ks = sorted(int(k) for k in pk)
        if not ks:
            continue
        plotted_ks.update(ks)
        vals = [pk[str(k)] for k in ks]
        ax.plot(ks, vals, color=color, linestyle='-', marker='o',
                linewidth=2, markersize=6, label=f'{prop_name} (BERT)')

        # Fingerprint baseline: draw only the k values it actually provides.
        fp = fingerprint_results.get(prop_name)
        if fp:
            fp_ks = [k for k in ks if str(k) in fp]
            fp_vals = [fp[str(k)] for k in fp_ks]
            ax.plot(fp_ks, fp_vals, color=color, linestyle='--', marker='s',
                    linewidth=1.5, markersize=5, alpha=0.7,
                    label=f'{prop_name} (fingerprint)')

        # Random baseline
        baseline = metrics['random_baseline']
        ax.axhline(y=baseline, color=color, linestyle=':', alpha=0.3)

    ax.set_xlabel('k (number of neighbors)')
    ax.set_ylabel('Precision@k')
    ax.set_title('Embedding Retrieval: BERT vs Structural Fingerprint',
                 fontsize=13, fontweight='bold')
    if plotted_ks:
        # Guarded: with nothing plotted there are no tick positions or legend entries.
        ax.set_xticks(sorted(plotted_ks))
        ax.legend(frameon=False, fontsize=8, loc='center right')
    ax.set_ylim(0, 1.0)
    ax.grid(axis='y', alpha=0.2)

    plt.tight_layout()
    _save_figure(output_path)


def plot_retrieval_examples(sim_matrix, labels, df, output_path, n_examples=5):
    """Show example queries with their top-5 neighbors.

    Args:
        sim_matrix: (N, N) similarity matrix; the caller is expected to have
            masked the diagonal so a query is not its own neighbor.
        labels: sequence of N labels aligned with ``sim_matrix`` rows.
        df: DataFrame aligned with ``labels``; its 'glycan' column (if any)
            supplies the display strings.
        output_path: path stem; '.png' and '.pdf' are appended.
        n_examples: number of random queries to show (capped at N).
    """
    n = len(labels)
    n_shown = min(n_examples, n)
    if n_shown < 1:
        print(" No retrieval examples to plot")
        return

    # Local generator seeded like the legacy np.random.seed(42) call, but
    # without mutating the global NumPy RNG state.
    rng = np.random.RandomState(42)
    query_indices = rng.choice(n, n_shown, replace=False)

    sorted_indices = np.argsort(-sim_matrix, axis=1)

    fig, axes = plt.subplots(n_shown, 1, figsize=(14, 3 * n_shown), squeeze=False)
    axes = axes[:, 0]

    label_arr = np.array(labels)
    has_glycan = 'glycan' in df.columns

    for ax, qi in zip(axes, query_indices):
        neighbors = sorted_indices[qi, :5]
        sims = sim_matrix[qi, neighbors]

        # Build display strings
        query_type = label_arr[qi]
        query_iupac = df.iloc[qi]['glycan'][:60] if has_glycan else '?'

        neighbor_info = []
        for ni, s in zip(neighbors, sims):
            nt = label_arr[ni]
            ni_iupac = df.iloc[ni]['glycan'][:40] if has_glycan else '?'
            match = '✓' if nt == query_type else '✗'
            neighbor_info.append(f"[{match}] sim={s:.3f} | {nt} | {ni_iupac}")

        text = f"Query: [{query_type}] {query_iupac}\n"
        text += '\n'.join([f" NN{j+1}: {info}" for j, info in enumerate(neighbor_info)])

        ax.text(0.02, 0.5, text, transform=ax.transAxes, fontsize=8,
                verticalalignment='center', fontfamily='monospace',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')

    plt.suptitle('Retrieval Examples: Query → Top-5 Nearest Neighbors',
                 fontsize=13, fontweight='bold')
    plt.tight_layout()
    _save_figure(output_path)


def plot_motif_transfer_bar(motif_results, output_path, top_n=15):
    """Bar chart showing motif transfer F1 scores.

    Draws the ``top_n`` motifs by F1 with their majority-class baseline
    accuracy overlaid. Does nothing when no motif was evaluated.
    """
    per_motif = motif_results['per_motif']
    if not per_motif:
        return

    sorted_motifs = sorted(per_motif.items(), key=lambda x: x[1]['f1'], reverse=True)
    top = sorted_motifs[:top_n]

    names = [m[0][:25] for m in top]
    f1s = [m[1]['f1'] for m in top]
    baselines = [m[1]['baseline_accuracy'] for m in top]

    fig, ax = plt.subplots(figsize=(10, 6))
    y_pos = np.arange(len(names))

    ax.barh(y_pos, f1s, height=0.6, color='#0072B2', alpha=0.8, label='BERT F1')
    ax.barh(y_pos, baselines, height=0.3, color='#999999', alpha=0.5,
            label='Majority baseline')

    ax.set_yticks(y_pos)
    ax.set_yticklabels(names, fontsize=8)
    ax.set_xlabel('F1 Score')
    ax.set_title(f'Motif Transfer via k-NN (top {top_n} motifs)',
                 fontsize=12, fontweight='bold')
    ax.legend(frameon=False, loc='lower right')
    ax.invert_yaxis()

    plt.tight_layout()
    _save_figure(output_path)


# ═════════════════════════════════════════════════════════════════════════════
# Main
# ═════════════════════════════════════════════════════════════════════════════

def _tokenizable_positions(tokenizer, wurcs_list):
    """Return the positions in ``wurcs_list`` that the tokenizer can process.

    Reads the same three fields that ``get_cls_embeddings`` reads, so a
    string kept here is not skipped there.
    """
    keep = []
    for pos, w in enumerate(wurcs_list):
        try:
            tok = tokenizer.tokenize(w)
            tok['token_ids'], tok['branch_depths'], tok['linkage_types']
        except Exception:
            continue
        keep.append(pos)
    return keep


def _take(labels, positions):
    """Select ``labels`` at ``positions``; passes ``None`` through unchanged."""
    if labels is None:
        return None
    return [labels[i] for i in positions]


def _labeled_subset(labels, is_labeled, min_count):
    """Return (bool mask, kept labels) if more than ``min_count`` labels pass, else (None, None)."""
    if not labels:
        return None, None
    mask = np.array([bool(is_labeled(lab)) for lab in labels], dtype=bool)
    if mask.sum() <= min_count:
        return None, None
    return mask, [lab for lab, m in zip(labels, mask) if m]


def run_retrieval(model_version, device, max_glycans=15000):
    """Run the full retrieval probe for one model and write plots and JSON.

    Args:
        model_version: 'v5' or 'v6' (key into ``CHECKPOINTS``).
        device: torch device used for embedding extraction.
        max_glycans: upper bound on the number of glycans evaluated. Larger
            datasets are randomly subsampled with a fixed seed, since the
            metrics build N x N similarity matrices. A falsy value disables
            the cap.

    Returns:
        dict of all metrics (also saved as ``retrieval_results.json``).

    Raises:
        ValueError: if the motif annotation table does not have one row per glycan.
        RuntimeError: if no glycan can be embedded, or if embeddings and rows
            cannot be aligned.
    """
    import gc
    from sklearn.metrics.pairwise import cosine_similarity

    model_name = {'v5': 'V5-A', 'v6': 'V6'}[model_version]
    # NOTE(review): 'probe_results_v6' appears twice in this path. Kept as-is;
    # confirm it matches the intended directory layout.
    output_dir = (PROJECT_ROOT / 'bert_v6_contrastive' / 'analysis' /
                  'probe_results_v6' / 'probe_results_v6' /
                  'probe_14_retrieval' / model_version)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*70}")
    print(f"Probe 14: Embedding Retrieval — GlycanBERT {model_name}")
    print(f"{'='*70}")

    # ── 1. Load data ──────────────────────────────────────────────────────
    print(f"\n1. Loading data from {DATA_PATH}")
    df = pd.read_csv(DATA_PATH)
    mask = df['glycan'].notna() & df['wurcs'].notna()
    df = df[mask].reset_index(drop=True)
    print(f" {len(df)} glycans with IUPAC + WURCS")

    # Honour the cap: every metric below is O(N^2) in memory.
    if max_glycans and len(df) > max_glycans:
        df = df.sample(n=max_glycans, random_state=42).reset_index(drop=True)
        print(f" Subsampled to {len(df)} glycans (max_glycans={max_glycans})")

    # ── 2. Annotate motifs ────────────────────────────────────────────────
    print(f"\n2. Annotating glycowork motifs...")
    from glycowork.motif.annotate import annotate_dataset
    iupac_list = df['glycan'].tolist()
    try:
        annotated = annotate_dataset(iupac_list, feature_set=['known'], condense=True)
    except TypeError:
        # Fallback for annotate_dataset signatures without these keywords.
        annotated = annotate_dataset(iupac_list)
    motif_cols = [c for c in annotated.columns if c != 'glycan']
    print(f" Found {len(motif_cols)} motif columns")
    if len(annotated) != len(df):
        # Everything downstream indexes the motif table by row position.
        raise ValueError(
            f"Motif annotation returned {len(annotated)} rows for {len(df)} glycans")

    # ── 3. Extract labels ─────────────────────────────────────────────────
    print(f"\n3. Extracting labels...")

    # Glycan type
    glycan_types = None
    if 'glycan_type' in df.columns:
        glycan_types = df['glycan_type'].fillna('Unknown').tolist()
        type_counts = Counter(glycan_types)
        # Consolidate very rare types
        glycan_types = ['Other' if type_counts[t] < 30 else t for t in glycan_types]
        print(f" Glycan types: {dict(Counter(glycan_types))}")

    # Domain
    domains = None
    if 'domain' in df.columns:
        domains = df['domain'].fillna('Unknown').tolist()
        print(f" Domains: {dict(Counter(domains))}")

    # Immunogenicity (-1 marks unlabeled glycans, filtered out before scoring)
    immunogenicity = None
    if 'immunogenicity' in df.columns:
        imm = df['immunogenicity']
        if imm.notna().sum() > 100:
            immunogenicity = imm.fillna(-1).astype(int).tolist()
            print(f" Immunogenicity: {Counter(immunogenicity)}")

    # ── 4. Load model & extract embeddings ────────────────────────────────
    print(f"\n4. Loading model and extracting CLS embeddings...")
    model, tokenizer = load_model(model_version, device)

    # Remove glycans that fail tokenization *by position* before embedding.
    # get_cls_embeddings silently skips them, so truncating labels to the
    # embedding count afterwards would pair embeddings with the wrong rows.
    wurcs_list = df['wurcs'].tolist()
    keep = _tokenizable_positions(tokenizer, wurcs_list)
    if len(keep) < len(wurcs_list):
        print(f" WARNING: {len(wurcs_list) - len(keep)} glycans failed tokenization, dropping them")
        df = df.iloc[keep].reset_index(drop=True)
        annotated = annotated.iloc[keep]
        glycan_types = _take(glycan_types, keep)
        domains = _take(domains, keep)
        immunogenicity = _take(immunogenicity, keep)
        wurcs_list = [wurcs_list[i] for i in keep]
    if not wurcs_list:
        raise RuntimeError("No glycans could be tokenized; nothing to evaluate")

    embeddings = get_cls_embeddings(model, tokenizer, wurcs_list, device)
    print(f" Embeddings: {embeddings.shape}")

    # Free GPU
    del model
    torch.cuda.empty_cache()
    gc.collect()

    if embeddings.shape[0] != len(df):
        raise RuntimeError(
            f"Got {embeddings.shape[0]} embeddings for {len(df)} glycans; "
            "cannot align embeddings with labels")

    # Labeled subsets shared by the BERT and fingerprint evaluations.
    domain_mask, domain_labels = _labeled_subset(domains, lambda d: d != 'Unknown', 500)
    imm_mask, imm_labels = _labeled_subset(immunogenicity, lambda v: v >= 0, 200)

    # ── 5. BERT retrieval metrics ─────────────────────────────────────────
    print(f"\n{'─'*50}")
    print(f"5. BERT Embedding Retrieval")
    print(f"{'─'*50}")

    all_results = {'model': model_name}

    # 5a. Glycan type
    if glycan_types:
        all_results['Glycan Type'] = compute_retrieval_metrics(
            embeddings, glycan_types, 'Glycan Type')

    # 5b. Domain (non-Unknown only)
    if domain_mask is not None:
        all_results['Domain'] = compute_retrieval_metrics(
            embeddings[domain_mask], domain_labels, 'Domain')

    # 5c. Immunogenicity (labeled only)
    if imm_mask is not None:
        all_results['Immunogenicity'] = compute_retrieval_metrics(
            embeddings[imm_mask], imm_labels, 'Immunogenicity')

    # ── 6. Motif transfer ─────────────────────────────────────────────────
    print(f"\n{'─'*50}")
    print(f"6. Motif Transfer via k-NN")
    print(f"{'─'*50}")

    motif_matrix = (annotated[motif_cols].values > 0).astype(float)
    motif_transfer = compute_motif_transfer(embeddings, motif_matrix, motif_cols, k=5)
    all_results['motif_transfer'] = motif_transfer

    # ── 7. Fingerprint baseline ───────────────────────────────────────────
    print(f"\n{'─'*50}")
    print(f"7. Glycowork Fingerprint Baseline")
    print(f"{'─'*50}")

    fingerprint_results = {}
    if glycan_types:
        fingerprint_results['Glycan Type'] = compute_fingerprint_baseline(
            annotated, motif_cols, glycan_types, 'Glycan Type')
    if domain_mask is not None:
        fingerprint_results['Domain'] = compute_fingerprint_baseline(
            annotated[domain_mask], motif_cols, domain_labels, 'Domain')

    all_results['fingerprint_baseline'] = fingerprint_results

    # ── 8. Visualizations ─────────────────────────────────────────────────
    print(f"\n{'─'*50}")
    print(f"8. Generating Visualizations")
    print(f"{'─'*50}")

    # P@k curves
    plot_precision_curves(all_results, fingerprint_results,
                          output_dir / f'precision_at_k_{model_name.lower()}')

    # Motif transfer bar chart
    plot_motif_transfer_bar(motif_transfer,
                            output_dir / f'motif_transfer_{model_name.lower()}')

    # Retrieval examples (first 1000 glycans keep the similarity matrix small)
    if glycan_types:
        sim_mat = cosine_similarity(embeddings[:1000])
        np.fill_diagonal(sim_mat, -1.0)
        plot_retrieval_examples(sim_mat, glycan_types[:1000], df.head(1000),
                                output_dir / f'retrieval_examples_{model_name.lower()}')

    # ── 9. Save results ───────────────────────────────────────────────────
    results_path = output_dir / 'retrieval_results.json'

    # Make JSON-serializable
    def to_serializable(obj):
        if isinstance(obj, (np.integer,)):
            return int(obj)
        if isinstance(obj, (np.floating,)):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return obj

    with open(results_path, 'w') as f:
        json.dump(all_results, f, indent=2, default=to_serializable)
    print(f"\n Results saved to: {results_path}")

    # ── 10. Summary table ─────────────────────────────────────────────────
    print(f"\n{'='*70}")
    print(f"SUMMARY — Probe 14 Retrieval ({model_name})")
    print(f"{'='*70}")
    print(f"\n{'Property':<20} {'P@1':>8} {'P@5':>8} {'P@10':>8} {'P@20':>8} {'mAP':>8} {'Baseline':>10}")
    print(f"{'─'*74}")
    for prop in ['Glycan Type', 'Domain', 'Immunogenicity']:
        if prop in all_results:
            r = all_results[prop]
            pk = r['precision_at_k']
            print(f"{prop:<20} {pk.get('1', 0):>8.4f} {pk.get('5', 0):>8.4f} "
                  f"{pk.get('10', 0):>8.4f} {pk.get('20', 0):>8.4f} "
                  f"{r['mAP']:>8.4f} {r['random_baseline']:>10.4f}")

    if fingerprint_results:
        print(f"\n Fingerprint baselines:")
        for prop, fp in fingerprint_results.items():
            vals = ' '.join([f"P@{k}={v:.4f}"
                             for k, v in sorted(fp.items(), key=lambda x: int(x[0]))])
            print(f" {prop}: {vals}")

    print(f"\n Motif transfer (k=5): mean F1={motif_transfer['mean_f1']:.4f}, "
          f"mean acc={motif_transfer['mean_accuracy']:.4f} "
          f"({motif_transfer['n_motifs_evaluated']} motifs)")

    return all_results


def main():
    """Parse CLI arguments, run the probe per model, and print a V5-A vs V6 comparison."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', choices=['v5', 'v6', 'both'], default='both')
    parser.add_argument('--max_glycans', type=int, default=15000)
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    models = ['v5', 'v6'] if args.model == 'both' else [args.model]
    all_model_results = {}

    for mv in models:
        all_model_results[mv] = run_retrieval(mv, device, args.max_glycans)

    # If both models, print comparison
    if len(all_model_results) == 2:
        print(f"\n{'='*70}")
        print(f"COMPARISON: V5-A vs V6")
        print(f"{'='*70}")
        for prop in ['Glycan Type', 'Domain', 'Immunogenicity']:
            v5_r = all_model_results.get('v5', {}).get(prop)
            v6_r = all_model_results.get('v6', {}).get(prop)
            if v5_r and v6_r:
                v5_map = v5_r['mAP']
                v6_map = v6_r['mAP']
                print(f"\n {prop}:")
                print(f" V5-A mAP: {v5_map:.4f}")
                print(f" V6 mAP: {v6_map:.4f}")
                print(f" Δ: {v6_map - v5_map:+.4f} ({'V6 better' if v6_map > v5_map else 'V5-A better'})")

    print(f"\nProbe 14 complete!")


if __name__ == '__main__':
    main()