#!/usr/bin/env python3
"""
Embed Benchmark Task Datasets with V5/V6 [CLS] Embeddings

Extracts frozen [CLS] embeddings for GlycanML benchmark task datasets and
produces t-SNE/UMAP visualizations colored by ground-truth labels.
Comparable to GlycanGT Figure 3.

Tasks:
    1. Taxonomy (domain, kingdom)
    2. Glycosylation type (N/O/free)
    3. Immunogenicity (0/1)

Usage:
    python embed_benchmark_tasks.py --model v5 [--splits val test] [--embed_all]
    python embed_benchmark_tasks.py --model v6 [--splits val test] [--embed_all]

Further options: --tasks, --method {tsne,umap,both}, --output_dir, --device.
Per task it writes <task>_<model>_embeddings.npz and <task>_<model>_<method>.png;
per run it writes benchmark_metrics_<model>.json.
"""

import argparse
import json
import os
import sys
import warnings
from collections import Counter
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

warnings.filterwarnings('ignore')

PROJECT_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / 'bert_training_v4'))

from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig  # noqa: E402
from downstream_tasks.utils.tokenizer import WURCSTokenizer  # noqa: E402

DATA_DIR = PROJECT_ROOT / 'bench' / 'GlycanML' / 'data'
VOCAB_PATH = PROJECT_ROOT / 'bert_training_v4' / 'data' / 'vocabulary.json'

CHECKPOINTS = {
    'v5': PROJECT_ROOT / 'checkpoints_v5b_excluded' / 'best_v5b_excluded_model.pt',
    'v6': PROJECT_ROOT / 'bert_v6_contrastive' / 'checkpoints' / 'phase_3_hard_checkpoint.pt',
}

# Fallback V6 locations, tried in order ONLY when the primary V6 checkpoint
# above does not exist, so an existing primary is never silently replaced.
_v6_alts = [
    PROJECT_ROOT / 'bert_v6_contrastive' / 'checkpoints' / 'best_model.pt',
    PROJECT_ROOT / 'bert_v6_contrastive' / 'checkpoints' / 'checkpoint_latest.pt',
    PROJECT_ROOT / 'bert_v6_contrastive' / 'phase_3_hard_checkpoint.pt',
]
if not CHECKPOINTS['v6'].exists():
    for _alt in _v6_alts:
        if _alt.exists():
            CHECKPOINTS['v6'] = _alt
            break

# Per-task CSV layout. ``split_cols`` maps this script's split names to the
# CSV columns whose truthy cell (1 / True / 'true' / '1.0') marks that a row
# belongs to the split; see ``_is_split_flag_set``.
TASKS = {
    'domain': {
        'csv': 'glycan_classification_wurcs_subset.csv',
        'label_col': 'domain',
        'wurcs_col': 'wurcs',
        'split_cols': {'train': 'train', 'val': 'validation', 'test': 'test'},
        'description': 'Taxonomy domain (Eukarya/Bacteria/Virus/Archaea)',
    },
    'kingdom': {
        'csv': 'glycan_classification_wurcs_subset.csv',
        'label_col': 'kingdom',
        'wurcs_col': 'wurcs',
        'split_cols': {'train': 'train', 'val': 'validation', 'test': 'test'},
        'description': 'Taxonomy kingdom (11 classes)',
    },
    'link': {
        'csv': 'glycan_link_wurcs_subset.csv',
        'label_col': 'link',
        'wurcs_col': 'wurcs',
        'split_cols': {'train': 'train', 'val': 'valid', 'test': 'test'},
        'description': 'Glycosylation type (N-linked/O-linked/free)',
    },
    'immunogenicity': {
        'csv': 'glycan_immunogenicity_wurcs_subset.csv',
        'label_col': 'immunogenicity',
        'wurcs_col': 'wurcs',
        'split_cols': {'train': 'train', 'val': 'valid', 'test': 'test'},
        'description': 'Immunogenicity (0=non-immunogenic, 1=immunogenic)',
    },
}

