#!/usr/bin/env python3
"""
Extract [CLS] Embeddings for Embedding Space Deep Dive
Extracts embeddings from V5 and V6 checkpoints for multiple data subsets:
- Training positives (sampled)
- Impossible negatives (easy/medium/hard)
- Benchmark test glycans (with taxonomy labels)
Output: .npz files with embeddings + metadata for analysis.
"""
import os, sys, torch, pickle, json, csv, argparse
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from tqdm import tqdm
project_root = Path('/work/ratul1/supantha/glycan-SD-VS/bert_training_v3/v3.1_cluster_training')
sys.path.insert(0, str(project_root))
sys.path.insert(0, str(project_root / 'bert_training_v4'))
from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig
def load_model(checkpoint_path, device='cuda'):
    """Build a MultimodalGlycanBERT from a checkpoint and return it in eval mode.

    The checkpoint is loaded on CPU. If it is a mapping with a
    ``model_state_dict`` entry that entry is used, otherwise the checkpoint
    itself is treated as the state dict. Keys starting with ``proj_head.``
    are dropped before loading, and the sequence vocabulary size is read from
    the shape of ``seq_embeddings.token_embeddings.weight``.

    Args:
        checkpoint_path: Path handed to ``torch.load``.
        device: Device the model is moved to after the weights are loaded.

    Returns:
        The model, moved to ``device`` and switched to ``eval()``.

    Raises:
        KeyError: If ``seq_embeddings.token_embeddings.weight`` is absent from
            the (stripped) state dict.
    """
    print(f"Loading model from {checkpoint_path}...")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point this at checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    # Strip projection head keys (V6)
    backbone_sd = {k: v for k, v in state_dict.items() if not k.startswith('proj_head.')}
    n_stripped = len(state_dict) - len(backbone_sd)
    if n_stripped > 0:
        print(f" Stripped {n_stripped} projection head keys")
    vocab_size = backbone_sd['seq_embeddings.token_embeddings.weight'].shape[0]
    config_kwargs = dict(
        seq_vocab_size=vocab_size, seq_hidden_size=768, seq_num_layers=12,
        seq_num_heads=12, seq_max_length=256, use_cnn_frontend=True, cnn_kernel_size=3,
    )
    if 'ms_embeddings.token_embeddings.weight' in backbone_sd:
        ms_total = backbone_sd['ms_embeddings.token_embeddings.weight'].shape[0]
        # NOTE(review): presumably the MS embedding table covers the sequence
        # vocabulary plus the MS-specific tokens -- confirm against the model.
        config_kwargs['ms_vocab_size'] = ms_total - vocab_size
    config = MultimodalGlycanBERTConfig(**config_kwargs)
    model = MultimodalGlycanBERT(config)
    # strict=False tolerates key mismatches, so report them instead of letting
    # a partially (randomly) initialised model slip through unnoticed.
    load_result = model.load_state_dict(backbone_sd, strict=False)
    if load_result.missing_keys:
        print(f" WARNING: {len(load_result.missing_keys)} model keys missing from checkpoint "
              f"(left at initial values), e.g. {load_result.missing_keys[:5]}")
    if load_result.unexpected_keys:
        print(f" WARNING: {len(load_result.unexpected_keys)} checkpoint keys not used by the model, "
              f"e.g. {load_result.unexpected_keys[:5]}")
    model.to(device).eval()
    print(f" Model: {sum(p.numel() for p in model.parameters()):,} params")
    return model
def get_cls_embedding(model, token_ids, branch_depths, linkage_types, device='cuda', max_len=256):
    """Return the position-0 output of ``model.seq_embeddings`` for one sample.

    Each of the three per-token inputs is turned into a long tensor (unless it
    already is a tensor), flattened, cut to the shortest of the three lengths,
    and then truncated or zero-padded on the right to exactly ``max_len``.
    The result is run as a batch of one with gradients disabled.

    Args:
        model: Object exposing ``seq_embeddings(token_ids, branch_depths=...,
            linkage_types=...)``.
        token_ids: Token ids as a tensor or anything ``torch.tensor`` accepts.
        branch_depths: Per-token branch depths, same accepted types.
        linkage_types: Per-token linkage types, same accepted types.
        device: Device the inputs are moved to before the forward call.
        max_len: Fixed sequence length fed to the model.

    Returns:
        A flattened NumPy array holding the output at sequence position 0.
    """
    def _flat(values):
        # Existing tensors keep their dtype; everything else becomes int64.
        tensor = values if isinstance(values, torch.Tensor) else torch.tensor(values, dtype=torch.long)
        return tensor.flatten()

    with torch.no_grad():
        streams = [_flat(token_ids), _flat(branch_depths), _flat(linkage_types)]
        # Align the three streams on their common prefix, capped at max_len.
        keep = min(min(len(stream) for stream in streams), max_len)
        pad_len = max_len - keep
        batch = []
        for stream in streams:
            stream = stream[:keep]
            if pad_len > 0:
                stream = F.pad(stream, (0, pad_len), value=0)
            batch.append(stream.unsqueeze(0).to(device))
        ids, depths, links = batch
        seq_out = model.seq_embeddings(ids, branch_depths=depths, linkage_types=links)
        return seq_out[:, 0, :].cpu().numpy().flatten()
