| |
| """ |
| Extract [CLS] Embeddings for Embedding Space Deep Dive |
| |
| Extracts embeddings from V5 and V6 checkpoints for multiple data subsets: |
| - Training positives (sampled) |
| - Impossible negatives (easy/medium/hard) |
| - Benchmark test glycans (with taxonomy labels) |
| |
| Output: .npz files with embeddings + metadata for analysis. |
| """ |
|
|
| import os, sys, torch, pickle, json, csv, argparse |
| import torch.nn.functional as F |
| import numpy as np |
| from pathlib import Path |
| from tqdm import tqdm |
|
|
# Make the project packages importable. After this runs, the `bert_training_v4`
# subdirectory sits at the very front of sys.path, immediately followed by the
# project root itself.
project_root = Path('/work/ratul1/supantha/glycan-SD-VS/bert_training_v3/v3.1_cluster_training')
sys.path[:0] = [str(project_root / 'bert_training_v4'), str(project_root)]
|
|
| from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig |
|
|
|
|
def load_model(checkpoint_path, device='cuda'):
    """Build a MultimodalGlycanBERT backbone and load checkpoint weights into it.

    Entries whose key starts with ``proj_head.`` are dropped before loading.
    The sequence vocabulary size is read from the checkpoint's
    ``seq_embeddings.token_embeddings.weight`` table; the remaining
    architecture values are hard-coded below.

    Args:
        checkpoint_path: File readable by ``torch.load``. Either a dict holding
            the weights under ``'model_state_dict'`` or a bare state dict.
        device: Device the model is moved to before it is returned.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no
            ``seq_embeddings.token_embeddings.weight`` entry.
    """
    print(f"Loading model from {checkpoint_path}...")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point this at checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    state_dict = checkpoint.get('model_state_dict', checkpoint)

    # Keep everything except the projection head.
    backbone_sd = {k: v for k, v in state_dict.items() if not k.startswith('proj_head.')}
    n_stripped = len(state_dict) - len(backbone_sd)
    if n_stripped > 0:
        print(f"  Stripped {n_stripped} projection head keys")

    vocab_size = backbone_sd['seq_embeddings.token_embeddings.weight'].shape[0]
    config_kwargs = dict(
        seq_vocab_size=vocab_size, seq_hidden_size=768, seq_num_layers=12,
        seq_num_heads=12, seq_max_length=256, use_cnn_frontend=True, cnn_kernel_size=3,
    )
    if 'ms_embeddings.token_embeddings.weight' in backbone_sd:
        # The MS vocab size is taken as the MS table's row count minus the
        # sequence vocab size.
        # NOTE(review): presumably the MS table is offset by the sequence
        # vocabulary -- confirm against the model definition.
        ms_total = backbone_sd['ms_embeddings.token_embeddings.weight'].shape[0]
        config_kwargs['ms_vocab_size'] = ms_total - vocab_size

    config = MultimodalGlycanBERTConfig(**config_kwargs)
    model = MultimodalGlycanBERT(config)

    # strict=False tolerates key mismatches, which would otherwise leave
    # parameters at their random initialisation without any sign of trouble.
    # Surface whatever did not line up so bad embeddings can be traced back.
    load_result = model.load_state_dict(backbone_sd, strict=False)
    for label in ('missing', 'unexpected'):
        keys = list(getattr(load_result, f'{label}_keys', None) or [])
        if keys:
            preview = ', '.join(keys[:5])
            suffix = ', ...' if len(keys) > 5 else ''
            print(f"  WARNING: {len(keys)} {label} keys in checkpoint load: {preview}{suffix}")

    model.to(device).eval()
    print(f"  Model: {sum(p.numel() for p in model.parameters()):,} params")
    return model
|
|
|
|
def get_cls_embedding(model, token_ids, branch_depths, linkage_types, device='cuda', max_len=256):
    """Return the position-0 output of ``model.seq_embeddings`` for one sample.

    The three input streams are flattened, cut to the length of the shortest
    one, capped at ``max_len`` and right-padded with zeros up to ``max_len``
    before being passed to the model as a batch of one.

    NOTE(review): only ``model.seq_embeddings`` is invoked here; whether that
    module includes the encoder layers is not visible from this file -- confirm
    that position 0 of its output is the intended [CLS] representation.

    Args:
        model: Object exposing ``seq_embeddings(ids, branch_depths=, linkage_types=)``.
        token_ids, branch_depths, linkage_types: Tensors, or anything
            ``torch.tensor`` can convert with ``dtype=torch.long``.
        device: Device the inputs are moved to.
        max_len: Fixed sequence length fed to the model.

    Returns:
        A flat NumPy array holding the output at sequence position 0.
    """
    def _flat_long(values):
        # Tensors are used as given; anything else is converted to int64.
        if not isinstance(values, torch.Tensor):
            values = torch.tensor(values, dtype=torch.long)
        return values.flatten()

    with torch.no_grad():
        streams = [_flat_long(v) for v in (token_ids, branch_depths, linkage_types)]

        # Shortest stream wins, and nothing may exceed max_len.
        keep = min(min(len(s) for s in streams), max_len)
        pad_len = max_len - keep

        batched = []
        for stream in streams:
            stream = stream[:keep]
            if pad_len > 0:
                stream = F.pad(stream, (0, pad_len), value=0)
            batched.append(stream.unsqueeze(0).to(device))

        ids, depths, links = batched
        seq_out = model.seq_embeddings(ids, branch_depths=depths, linkage_types=links)
        return seq_out[:, 0, :].cpu().numpy().flatten()