# Label -> color maps for the scatter plots. Labels not listed fall back to
# matplotlib's default color cycle.
DOMAIN_COLORS = {
    'Eukarya': '#2196F3', 'Bacteria': '#FF5722', 'Virus': '#9C27B0', 'Archaea': '#4CAF50',
}
KINGDOM_COLORS = {
    'Plantae': '#4CAF50', 'Animalia': '#F44336', 'Fungi': '#FF9800',
    'Protista': '#9C27B0', 'Viridiplantae': '#8BC34A', 'Metazoa': '#E91E63',
}
LINK_COLORS = {'N': '#2196F3', 'O': '#FF5722', 'free': '#4CAF50'}
# load_task_data stringifies labels, so the '0.0'/'1.0' keys are the ones hit;
# the float keys are kept for callers passing raw numeric labels.
IMMUNO_COLORS = {0.0: '#607D8B', 1.0: '#F44336', '0.0': '#607D8B', '1.0': '#F44336'}


def load_model(checkpoint_path, device='cuda'):
    """Build a MultimodalGlycanBERT backbone from a checkpoint, in eval mode.

    Keys starting with ``proj_head.`` are dropped so that only the backbone is
    loaded. The sequence (and optional MS) vocabulary sizes are read from the
    embedding matrices in the checkpoint; the remaining hyper-parameters are
    hard-coded below and must match the training configuration.

    Args:
        checkpoint_path: Path to a ``.pt`` file holding either a raw state dict
            or a dict with a ``'model_state_dict'`` entry.
        device: Torch device string the model is moved to.

    Returns:
        The model on ``device`` in eval mode.
    """
    print(f"Loading model from {checkpoint_path}...")
    # weights_only=False unpickles arbitrary Python objects: only use this on
    # trusted, locally produced checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    backbone_sd = {k: v for k, v in state_dict.items() if not k.startswith('proj_head.')}
    n_stripped = len(state_dict) - len(backbone_sd)
    if n_stripped > 0:
        print(f"  Stripped {n_stripped} projection head keys")

    vocab_size = backbone_sd['seq_embeddings.token_embeddings.weight'].shape[0]
    ms_total_vocab = None
    if 'ms_embeddings.token_embeddings.weight' in backbone_sd:
        ms_total_vocab = backbone_sd['ms_embeddings.token_embeddings.weight'].shape[0]

    config_kwargs = dict(
        seq_vocab_size=vocab_size,
        seq_hidden_size=768,
        seq_num_layers=12,
        seq_num_heads=12,
        seq_max_length=256,
        use_cnn_frontend=True,
        cnn_kernel_size=3,
    )
    if ms_total_vocab is not None:
        # The MS embedding table's row count minus the sequence vocab is
        # passed as ms_vocab_size.
        config_kwargs['ms_vocab_size'] = ms_total_vocab - vocab_size

    config = MultimodalGlycanBERTConfig(**config_kwargs)
    model = MultimodalGlycanBERT(config)
    # strict=False tolerates architecture drift; report it rather than hide it.
    incompatible = model.load_state_dict(backbone_sd, strict=False)
    if incompatible.missing_keys:
        print(f"  Warning: {len(incompatible.missing_keys)} model keys missing from checkpoint "
              f"(e.g. {incompatible.missing_keys[:3]})")
    if incompatible.unexpected_keys:
        print(f"  Warning: {len(incompatible.unexpected_keys)} unexpected checkpoint keys ignored "
              f"(e.g. {incompatible.unexpected_keys[:3]})")
    model.to(device)
    model.eval()
    print(f"  Model loaded: {sum(p.numel() for p in model.parameters()):,} params")
    return model