def extract_batch(model, samples, device='cuda'):
    """Extract one embedding per sample and stack them into a 2-D array.

    For each sample (a mapping) the token ids are read from ``token_ids`` or,
    failing that, ``tokens``; ``branch_depths`` and ``linkage_types`` default
    to all-zero lists of the same length. String-valued fields are parsed
    into Python objects first.

    A sample whose extraction raises is reported on stdout and replaced by a
    zero vector, so row ``i`` of the result always corresponds to
    ``samples[i]``. The zero vector takes the shape and dtype of the
    successfully extracted embeddings; only when no sample succeeds (or
    ``samples`` is empty) does the width fall back to 768.

    Args:
        model: Model passed through to ``get_cls_embedding``.
        samples: Iterable of sample mappings.
        device: Device passed through to ``get_cls_embedding``.

    Returns:
        Array of shape ``(len(samples), dim)``; ``(0, 768)`` for empty input.
    """
    fallback_dim = 768

    def _parse(field):
        # NOTE(security): eval() executes arbitrary code contained in the
        # string. It is kept so existing data files keep loading; if the
        # fields are always plain list literals, ast.literal_eval is the safe
        # replacement. Only feed this trusted data.
        return eval(field) if isinstance(field, str) else field

    rows = []
    for i, sample in enumerate(tqdm(samples, desc="Extracting")):
        token_ids = _parse(sample.get('token_ids', sample.get('tokens', [])))
        branch_depths = _parse(sample.get('branch_depths', [0] * len(token_ids)))
        linkage_types = _parse(sample.get('linkage_types', [0] * len(token_ids)))
        try:
            rows.append(get_cls_embedding(model, token_ids, branch_depths, linkage_types, device=device))
        except Exception as e:
            # Deliberately best-effort: keep going, but remember the gap so the
            # placeholder can be sized once a real embedding is known.
            print(f" Error sample {i}: {e}")
            rows.append(None)

    template = next((row for row in rows if row is not None), None)
    if template is None:
        # Nothing succeeded (or nothing was requested): keep a well-formed
        # 2-D result instead of a shapeless empty array.
        return np.zeros((len(rows), fallback_dim))
    return np.array([np.zeros_like(template) if row is None else row for row in rows])
def load_benchmark_data(csv_path):
    """Read glycan names and taxonomy/split labels from a benchmark CSV.

    The CSV must have a ``target`` column. The ``kingdom``, ``phylum`` and
    ``class`` columns are optional and default to ``''``. The split label is
    taken from the first of the ``train``, ``validation`` and ``test`` columns
    whose value equals ``true`` (case-insensitive), giving ``'train'``,
    ``'val'`` or ``'test'``; otherwise it is ``'unknown'``.

    Rows with fewer fields than the header are tolerated: their missing
    values are treated as empty strings.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        ``(iupac_list, labels)`` where ``iupac_list`` holds the ``target``
        values in file order and ``labels`` maps ``'kingdom'``, ``'phylum'``,
        ``'class'`` and ``'split'`` to parallel lists.

    Raises:
        KeyError: If the file has no ``target`` column.
    """
    print(f"Loading benchmark from {csv_path}...")
    rank_columns = ('kingdom', 'phylum', 'class')
    # (CSV flag column, split label) in priority order.
    split_columns = (('train', 'train'), ('validation', 'val'), ('test', 'test'))
    iupac_list = []
    labels = {'kingdom': [], 'phylum': [], 'class': [], 'split': []}
    # newline='' lets the csv module handle line endings itself, which keeps
    # newlines embedded in quoted fields intact.
    with open(csv_path, 'r', newline='') as f:
        for row in csv.DictReader(f):
            iupac_list.append(row['target'])
            for rank in rank_columns:
                # DictReader yields None for fields missing from short rows.
                labels[rank].append(row.get(rank) or '')
            split = next(
                (name for column, name in split_columns
                 if (row.get(column) or '').lower() == 'true'),
                'unknown',
            )
            labels['split'].append(split)
    print(f" {len(iupac_list)} samples")
    return iupac_list, labels
def iupac_to_tokenized(iupac_list, sequences_data):
    """Look up the tokenized record for each IUPAC string that has one.

    Records in ``sequences_data`` are indexed by their non-empty
    ``iupac_name``; when several records share a name the last one wins.

    Args:
        iupac_list: IUPAC strings to resolve, in the order they should be kept.
        sequences_data: Iterable of record mappings.

    Returns:
        ``(matched, indices)``: the records found, and for each of them the
        position of its IUPAC string within ``iupac_list``.
    """
    by_name = {}
    for record in sequences_data:
        name = record.get('iupac_name')
        if name:
            by_name[name] = record

    hits = [(position, by_name[name]) for position, name in enumerate(iupac_list) if name in by_name]
    matched = [record for _, record in hits]
    indices = [position for position, _ in hits]
    print(f" Matched {len(matched)}/{len(iupac_list)} IUPAC strings")
    return matched, indices
