| |
| """ |
| Embed Benchmark Task Datasets with V5/V6 [CLS] Embeddings |
| |
| Extracts frozen [CLS] embeddings for GlycanML benchmark task datasets |
| and produces t-SNE/UMAP visualizations colored by ground-truth labels. |
| |
| Comparable to GlycanGT Figure 3. |
| |
| Tasks: |
| 1. Taxonomy (domain, kingdom) |
| 2. Glycosylation type (N/O/free) |
| 3. Immunogenicity (0/1) |
| |
| Usage: |
| python embed_benchmark_tasks.py --model v5 [--splits val test] [--embed_all] |
| python embed_benchmark_tasks.py --model v6 [--splits val test] [--embed_all] |
| """ |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| import warnings |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import pandas as pd |
|
|
# Suppress all warnings process-wide for this script.
warnings.simplefilter('ignore')


# Project root: two directory levels above the directory containing this file.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Each insert goes to the front of sys.path, so after this loop the search order is
# bert_training_v4 first, then the project root.
for _import_dir in (PROJECT_ROOT, PROJECT_ROOT / 'bert_training_v4'):
    sys.path.insert(0, str(_import_dir))
del _import_dir
|
|
| from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig |
| from downstream_tasks.utils.tokenizer import WURCSTokenizer |
|
|
# Benchmark CSV directory and the tokenizer vocabulary file.
DATA_DIR = PROJECT_ROOT / 'bench' / 'GlycanML' / 'data'
VOCAB_PATH = PROJECT_ROOT / 'bert_training_v4' / 'data' / 'vocabulary.json'


# Default checkpoint for each model version.
CHECKPOINTS = {
    'v5': PROJECT_ROOT / 'checkpoints_v5b_excluded' / 'best_v5b_excluded_model.pt',
    'v6': PROJECT_ROOT / 'bert_v6_contrastive' / 'checkpoints' / 'phase_3_hard_checkpoint.pt',
}


# Fallback locations for the V6 checkpoint, tried in order ONLY when the default
# path above does not exist; the first existing one wins. (Previously the first
# existing alternative overrode the default even when the default was present.)
_v6_alts = [
    PROJECT_ROOT / 'bert_v6_contrastive' / 'checkpoints' / 'best_model.pt',
    PROJECT_ROOT / 'bert_v6_contrastive' / 'checkpoints' / 'checkpoint_latest.pt',
    PROJECT_ROOT / 'bert_v6_contrastive' / 'phase_3_hard_checkpoint.pt',
]
if not CHECKPOINTS['v6'].exists():
    for _alt in _v6_alts:
        if _alt.exists():
            CHECKPOINTS['v6'] = _alt
            break
|
|
def _task_spec(csv, label_col, val_col, description):
    """Assemble one TASKS entry.

    Every task reads its WURCS strings from a column named 'wurcs', and its train/test
    indicator columns are named 'train'/'test'. Only the validation column name varies.
    """
    return {
        'csv': csv,
        'label_col': label_col,
        'wurcs_col': 'wurcs',
        'split_cols': {'train': 'train', 'val': val_col, 'test': 'test'},
        'description': description,
    }