def extract_cls_embeddings(model, tokenized_samples, device='cuda', batch_size=64, max_len=256,
                           return_indices=False):
    """Return the position-0 output vector for every tokenized sample.

    Each sample's ``token_ids`` / ``branch_depths`` / ``linkage_types`` are
    truncated to ``max_len``, trimmed to a common length and right-padded with
    0 to ``max_len``. Samples that cannot be turned into tensors are skipped
    and counted; a warning is printed.

    NOTE(review): this calls only ``model.seq_embeddings`` (the input
    embedding stage), not the transformer encoder, and passes no attention
    mask. Unless that module mixes context, position 0 is not a contextual
    [CLS] representation -- confirm this is intended. Padding with id 0
    assumes 0 is the pad token -- TODO confirm against the vocabulary.

    Args:
        model: Model exposing ``seq_embeddings``.
        tokenized_samples: List of dicts as produced by ``load_task_data``.
        device: Device the batches are moved to.
        batch_size: Samples per forward pass.
        max_len: Truncation / padding length.
        return_indices: If True, also return the indices of the samples that
            produced a row, so callers can realign per-sample metadata.

    Returns:
        ``np.ndarray`` with one row per successfully processed sample (an empty
        1-D array if none). With ``return_indices=True``, a tuple
        ``(embeddings, indices)`` where ``indices[i]`` is the position in
        ``tokenized_samples`` of row ``i``.
    """
    all_embeddings = []
    kept_indices = []
    n_failed = 0
    for start in range(0, len(tokenized_samples), batch_size):
        batch = tokenized_samples[start:start + batch_size]
        batch_tids, batch_bdeps, batch_ltypes = [], [], []
        for offset, sample in enumerate(batch):
            try:
                tids = sample['token_ids']
                bdeps = sample.get('branch_depths', [0] * len(tids))
                ltypes = sample.get('linkage_types', [0] * len(tids))
                tids_t = torch.tensor(tids[:max_len], dtype=torch.long)
                bdeps_t = torch.tensor(bdeps[:max_len], dtype=torch.long)
                ltypes_t = torch.tensor(ltypes[:max_len], dtype=torch.long)
                # Keep the three tracks position-aligned.
                min_len = min(len(tids_t), len(bdeps_t), len(ltypes_t))
                tids_t, bdeps_t, ltypes_t = tids_t[:min_len], bdeps_t[:min_len], ltypes_t[:min_len]
                pad_len = max_len - min_len
                if pad_len > 0:
                    tids_t = F.pad(tids_t, (0, pad_len), value=0)
                    bdeps_t = F.pad(bdeps_t, (0, pad_len), value=0)
                    ltypes_t = F.pad(ltypes_t, (0, pad_len), value=0)
            except Exception:  # best-effort: a malformed sample is skipped, not fatal
                n_failed += 1
                continue
            # Append only after every tensor was built, so the three lists
            # (and kept_indices) can never get out of step.
            batch_tids.append(tids_t)
            batch_bdeps.append(bdeps_t)
            batch_ltypes.append(ltypes_t)
            kept_indices.append(start + offset)

        if not batch_tids:
            continue
        with torch.no_grad():
            seq_out = model.seq_embeddings(
                torch.stack(batch_tids).to(device),
                branch_depths=torch.stack(batch_bdeps).to(device),
                linkage_types=torch.stack(batch_ltypes).to(device),
            )
            all_embeddings.append(seq_out[:, 0, :].cpu().numpy())

    if n_failed > 0:
        print(f"  Warning: {n_failed} samples failed")
    embeddings = np.concatenate(all_embeddings, axis=0) if all_embeddings else np.array([])
    if return_indices:
        return embeddings, np.asarray(kept_indices, dtype=np.int64)
    return embeddings


def _is_split_flag_set(val):
    """Return True if a split-membership cell marks the row as in that split."""
    return val == 1 or val == True or str(val).lower() in ('true', '1', '1.0')  # noqa: E712