def _sample_without_replacement(items, n_max):
    """Return up to ``n_max`` items drawn without replacement via NumPy's global RNG."""
    picks = np.random.choice(len(items), min(n_max, len(items)), replace=False)
    return [items[i] for i in picks]


def main():
    """Command-line entry point: extract embeddings for every subset and save them.

    Steps, in order:
      1. Load the model from ``--checkpoint``.
      2. Sample up to ``--n_train_sample`` records from ``--sequences``.
      3. Sample up to ``--n_neg_sample`` records per difficulty category
         (easy / medium / hard) from ``--negatives``.
      4. Match the benchmark CSV's ``target`` strings against the sequence
         records and embed the ones found.
      5. Write ``embeddings_<name>.npz`` and ``summary_<name>.json`` into
         ``--output_dir``.

    NumPy's global RNG is seeded with 42 once, before the first draw, and the
    draws happen in the order train, easy, medium, hard.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', required=True)
    parser.add_argument('--name', required=True, help='v5 or v6')
    parser.add_argument('--sequences', default='bert_v5_bpe_topo/data/sequences_bpe_expanded.pkl')
    parser.add_argument('--negatives', default='bert_v6_contrastive/data/negatives_scored.pkl')
    parser.add_argument('--benchmark_csv', default='bert_training_v4/downstream_tasks/baseline_data_strict/glycanml/glycan_classification.csv')
    parser.add_argument('--output_dir', default='bert_v6_contrastive/analysis')
    parser.add_argument('--n_train_sample', type=int, default=10000)
    parser.add_argument('--n_neg_sample', type=int, default=5000)
    parser.add_argument('--device', default='cuda')
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    model = load_model(args.checkpoint, device=args.device)

    # 1. Training positives
    print("\n=== Training positives ===")
    # NOTE(security): pickle.load executes arbitrary code from the file; only
    # use data files from a trusted source.
    with open(args.sequences, 'rb') as f:
        sequences = pickle.load(f)
    np.random.seed(42)
    train_embs = extract_batch(
        model, _sample_without_replacement(sequences, args.n_train_sample), device=args.device)

    # 2. Negatives by difficulty
    print("\n=== Negatives ===")
    with open(args.negatives, 'rb') as f:  # same pickle caveat as above
        negatives = pickle.load(f)
    neg_embs = {}
    # Keep this order: it fixes the sequence of RNG draws and hence the samples.
    for level in ('easy', 'medium', 'hard'):
        pool = [n for n in negatives if n.get('difficulty_category') == level]
        neg_embs[level] = extract_batch(
            model, _sample_without_replacement(pool, args.n_neg_sample), device=args.device)

    # 3. Benchmark glycans
    print("\n=== Benchmark ===")
    iupac_list, taxonomy_labels = load_benchmark_data(args.benchmark_csv)
    matched, matched_idx = iupac_to_tokenized(iupac_list, sequences)
    if matched:
        benchmark_embs = extract_batch(model, matched, device=args.device)
        benchmark_labels = {k: [taxonomy_labels[k][i] for i in matched_idx] for k in taxonomy_labels}
    else:
        benchmark_embs = np.zeros((0, 768))
        benchmark_labels = {k: [] for k in taxonomy_labels}

    # Save
    arrays = {
        'train_embs': train_embs,
        'easy_embs': neg_embs['easy'],
        'medium_embs': neg_embs['medium'],
        'hard_embs': neg_embs['hard'],
        'benchmark_embs': benchmark_embs,
    }
    out = os.path.join(args.output_dir, f'embeddings_{args.name}.npz')
    np.savez_compressed(
        out,
        **arrays,
        benchmark_kingdom=np.array(benchmark_labels['kingdom']),
        benchmark_phylum=np.array(benchmark_labels['phylum']),
        benchmark_class=np.array(benchmark_labels['class']),
        benchmark_split=np.array(benchmark_labels['split']),
    )
    print(f"\nSaved: {out}")
    for k, arr in arrays.items():
        print(f" {k}: {arr.shape}")

    summary = {
        'model': args.name,
        'n_train': len(train_embs),
        'n_easy': len(neg_embs['easy']),
        'n_medium': len(neg_embs['medium']),
        'n_hard': len(neg_embs['hard']),
        'n_benchmark': len(benchmark_embs),
    }
    # Context manager so the summary is flushed and the handle closed.
    with open(os.path.join(args.output_dir, f'summary_{args.name}.json'), 'w') as f:
        json.dump(summary, f, indent=2)
# Script entry point: run the extraction only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
|