#!/usr/bin/env python3
"""
Extract [CLS] Embeddings for Embedding Space Deep Dive

Extracts embeddings from V5 and V6 checkpoints for multiple data subsets:
  - Training positives (sampled)
  - Impossible negatives (easy/medium/hard)
  - Benchmark test glycans (with taxonomy labels)

Output: one .npz file with embeddings + metadata for analysis, plus a small
JSON summary of subset sizes.
"""
import argparse
import ast
import csv
import json
import os
import pickle
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

project_root = Path('/work/ratul1/supantha/glycan-SD-VS/bert_training_v3/v3.1_cluster_training')
sys.path.insert(0, str(project_root))
sys.path.insert(0, str(project_root / 'bert_training_v4'))

from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig  # noqa: E402

# Architecture constants; must match the values the checkpoints were trained with.
HIDDEN_SIZE = 768
MAX_SEQ_LEN = 256


def load_model(checkpoint_path, device='cuda'):
    """Build a MultimodalGlycanBERT and load backbone weights from a checkpoint.

    Keys prefixed ``proj_head.`` (the V6 projection head) are dropped before
    loading. Vocabulary sizes are inferred from embedding weight shapes in the
    checkpoint; the remaining hyper-parameters are fixed below.

    Args:
        checkpoint_path: File produced by ``torch.save``; either a raw state
            dict or a dict containing ``'model_state_dict'``.
        device: Device the model is moved to.

    Returns:
        The model, in eval mode, on ``device``.
    """
    print(f"Loading model from {checkpoint_path}...")
    # NOTE: weights_only=False unpickles arbitrary objects -- load trusted checkpoints only.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    state_dict = checkpoint.get('model_state_dict', checkpoint)

    # Strip projection head keys (V6)
    backbone_sd = {k: v for k, v in state_dict.items() if not k.startswith('proj_head.')}
    n_stripped = len(state_dict) - len(backbone_sd)
    if n_stripped > 0:
        print(f" Stripped {n_stripped} projection head keys")

    vocab_size = backbone_sd['seq_embeddings.token_embeddings.weight'].shape[0]
    config_kwargs = dict(
        seq_vocab_size=vocab_size,
        seq_hidden_size=HIDDEN_SIZE,
        seq_num_layers=12,
        seq_num_heads=12,
        seq_max_length=MAX_SEQ_LEN,
        use_cnn_frontend=True,
        cnn_kernel_size=3,
    )
    if 'ms_embeddings.token_embeddings.weight' in backbone_sd:
        # The config receives the MS table size minus the sequence vocab size.
        ms_total = backbone_sd['ms_embeddings.token_embeddings.weight'].shape[0]
        config_kwargs['ms_vocab_size'] = ms_total - vocab_size

    config = MultimodalGlycanBERTConfig(**config_kwargs)
    model = MultimodalGlycanBERT(config)
    # strict=False tolerates extra/missing entries, but silently skipped keys
    # would leave freshly initialised weights behind -- surface them.
    incompatible = model.load_state_dict(backbone_sd, strict=False)
    missing = list(getattr(incompatible, 'missing_keys', []) or [])
    unexpected = list(getattr(incompatible, 'unexpected_keys', []) or [])
    if missing:
        print(f" WARNING: {len(missing)} model keys missing from checkpoint (left at init), e.g. {missing[:5]}")
    if unexpected:
        print(f" WARNING: {len(unexpected)} checkpoint keys unused by model, e.g. {unexpected[:5]}")
    model.to(device).eval()
    print(f" Model: {sum(p.numel() for p in model.parameters()):,} params")
    return model