def load_task_data(task_name, tokenizer, splits=None, embed_all=False):
    """Read a task CSV and tokenize its WURCS strings.

    A row's split is the first entry of the task's ``split_cols`` whose column
    exists and is flagged (see ``_is_split_flag_set``); otherwise 'unknown'.
    Rows outside the requested splits are skipped unless ``embed_all``.
    Missing labels become 'Unknown'. Rows whose WURCS is empty or does not
    start with 'WURCS' are counted as ambiguous and skipped.

    Args:
        task_name: Key into ``TASKS``.
        tokenizer: Object with ``tokenize(str, max_length=...)`` returning a
            dict with ``token_ids`` and optionally ``branch_depths`` /
            ``linkage_types``.
        splits: Split names to keep; None keeps all configured splits.
        embed_all: Keep every row, including 'unknown'-split ones.

    Returns:
        List of dicts with keys token_ids, branch_depths, linkage_types,
        label (str), split, wurcs.
    """
    task_cfg = TASKS[task_name]
    csv_path = DATA_DIR / task_cfg['csv']
    label_col = task_cfg['label_col']
    wurcs_col = task_cfg['wurcs_col']
    split_cols = task_cfg['split_cols']

    print(f"\n{'='*60}")
    print(f"Loading task: {task_name} ({task_cfg['description']})")
    print(f"  CSV: {csv_path}")

    df = pd.read_csv(csv_path)
    print(f"  Total rows: {len(df)}")

    target_splits = list(split_cols.keys()) if embed_all or splits is None else splits

    # Resolve which split columns exist once, instead of once per row.
    present_split_cols = [(name, col) for name, col in split_cols.items() if col in df.columns]
    if not present_split_cols:
        print(f"  Warning: none of the split columns {list(split_cols.values())} found; "
              f"every row gets split 'unknown'")

    results = []
    n_tokenized = n_failed = n_ambiguous = 0

    for _, row in df.iterrows():
        split = next(
            (name for name, col in present_split_cols if _is_split_flag_set(row.get(col))),
            'unknown',
        )
        if split not in target_splits and not embed_all:
            continue

        label = row.get(label_col, '')
        if pd.isna(label) or label == '' or label == 'nan':
            label = 'Unknown'

        wurcs = row.get(wurcs_col, '')
        if pd.isna(wurcs) or wurcs == '' or not str(wurcs).startswith('WURCS'):
            n_ambiguous += 1
            continue

        try:
            tok = tokenizer.tokenize(str(wurcs), max_length=256)
            results.append({
                'token_ids': tok['token_ids'],
                'branch_depths': tok.get('branch_depths', [0] * len(tok['token_ids'])),
                'linkage_types': tok.get('linkage_types', [0] * len(tok['token_ids'])),
                'label': str(label),
                'split': split,
                'wurcs': str(wurcs),
            })
            n_tokenized += 1
        except Exception:  # best-effort: untokenizable structures are counted, not fatal
            n_failed += 1

    print(f"  Tokenized: {n_tokenized}, Failed: {n_failed}, Ambiguous: {n_ambiguous}")
    for s in target_splits:
        s_data = [r for r in results if r['split'] == s]
        labels = dict(Counter(r['label'] for r in s_data))
        print(f"  Split '{s}': {len(s_data)} samples, labels: {labels}")

    return results


