| |
| """ |
| Probe 14: Embedding-Based Nearest-Neighbor Retrieval |
| ===================================================== |
| Evaluates whether nearest neighbors in GlycanBERT's embedding space share |
| biological properties. This demonstrates practical utility: given any glycan, |
| retrieve structurally and functionally similar glycans from a reference set. |
| |
| Metrics: |
| - Precision@k (k=1, 5, 10, 20) for glycan type, domain, motif transfer |
| - Mean Average Precision (mAP) per property |
| - Motif transfer accuracy (majority vote of k-NN motifs) |
| - Comparison against glycowork structural fingerprint baseline |
| |
| Usage: |
| python probe_14_retrieval.py # Both V5-A and V6 |
| python probe_14_retrieval.py --model v5 # V5-A only |
| """ |
|
|
| import sys, os, json, argparse, warnings |
| import numpy as np |
| import pandas as pd |
| from pathlib import Path |
| from collections import Counter |
|
|
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| import torch |
|
|
# Suppress FutureWarning messages globally for this script.
warnings.filterwarnings(action='ignore', category=FutureWarning)

# Repository root: two directory levels above the directory holding this file.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
VOCAB_PATH = PROJECT_ROOT.joinpath('bert_training_v4', 'data', 'vocabulary.json')
DATA_PATH = PROJECT_ROOT.joinpath(
    'bert_v6_contrastive', 'analysis', 'glycowork_iupac_wurcs_unified.csv')

# Model tag -> checkpoint file.
# NOTE(review): the 'v6' entry points at the bert_v5.1_contrastive checkpoint;
# confirm this is the model meant to be reported as "V6".
CHECKPOINTS = dict(
    v5=PROJECT_ROOT.joinpath('checkpoints_v5_bpe_topo', 'best_v5_bpe_topo_model.pt'),
    v6=PROJECT_ROOT.joinpath('bert_v5.1_contrastive', 'checkpoints',
                             'best_v51_contrastive_model.pt'),
)
|
|
| sys.path.insert(0, str(PROJECT_ROOT)) |
| sys.path.insert(0, str(PROJECT_ROOT / 'bert_training_v4')) |
| from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig |
| from downstream_tasks.utils.tokenizer import WURCSTokenizer |
|
|
| |
# Publication-style matplotlib defaults applied to every figure this probe
# saves: typography, 300-dpi tight-bbox output, and no top/right spines.
plt.rcParams.update(
    {
        # Typography
        'font.family': 'sans-serif',
        'font.sans-serif': ['Arial', 'Helvetica', 'DejaVu Sans'],
        'font.size': 10,
        'axes.titlesize': 12,
        'axes.labelsize': 11,
        'xtick.labelsize': 9,
        'ytick.labelsize': 9,
        'legend.fontsize': 9,
        # Resolution and saving
        'figure.dpi': 300,
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
        # Axes frame
        'axes.linewidth': 0.8,
        'axes.spines.top': False,
        'axes.spines.right': False,
    }
)
|
|
| |
| |
| |
|
|
def load_model(model_version, device):
    """Build a MultimodalGlycanBERT from a checkpoint and return it with the tokenizer.

    Vocabulary size and hidden width are read from the token-embedding weight
    stored in the checkpoint (falling back to 2200 / 768 when neither known key
    is present); the remaining architecture settings are fixed below. When the
    checkpoint also carries a ``proj_head_state_dict``, keys starting with
    ``proj_head`` are removed from the model state dict before loading.

    Args:
        model_version: key into CHECKPOINTS ('v5' or 'v6').
        device: torch device the model is moved to.

    Returns:
        (model in eval mode on ``device``, WURCSTokenizer built from VOCAB_PATH)
    """
    ckpt_path = CHECKPOINTS[model_version]
    print(f"Loading {model_version} from {ckpt_path}")
    # weights_only=False lets torch.load unpickle arbitrary Python objects:
    # only point CHECKPOINTS at trusted, locally produced files.
    state = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    sd = state.get('model_state_dict', state)
    if 'proj_head_state_dict' in state:
        sd = {k: v for k, v in sd.items() if not k.startswith('proj_head')}

    emb_weight = sd.get('seq_embeddings.token_embeddings.weight',
                        sd.get('token_embeddings.weight'))
    if emb_weight is None:
        print("  WARNING: no token-embedding weight in checkpoint; "
              "using default sizes (vocab=2200, hidden=768)")
        vocab_size, hidden = 2200, 768
    else:
        vocab_size, hidden = emb_weight.shape[0], emb_weight.shape[1]

    config = MultimodalGlycanBERTConfig(
        seq_vocab_size=vocab_size, seq_hidden_size=hidden,
        seq_num_layers=12, seq_num_heads=12, seq_max_length=256,
        use_cnn_frontend=True, cnn_kernel_size=3,
    )
    model = MultimodalGlycanBERT(config)

    # strict=False tolerates small checkpoint/architecture differences, but the
    # mismatch must be visible: a checkpoint whose keys largely fail to match
    # would otherwise silently give a randomly initialised model and
    # meaningless probe numbers.
    incompatible = model.load_state_dict(sd, strict=False)
    missing, unexpected = incompatible.missing_keys, incompatible.unexpected_keys
    if missing or unexpected:
        print(f"  WARNING: {len(missing)} missing / {len(unexpected)} unexpected "
              f"state_dict keys when loading {ckpt_path}")
        for key in missing[:5]:
            print(f"    missing:    {key}")
        for key in unexpected[:5]:
            print(f"    unexpected: {key}")

    model = model.to(device).eval()
    print(f" Loaded: {sum(p.numel() for p in model.parameters()):,} params")
    tokenizer = WURCSTokenizer(str(VOCAB_PATH))
    return model, tokenizer