def get_cls_embedding(model, token_ids, branch_depths, linkage_types, device='cuda', max_len=MAX_SEQ_LEN):
    """Return the position-0 vector of ``model.seq_embeddings`` for one sample.

    The three per-token inputs are flattened, cut to their common length,
    truncated to ``max_len`` and right-padded with 0 up to ``max_len``.

    NOTE(review): only ``model.seq_embeddings`` is called here, not the
    transformer encoder. If that module is the input-embedding layer (as the
    name suggests), the returned vector is not a contextualised [CLS]
    representation -- confirm this is intended. That 0 is the padding id and
    that no attention mask is needed are also assumptions to verify.

    Args:
        model: Loaded model exposing ``seq_embeddings``.
        token_ids, branch_depths, linkage_types: Sequences or tensors of ints.
        device: Device to run on.
        max_len: Fixed sequence length fed to the model.

    Returns:
        A flattened numpy array taken from ``seq_out[:, 0, :]``.
    """
    def _as_flat(x):
        # Lists are converted to int64; existing tensors keep their dtype.
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, dtype=torch.long)
        return x.flatten()

    with torch.no_grad():
        tok, dep, lnk = (_as_flat(x) for x in (token_ids, branch_depths, linkage_types))
        # Align the three streams, then cap at max_len.
        n = min(len(tok), len(dep), len(lnk), max_len)
        tok, dep, lnk = tok[:n], dep[:n], lnk[:n]

        pad_len = max_len - n
        if pad_len > 0:
            tok = F.pad(tok, (0, pad_len), value=0)
            dep = F.pad(dep, (0, pad_len), value=0)
            lnk = F.pad(lnk, (0, pad_len), value=0)

        seq_out = model.seq_embeddings(
            tok.unsqueeze(0).to(device),
            branch_depths=dep.unsqueeze(0).to(device),
            linkage_types=lnk.unsqueeze(0).to(device),
        )
        return seq_out[:, 0, :].cpu().numpy().flatten()


def _parse_field(value):
    """Decode a list field that was stored as its string repr (e.g. ``"[1, 2]"``)."""
    # literal_eval accepts only Python literals; plain eval() would execute
    # arbitrary code embedded in the data file.
    return ast.literal_eval(value) if isinstance(value, str) else value


def extract_batch(model, samples, device='cuda', emb_dim=HIDDEN_SIZE):
    """Embed each sample dict with :func:`get_cls_embedding`.

    Token ids are read from ``'token_ids'`` (falling back to ``'tokens'``);
    missing ``'branch_depths'`` / ``'linkage_types'`` default to zeros.
    A sample whose embedding raises is logged and replaced by a zero vector
    of length ``emb_dim`` so rows stay aligned with ``samples``.

    Returns:
        Array with one row per sample; shape ``(0, emb_dim)`` when empty.
    """
    all_embs = []
    n_failed = 0
    for i, sample in enumerate(tqdm(samples, desc="Extracting")):
        token_ids = _parse_field(sample.get('token_ids', sample.get('tokens', [])))
        branch_depths = _parse_field(sample.get('branch_depths', [0] * len(token_ids)))
        linkage_types = _parse_field(sample.get('linkage_types', [0] * len(token_ids)))
        try:
            emb = get_cls_embedding(model, token_ids, branch_depths, linkage_types, device=device)
        except Exception as e:  # best-effort: keep going, but count the failures
            print(f" Error sample {i}: {e}")
            n_failed += 1
            # float32 so a failure does not upcast the whole result array.
            emb = np.zeros(emb_dim, dtype=np.float32)
        all_embs.append(emb)
    if n_failed:
        print(f" WARNING: {n_failed}/{len(all_embs)} samples failed and were zero-filled")
    if not all_embs:
        return np.zeros((0, emb_dim), dtype=np.float32)
    return np.array(all_embs)


def load_benchmark_data(csv_path):
    """Read benchmark IUPAC strings and their taxonomy/split labels from a CSV.

    Returns:
        ``(iupac_list, labels)`` where ``labels`` maps ``'kingdom'``,
        ``'phylum'``, ``'class'`` and ``'split'`` to lists aligned with
        ``iupac_list``. ``split`` is ``'train'``/``'val'``/``'test'`` from the
        first of the ``train``/``validation``/``test`` columns equal to
        ``'true'`` (case-insensitive), else ``'unknown'``.
    """
    print(f"Loading benchmark from {csv_path}...")
    iupac_list = []
    labels = {'kingdom': [], 'phylum': [], 'class': [], 'split': []}
    split_columns = (('train', 'train'), ('validation', 'val'), ('test', 'test'))
    # newline='' is what the csv module requires; utf-8-sig also strips a BOM.
    with open(csv_path, 'r', newline='', encoding='utf-8-sig') as f:
        for row in csv.DictReader(f):
            iupac_list.append(row['target'])
            for key in ('kingdom', 'phylum', 'class'):
                # `or ''` also covers None, which DictReader yields for short rows.
                labels[key].append(row.get(key) or '')
            split = next(
                (name for col, name in split_columns
                 if (row.get(col) or '').strip().lower() == 'true'),
                'unknown',
            )
            labels['split'].append(split)
    print(f" {len(iupac_list)} samples")
    return iupac_list, labels


