bertose-affinose-training-code / code /probes /embed_benchmark_tasks.py
supanthadey1's picture
Add BERTose and AFFINose training code release
1d6f391 verified
Raw
History Blame Contribute Delete
15.8 kB
#!/usr/bin/env python3
"""
Embed Benchmark Task Datasets with V5/V6 [CLS] Embeddings
Extracts frozen [CLS] embeddings for GlycanML benchmark task datasets
and produces t-SNE/UMAP visualizations colored by ground-truth labels.
Comparable to GlycanGT Figure 3.
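Tasks (keys accepted by --tasks):
1. Taxonomy: domain, kingdom
2. Glycosylation type: link (N/O/free)
3. Immunogenicity: immunogenicity (0/1)
Usage:
python embed_benchmark_tasks.py --model v5 [--splits val test] [--embed_all]
python embed_benchmark_tasks.py --model v6 [--splits val test] [--embed_all]
Further options: [--tasks domain kingdom link immunogenicity] [--method tsne|umap|both]
[--output_dir DIR] [--device cuda|cpu]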
"""
import argparse
import json
import os
import sys
import warnings
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import pandas as pd
# Silence every warning category process-wide so the progress output stays readable.
warnings.filterwarnings('ignore')

# Repository root: the directory two levels above the folder containing this script.
PROJECT_ROOT = Path(__file__).resolve().parents[2]

# Prepend the project import roots. Each insert goes to position 0, so afterwards
# sys.path begins with bert_training_v4 followed by the repository root.
for _import_root in (PROJECT_ROOT, PROJECT_ROOT / 'bert_training_v4'):
    sys.path.insert(0, str(_import_root))
del _import_root
from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig
from downstream_tasks.utils.tokenizer import WURCSTokenizer
# Benchmark CSV directory and the WURCS tokenizer vocabulary.
DATA_DIR = PROJECT_ROOT / 'bench' / 'GlycanML' / 'data'
VOCAB_PATH = PROJECT_ROOT / 'bert_training_v4' / 'data' / 'vocabulary.json'

# Primary checkpoint for each model version selectable with --model.
CHECKPOINTS = {
    'v5': PROJECT_ROOT / 'checkpoints_v5b_excluded' / 'best_v5b_excluded_model.pt',
    'v6': PROJECT_ROOT / 'bert_v6_contrastive' / 'checkpoints' / 'phase_3_hard_checkpoint.pt',
}

# Alternate V6 locations, tried in order ONLY when the primary V6 checkpoint is
# missing, so that an unrelated best_model.pt / checkpoint_latest.pt can never
# shadow an existing phase-3 checkpoint.
_v6_alts = [
    PROJECT_ROOT / 'bert_v6_contrastive' / 'checkpoints' / 'best_model.pt',
    PROJECT_ROOT / 'bert_v6_contrastive' / 'checkpoints' / 'checkpoint_latest.pt',
    PROJECT_ROOT / 'bert_v6_contrastive' / 'phase_3_hard_checkpoint.pt',
]
if not CHECKPOINTS['v6'].exists():
    for _alt in _v6_alts:
        if _alt.exists():
            CHECKPOINTS['v6'] = _alt
            break
def _benchmark_task(csv, label_col, val_split_col, description):
    """Build one TASKS entry.

    Every benchmark CSV stores its WURCS strings in a 'wurcs' column and flags
    rows with 'train' / 'test' columns; only the validation column name
    (``val_split_col``) differs between the CSVs.
    """
    return {
        'csv': csv,
        'label_col': label_col,
        'wurcs_col': 'wurcs',
        'split_cols': {'train': 'train', 'val': val_split_col, 'test': 'test'},
        'description': description,
    }


# Task name -> dataset spec read by load_task_data(); keys are valid --tasks values.
TASKS = {
    'domain': _benchmark_task(
        'glycan_classification_wurcs_subset.csv', 'domain', 'validation',
        'Taxonomy domain (Eukarya/Bacteria/Virus/Archaea)',
    ),
    'kingdom': _benchmark_task(
        'glycan_classification_wurcs_subset.csv', 'kingdom', 'validation',
        'Taxonomy kingdom (11 classes)',
    ),
    'link': _benchmark_task(
        'glycan_link_wurcs_subset.csv', 'link', 'valid',
        'Glycosylation type (N-linked/O-linked/free)',
    ),
    'immunogenicity': _benchmark_task(
        'glycan_immunogenicity_wurcs_subset.csv', 'immunogenicity', 'valid',
        'Immunogenicity (0=non-immunogenic, 1=immunogenic)',
    ),
}
# Fixed colours per label value, handed to plot_embeddings() as ``color_map``.
# A label missing from its map is drawn with matplotlib's default colour cycle.
DOMAIN_COLORS = {
    'Eukarya': '#2196F3', 'Bacteria': '#FF5722', 'Virus': '#9C27B0', 'Archaea': '#4CAF50'
}
# NOTE(review): the kingdom task is described as having 11 classes, but only six
# have a fixed colour here; the others use the default cycle and may clash.
KINGDOM_COLORS = {
    'Plantae': '#4CAF50', 'Animalia': '#F44336', 'Fungi': '#FF9800',
    'Protista': '#9C27B0', 'Viridiplantae': '#8BC34A', 'Metazoa': '#E91E63',
}
LINK_COLORS = {'N': '#2196F3', 'O': '#FF5722', 'free': '#4CAF50'}
# load_task_data() stores labels as str(label): a float CSV column gives
# '0.0'/'1.0', an integer column gives '0'/'1'. Both spellings are mapped
# (numeric keys also match int 0/1 because 0 == 0.0 hashes equal).
IMMUNO_COLORS = {
    0.0: '#607D8B', 1.0: '#F44336',
    '0.0': '#607D8B', '1.0': '#F44336',
    '0': '#607D8B', '1': '#F44336',
}
def load_model(checkpoint_path, device='cuda'):
    """Rebuild a MultimodalGlycanBERT backbone from a checkpoint, in eval mode.

    The checkpoint may be a dict holding 'model_state_dict' or a bare state
    dict. Keys under 'proj_head.' are dropped before loading. The sequence
    vocabulary size (and, if present, the MS vocabulary size) is read from the
    embedding tables; the remaining architecture sizes are fixed below.

    Args:
        checkpoint_path: Path to a ``torch.load``-able checkpoint file.
        device: Torch device the model is moved to.

    Returns:
        The loaded model on ``device``, switched to eval mode.
    """
    print(f"Loading model from {checkpoint_path}...")
    # SECURITY: weights_only=False unpickles arbitrary objects - only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    # The projection head is not part of the backbone used for embeddings.
    backbone_sd = {k: v for k, v in state_dict.items() if not k.startswith('proj_head.')}
    n_stripped = len(state_dict) - len(backbone_sd)
    if n_stripped > 0:
        print(f" Stripped {n_stripped} projection head keys")

    vocab_size = backbone_sd['seq_embeddings.token_embeddings.weight'].shape[0]
    ms_total_vocab = None
    if 'ms_embeddings.token_embeddings.weight' in backbone_sd:
        ms_total_vocab = backbone_sd['ms_embeddings.token_embeddings.weight'].shape[0]

    config_kwargs = dict(
        seq_vocab_size=vocab_size, seq_hidden_size=768, seq_num_layers=12,
        seq_num_heads=12, seq_max_length=256, use_cnn_frontend=True, cnn_kernel_size=3,
    )
    if ms_total_vocab is not None:
        # NOTE(review): assumes the MS table = sequence vocab + MS-only tokens; confirm
        # against how MultimodalGlycanBERTConfig sizes ms_embeddings.
        config_kwargs['ms_vocab_size'] = ms_total_vocab - vocab_size
    config = MultimodalGlycanBERTConfig(**config_kwargs)
    model = MultimodalGlycanBERT(config)

    # strict=False tolerates optional modules, but a silently mismatched backbone
    # would leave randomly initialised weights and meaningless embeddings, so
    # surface what did not line up.
    load_result = model.load_state_dict(backbone_sd, strict=False)
    missing = list(getattr(load_result, 'missing_keys', []))
    unexpected = list(getattr(load_result, 'unexpected_keys', []))
    if missing:
        print(f" Warning: {len(missing)} model parameters not found in checkpoint "
              f"(e.g. {missing[:3]})")
    if unexpected:
        print(f" Warning: {len(unexpected)} checkpoint keys not used by the model "
              f"(e.g. {unexpected[:3]})")

    model.to(device)
    model.eval()
    print(f" Model loaded: {sum(p.numel() for p in model.parameters()):,} params")
    return model
def extract_cls_embeddings(model, tokenized_samples, device='cuda', batch_size=64, max_len=256):
    """Collect the position-0 ([CLS]) vector of ``model.seq_embeddings`` per sample.

    Each sample is a dict with 'token_ids' and optional 'branch_depths' /
    'linkage_types' (both default to zeros of the token length). The three
    sequences are truncated to ``max_len``, cut to their shortest common length
    and right-padded with 0 up to ``max_len`` before batching.

    A sample that raises while being prepared is skipped and only counted (one
    warning with the total is printed). The result can therefore have FEWER
    rows than ``tokenized_samples``, and row i need not correspond to sample i.

    NOTE(review): only ``model.seq_embeddings`` is called here - confirm that
    module runs the transformer layers; if it is just the input-embedding layer,
    position 0 carries little sample-specific signal. No attention mask is
    passed and 0 is assumed to be the pad id - verify against the tokenizer.

    Returns:
        2-D array (n_embedded, hidden) built from ``seq_out[:, 0, :]``, or an
        empty 1-D array when no sample could be embedded.
    """
    def _prepare(sample):
        # -> [token_ids, branch_depths, linkage_types] as equal-length long tensors.
        token_ids = sample['token_ids']
        zeros = [0] * len(token_ids)
        raw = (token_ids,
               sample.get('branch_depths', zeros),
               sample.get('linkage_types', zeros))
        tensors = [torch.tensor(seq[:max_len], dtype=torch.long) for seq in raw]
        common = min(len(t) for t in tensors)
        shortfall = max_len - common
        if shortfall > 0:
            return [F.pad(t[:common], (0, shortfall), value=0) for t in tensors]
        return [t[:common] for t in tensors]

    pieces = []
    skipped = 0
    for start in range(0, len(tokenized_samples), batch_size):
        prepared = []
        for sample in tokenized_samples[start:start + batch_size]:
            try:
                prepared.append(_prepare(sample))
            except Exception:
                skipped += 1
        if not prepared:
            continue
        ids_batch, depth_batch, link_batch = (
            torch.stack(column).to(device) for column in zip(*prepared)
        )
        with torch.no_grad():
            hidden = model.seq_embeddings(
                ids_batch,
                branch_depths=depth_batch,
                linkage_types=link_batch,
            )
        pieces.append(hidden[:, 0, :].cpu().numpy())

    if skipped:
        print(f" Warning: {skipped} samples failed")
    if not pieces:
        return np.array([])
    return np.concatenate(pieces, axis=0)
def _row_split(row, present_split_cols):
    """Return the first split whose flag column is truthy in ``row``, else 'unknown'.

    A flag counts as set when it equals 1/True or its lower-cased string form is
    'true', '1' or '1.0'. ``present_split_cols`` is a list of
    (split_name, column_name) pairs whose columns exist in the CSV.
    """
    for split_name, col_name in present_split_cols:
        val = row.get(col_name)
        if val == 1 or str(val).lower() in ('true', '1', '1.0'):
            return split_name
    return 'unknown'


def load_task_data(task_name, tokenizer, splits=None, embed_all=False):
    """Read one benchmark CSV and tokenize the WURCS strings of the requested splits.

    Args:
        task_name: Key of TASKS.
        tokenizer: Object whose ``tokenize(wurcs, max_length=256)`` returns a dict
            with 'token_ids' and optionally 'branch_depths' / 'linkage_types'.
        splits: Split names ('train' / 'val' / 'test') to keep. Ignored when
            ``embed_all`` is set; None keeps every defined split.
        embed_all: Keep every row, including rows with no split flag set
            (these get split 'unknown').

    Returns:
        List of dicts with keys token_ids, branch_depths, linkage_types, label,
        split and wurcs. Missing labels become 'Unknown'; rows whose WURCS value
        does not start with 'WURCS' are skipped (counted as "Ambiguous"), and
        rows the tokenizer rejects are skipped (counted as "Failed").
    """
    from collections import Counter

    task_cfg = TASKS[task_name]
    csv_path = DATA_DIR / task_cfg['csv']
    label_col = task_cfg['label_col']
    wurcs_col = task_cfg['wurcs_col']
    split_cols = task_cfg['split_cols']
    print(f"\n{'='*60}")
    print(f"Loading task: {task_name} ({task_cfg['description']})")
    print(f" CSV: {csv_path}")
    df = pd.read_csv(csv_path)
    print(f" Total rows: {len(df)}")

    target_splits = list(split_cols.keys()) if embed_all or splits is None else splits
    undefined = [s for s in target_splits if s not in split_cols]
    if undefined:
        # e.g. '--splits valid' instead of 'val' would otherwise silently match nothing.
        print(f" Warning: split(s) {undefined} not defined for this task; "
              f"expected one of {list(split_cols)}")
    # Only split columns that exist in the CSV can flag a row; check them once.
    present_split_cols = [(name, col) for name, col in split_cols.items() if col in df.columns]

    results = []
    n_tokenized = n_failed = n_ambiguous = 0
    for _, row in df.iterrows():
        split = _row_split(row, present_split_cols)
        if split not in target_splits and not embed_all:
            continue
        label = row.get(label_col, '')
        if pd.isna(label) or label == '' or label == 'nan':
            label = 'Unknown'
        wurcs = row.get(wurcs_col, '')
        if pd.isna(wurcs) or wurcs == '' or not str(wurcs).startswith('WURCS'):
            n_ambiguous += 1
            continue
        try:
            tok = tokenizer.tokenize(str(wurcs), max_length=256)
            results.append({
                'token_ids': tok['token_ids'],
                'branch_depths': tok.get('branch_depths', [0] * len(tok['token_ids'])),
                'linkage_types': tok.get('linkage_types', [0] * len(tok['token_ids'])),
                'label': str(label), 'split': split, 'wurcs': str(wurcs),
            })
            n_tokenized += 1
        except Exception:
            # Best-effort: an untokenizable glycan is counted, not fatal.
            n_failed += 1
    print(f" Tokenized: {n_tokenized}, Failed: {n_failed}, Ambiguous: {n_ambiguous}")

    report_splits = list(target_splits)
    # With embed_all, rows without any split flag are kept; report them too.
    if 'unknown' not in report_splits and any(r['split'] == 'unknown' for r in results):
        report_splits.append('unknown')
    for s in report_splits:
        s_data = [r for r in results if r['split'] == s]
        labels = dict(Counter(r['label'] for r in s_data))
        print(f" Split '{s}': {len(s_data)} samples, labels: {labels}")
    return results
def _separation_metrics(embeddings, int_labels, n_classes):
    """Silhouette / Calinski-Harabasz scores of ``int_labels`` in the original embedding space.

    Scores are only attempted when 2 <= n_classes < n_samples. A score that
    raises is stored as None (with a printed warning) instead of aborting.
    """
    from sklearn.metrics import silhouette_score, calinski_harabasz_score
    metrics = {}
    if not 2 <= n_classes < len(embeddings):
        return metrics
    for key, scorer in (('silhouette', silhouette_score),
                        ('calinski_harabasz', calinski_harabasz_score)):
        try:
            metrics[key] = float(scorer(embeddings, int_labels))
        except Exception as exc:  # best-effort: keep plotting if a score fails
            print(f"  Warning: {key} score failed: {exc}")
            metrics[key] = None
    return metrics


def plot_embeddings(embeddings, labels, task_name, model_name, output_dir, method='tsne',
                    color_map=None, split_labels=None):
    """Project embeddings to 2-D, scatter them by label and save a PNG.

    Args:
        embeddings: 2-D array (n_samples, dim) of [CLS] vectors.
        labels: Per-row class labels, same length as ``embeddings``.
        task_name, model_name: Used in the title and in the output file name
            ``{task_name}_{model_name}_{method}.png``.
        output_dir: Directory the PNG is written to.
        method: 'tsne' (perplexity min(30, n - 1)); any other value uses UMAP.
        color_map: Optional label -> colour mapping; unmapped labels get
            matplotlib's default colour cycle.
        split_labels: Optional per-row split names; 'test' rows are outlined.

    Returns:
        dict with 'silhouette' / 'calinski_harabasz' (computed on the original
        embeddings; None on failure; absent unless 2 <= n_classes < n_samples),
        plus 'n_samples', 'n_classes' and 'classes'.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    print(f" Plotting {method.upper()} for {task_name} ({model_name})...")
    if method == 'tsne':
        from sklearn.manifold import TSNE
        # t-SNE requires perplexity < n_samples.
        perplexity = min(30, len(embeddings) - 1)
        coords = TSNE(n_components=2, perplexity=perplexity, max_iter=1000,
                      init='pca', random_state=42, learning_rate='auto').fit_transform(embeddings)
    else:
        import umap
        coords = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42).fit_transform(embeddings)

    labels_arr = np.asarray(labels)
    unique_labels = sorted(set(labels))
    label_to_int = {lab: i for i, lab in enumerate(unique_labels)}
    int_labels = np.array([label_to_int[lab] for lab in labels])

    metrics = _separation_metrics(embeddings, int_labels, len(unique_labels))
    metrics['n_samples'] = len(embeddings)
    metrics['n_classes'] = len(unique_labels)
    metrics['classes'] = unique_labels

    fig, ax = plt.subplots(1, 1, figsize=(10, 8))
    for label in unique_labels:
        mask = labels_arr == label
        color = color_map.get(label, None) if color_map else None
        ax.scatter(coords[mask, 0], coords[mask, 1], c=color,
                   label=f'{label} (n={mask.sum()})', s=15, alpha=0.7, edgecolors='none')
    if split_labels is not None:
        # Outline held-out test points on top of the class colours.
        test_mask = np.asarray(split_labels) == 'test'
        if test_mask.any():
            ax.scatter(coords[test_mask, 0], coords[test_mask, 1], facecolors='none',
                       edgecolors='black', s=40, linewidths=0.5, alpha=0.3,
                       label=f'test split (n={test_mask.sum()})')

    sil = metrics.get('silhouette')
    ch = metrics.get('calinski_harabasz')
    sil_str = f"Sil={sil:.3f}" if sil is not None else "Sil=N/A"
    ch_str = f"CH={ch:.1f}" if ch is not None else "CH=N/A"
    ax.set_title(f"{task_name} - {model_name.upper()} [CLS] ({method.upper()})\n{sil_str} | {ch_str} | n={len(embeddings)}", fontsize=13)
    ax.set_xlabel(f'{method.upper()}-1')
    ax.set_ylabel(f'{method.upper()}-2')
    ax.legend(loc='best', fontsize=8, framealpha=0.8)
    ax.set_aspect('equal', adjustable='box')
    fig.tight_layout()
    fname = f'{task_name}_{model_name}_{method}.png'
    fig.savefig(os.path.join(output_dir, fname), dpi=200, bbox_inches='tight')
    plt.close(fig)
    print(f" Saved: {fname}")
    return metrics
def main():
    """Command-line entry point.

    For every requested task: load and tokenize the benchmark rows, extract
    [CLS] embeddings with the selected checkpoint, drop rows labelled
    'Unknown', save ``<task>_<model>_embeddings.npz``, plot t-SNE and/or UMAP,
    then write all metrics to ``benchmark_metrics_<model>.json`` and print a
    summary.
    """
    parser = argparse.ArgumentParser(description='Embed benchmark tasks with V5/V6')
    parser.add_argument('--model', choices=['v5', 'v6'], required=True)
    parser.add_argument('--splits', nargs='+', default=['val', 'test'])
    parser.add_argument('--embed_all', action='store_true')
    parser.add_argument('--tasks', nargs='+', default=list(TASKS.keys()))
    parser.add_argument('--method', choices=['tsne', 'umap', 'both'], default='tsne')
    parser.add_argument('--output_dir', default=None)
    parser.add_argument('--device', default='cuda')
    args = parser.parse_args()
    if args.output_dir is None:
        args.output_dir = str(PROJECT_ROOT / 'bert_v6_contrastive' / 'analysis' / 'benchmark_embeddings')
    os.makedirs(args.output_dir, exist_ok=True)

    # The default device is 'cuda'; fall back to CPU instead of crashing in model.to().
    if args.device.startswith('cuda') and not torch.cuda.is_available():
        print(f"WARNING: CUDA not available, using cpu instead of {args.device}")
        args.device = 'cpu'

    print(f"Loading tokenizer from {VOCAB_PATH}...")
    tokenizer = WURCSTokenizer(str(VOCAB_PATH))
    print(f" Vocab size: {tokenizer.vocab_size}")
    ckpt_path = CHECKPOINTS[args.model]
    if not ckpt_path.exists():
        print(f"ERROR: Checkpoint not found: {ckpt_path}")
        sys.exit(1)
    model = load_model(str(ckpt_path), device=args.device)

    color_maps = {'domain': DOMAIN_COLORS, 'kingdom': KINGDOM_COLORS,
                  'link': LINK_COLORS, 'immunogenicity': IMMUNO_COLORS}
    all_metrics = {}
    for task_name in args.tasks:
        if task_name not in TASKS:
            print(f"WARNING: Unknown task '{task_name}', skipping")
            continue
        data = load_task_data(task_name, tokenizer,
                              splits=args.splits if not args.embed_all else None,
                              embed_all=args.embed_all)
        if len(data) < 10:
            print(f" Skipping {task_name}: too few samples ({len(data)})")
            continue
        print(f" Extracting [CLS] embeddings for {len(data)} samples...")
        embeddings = extract_cls_embeddings(model, data, device=args.device)
        # extract_cls_embeddings skips samples it cannot batch, so its rows would
        # no longer line up with `data`; never pair labels with the wrong vectors.
        if len(embeddings) != len(data):
            print(f" Skipping {task_name}: got {len(embeddings)} embeddings for "
                  f"{len(data)} samples, cannot align labels")
            continue
        labels = [d['label'] for d in data]
        split_labels = [d['split'] for d in data]
        keep = np.array([lab != 'Unknown' for lab in labels], dtype=bool)
        embeddings = embeddings[keep]
        labels = [lab for lab, k in zip(labels, keep) if k]
        split_labels = [s for s, k in zip(split_labels, keep) if k]
        if len(embeddings) < 10:
            print(f" Skipping {task_name}: too few labeled samples")
            continue
        print(f" Embeddings shape: {embeddings.shape}")
        npz_path = os.path.join(args.output_dir, f'{task_name}_{args.model}_embeddings.npz')
        np.savez_compressed(npz_path, embeddings=embeddings,
                            labels=np.array(labels), splits=np.array(split_labels))
        print(f" Saved: {npz_path}")

        methods = ['tsne', 'umap'] if args.method == 'both' else [args.method]
        task_metrics = {}
        for method in methods:
            task_metrics[method] = plot_embeddings(
                embeddings, labels, task_name, args.model,
                args.output_dir, method=method,
                color_map=color_maps.get(task_name, None),
                split_labels=split_labels)
        all_metrics[task_name] = task_metrics

    metrics_path = os.path.join(args.output_dir, f'benchmark_metrics_{args.model}.json')
    with open(metrics_path, 'w') as f:
        json.dump(all_metrics, f, indent=2, default=str)
    print(f"\nAll metrics saved to: {metrics_path}")

    print(f"\n{'='*60}")
    print(f"SUMMARY - {args.model.upper()}")
    print(f"{'='*60}")
    for task, tmetrics in all_metrics.items():
        for method, m in tmetrics.items():
            sil = m.get('silhouette', 'N/A')
            sil_str = f"{sil:.4f}" if isinstance(sil, float) else str(sil)
            print(f" {task:20s} ({method:5s}): Silhouette={sil_str}, n={m.get('n_samples',0)}, classes={m.get('n_classes',0)}")


if __name__ == '__main__':
    main()