|
|
|
|
def extract_batch(model, samples, device='cuda', fallback_dim=768):
    """Embed every sample with ``get_cls_embedding`` and stack the results.

    Each sample is a dict. Token ids are read from ``'token_ids'`` (falling
    back to ``'tokens'``, then to an empty list); ``'branch_depths'`` and
    ``'linkage_types'`` default to all-zero lists of the same length. Any of
    these fields that arrives as a string is parsed as a Python literal.

    A sample whose embedding raises is reported and replaced by a zero vector,
    so row ``i`` of the result always corresponds to ``samples[i]``. The
    placeholder takes its shape and dtype from the successfully extracted
    embeddings, so one failure neither upcasts the whole array nor breaks
    stacking when the model's hidden size is not ``fallback_dim``.

    Args:
        model: Passed through to ``get_cls_embedding``.
        samples: Sequence of sample dicts.
        device: Passed through to ``get_cls_embedding``.
        fallback_dim: Embedding width used only when no sample succeeded
            (including an empty ``samples``).

    Returns:
        A 2-D NumPy array with one row per sample; shape ``(0, fallback_dim)``
        when ``samples`` is empty.

    Raises:
        ValueError, SyntaxError: If a string-valued field is not a valid
            Python literal.
    """
    import ast  # local import so this block does not depend on a file-level change

    def _parse(value):
        # NOTE(security): this used to be eval(); literal_eval accepts only
        # plain literals (lists, numbers, ...) and cannot execute code.
        return ast.literal_eval(value) if isinstance(value, str) else value

    embeddings = []
    failed = []
    for i, sample in enumerate(tqdm(samples, desc="Extracting")):
        token_ids = _parse(sample.get('token_ids', sample.get('tokens', [])))
        zeros = [0] * len(token_ids)
        branch_depths = _parse(sample.get('branch_depths', zeros))
        linkage_types = _parse(sample.get('linkage_types', zeros))

        try:
            emb = get_cls_embedding(model, token_ids, branch_depths, linkage_types, device=device)
        except Exception as e:
            # Deliberate best-effort: keep going, but remember the slot so the
            # output stays aligned with `samples`.
            print(f"  Error sample {i}: {e}")
            failed.append(i)
            emb = None
        embeddings.append(emb)

    if not embeddings:
        return np.zeros((0, fallback_dim), dtype=np.float32)

    if failed:
        print(f"  {len(failed)}/{len(embeddings)} samples failed; using zero vectors for them")
        template = next((e for e in embeddings if e is not None), None)
        if template is None:
            template = np.zeros(fallback_dim, dtype=np.float32)
        placeholder = np.zeros_like(template)
        embeddings = [placeholder if e is None else e for e in embeddings]

    return np.array(embeddings)
|
|
|
|
def load_benchmark_data(csv_path):
    """Read the benchmark CSV into a list of targets plus per-row labels.

    The ``target`` column is required. ``kingdom``, ``phylum`` and ``class``
    are optional and default to ``''``. The split label is the first of
    ``train`` / ``validation`` / ``test`` whose cell reads ``true`` (case and
    surrounding whitespace ignored), otherwise ``'unknown'``.

    Rows shorter than the header are tolerated: ``csv.DictReader`` fills their
    missing cells with ``None``, which is treated here like an empty cell.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        ``(iupac_list, labels)`` where ``labels`` maps ``'kingdom'``,
        ``'phylum'``, ``'class'`` and ``'split'`` to lists aligned with
        ``iupac_list``.

    Raises:
        KeyError: If the file has no ``target`` column.
    """
    print(f"Loading benchmark from {csv_path}...")
    iupac_list = []
    labels = {'kingdom': [], 'phylum': [], 'class': [], 'split': []}
    split_columns = (('train', 'train'), ('validation', 'val'), ('test', 'test'))

    def _cell(row, column):
        # Missing column or short row -> empty string, never None.
        return row.get(column) or ''

    # newline='' is what the csv module requires for correct handling of
    # quoted fields; utf-8-sig also accepts plain UTF-8 and drops a leading
    # BOM that would otherwise corrupt the first header name.
    with open(csv_path, 'r', newline='', encoding='utf-8-sig') as f:
        for row in csv.DictReader(f):
            iupac_list.append(row['target'])
            for level in ('kingdom', 'phylum', 'class'):
                labels[level].append(_cell(row, level))

            split = 'unknown'
            for column, name in split_columns:
                if _cell(row, column).strip().lower() == 'true':
                    split = name
                    break
            labels['split'].append(split)

    print(f"  {len(iupac_list)} samples")
    return iupac_list, labels