# Benchmark tasks: CSV file (under DATA_DIR), label column, WURCS column,
# split-key -> indicator-column mapping, and a human-readable description.
TASKS = {
    'domain': _task_spec(
        'glycan_classification_wurcs_subset.csv', 'domain', 'validation',
        'Taxonomy domain (Eukarya/Bacteria/Virus/Archaea)',
    ),
    'kingdom': _task_spec(
        'glycan_classification_wurcs_subset.csv', 'kingdom', 'validation',
        'Taxonomy kingdom (11 classes)',
    ),
    'link': _task_spec(
        'glycan_link_wurcs_subset.csv', 'link', 'valid',
        'Glycosylation type (N-linked/O-linked/free)',
    ),
    'immunogenicity': _task_spec(
        'glycan_immunogenicity_wurcs_subset.csv', 'immunogenicity', 'valid',
        'Immunogenicity (0=non-immunogenic, 1=immunogenic)',
    ),
}
del _task_spec
|
|
# Per-task label -> colour maps for the scatter plots. Labels missing from a map are
# plotted with colour None, i.e. matplotlib's default colour cycle.
DOMAIN_COLORS = {
    'Eukarya': '#2196F3', 'Bacteria': '#FF5722', 'Virus': '#9C27B0', 'Archaea': '#4CAF50'
}
KINGDOM_COLORS = {
    'Plantae': '#4CAF50', 'Animalia': '#F44336', 'Fungi': '#FF9800',
    'Protista': '#9C27B0', 'Viridiplantae': '#8BC34A', 'Metazoa': '#E91E63',
}
LINK_COLORS = {'N': '#2196F3', 'O': '#FF5722', 'free': '#4CAF50'}
# Task labels are stored as str(label), so the string keys are the ones that match:
# '0.0'/'1.0' when the CSV column is read as float, '0'/'1' when it is read as int.
# The float keys are kept for any caller looking up raw numeric values.
IMMUNO_COLORS = {
    0.0: '#607D8B', 1.0: '#F44336',
    '0.0': '#607D8B', '1.0': '#F44336',
    '0': '#607D8B', '1': '#F44336',
}
|
|
|
|
def load_model(checkpoint_path, device='cuda'):
    """Build a MultimodalGlycanBERT and load backbone weights from a checkpoint.

    The checkpoint may be a bare state dict or a dict holding one under
    ``'model_state_dict'``. Keys starting with ``proj_head.`` are dropped before
    loading. The sequence (and, if present, MS) vocabulary sizes are read from the
    embedding weight shapes; the other architecture sizes are fixed below.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load``.
        device: Device the model is moved to.

    Returns:
        The model on ``device``, in eval mode.

    Raises:
        KeyError: If the state dict has no ``seq_embeddings.token_embeddings.weight``.
    """
    print(f"Loading model from {checkpoint_path}...")
    # weights_only=False unpickles arbitrary Python objects: only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    backbone_sd = {k: v for k, v in state_dict.items() if not k.startswith('proj_head.')}
    n_stripped = len(state_dict) - len(backbone_sd)
    if n_stripped > 0:
        print(f" Stripped {n_stripped} projection head keys")

    vocab_size = backbone_sd['seq_embeddings.token_embeddings.weight'].shape[0]
    ms_total_vocab = None
    if 'ms_embeddings.token_embeddings.weight' in backbone_sd:
        ms_total_vocab = backbone_sd['ms_embeddings.token_embeddings.weight'].shape[0]

    # NOTE(review): these sizes are not read from the checkpoint; they must match the
    # architecture the checkpoint was trained with -- confirm if training configs change.
    config_kwargs = dict(
        seq_vocab_size=vocab_size, seq_hidden_size=768, seq_num_layers=12,
        seq_num_heads=12, seq_max_length=256, use_cnn_frontend=True, cnn_kernel_size=3,
    )
    if ms_total_vocab is not None:
        # Presumably the MS embedding table holds the sequence vocab plus the MS vocab,
        # and the config takes only the MS part -- confirm against MultimodalGlycanBERTConfig.
        config_kwargs['ms_vocab_size'] = ms_total_vocab - vocab_size
    config = MultimodalGlycanBERTConfig(**config_kwargs)
    model = MultimodalGlycanBERT(config)

    # strict=False tolerates key mismatches, but weights silently left at their
    # initialisation would corrupt every embedding, so report what did not line up.
    incompatible = model.load_state_dict(backbone_sd, strict=False)
    if incompatible.missing_keys:
        print(f" Warning: {len(incompatible.missing_keys)} model keys missing from checkpoint "
              f"(left at initialisation), e.g. {incompatible.missing_keys[:5]}")
    if incompatible.unexpected_keys:
        print(f" Warning: {len(incompatible.unexpected_keys)} checkpoint keys unused by the model, "
              f"e.g. {incompatible.unexpected_keys[:5]}")

    model.to(device)
    model.eval()
    print(f" Model loaded: {sum(p.numel() for p in model.parameters()):,} params")
    return model