def iupac_to_tokenized(iupac_list, sequences_data):
    """Match IUPAC strings to tokenized sequence records by ``'iupac_name'``.

    Records with an empty/missing name are ignored; if several records share
    a name, the last one wins.

    Returns:
        ``(matched, indices)``: matched records and their positions in
        ``iupac_list``.
    """
    lookup = {s.get('iupac_name', ''): s for s in sequences_data if s.get('iupac_name')}
    matched, indices = [], []
    for idx, iupac in enumerate(iupac_list):
        record = lookup.get(iupac)
        if record is not None:
            matched.append(record)
            indices.append(idx)
    print(f" Matched {len(matched)}/{len(iupac_list)} IUPAC strings")
    return matched, indices


def _subsample(items, n):
    """Return ``min(n, len(items))`` distinct items drawn with the global NumPy RNG."""
    idx = np.random.choice(len(items), min(n, len(items)), replace=False)
    return [items[i] for i in idx]


def main():
    """CLI entry point: embed every subset, then write the .npz and summary JSON."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', required=True)
    parser.add_argument('--name', required=True, help='v5 or v6')
    parser.add_argument('--sequences', default='bert_v5_bpe_topo/data/sequences_bpe_expanded.pkl')
    parser.add_argument('--negatives', default='bert_v6_contrastive/data/negatives_scored.pkl')
    parser.add_argument('--benchmark_csv',
                        default='bert_training_v4/downstream_tasks/baseline_data_strict/glycanml/glycan_classification.csv')
    parser.add_argument('--output_dir', default='bert_v6_contrastive/analysis')
    parser.add_argument('--n_train_sample', type=int, default=10000)
    parser.add_argument('--n_neg_sample', type=int, default=5000)
    parser.add_argument('--device', default='cuda')
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    model = load_model(args.checkpoint, device=args.device)

    # 1. Training positives
    print("\n=== Training positives ===")
    # NOTE: pickle can execute arbitrary code -- load trusted files only.
    with open(args.sequences, 'rb') as f:
        sequences = pickle.load(f)
    # One seed fixes every draw below (positives, then easy/medium/hard).
    np.random.seed(42)
    train_embs = extract_batch(model, _subsample(sequences, args.n_train_sample), device=args.device)

    # 2. Negatives by difficulty
    print("\n=== Negatives ===")
    with open(args.negatives, 'rb') as f:
        negatives = pickle.load(f)
    by_difficulty = {'easy': [], 'medium': [], 'hard': []}
    for neg in negatives:
        bucket = by_difficulty.get(neg.get('difficulty_category'))
        if bucket is not None:
            bucket.append(neg)
    # Dict order (easy, medium, hard) keeps the RNG draw order stable.
    neg_embs = {
        level: extract_batch(model, _subsample(items, args.n_neg_sample), device=args.device)
        for level, items in by_difficulty.items()
    }

    # 3. Benchmark glycans
    print("\n=== Benchmark ===")
    iupac_list, taxonomy_labels = load_benchmark_data(args.benchmark_csv)
    matched, matched_idx = iupac_to_tokenized(iupac_list, sequences)
    if matched:
        benchmark_embs = extract_batch(model, matched, device=args.device)
        benchmark_labels = {k: [v[i] for i in matched_idx] for k, v in taxonomy_labels.items()}
    else:
        benchmark_embs = np.zeros((0, HIDDEN_SIZE))
        benchmark_labels = {k: [] for k in taxonomy_labels}

    # Save
    arrays = {
        'train_embs': train_embs,
        'easy_embs': neg_embs['easy'],
        'medium_embs': neg_embs['medium'],
        'hard_embs': neg_embs['hard'],
        'benchmark_embs': benchmark_embs,
    }
    out = os.path.join(args.output_dir, f'embeddings_{args.name}.npz')
    np.savez_compressed(
        out,
        **arrays,
        benchmark_kingdom=np.array(benchmark_labels['kingdom']),
        benchmark_phylum=np.array(benchmark_labels['phylum']),
        benchmark_class=np.array(benchmark_labels['class']),
        benchmark_split=np.array(benchmark_labels['split']),
    )
    print(f"\nSaved: {out}")
    for key, arr in arrays.items():
        print(f" {key}: {arr.shape}")

    summary = {
        'model': args.name,
        'n_train': len(train_embs),
        'n_easy': len(neg_embs['easy']),
        'n_medium': len(neg_embs['medium']),
        'n_hard': len(neg_embs['hard']),
        'n_benchmark': len(benchmark_embs),
    }
    with open(os.path.join(args.output_dir, f'summary_{args.name}.json'), 'w') as f:
        json.dump(summary, f, indent=2)


if __name__ == '__main__':
    main()