def plot_embeddings(embeddings, labels, task_name, model_name, output_dir,
                    method='tsne', color_map=None, split_labels=None):
    """Project embeddings to 2-D, save a labeled scatter plot, return metrics.

    Silhouette and Calinski-Harabasz scores are computed on the full
    ``embeddings`` (not on the 2-D projection), so they are identical for
    every ``method``. They are only attempted when
    ``2 <= n_classes < n_samples`` and are set to None if scoring fails.

    Args:
        embeddings: 2-D array, one row per sample.
        labels: Per-row class labels.
        task_name, model_name: Used in the title and output filename.
        output_dir: Directory receiving ``{task}_{model}_{method}.png``.
        method: 'tsne'; any other value selects UMAP.
        color_map: Optional label -> color mapping.
        split_labels: Optional per-row split names; rows in 'test' get an
            extra black outline.

    Returns:
        Dict with n_samples, n_classes, classes and, when computed,
        silhouette / calinski_harabasz.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from sklearn.metrics import silhouette_score, calinski_harabasz_score

    print(f"  Plotting {method.upper()} for {task_name} ({model_name})...")

    if method == 'tsne':
        from sklearn.manifold import TSNE
        # t-SNE requires perplexity < n_samples.
        perplexity = min(30, len(embeddings) - 1)
        coords = TSNE(n_components=2, perplexity=perplexity, max_iter=1000,
                      init='pca', random_state=42, learning_rate='auto').fit_transform(embeddings)
    else:
        import umap
        coords = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42).fit_transform(embeddings)

    labels_arr = np.asarray(labels)
    unique_labels = sorted(set(labels))
    label_to_int = {lab: i for i, lab in enumerate(unique_labels)}
    int_labels = np.array([label_to_int[lab] for lab in labels])

    metrics = {}
    # Both scores are undefined unless 2 <= n_classes <= n_samples - 1.
    if 2 <= len(unique_labels) < len(embeddings):
        for key, scorer in (('silhouette', silhouette_score),
                            ('calinski_harabasz', calinski_harabasz_score)):
            try:
                metrics[key] = float(scorer(embeddings, int_labels))
            except (ValueError, MemoryError) as err:
                print(f"    Warning: {key} could not be computed: {err}")
                metrics[key] = None
    metrics['n_samples'] = len(embeddings)
    metrics['n_classes'] = len(unique_labels)
    metrics['classes'] = unique_labels

    fig, ax = plt.subplots(1, 1, figsize=(10, 8))
    for label in unique_labels:
        mask = labels_arr == label
        color = color_map.get(label, None) if color_map else None
        ax.scatter(coords[mask, 0], coords[mask, 1], c=color,
                   label=f'{label} (n={mask.sum()})', s=15, alpha=0.7, edgecolors='none')

    if split_labels is not None:
        test_mask = np.asarray(split_labels) == 'test'
        if test_mask.any():
            ax.scatter(coords[test_mask, 0], coords[test_mask, 1], facecolors='none',
                       edgecolors='black', s=40, linewidths=0.5, alpha=0.3,
                       label=f'test split (n={test_mask.sum()})')

    sil = metrics.get('silhouette')
    ch = metrics.get('calinski_harabasz')
    sil_str = f"Sil={sil:.3f}" if sil is not None else "Sil=N/A"
    ch_str = f"CH={ch:.1f}" if ch is not None else "CH=N/A"
    ax.set_title(f"{task_name} - {model_name.upper()} [CLS] ({method.upper()})\n"
                 f"{sil_str} | {ch_str} | n={len(embeddings)}", fontsize=13)
    ax.set_xlabel(f'{method.upper()}-1')
    ax.set_ylabel(f'{method.upper()}-2')
    ax.legend(loc='best', fontsize=8, framealpha=0.8)
    ax.set_aspect('equal', adjustable='box')

    plt.tight_layout()
    fname = f'{task_name}_{model_name}_{method}.png'
    plt.savefig(os.path.join(output_dir, fname), dpi=200, bbox_inches='tight')
    plt.close(fig)
    print(f"  Saved: {fname}")
    return metrics


def main():
    """Parse CLI args, embed each requested task, save arrays, plots and metrics."""
    parser = argparse.ArgumentParser(description='Embed benchmark tasks with V5/V6')
    parser.add_argument('--model', choices=['v5', 'v6'], required=True)
    parser.add_argument('--splits', nargs='+', default=['val', 'test'])
    parser.add_argument('--embed_all', action='store_true')
    parser.add_argument('--tasks', nargs='+', default=list(TASKS.keys()))
    parser.add_argument('--method', choices=['tsne', 'umap', 'both'], default='tsne')
    parser.add_argument('--output_dir', default=None)
    parser.add_argument('--device', default='cuda')
    args = parser.parse_args()

    if args.output_dir is None:
        args.output_dir = str(PROJECT_ROOT / 'bert_v6_contrastive' / 'analysis' / 'benchmark_embeddings')
    os.makedirs(args.output_dir, exist_ok=True)

    if args.device.startswith('cuda') and not torch.cuda.is_available():
        print("WARNING: CUDA not available, falling back to CPU")
        args.device = 'cpu'

    print(f"Loading tokenizer from {VOCAB_PATH}...")
    tokenizer = WURCSTokenizer(str(VOCAB_PATH))
    print(f"  Vocab size: {tokenizer.vocab_size}")

    ckpt_path = CHECKPOINTS[args.model]
    if not ckpt_path.exists():
        print(f"ERROR: Checkpoint not found: {ckpt_path}")
        sys.exit(1)
    model = load_model(str(ckpt_path), device=args.device)

    color_maps = {'domain': DOMAIN_COLORS, 'kingdom': KINGDOM_COLORS,
                  'link': LINK_COLORS, 'immunogenicity': IMMUNO_COLORS}
    all_metrics = {}

    for task_name in args.tasks:
        if task_name not in TASKS:
            print(f"WARNING: Unknown task '{task_name}', skipping")
            continue

        data = load_task_data(task_name, tokenizer,
                              splits=args.splits if not args.embed_all else None,
                              embed_all=args.embed_all)
        if len(data) < 10:
            print(f"  Skipping {task_name}: too few samples ({len(data)})")
            continue

        print(f"  Extracting [CLS] embeddings for {len(data)} samples...")
        embeddings, kept_idx = extract_cls_embeddings(model, data, device=args.device,
                                                      return_indices=True)
        # Realign metadata with the rows actually produced: failed samples
        # have no embedding row.
        kept_data = [data[i] for i in kept_idx]
        labels = [d['label'] for d in kept_data]
        split_labels = [d['split'] for d in kept_data]

        valid_mask = np.array([lab != 'Unknown' for lab in labels], dtype=bool)
        embeddings = embeddings[valid_mask]
        labels = [lab for lab, keep in zip(labels, valid_mask) if keep]
        split_labels = [s for s, keep in zip(split_labels, valid_mask) if keep]

        if len(embeddings) < 10:
            print(f"  Skipping {task_name}: too few labeled samples")
            continue

        print(f"  Embeddings shape: {embeddings.shape}")

        npz_path = os.path.join(args.output_dir, f'{task_name}_{args.model}_embeddings.npz')
        np.savez_compressed(npz_path, embeddings=embeddings,
                            labels=np.array(labels), splits=np.array(split_labels))
        print(f"  Saved: {npz_path}")

        methods = ['tsne', 'umap'] if args.method == 'both' else [args.method]
        task_metrics = {}
        for method in methods:
            task_metrics[method] = plot_embeddings(
                embeddings, labels, task_name, args.model, args.output_dir,
                method=method, color_map=color_maps.get(task_name, None),
                split_labels=split_labels,
            )
        all_metrics[task_name] = task_metrics

    metrics_path = os.path.join(args.output_dir, f'benchmark_metrics_{args.model}.json')
    with open(metrics_path, 'w') as f:
        json.dump(all_metrics, f, indent=2, default=str)
    print(f"\nAll metrics saved to: {metrics_path}")

    print(f"\n{'='*60}")
    print(f"SUMMARY - {args.model.upper()}")
    print(f"{'='*60}")
    for task, tmetrics in all_metrics.items():
        for method, m in tmetrics.items():
            sil = m.get('silhouette', 'N/A')
            sil_str = f"{sil:.4f}" if isinstance(sil, float) else str(sil)
            print(f"  {task:20s} ({method:5s}): Silhouette={sil_str}, "
                  f"n={m.get('n_samples',0)}, classes={m.get('n_classes',0)}")


if __name__ == '__main__':
    main()