|
|
|
|
def extract_cls_embeddings(model, tokenized_samples, device='cuda', batch_size=64, max_len=256,
                           *, return_indices=False):
    """Run tokenized samples through ``model.seq_embeddings`` and keep position 0.

    Each sample is a dict with ``token_ids`` and optional ``branch_depths`` /
    ``linkage_types`` (missing ones default to zeros). The three sequences are
    truncated to ``max_len``, cut to their shortest common length, and right-padded
    with 0 up to ``max_len`` before batching.

    Samples that raise while being converted to tensors are skipped and counted, so the
    result can have FEWER rows than ``tokenized_samples``. Pass ``return_indices=True``
    to learn which input each row belongs to, so labels can be kept aligned.

    Args:
        model: Object exposing ``seq_embeddings(ids, branch_depths=..., linkage_types=...)``.
        tokenized_samples: Sliceable sequence of sample dicts.
        device: Device the batches are moved to.
        batch_size: Samples per forward pass.
        max_len: Fixed sequence length after truncation/padding.
        return_indices: If True, also return the input indices of the embedded rows.

    Returns:
        The per-batch ``seq_out[:, 0, :]`` arrays concatenated along axis 0 (an empty
        array when nothing was embedded). When ``return_indices`` is True, a tuple
        ``(embeddings, kept_indices)`` with ``kept_indices[i]`` the input index of row i.
    """
    all_embeddings = []
    kept_indices = []
    n_failed = 0
    for start in range(0, len(tokenized_samples), batch_size):
        batch = tokenized_samples[start:start + batch_size]
        batch_tids, batch_bdeps, batch_ltypes = [], [], []
        for offset, sample in enumerate(batch):
            try:
                tids = sample['token_ids']
                bdeps = sample.get('branch_depths', [0] * len(tids))
                ltypes = sample.get('linkage_types', [0] * len(tids))
                tids_t = torch.tensor(tids[:max_len], dtype=torch.long)
                bdeps_t = torch.tensor(bdeps[:max_len], dtype=torch.long)
                ltypes_t = torch.tensor(ltypes[:max_len], dtype=torch.long)
                # Guard against the three per-token sequences disagreeing in length.
                min_len = min(len(tids_t), len(bdeps_t), len(ltypes_t))
                tids_t, bdeps_t, ltypes_t = tids_t[:min_len], bdeps_t[:min_len], ltypes_t[:min_len]
                pad_len = max_len - min_len
                if pad_len > 0:
                    tids_t = F.pad(tids_t, (0, pad_len), value=0)
                    bdeps_t = F.pad(bdeps_t, (0, pad_len), value=0)
                    ltypes_t = F.pad(ltypes_t, (0, pad_len), value=0)
            except Exception:
                # Best-effort: one malformed sample must not abort the whole task.
                n_failed += 1
                continue
            else:
                batch_tids.append(tids_t)
                batch_bdeps.append(bdeps_t)
                batch_ltypes.append(ltypes_t)
                kept_indices.append(start + offset)
        if not batch_tids:
            continue
        with torch.no_grad():
            # NOTE(review): this calls the `seq_embeddings` sub-module directly, not the
            # full model forward. If it is only the embedding layer (load_model reads
            # 'seq_embeddings.token_embeddings.weight'), position 0 is not a
            # contextualised [CLS] vector -- confirm against the model definition.
            seq_out = model.seq_embeddings(
                torch.stack(batch_tids).to(device),
                branch_depths=torch.stack(batch_bdeps).to(device),
                linkage_types=torch.stack(batch_ltypes).to(device),
            )
        # Presumably seq_out is (batch, seq, hidden); index 0 on axis 1 is the first token.
        all_embeddings.append(seq_out[:, 0, :].cpu().numpy())
    if n_failed > 0:
        print(f" Warning: {n_failed} samples failed")
    embeddings = np.concatenate(all_embeddings, axis=0) if all_embeddings else np.array([])
    if return_indices:
        return embeddings, kept_indices
    return embeddings