|
|
|
|
def iupac_to_tokenized(iupac_list, sequences_data):
    """Look up each IUPAC string among the tokenized sequence records.

    Records are indexed by their non-empty ``'iupac_name'``; when several
    records share a name, the last one wins.

    Args:
        iupac_list: IUPAC strings to resolve.
        sequences_data: Iterable of dict records.

    Returns:
        ``(matched, indices)``: the records found, and for each of them the
        position in ``iupac_list`` it was found for.
    """
    by_name = {}
    for record in sequences_data:
        name = record.get('iupac_name')
        if name:
            by_name[name] = record

    hits = [(pos, by_name[name]) for pos, name in enumerate(iupac_list) if name in by_name]
    indices = [pos for pos, _ in hits]
    matched = [record for _, record in hits]

    print(f"  Matched {len(matched)}/{len(iupac_list)} IUPAC strings")
    return matched, indices
|
|
|
|
def _sample_and_extract(model, pool, n_max, device):
    """Embed ``min(n_max, len(pool))`` items drawn from ``pool`` without replacement."""
    picks = np.random.choice(len(pool), min(n_max, len(pool)), replace=False)
    return extract_batch(model, [pool[i] for i in picks], device=device)


def main():
    """Command-line entry point: extract and save embeddings for one checkpoint.

    Embeds, in this order, a random sample of training sequences, random
    samples of the easy / medium / hard negatives, and every benchmark glycan
    whose IUPAC string is found among the training sequences. NumPy's global
    RNG is seeded with 42 before the first draw.

    Writes into ``--output_dir``:
        * ``embeddings_<name>.npz`` -- the five embedding arrays plus the
          benchmark kingdom / phylum / class / split labels.
        * ``summary_<name>.json`` -- the number of rows in each array.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', required=True)
    parser.add_argument('--name', required=True, help='v5 or v6')
    parser.add_argument('--sequences', default='bert_v5_bpe_topo/data/sequences_bpe_expanded.pkl')
    parser.add_argument('--negatives', default='bert_v6_contrastive/data/negatives_scored.pkl')
    parser.add_argument('--benchmark_csv', default='bert_training_v4/downstream_tasks/baseline_data_strict/glycanml/glycan_classification.csv')
    parser.add_argument('--output_dir', default='bert_v6_contrastive/analysis')
    parser.add_argument('--n_train_sample', type=int, default=10000)
    parser.add_argument('--n_neg_sample', type=int, default=5000)
    parser.add_argument('--device', default='cuda')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    model = load_model(args.checkpoint, device=args.device)

    # Insertion order doubles as the reporting order further down.
    arrays = {}

    print("\n=== Training positives ===")
    # NOTE(security): pickle.load can execute arbitrary code; only use
    # trusted data files here and below.
    with open(args.sequences, 'rb') as f:
        sequences = pickle.load(f)
    np.random.seed(42)
    arrays['train_embs'] = _sample_and_extract(model, sequences, args.n_train_sample, args.device)

    print("\n=== Negatives ===")
    with open(args.negatives, 'rb') as f:
        negatives = pickle.load(f)
    # Draw order (easy, medium, hard) is kept fixed so the seeded RNG yields
    # the same subsets from run to run.
    for level in ('easy', 'medium', 'hard'):
        pool = [n for n in negatives if n.get('difficulty_category') == level]
        arrays[f'{level}_embs'] = _sample_and_extract(model, pool, args.n_neg_sample, args.device)

    print("\n=== Benchmark ===")
    iupac_list, taxonomy_labels = load_benchmark_data(args.benchmark_csv)
    matched, matched_idx = iupac_to_tokenized(iupac_list, sequences)
    if matched:
        arrays['benchmark_embs'] = extract_batch(model, matched, device=args.device)
        # Keep only the labels of rows that were actually embedded.
        benchmark_labels = {k: [taxonomy_labels[k][i] for i in matched_idx] for k in taxonomy_labels}
    else:
        arrays['benchmark_embs'] = np.zeros((0, 768))
        benchmark_labels = {k: [] for k in taxonomy_labels}

    out = os.path.join(args.output_dir, f'embeddings_{args.name}.npz')
    np.savez_compressed(
        out,
        **arrays,
        benchmark_kingdom=np.array(benchmark_labels['kingdom']),
        benchmark_phylum=np.array(benchmark_labels['phylum']),
        benchmark_class=np.array(benchmark_labels['class']),
        benchmark_split=np.array(benchmark_labels['split']),
    )
    print(f"\nSaved: {out}")
    for key, arr in arrays.items():
        print(f"  {key}: {arr.shape}")

    summary = {
        'model': args.name,
        'n_train': len(arrays['train_embs']),
        'n_easy': len(arrays['easy_embs']),
        'n_medium': len(arrays['medium_embs']),
        'n_hard': len(arrays['hard_embs']),
        'n_benchmark': len(arrays['benchmark_embs']),
    }
    # Context manager so the summary is flushed and the handle closed.
    with open(os.path.join(args.output_dir, f'summary_{args.name}.json'), 'w') as f:
        json.dump(summary, f, indent=2)
|
|
# Run the extraction only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|