|
|
|
|
def get_cls_embeddings(model, tokenizer, wurcs_list, device,
                       batch_size=128, max_len=256, return_valid_indices=False):
    """Embed WURCS strings and return the position-0 ("CLS") vector of each.

    Strings whose tokenization raises (or whose tokenizer output lacks one of
    the required feature lists) are counted as errors and skipped, so the
    result can have fewer rows than ``wurcs_list``. Pass
    ``return_valid_indices=True`` to also receive the input positions that the
    returned rows correspond to, so per-glycan labels can be kept aligned.

    NOTE(review): vectors come from ``model.seq_embeddings(...)``. Judging by
    the checkpoint key ``seq_embeddings.token_embeddings.weight`` this looks
    like the input-embedding module rather than the full transformer encoder;
    if so, position 0 is not contextualised by the encoder stack. Confirm this
    is the intended representation.

    Args:
        model: object exposing ``seq_embeddings(ids, branch_depths=...,
            linkage_types=...)``; its output is indexed as (batch, seq, hidden).
        tokenizer: object whose ``tokenize(wurcs)`` returns a dict with
            'token_ids', 'branch_depths' and 'linkage_types' lists.
        wurcs_list: sequence of WURCS strings.
        device: torch device for the forward pass.
        batch_size: number of strings processed per forward pass.
        max_len: per-sequence truncation length.
        return_valid_indices: also return the list of input positions kept.

    Returns:
        float array of shape (n_ok, hidden) -- or ``(array, valid_indices)``
        when ``return_valid_indices`` is True. With no successful rows the
        array has shape (0, 768).
    """
    all_embs, valid_indices, errors = [], [], 0
    for start in range(0, len(wurcs_list), batch_size):
        batch = wurcs_list[start:start + batch_size]
        token_ids_list, bd_list, lt_list = [], [], []
        for offset, w in enumerate(batch):
            # Read all three feature lists before appending anything, so a
            # failure part-way through cannot leave the lists out of step.
            try:
                tok = tokenizer.tokenize(w)
                ids = tok['token_ids'][:max_len]
                bds = tok['branch_depths'][:max_len]
                lts = tok['linkage_types'][:max_len]
            except Exception:
                # Best-effort: unparseable WURCS are counted and skipped.
                errors += 1
                continue
            token_ids_list.append(ids)
            bd_list.append(bds)
            lt_list.append(lts)
            valid_indices.append(start + offset)
        if not token_ids_list:
            continue

        # Right-pad every feature tensor with zeros to the longest sequence.
        ml = max(max(len(ids), len(bds), len(lts))
                 for ids, bds, lts in zip(token_ids_list, bd_list, lt_list))
        ids_t = torch.zeros(len(token_ids_list), ml, dtype=torch.long)
        bd_t = torch.zeros_like(ids_t)
        lt_t = torch.zeros_like(ids_t)
        for j, (ids, bd, lt) in enumerate(zip(token_ids_list, bd_list, lt_list)):
            ids_t[j, :len(ids)] = torch.tensor(ids, dtype=torch.long)
            bd_t[j, :len(bd)] = torch.tensor(bd, dtype=torch.long)
            lt_t[j, :len(lt)] = torch.tensor(lt, dtype=torch.long)
        ids_t, bd_t, lt_t = ids_t.to(device), bd_t.to(device), lt_t.to(device)

        with torch.no_grad():
            seq_out = model.seq_embeddings(ids_t, branch_depths=bd_t,
                                           linkage_types=lt_t)
        all_embs.append(seq_out[:, 0, :].cpu().numpy())
        if (start // batch_size) % 10 == 0:
            print(f" Embedded {start+len(batch)}/{len(wurcs_list)} ({errors} errors)")

    print(f" Total: {sum(e.shape[0] for e in all_embs):,} ({errors} errors)")
    embs = np.vstack(all_embs) if all_embs else np.zeros((0, 768))
    if return_valid_indices:
        return embs, valid_indices
    return embs
|
|
|
|
| |
| |
| |
|
|
def compute_retrieval_metrics(embeddings, labels, label_name,
                              k_values=(1, 5, 10, 20)):
    """Compute Precision@k and mAP for cosine nearest-neighbour retrieval.

    Every sample is used once as a query against all *other* samples (a query
    is never its own neighbour); a neighbour is relevant when it has the same
    label as the query.

    Args:
        embeddings: (N, D) array; rows are L2-normalised here, so only their
            direction matters (all-zero rows get similarity 0 to everything).
        labels: sequence of N hashable labels (strings or ints).
        label_name: name used in the printed report.
        k_values: neighbour counts to evaluate. A k >= N is clamped to N - 1
            and reported under the clamped key.

    Returns:
        dict with ``precision_at_k`` (str(k) -> float), ``mAP`` (mean average
        precision over the queries that have at least one other same-label
        sample; queries from singleton classes cannot retrieve anything and
        are excluded), ``random_baseline`` (sum of squared class frequencies,
        an approximation of chance-level precision), ``n_classes``,
        ``n_samples``, ``n_map_queries`` and ``class_distribution``.

    Raises:
        ValueError: if fewer than 2 samples are given or the number of labels
            differs from the number of embeddings.
    """
    emb = np.asarray(embeddings)
    if not np.issubdtype(emb.dtype, np.floating):
        emb = emb.astype(np.float64)
    n = emb.shape[0]
    if n < 2:
        raise ValueError(f"{label_name}: need at least 2 samples for retrieval, got {n}")
    if len(labels) != n:
        raise ValueError(f"{label_name}: {len(labels)} labels for {n} embeddings")
    label_arr = np.asarray(labels)

    print(f" Computing {n}x{n} cosine similarity matrix...")
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    unit = emb / norms
    neg_sim = unit @ unit.T
    # Negate in place (avoids a second N x N buffer) so an ascending argsort
    # ranks most-similar first; +inf on the diagonal puts each query itself
    # strictly last, where it is sliced off below.
    np.negative(neg_sim, out=neg_sim)
    np.fill_diagonal(neg_sim, np.inf)
    ranked = np.argsort(neg_sim, axis=1)[:, :-1]   # (N, N-1) neighbours, self excluded
    del neg_sim

    # Precision@k: fraction of each query's top-k sharing its label, averaged.
    precision_at_k = {}
    for k in k_values:
        k = min(k, n - 1)
        same = label_arr[ranked[:, :k]] == label_arr[:, None]
        precision_at_k[k] = float(same.mean())

    label_counts = Counter(label_arr.tolist())
    random_baseline = sum((c / n) ** 2 for c in label_counts.values())

    # Average precision per query over its full ranking (vectorised per row).
    ranks = np.arange(1, n, dtype=np.float64)
    ap_values = []
    for i in range(n):
        n_relevant = label_counts[label_arr[i]] - 1
        if n_relevant <= 0:
            continue
        relevant = label_arr[ranked[i]] == label_arr[i]
        hits_so_far = np.cumsum(relevant)
        ap_values.append(float(np.sum(hits_so_far[relevant] / ranks[relevant])) / n_relevant)
    mAP = float(np.mean(ap_values)) if ap_values else 0.0

    print(f"\n {label_name}:")
    print(f" Random baseline: {random_baseline:.4f}")
    for k, p in precision_at_k.items():
        lift = p / random_baseline if random_baseline > 0 else 0
        print(f" P@{k:>2d}: {p:.4f} (lift: {lift:.2f}x)")
    print(f" mAP: {mAP:.4f} (over {len(ap_values)} queries with >=1 relevant item)")
    print(f" # classes: {len(label_counts)}, largest: "
          f"{max(label_counts.values())}/{n} ({max(label_counts.values())/n*100:.1f}%)")

    return {
        'precision_at_k': {str(k): float(v) for k, v in precision_at_k.items()},
        'mAP': float(mAP),
        'random_baseline': float(random_baseline),
        'n_classes': len(label_counts),
        'n_samples': n,
        'n_map_queries': len(ap_values),
        'class_distribution': {str(k): int(v) for k, v in label_counts.items()},
    }
|
|
|
|
def compute_motif_transfer(embeddings, motif_matrix, motif_names, k=5):
    """Predict each glycan's motifs by majority vote of its k nearest neighbours.

    Neighbours are ranked by cosine similarity of the embeddings (rows are
    L2-normalised; a query is never its own neighbour). A motif is predicted
    present when more than half of the neighbours carry it. Motifs with fewer
    than 10 positives or fewer than 10 negatives are skipped as uninformative.

    Args:
        embeddings: (N, D) array.
        motif_matrix: (N, M) array of 0/1 motif-presence flags; F1 treats 1 as
            the positive class, so pass binarised values.
        motif_names: list of M motif names in column order of ``motif_matrix``.
        k: number of neighbours (>= 1); clamped to N - 1.

    Returns:
        dict with ``per_motif`` (name -> accuracy, baseline_accuracy, f1,
        n_positive, lift), ``mean_accuracy``, ``mean_f1`` and
        ``n_motifs_evaluated``.
    """
    emb = np.asarray(embeddings)
    if not np.issubdtype(emb.dtype, np.floating):
        emb = emb.astype(np.float64)
    motif_matrix = np.asarray(motif_matrix)
    n = emb.shape[0]
    print(f"\n Motif transfer (k={k}, {len(motif_names)} motifs)...")

    k_eff = min(k, n - 1)
    if k_eff >= 1:
        norms = np.linalg.norm(emb, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        unit = emb / norms
        neg_sim = unit @ unit.T
        np.negative(neg_sim, out=neg_sim)
        # +inf on the diagonal: a query can never be selected as its own neighbour.
        np.fill_diagonal(neg_sim, np.inf)
        # A vote only needs the top-k *set*, so a partial sort is enough.
        neighbor_idx = np.argpartition(neg_sim, k_eff - 1, axis=1)[:, :k_eff]
        del neg_sim
    else:
        neighbor_idx = np.zeros((n, 0), dtype=np.intp)

    results = {}
    for m_idx, mname in enumerate(motif_names):
        true_labels = motif_matrix[:, m_idx]
        n_pos = int(true_labels.sum())
        if n_pos < 10 or n_pos > n - 10:
            continue

        # (N, k) neighbour flags -> majority vote per glycan.
        votes = motif_matrix[neighbor_idx, m_idx].mean(axis=1)
        predicted = (votes > 0.5).astype(float)

        acc = float(np.mean(predicted == true_labels))
        # Binary F1 = 2TP / (predicted positives + actual positives); 0 if both are 0.
        tp = float(np.sum((predicted == 1) & (true_labels == 1)))
        denom = float(predicted.sum() + true_labels.sum())
        f1 = 2.0 * tp / denom if denom > 0 else 0.0

        # Accuracy of always predicting the majority class for this motif.
        baseline_acc = max(n_pos / n, 1.0 - n_pos / n)

        results[mname] = {
            'accuracy': acc,
            'baseline_accuracy': float(baseline_acc),
            'f1': float(f1),
            'n_positive': n_pos,
            'lift': float(acc / baseline_acc) if baseline_acc > 0 else 0,
        }

    mean_acc = float(np.mean([r['accuracy'] for r in results.values()])) if results else 0
    mean_f1 = float(np.mean([r['f1'] for r in results.values()])) if results else 0
    print(f" Evaluated {len(results)} non-trivial motifs")
    print(f" Mean transfer accuracy: {mean_acc:.4f}")
    print(f" Mean transfer F1: {mean_f1:.4f}")

    if results:
        sorted_motifs = sorted(results.items(), key=lambda x: x[1]['f1'], reverse=True)
        print(f"\n Top 5 transferable motifs:")
        for name, r in sorted_motifs[:5]:
            print(f" {name}: F1={r['f1']:.3f}, acc={r['accuracy']:.3f} "
                  f"(baseline={r['baseline_accuracy']:.3f}, n_pos={r['n_positive']})")
        # Skip motifs already listed in the top 5 when there are fewer than 10.
        bottom = sorted_motifs[max(5, len(sorted_motifs) - 5):]
        if bottom:
            print(f"\n Bottom 5 (hardest to transfer):")
            for name, r in bottom:
                print(f" {name}: F1={r['f1']:.3f}, acc={r['accuracy']:.3f} "
                      f"(baseline={r['baseline_accuracy']:.3f}, n_pos={r['n_positive']})")

    return {
        'per_motif': results,
        'mean_accuracy': float(mean_acc),
        'mean_f1': float(mean_f1),
        'n_motifs_evaluated': len(results),
    }
|
|
|
|
def compute_fingerprint_baseline(annotated_df, motif_cols, labels, label_name,
                                 k_values=(1, 5, 10, 20)):
    """Compute P@k using glycowork motif fingerprints as a non-ML baseline.

    Each glycan's fingerprint is its row of ``annotated_df[motif_cols]``.
    Neighbours are ranked by cosine similarity of these vectors (a query is
    never its own neighbour), and P@k is the fraction of the top-k sharing the
    query's label, averaged over queries. Glycans with no annotated motifs have
    an all-zero fingerprint and similarity 0 to everything. Identical
    fingerprints tie, and tie order is whatever numpy's argsort produces.

    Args:
        annotated_df: DataFrame with one row per glycan, row-aligned with labels.
        motif_cols: columns of ``annotated_df`` forming the fingerprint.
        labels: sequence of N labels.
        label_name: name used in the printed report.
        k_values: neighbour counts; a k >= N is clamped to N - 1.

    Returns:
        dict mapping str(k) -> P@k.

    Raises:
        ValueError: if fewer than 2 rows are given or the label count differs.
    """
    fingerprints = annotated_df[motif_cols].to_numpy(dtype=float)
    n = fingerprints.shape[0]
    if n < 2:
        raise ValueError(f"{label_name}: need at least 2 glycans, got {n}")
    if len(labels) != n:
        raise ValueError(f"{label_name}: {len(labels)} labels for {n} fingerprints")
    label_arr = np.asarray(labels)

    # Cosine similarity: L2-normalise once (zero rows stay zero), then dot.
    norms = np.linalg.norm(fingerprints, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    unit = fingerprints / norms
    neg_sim = unit @ unit.T
    np.negative(neg_sim, out=neg_sim)
    # +inf on the diagonal sorts each query itself last; it is then sliced off.
    np.fill_diagonal(neg_sim, np.inf)
    ranked = np.argsort(neg_sim, axis=1)[:, :-1]
    del neg_sim

    precision_at_k = {}
    for k in k_values:
        k = min(k, n - 1)
        same = label_arr[ranked[:, :k]] == label_arr[:, None]
        precision_at_k[k] = float(same.mean())

    print(f"\n Fingerprint baseline for {label_name}:")
    for k, p in precision_at_k.items():
        print(f" P@{k:>2d}: {p:.4f}")

    return {str(k): float(v) for k, v in precision_at_k.items()}
|
|
|
|
| |
| |
| |
|
|
def plot_precision_curves(results, fingerprint_results, output_path):
    """Plot P@k curves for BERT vs the motif-fingerprint baseline across properties.

    Args:
        results: dict of property name -> metrics dict (as returned by
            compute_retrieval_metrics). Only 'Glycan Type', 'Domain' and
            'Immunogenicity' are plotted; any other keys are ignored.
        fingerprint_results: dict of property name -> {str(k): P@k}; may lack
            some properties.
        output_path: path stem; '.png' and '.pdf' are appended.
    """
    colors = {
        'Glycan Type': '#0072B2',
        'Domain': '#D55E00',
        'Immunogenicity': '#009E73',
    }
    to_plot = [(name, metrics) for name, metrics in results.items() if name in colors]
    if not to_plot:
        # Nothing to draw (and no k values for the x axis).
        print(" Skipping precision@k plot: no retrieval properties available")
        return

    fig, ax = plt.subplots(figsize=(10, 6))
    all_ks = set()
    for prop_name, metrics in to_plot:
        color = colors[prop_name]
        pk = metrics['precision_at_k']
        ks = sorted(int(k) for k in pk)
        all_ks.update(ks)
        ax.plot(ks, [pk[str(k)] for k in ks], color=color, linestyle='-', marker='o',
                linewidth=2, markersize=6, label=f'{prop_name} (BERT)')

        # Fingerprint curve uses its own k keys (they may be clamped differently).
        fp = fingerprint_results.get(prop_name)
        if fp:
            fp_ks = sorted(int(k) for k in fp)
            all_ks.update(fp_ks)
            ax.plot(fp_ks, [fp[str(k)] for k in fp_ks], color=color, linestyle='--',
                    marker='s', linewidth=1.5, markersize=5, alpha=0.7,
                    label=f'{prop_name} (fingerprint)')

        # Dotted line: chance-level precision for this property.
        ax.axhline(y=metrics['random_baseline'], color=color, linestyle=':', alpha=0.3)

    ax.set_xlabel('k (number of neighbors)')
    ax.set_ylabel('Precision@k')
    ax.set_title('Embedding Retrieval: BERT vs Structural Fingerprint',
                 fontsize=13, fontweight='bold')
    ax.set_xticks(sorted(all_ks))
    ax.legend(frameon=False, fontsize=8, loc='center right')
    ax.set_ylim(0, 1.0)
    ax.grid(axis='y', alpha=0.2)

    fig.tight_layout()
    fig.savefig(str(output_path) + '.png', dpi=300, facecolor='white')
    fig.savefig(str(output_path) + '.pdf', facecolor='white')
    plt.close(fig)
    print(f" Saved: {output_path}.png")
|
|
|
|
def plot_retrieval_examples(sim_matrix, labels, df, output_path, n_examples=5):
    """Render a few random queries with their top-5 neighbours as a text figure.

    Args:
        sim_matrix: (N, N) similarity matrix. The caller should fill the
            diagonal with a low value so a query is not listed as its own
            neighbour.
        labels: N labels shown for each query/neighbour. A neighbour is marked
            '+' when its label equals the query's and 'x' otherwise.
        df: DataFrame row-aligned with ``labels``; its 'glycan' column (IUPAC),
            if present, is shown truncated.
        output_path: path stem; '.png' and '.pdf' are appended.
        n_examples: number of queries to show (capped at N).
    """
    n = len(labels)
    n_shown = min(n_examples, n)
    if n_shown <= 0:
        print(" Skipping retrieval examples: nothing to show")
        return

    # Local RNG seeded like the old global np.random.seed(42) + np.random.choice:
    # same queries are picked, but global RNG state is no longer mutated.
    query_indices = np.random.RandomState(42).choice(n, n_shown, replace=False)
    label_arr = np.asarray(labels)
    has_glycan = 'glycan' in df.columns

    fig, axes = plt.subplots(n_shown, 1, figsize=(14, 3 * n_shown), squeeze=False)
    for ax, qi in zip(axes[:, 0], query_indices):
        row = np.asarray(sim_matrix[qi])
        neighbors = np.argsort(-row)[:5]   # only the query rows need ranking

        query_type = label_arr[qi]
        query_iupac = df.iloc[qi]['glycan'][:60] if has_glycan else '?'
        lines = [f"Query: [{query_type}] {query_iupac}"]
        for rank, ni in enumerate(neighbors, start=1):
            nt = label_arr[ni]
            ni_iupac = df.iloc[ni]['glycan'][:40] if has_glycan else '?'
            match = '+' if nt == query_type else 'x'
            lines.append(f"  NN{rank}: [{match}] sim={row[ni]:.3f} | {nt} | {ni_iupac}")

        ax.text(0.02, 0.5, '\n'.join(lines), transform=ax.transAxes, fontsize=8,
                verticalalignment='center', fontfamily='monospace',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')

    fig.suptitle('Retrieval Examples: Query -> Top-5 Nearest Neighbors',
                 fontsize=13, fontweight='bold')
    fig.tight_layout()
    fig.savefig(str(output_path) + '.png', dpi=300, facecolor='white')
    fig.savefig(str(output_path) + '.pdf', facecolor='white')
    plt.close(fig)
    print(f" Saved: {output_path}.png")
|
|
|
|
def plot_motif_transfer_bar(motif_results, output_path, top_n=15):
    """Grouped bar chart of k-NN motif transfer for the best-F1 motifs.

    For each of the ``top_n`` motifs with the highest F1, three bars are drawn:
    k-NN F1, k-NN accuracy, and majority-class accuracy. Accuracy is compared
    only with accuracy; the majority-class baseline is an accuracy, not an F1.

    Args:
        motif_results: dict returned by compute_motif_transfer; its
            ``per_motif`` entries need 'f1', 'accuracy' and 'baseline_accuracy'.
        output_path: path stem; '.png' and '.pdf' are appended.
        top_n: maximum number of motifs shown.
    """
    per_motif = motif_results['per_motif']
    if not per_motif:
        return

    top = sorted(per_motif.items(), key=lambda item: item[1]['f1'], reverse=True)[:top_n]
    names = [name[:25] for name, _ in top]
    f1s = [r['f1'] for _, r in top]
    accs = [r['accuracy'] for _, r in top]
    baselines = [r['baseline_accuracy'] for _, r in top]

    fig, ax = plt.subplots(figsize=(10, 6))
    y_pos = np.arange(len(names))
    bar_h = 0.27
    ax.barh(y_pos - bar_h, f1s, height=bar_h, color='#0072B2', alpha=0.8,
            label='k-NN F1')
    ax.barh(y_pos, accs, height=bar_h, color='#56B4E9', alpha=0.8,
            label='k-NN accuracy')
    ax.barh(y_pos + bar_h, baselines, height=bar_h, color='#999999', alpha=0.6,
            label='Majority-class accuracy')

    ax.set_yticks(y_pos)
    ax.set_yticklabels(names, fontsize=8)
    ax.set_xlabel('Score')
    ax.set_xlim(0, 1.0)
    ax.set_title(f'Motif Transfer via k-NN (top {len(top)} motifs by F1)',
                 fontsize=12, fontweight='bold')
    ax.legend(frameon=False, loc='lower right')
    ax.invert_yaxis()   # best motif at the top

    fig.tight_layout()
    fig.savefig(str(output_path) + '.png', dpi=300, facecolor='white')
    fig.savefig(str(output_path) + '.pdf', facecolor='white')
    plt.close(fig)
    print(f" Saved: {output_path}.png")
|
|
|
|
| |
| |
| |
|
|
def _tokenizable_positions(tokenizer, wurcs_list):
    """Return the positions in *wurcs_list* that the WURCS tokenizer can encode.

    A string counts as encodable when ``tokenizer.tokenize`` succeeds and its
    'token_ids', 'branch_depths' and 'linkage_types' entries can be sliced,
    the same accesses get_cls_embeddings performs; any exception rejects it.
    """
    positions = []
    for pos, wurcs in enumerate(wurcs_list):
        try:
            tok = tokenizer.tokenize(wurcs)
            _ = [tok[key][:1] for key in ('token_ids', 'branch_depths', 'linkage_types')]
        except Exception:
            continue
        positions.append(pos)
    return positions


def run_retrieval(model_version, device, max_glycans=15000):
    """Run Probe 14 for one model: retrieval metrics, motif transfer, baseline, plots.

    Pipeline:
    1. Load glycans that have both IUPAC and WURCS.
    2. Drop rows the WURCS tokenizer cannot encode.
    3. Optionally down-sample.
    4. Annotate glycowork motifs.
    5. Embed with the checkpoint.
    6. Evaluate, then save figures plus ``retrieval_results.json`` under the
       model's output directory.

    Args:
        model_version: 'v5' or 'v6' (key into CHECKPOINTS).
        device: torch device used to compute embeddings.
        max_glycans: maximum number of glycans evaluated. The metrics build
            dense N x N similarity matrices, so larger datasets are randomly
            down-sampled (seed 42, hence the same subset for every model).
            ``None`` or a value <= 0 disables the cap.

    Returns:
        dict with all metrics (the same content written to JSON).

    Raises:
        RuntimeError: if motif annotations or embeddings do not come back with
            exactly one row per glycan (labels would otherwise be misaligned).
    """
    model_name = {'v5': 'V5-A', 'v6': 'V6'}[model_version]
    output_dir = (PROJECT_ROOT / 'bert_v6_contrastive' / 'analysis' /
                  'probe_results_v6' / 'probe_results_v6' /
                  'probe_14_retrieval' / model_version)
    output_dir.mkdir(parents=True, exist_ok=True)
    rule = '-' * 50

    print(f"\n{'='*70}")
    print(f"Probe 14: Embedding Retrieval - GlycanBERT {model_name}")
    print(f"{'='*70}")

    # 1. Data ----------------------------------------------------------------
    print(f"\n1. Loading data from {DATA_PATH}")
    df = pd.read_csv(DATA_PATH)
    df = df[df['glycan'].notna() & df['wurcs'].notna()].reset_index(drop=True)
    print(f" {len(df)} glycans with IUPAC + WURCS")

    # Drop untokenizable glycans *before* anything else is derived from df, so
    # labels, motif annotations and embeddings stay row-aligned. (Skipping them
    # during embedding and truncating with df.head(n_embeddings) afterwards
    # shifts every row after the first failure onto the wrong embedding.)
    keep = _tokenizable_positions(WURCSTokenizer(str(VOCAB_PATH)), df['wurcs'].tolist())
    if len(keep) < len(df):
        print(f" Dropping {len(df) - len(keep)} glycans the tokenizer cannot encode")
        df = df.iloc[keep].reset_index(drop=True)

    if max_glycans is not None and 0 < max_glycans < len(df):
        df = (df.sample(n=max_glycans, random_state=42)
                .sort_index()
                .reset_index(drop=True))
        print(f" Down-sampled to {len(df)} glycans (max_glycans={max_glycans}, seed=42)")

    # 2. Motif annotation -----------------------------------------------------
    print(f"\n2. Annotating glycowork motifs...")
    from glycowork.motif.annotate import annotate_dataset
    iupac_list = df['glycan'].tolist()
    try:
        annotated = annotate_dataset(iupac_list, feature_set=['known'],
                                     condense=True)
    except TypeError:
        # Fallback for glycowork versions whose annotate_dataset rejects these
        # keyword arguments (presumably older releases).
        annotated = annotate_dataset(iupac_list)
    if len(annotated) != len(df):
        raise RuntimeError(
            f"annotate_dataset returned {len(annotated)} rows for {len(df)} glycans; "
            "motif annotations would not line up with the embeddings")
    motif_cols = [c for c in annotated.columns if c != 'glycan']
    print(f" Found {len(motif_cols)} motif columns")

    # 3. Labels -----------------------------------------------------------------
    print(f"\n3. Extracting labels...")
    glycan_types = None
    if 'glycan_type' in df.columns:
        glycan_types = df['glycan_type'].fillna('Unknown').tolist()
        type_counts = Counter(glycan_types)
        # Merge rare types (< 30 members) into 'Other'.
        glycan_types = ['Other' if type_counts[t] < 30 else t for t in glycan_types]
        print(f" Glycan types: {dict(Counter(glycan_types))}")

    domains = None
    if 'domain' in df.columns:
        domains = df['domain'].fillna('Unknown').tolist()
        print(f" Domains: {dict(Counter(domains))}")

    immunogenicity = None
    if 'immunogenicity' in df.columns:
        imm = df['immunogenicity']
        if imm.notna().sum() > 100:
            # -1 marks unlabelled rows; they are masked out below.
            immunogenicity = imm.fillna(-1).astype(int).tolist()
            print(f" Immunogenicity: {Counter(immunogenicity)}")

    # 4. Embeddings ---------------------------------------------------------
    print(f"\n4. Loading model and extracting CLS embeddings...")
    model, tokenizer = load_model(model_version, device)
    embeddings = get_cls_embeddings(model, tokenizer, df['wurcs'].tolist(), device)
    print(f" Embeddings: {embeddings.shape}")

    import gc
    del model
    torch.cuda.empty_cache()
    gc.collect()

    if embeddings.shape[0] != len(df):
        raise RuntimeError(
            f"Got {embeddings.shape[0]} embeddings for {len(df)} glycans; "
            "cannot align embeddings with labels")

    # Row masks for partially labelled properties, shared by the BERT and
    # fingerprint evaluations. None = property absent or too few labelled rows.
    domain_mask, d_labels = None, None
    if domains:
        domain_mask = np.array([d != 'Unknown' for d in domains])
        if domain_mask.sum() > 500:
            d_labels = [d for d, m in zip(domains, domain_mask) if m]
        else:
            domain_mask = None
    imm_mask, i_labels = None, None
    if immunogenicity:
        imm_mask = np.array([v >= 0 for v in immunogenicity])
        if imm_mask.sum() > 200:
            i_labels = [v for v, m in zip(immunogenicity, imm_mask) if m]
        else:
            imm_mask = None

    # 5. BERT retrieval -----------------------------------------------------
    print(f"\n{rule}")
    print(f"5. BERT Embedding Retrieval")
    print(f"{rule}")

    all_results = {'model': model_name}
    if glycan_types:
        all_results['Glycan Type'] = compute_retrieval_metrics(
            embeddings, glycan_types, 'Glycan Type')
    if domain_mask is not None:
        all_results['Domain'] = compute_retrieval_metrics(
            embeddings[domain_mask], d_labels, 'Domain')
    if imm_mask is not None:
        all_results['Immunogenicity'] = compute_retrieval_metrics(
            embeddings[imm_mask], i_labels, 'Immunogenicity')

    # 6. Motif transfer ------------------------------------------------------
    print(f"\n{rule}")
    print(f"6. Motif Transfer via k-NN")
    print(f"{rule}")

    motif_matrix = (annotated[motif_cols].values > 0).astype(float)
    motif_transfer = compute_motif_transfer(embeddings, motif_matrix,
                                            motif_cols, k=5)
    all_results['motif_transfer'] = motif_transfer

    # 7. Fingerprint baseline (same row subsets as the BERT evaluation) -------
    print(f"\n{rule}")
    print(f"7. Glycowork Fingerprint Baseline")
    print(f"{rule}")

    fingerprint_results = {}
    if glycan_types:
        fingerprint_results['Glycan Type'] = compute_fingerprint_baseline(
            annotated, motif_cols, glycan_types, 'Glycan Type')
    if domain_mask is not None:
        fingerprint_results['Domain'] = compute_fingerprint_baseline(
            annotated.iloc[domain_mask], motif_cols, d_labels, 'Domain')
    if imm_mask is not None:
        fingerprint_results['Immunogenicity'] = compute_fingerprint_baseline(
            annotated.iloc[imm_mask], motif_cols, i_labels, 'Immunogenicity')
    all_results['fingerprint_baseline'] = fingerprint_results

    # 8. Figures ---------------------------------------------------------------
    print(f"\n{rule}")
    print(f"8. Generating Visualizations")
    print(f"{rule}")

    plot_precision_curves(all_results, fingerprint_results,
                          output_dir / f'precision_at_k_{model_name.lower()}')
    plot_motif_transfer_bar(motif_transfer,
                            output_dir / f'motif_transfer_{model_name.lower()}')

    # Examples are drawn from the first 1000 glycans only, to keep the matrix small.
    from sklearn.metrics.pairwise import cosine_similarity
    sim_mat = cosine_similarity(embeddings[:1000])
    np.fill_diagonal(sim_mat, -1.0)
    if glycan_types:
        plot_retrieval_examples(sim_mat, glycan_types[:1000], df.head(1000),
                                output_dir / f'retrieval_examples_{model_name.lower()}')

    # JSON ---------------------------------------------------------------------
    results_path = output_dir / 'retrieval_results.json'

    def to_serializable(obj):
        """json.dump fallback for numpy scalars/arrays; rejects anything else."""
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    with open(results_path, 'w') as f:
        json.dump(all_results, f, indent=2, default=to_serializable)
    print(f"\n Results saved to: {results_path}")

    # Summary ------------------------------------------------------------------
    print(f"\n{'='*70}")
    print(f"SUMMARY - Probe 14 Retrieval ({model_name})")
    print(f"{'='*70}")

    print(f"\n{'Property':<20} {'P@1':>8} {'P@5':>8} {'P@10':>8} {'P@20':>8} {'mAP':>8} {'Baseline':>10}")
    print('-' * 74)
    for prop in ['Glycan Type', 'Domain', 'Immunogenicity']:
        if prop in all_results:
            r = all_results[prop]
            pk = r['precision_at_k']
            print(f"{prop:<20} {pk.get('1', 0):>8.4f} {pk.get('5', 0):>8.4f} "
                  f"{pk.get('10', 0):>8.4f} {pk.get('20', 0):>8.4f} "
                  f"{r['mAP']:>8.4f} {r['random_baseline']:>10.4f}")

    if fingerprint_results:
        print(f"\n Fingerprint baselines:")
        for prop, fp in fingerprint_results.items():
            vals = ' '.join(f"P@{k}={v:.4f}"
                            for k, v in sorted(fp.items(), key=lambda x: int(x[0])))
            print(f" {prop}: {vals}")

    print(f"\n Motif transfer (k=5): mean F1={motif_transfer['mean_f1']:.4f}, "
          f"mean acc={motif_transfer['mean_accuracy']:.4f} "
          f"({motif_transfer['n_motifs_evaluated']} motifs)")

    return all_results
|
|
|
|
def main():
    """Command-line entry point: run Probe 14 for the chosen model(s) and compare mAP."""
    parser = argparse.ArgumentParser(
        description='Probe 14: embedding-based nearest-neighbour retrieval')
    parser.add_argument('--model', choices=['v5', 'v6', 'both'], default='both',
                        help='checkpoint(s) to evaluate (default: both)')
    parser.add_argument('--max_glycans', type=int, default=15000,
                        help='maximum number of glycans to evaluate (passed to run_retrieval)')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    models = ['v5', 'v6'] if args.model == 'both' else [args.model]
    all_model_results = {mv: run_retrieval(mv, device, args.max_glycans) for mv in models}

    # Side-by-side mAP comparison when both models were run.
    if len(all_model_results) == 2:
        print(f"\n{'='*70}")
        print(f"COMPARISON: V5-A vs V6")
        print(f"{'='*70}")
        for prop in ['Glycan Type', 'Domain', 'Immunogenicity']:
            v5_r = all_model_results['v5'].get(prop)
            v6_r = all_model_results['v6'].get(prop)
            if not (v5_r and v6_r):
                continue
            v5_map, v6_map = v5_r['mAP'], v6_r['mAP']
            delta = v6_map - v5_map
            if delta > 0:
                verdict = 'V6 better'
            elif delta < 0:
                verdict = 'V5-A better'
            else:
                verdict = 'tie'
            print(f"\n {prop}:")
            print(f" V5-A mAP: {v5_map:.4f}")
            print(f" V6 mAP: {v6_map:.4f}")
            print(f" Delta: {delta:+.4f} ({verdict})")

    print(f"\nProbe 14 complete!")


if __name__ == '__main__':
    main()
|
|