|
|
|
|
def load_task_data(task_name, tokenizer, splits=None, embed_all=False):
    """Read a task's benchmark CSV and tokenize its WURCS strings.

    A row's split is the first key of ``split_cols`` whose indicator column holds a
    truthy flag; rows with no such flag get split ``'unknown'``.

    Args:
        task_name: Key into ``TASKS``.
        tokenizer: Object whose ``tokenize(wurcs, max_length=256)`` returns a dict with
            ``token_ids`` and optionally ``branch_depths`` / ``linkage_types``.
        splits: Split keys (e.g. ``['val', 'test']``) to keep; ``None`` keeps every known split.
        embed_all: Keep every row, including rows whose split is ``'unknown'``.

    Returns:
        A list of dicts with ``token_ids``, ``branch_depths``, ``linkage_types``,
        ``label`` (str; ``'Unknown'`` when missing), ``split`` and ``wurcs``. Rows whose
        WURCS is missing or does not start with ``'WURCS'``, or that fail to tokenize,
        are counted and skipped.
    """
    from collections import Counter

    task_cfg = TASKS[task_name]
    csv_path = DATA_DIR / task_cfg['csv']
    label_col = task_cfg['label_col']
    wurcs_col = task_cfg['wurcs_col']
    split_cols = task_cfg['split_cols']
    print(f"\n{'='*60}")
    print(f"Loading task: {task_name} ({task_cfg['description']})")
    print(f" CSV: {csv_path}")
    df = pd.read_csv(csv_path)
    print(f" Total rows: {len(df)}")

    target_splits = list(split_cols.keys()) if embed_all or splits is None else list(splits)
    # A misspelled split name (e.g. 'valid' instead of 'val') would otherwise silently
    # select zero rows.
    unknown_requested = [s for s in target_splits if s not in split_cols]
    if unknown_requested:
        print(f" Warning: unknown split(s) {unknown_requested}; expected any of {list(split_cols)}")

    # Only indicator columns that exist in this CSV can assign a split; check once, not per row.
    present_split_cols = [(name, col) for name, col in split_cols.items() if col in df.columns]
    if not present_split_cols:
        print(f" Warning: none of the split columns {list(split_cols.values())} are in the CSV; "
              f"every row is split 'unknown'")

    results = []
    # n_ambiguous counts rows without a usable WURCS string (missing or not 'WURCS...').
    n_tokenized = n_failed = n_ambiguous = 0
    for _, row in df.iterrows():
        split = 'unknown'
        for split_name, col_name in present_split_cols:
            val = row.get(col_name)
            # Accept True, numeric 1, and the strings 'true' / '1' / '1.0' as membership flags.
            if val == 1 or val == True or str(val).lower() in ('true', '1', '1.0'):  # noqa: E712
                split = split_name
                break
        if split not in target_splits and not embed_all:
            continue

        label = row.get(label_col, '')
        if pd.isna(label) or label == '' or label == 'nan':
            label = 'Unknown'
        wurcs = row.get(wurcs_col, '')
        if pd.isna(wurcs) or wurcs == '' or not str(wurcs).startswith('WURCS'):
            n_ambiguous += 1
            continue
        try:
            tok = tokenizer.tokenize(str(wurcs), max_length=256)
            results.append({
                'token_ids': tok['token_ids'],
                'branch_depths': tok.get('branch_depths', [0] * len(tok['token_ids'])),
                'linkage_types': tok.get('linkage_types', [0] * len(tok['token_ids'])),
                'label': str(label), 'split': split, 'wurcs': str(wurcs),
            })
            n_tokenized += 1
        except Exception:
            # Best-effort: an untokenizable glycan is counted, not fatal.
            n_failed += 1

    print(f" Tokenized: {n_tokenized}, Failed: {n_failed}, Ambiguous: {n_ambiguous}")
    for s in target_splits:
        s_data = [r for r in results if r['split'] == s]
        label_counts = dict(Counter(r['label'] for r in s_data))
        print(f" Split '{s}': {len(s_data)} samples, labels: {label_counts}")
    return results
|
|
|
|
def plot_embeddings(embeddings, labels, task_name, model_name, output_dir, method='tsne',
                    color_map=None, split_labels=None):
    """Project embeddings to 2-D, scatter-plot them by label, and score class separation.

    The PNG is written to ``output_dir/<task_name>_<model_name>_<method>.png``. The
    silhouette and Calinski-Harabasz scores are computed on the ORIGINAL embeddings,
    not on the 2-D projection.

    Args:
        embeddings: Array of shape (n_samples, n_features).
        labels: One label per row; used for colouring and for the scores.
        task_name: Task name used in the title and file name.
        model_name: Model name used in the title and file name.
        output_dir: Directory the PNG is written to.
        method: ``'tsne'``; any other value uses UMAP.
        color_map: Optional label -> colour mapping; unmapped labels use matplotlib defaults.
        split_labels: Optional split per row; rows with split ``'test'`` get an outline.

    Returns:
        Dict with ``n_samples``, ``n_classes`` and ``classes`` (sorted labels). It also has
        ``silhouette`` / ``calinski_harabasz`` (float, or None if scoring failed) when
        2 <= n_classes < n_samples.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from sklearn.metrics import silhouette_score, calinski_harabasz_score

    print(f" Plotting {method.upper()} for {task_name} ({model_name})...")
    if method == 'tsne':
        from sklearn.manifold import TSNE
        # t-SNE requires perplexity < n_samples.
        perplexity = min(30, len(embeddings) - 1)
        coords = TSNE(n_components=2, perplexity=perplexity, max_iter=1000,
                      init='pca', random_state=42, learning_rate='auto').fit_transform(embeddings)
    else:
        import umap
        coords = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42).fit_transform(embeddings)

    labels_arr = np.asarray(labels)
    unique_labels = sorted(set(labels))
    label_to_int = {l: i for i, l in enumerate(unique_labels)}
    int_labels = np.array([label_to_int[l] for l in labels])

    metrics = {}
    # Both scores are only defined for 2 <= n_classes <= n_samples - 1.
    if 2 <= len(unique_labels) < len(embeddings):
        for key, score_fn in (('silhouette', silhouette_score),
                              ('calinski_harabasz', calinski_harabasz_score)):
            try:
                metrics[key] = float(score_fn(embeddings, int_labels))
            except Exception as exc:  # scoring is optional; never abort the plot for it
                print(f" Warning: {key} failed: {exc}")
                metrics[key] = None
    metrics['n_samples'] = len(embeddings)
    metrics['n_classes'] = len(unique_labels)
    metrics['classes'] = unique_labels

    fig, ax = plt.subplots(1, 1, figsize=(10, 8))
    for label in unique_labels:
        mask = labels_arr == label
        color = color_map.get(label, None) if color_map else None
        ax.scatter(coords[mask, 0], coords[mask, 1], c=color,
                   label=f'{label} (n={mask.sum()})', s=15, alpha=0.7, edgecolors='none')
    if split_labels is not None:
        # Outline held-out test points so they can be spotted among the others.
        mask = np.asarray(split_labels) == 'test'
        if mask.any():
            ax.scatter(coords[mask, 0], coords[mask, 1], facecolors='none',
                       edgecolors='black', s=40, linewidths=0.5, alpha=0.3,
                       label=f'test split (n={mask.sum()})')

    sil = metrics.get('silhouette')
    ch = metrics.get('calinski_harabasz')
    sil_str = f"Sil={sil:.3f}" if sil is not None else "Sil=N/A"
    ch_str = f"CH={ch:.1f}" if ch is not None else "CH=N/A"
    ax.set_title(f"{task_name} - {model_name.upper()} [CLS] ({method.upper()})\n{sil_str} | {ch_str} | n={len(embeddings)}", fontsize=13)
    ax.set_xlabel(f'{method.upper()}-1')
    ax.set_ylabel(f'{method.upper()}-2')
    ax.legend(loc='best', fontsize=8, framealpha=0.8)
    ax.set_aspect('equal', adjustable='box')
    fig.tight_layout()
    fname = f'{task_name}_{model_name}_{method}.png'
    fig.savefig(os.path.join(output_dir, fname), dpi=200, bbox_inches='tight')
    plt.close(fig)
    print(f" Saved: {fname}")
    return metrics
|
|
|
|
def main():
    """CLI entry point: embed each requested benchmark task, then plot and score it.

    For every task it writes ``<task>_<model>_embeddings.npz`` and one PNG per
    projection method into ``--output_dir``. It then writes all scores to
    ``benchmark_metrics_<model>.json`` and prints a summary.
    """
    parser = argparse.ArgumentParser(description='Embed benchmark tasks with V5/V6')
    parser.add_argument('--model', choices=['v5', 'v6'], required=True)
    parser.add_argument('--splits', nargs='+', default=['val', 'test'])
    parser.add_argument('--embed_all', action='store_true')
    parser.add_argument('--tasks', nargs='+', default=list(TASKS.keys()))
    parser.add_argument('--method', choices=['tsne', 'umap', 'both'], default='tsne')
    parser.add_argument('--output_dir', default=None)
    parser.add_argument('--device', default='cuda')
    args = parser.parse_args()

    if args.output_dir is None:
        args.output_dir = str(PROJECT_ROOT / 'bert_v6_contrastive' / 'analysis' / 'benchmark_embeddings')
    os.makedirs(args.output_dir, exist_ok=True)

    # Fall back to CPU instead of failing in model.to('cuda') on machines without a GPU.
    if args.device.startswith('cuda') and not torch.cuda.is_available():
        print(f"WARNING: device '{args.device}' requested but CUDA is unavailable; using CPU")
        args.device = 'cpu'

    print(f"Loading tokenizer from {VOCAB_PATH}...")
    tokenizer = WURCSTokenizer(str(VOCAB_PATH))
    print(f" Vocab size: {tokenizer.vocab_size}")

    ckpt_path = CHECKPOINTS[args.model]
    if not ckpt_path.exists():
        print(f"ERROR: Checkpoint not found: {ckpt_path}")
        sys.exit(1)
    model = load_model(str(ckpt_path), device=args.device)

    color_maps = {'domain': DOMAIN_COLORS, 'kingdom': KINGDOM_COLORS,
                  'link': LINK_COLORS, 'immunogenicity': IMMUNO_COLORS}
    all_metrics = {}
    for task_name in args.tasks:
        if task_name not in TASKS:
            print(f"WARNING: Unknown task '{task_name}', skipping")
            continue
        data = load_task_data(task_name, tokenizer,
                              splits=args.splits if not args.embed_all else None,
                              embed_all=args.embed_all)
        if len(data) < 10:
            print(f" Skipping {task_name}: too few samples ({len(data)})")
            continue

        print(f" Extracting [CLS] embeddings for {len(data)} samples...")
        embeddings = extract_cls_embeddings(model, data, device=args.device)
        # extract_cls_embeddings skips samples it cannot tensorize; then rows no longer
        # line up with `data` and labels cannot be attached safely.
        if len(embeddings) != len(data):
            print(f" Skipping {task_name}: got {len(embeddings)} embeddings for {len(data)} "
                  f"samples (some samples failed to embed; labels would be misaligned)")
            continue

        labels = [d['label'] for d in data]
        split_labels = [d['split'] for d in data]
        # Samples without a ground-truth label can be neither coloured nor scored.
        valid_mask = np.array([l != 'Unknown' for l in labels], dtype=bool)
        embeddings = embeddings[valid_mask]
        labels = [l for l, keep in zip(labels, valid_mask) if keep]
        split_labels = [s for s, keep in zip(split_labels, valid_mask) if keep]
        if len(embeddings) < 10:
            print(f" Skipping {task_name}: too few labeled samples")
            continue
        print(f" Embeddings shape: {embeddings.shape}")

        npz_path = os.path.join(args.output_dir, f'{task_name}_{args.model}_embeddings.npz')
        np.savez_compressed(npz_path, embeddings=embeddings,
                            labels=np.array(labels), splits=np.array(split_labels))
        print(f" Saved: {npz_path}")

        methods = ['tsne', 'umap'] if args.method == 'both' else [args.method]
        task_metrics = {}
        for method in methods:
            m = plot_embeddings(embeddings, labels, task_name, args.model,
                                args.output_dir, method=method,
                                color_map=color_maps.get(task_name, None),
                                split_labels=split_labels)
            task_metrics[method] = m
        all_metrics[task_name] = task_metrics

    metrics_path = os.path.join(args.output_dir, f'benchmark_metrics_{args.model}.json')
    with open(metrics_path, 'w', encoding='utf-8') as f:
        json.dump(all_metrics, f, indent=2, default=str)
    print(f"\nAll metrics saved to: {metrics_path}")

    print(f"\n{'='*60}")
    print(f"SUMMARY - {args.model.upper()}")
    print(f"{'='*60}")
    for task, tmetrics in all_metrics.items():
        for method, m in tmetrics.items():
            sil = m.get('silhouette', 'N/A')
            sil_str = f"{sil:.4f}" if isinstance(sil, float) else str(sil)
            print(f" {task:20s} ({method:5s}): Silhouette={sil_str}, n={m.get('n_samples',0)}, classes={m.get('n_classes',0)}")
|
|
|
|
if __name__ == "__main__":  # run the CLI only when executed as a script, not on import
    main()
|
|