bertose-affinose-training-code / code /probes /extract_embeddings.py
supanthadey1's picture
Add BERTose and AFFINose training code release
1d6f391 verified
Raw
History Blame Contribute Delete
9.54 kB
#!/usr/bin/env python3
"""
Extract [CLS] Embeddings for Embedding Space Deep Dive
Extracts embeddings from V5 and V6 checkpoints for multiple data subsets:
- Training positives (sampled)
- Impossible negatives (easy/medium/hard)
- Benchmark test glycans (with taxonomy labels)
Output: .npz files with embeddings + metadata for analysis.
"""
import argparse
import ast
import csv
import json
import os
import pickle
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
# Root of the training repository that provides the project ``model`` package
# imported below. It can be overridden with GLYCAN_PROJECT_ROOT so the script
# runs outside the original cluster layout; the default is the original path.
project_root = Path(os.environ.get(
    'GLYCAN_PROJECT_ROOT',
    '/work/ratul1/supantha/glycan-SD-VS/bert_training_v3/v3.1_cluster_training',
))
# Both inserts go to index 0, so bert_training_v4 ends up ahead of the root.
sys.path.insert(0, str(project_root))
sys.path.insert(0, str(project_root / 'bert_training_v4'))
from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig
def load_model(checkpoint_path, device='cuda'):
    """Build a MultimodalGlycanBERT backbone from a checkpoint and load its weights.

    Part of the config is inferred from the weights:
    - the sequence vocabulary size comes from
      ``seq_embeddings.token_embeddings.weight``;
    - when an ``ms_embeddings.token_embeddings.weight`` table exists,
      ``ms_vocab_size`` is its row count minus that vocabulary size.

    Hidden size, layers, heads and max length are fixed at 768/12/12/256, with
    a CNN front end of kernel size 3. Keys under ``proj_head.`` (the V6
    projection head) are dropped, so V5 and V6 checkpoints load into the same
    backbone.

    Args:
        checkpoint_path: ``torch.save`` file holding either a dict with a
            ``model_state_dict`` entry or a bare state dict.
        device: device the model is moved to.

    Returns:
        The model, in eval mode, on ``device``.

    Raises:
        KeyError: if the checkpoint has no
            ``seq_embeddings.token_embeddings.weight`` entry.
    """
    print(f"Loading model from {checkpoint_path}...")
    # NOTE: weights_only=False unpickles arbitrary objects and can execute
    # code, so only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    # Checkpoints saved from a DataParallel/DDP wrapper prefix every key with
    # 'module.'. Unwrap them so the key lookups below and load_state_dict match.
    if state_dict and all(k.startswith('module.') for k in state_dict):
        state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
    # Strip projection head keys (V6)
    backbone_sd = {k: v for k, v in state_dict.items() if not k.startswith('proj_head.')}
    n_stripped = len(state_dict) - len(backbone_sd)
    if n_stripped > 0:
        print(f" Stripped {n_stripped} projection head keys")

    vocab_size = backbone_sd['seq_embeddings.token_embeddings.weight'].shape[0]
    config_kwargs = dict(
        seq_vocab_size=vocab_size, seq_hidden_size=768, seq_num_layers=12,
        seq_num_heads=12, seq_max_length=256, use_cnn_frontend=True, cnn_kernel_size=3,
    )
    if 'ms_embeddings.token_embeddings.weight' in backbone_sd:
        ms_total = backbone_sd['ms_embeddings.token_embeddings.weight'].shape[0]
        config_kwargs['ms_vocab_size'] = ms_total - vocab_size
    config = MultimodalGlycanBERTConfig(**config_kwargs)

    model = MultimodalGlycanBERT(config)
    # strict=False tolerates architecture drift between checkpoint versions.
    # Without the reporting below it would also hide a checkpoint that loaded
    # almost nothing.
    result = model.load_state_dict(backbone_sd, strict=False)
    if result.missing_keys:
        print(f" WARNING: {len(result.missing_keys)} model keys not found in checkpoint "
              f"(e.g. {result.missing_keys[:3]})")
    if result.unexpected_keys:
        print(f" WARNING: {len(result.unexpected_keys)} checkpoint keys unused by the model "
              f"(e.g. {result.unexpected_keys[:3]})")
    model.to(device).eval()
    print(f" Model: {sum(p.numel() for p in model.parameters()):,} params")
    return model
def get_cls_embedding(model, token_ids, branch_depths, linkage_types, device='cuda', max_len=256):
    """Return the position-0 ([CLS]) vector from ``model.seq_embeddings`` for one sample.

    Processing steps:
    1. Flatten the three per-token tracks and cast them to long.
    2. Cut them to the length of the shortest track, capped at ``max_len``.
    3. Right-pad them with 0 up to ``max_len``.
    4. Pass them to the model as a batch of one.

    Args:
        model: object exposing
            ``seq_embeddings(token_ids, branch_depths=..., linkage_types=...)``.
        token_ids: integer ids (list, array or tensor; any shape, flattened).
        branch_depths: per-token branch depths, same accepted forms.
        linkage_types: per-token linkage ids, same accepted forms.
        device: device the inputs are moved to before the call.
        max_len: fixed sequence length fed to the model.

    Returns:
        1-D float32 NumPy array, ``seq_out[0, 0, :]``.
    """
    # NOTE(review): only model.seq_embeddings is invoked here, not a full
    # encoder forward. If seq_embeddings is just the embedding layer, the
    # position-0 vector sees little sample-specific context. Confirm this is
    # intended against the model definition.
    with torch.no_grad():
        # as_tensor converts lists/arrays. It also casts tensors of another
        # dtype (e.g. float) to long, which embedding lookups require.
        tracks = [torch.as_tensor(x, dtype=torch.long).flatten()
                  for x in (token_ids, branch_depths, linkage_types)]
        # Align to the shortest track and cap at max_len in a single step.
        n = min(min(len(t) for t in tracks), max_len)
        pad = max_len - n
        # 0 is used as the padding value on all three tracks (assumed to be
        # the pad id -- TODO confirm against the tokenizer).
        tok, depth, link = (F.pad(t[:n], (0, pad), value=0).unsqueeze(0).to(device)
                            for t in tracks)
        seq_out = model.seq_embeddings(tok, branch_depths=depth, linkage_types=link)
        # .float() so half/bfloat16 outputs convert cleanly to NumPy.
        return seq_out[:, 0, :].float().cpu().numpy().flatten()
def extract_batch(model, samples, device='cuda', emb_dim=768):
    """Embed every sample with ``get_cls_embedding`` and stack the results row-wise.

    Each sample is a dict with these fields:
    - ``token_ids`` (or, as a fallback, ``tokens``);
    - optional ``branch_depths`` and ``linkage_types``. When one is missing
      or ``None``, it defaults to zeros of the token length.
    Any field may be stored as the string repr of a list.

    Extraction is best-effort. A sample that fails to parse or embed is
    reported and replaced by a zero row, so row ``i`` always corresponds to
    ``samples[i]``.

    Args:
        model: model passed through to ``get_cls_embedding``.
        samples: sequence of per-glycan dicts.
        device: device passed through to ``get_cls_embedding``.
        emb_dim: row width used only when no sample succeeds or ``samples``
            is empty. Otherwise zero rows copy the real embeddings' width/dtype.

    Returns:
        NumPy array of shape ``(len(samples), dim)``.
    """
    def _as_list(value):
        # Fields serialised as their repr are parsed with literal_eval rather
        # than eval, so a malformed or hostile string cannot execute code.
        return ast.literal_eval(value) if isinstance(value, str) else value

    rows, failed = [], []
    for i, sample in enumerate(tqdm(samples, desc="Extracting")):
        try:
            token_ids = _as_list(sample.get('token_ids', sample.get('tokens', [])))
            branch_depths = _as_list(sample.get('branch_depths'))
            if branch_depths is None:
                branch_depths = [0] * len(token_ids)
            linkage_types = _as_list(sample.get('linkage_types'))
            if linkage_types is None:
                linkage_types = [0] * len(token_ids)
            rows.append(get_cls_embedding(model, token_ids, branch_depths, linkage_types, device=device))
        except Exception as exc:  # best-effort: record the failure and keep going
            print(f" Error sample {i}: {exc}")
            rows.append(None)
            failed.append(i)

    if not rows:
        return np.zeros((0, emb_dim))
    if failed:
        print(f" {len(failed)}/{len(rows)} samples failed and were zero-filled")
    # Zero-fill failures with the width/dtype the model actually produced
    # (not a hard-coded 768), so the rows never become ragged.
    template = next((r for r in rows if r is not None), None)
    if template is None:
        template = np.zeros(emb_dim)
    return np.stack([np.zeros_like(template) if r is None else r for r in rows])
def load_benchmark_data(csv_path):
    """Read benchmark glycans and their taxonomy/split labels from a CSV.

    Expected columns:
    - ``target``: the IUPAC string (required);
    - ``kingdom``, ``phylum``, ``class``: taxonomy labels (optional);
    - ``train``, ``validation``, ``test``: flags, true when the cell equals
      'true' case-insensitively.

    The split label is the first true flag, checked in that order. It is
    'unknown' when none is true. Missing or short-row cells become ''.

    Args:
        csv_path: path to the CSV file (UTF-8, optionally with a BOM).

    Returns:
        ``(iupac_list, labels)``. ``labels`` maps 'kingdom', 'phylum',
        'class' and 'split' to lists aligned with ``iupac_list``.

    Raises:
        KeyError: if the CSV has no ``target`` column.
    """
    print(f"Loading benchmark from {csv_path}...")
    iupac_list = []
    labels = {'kingdom': [], 'phylum': [], 'class': [], 'split': []}

    def _field(row, key):
        # DictReader fills missing trailing cells with None (its restval), so
        # row.get(key, '') alone can still yield None. Normalise to ''.
        return row.get(key) or ''

    def _flag(row, key):
        return _field(row, key).strip().lower() == 'true'

    # newline='' is required by the csv module for correct line handling.
    # utf-8-sig also strips a BOM that would otherwise rename the 'target' header.
    with open(csv_path, 'r', newline='', encoding='utf-8-sig') as f:
        for row in csv.DictReader(f):
            iupac_list.append(row['target'])
            for key in ('kingdom', 'phylum', 'class'):
                labels[key].append(_field(row, key))
            if _flag(row, 'train'):
                split = 'train'
            elif _flag(row, 'validation'):
                split = 'val'
            elif _flag(row, 'test'):
                split = 'test'
            else:
                split = 'unknown'
            labels['split'].append(split)
    print(f" {len(iupac_list)} samples")
    return iupac_list, labels
def iupac_to_tokenized(iupac_list, sequences_data):
    """Pair benchmark IUPAC strings with the tokenized records that share their name.

    Records without a truthy ``iupac_name`` are ignored. When several records
    share a name, the last one wins. Benchmark strings with no matching
    record are skipped.

    Args:
        iupac_list: benchmark IUPAC strings, in benchmark order.
        sequences_data: iterable of dicts, possibly carrying ``iupac_name``.

    Returns:
        ``(matched, indices)``. ``matched[j]`` is the record for
        ``iupac_list[indices[j]]``.
    """
    by_name = {}
    for record in sequences_data:
        name = record.get('iupac_name')
        if name:
            by_name[name] = record

    hits = [(pos, by_name[name]) for pos, name in enumerate(iupac_list) if name in by_name]
    positions = [pos for pos, _ in hits]
    records = [rec for _, rec in hits]
    print(f" Matched {len(records)}/{len(iupac_list)} IUPAC strings")
    return records, positions
def _sample_subset(items, n):
    """Return up to ``n`` elements of ``items``, drawn without replacement.

    Draws from NumPy's global legacy RNG, so the result depends on the seed
    and on the order of the preceding draws.
    """
    picks = np.random.choice(len(items), min(n, len(items)), replace=False)
    return [items[i] for i in picks]


def main():
    """CLI entry point: embed training positives, graded negatives and benchmark glycans.

    Writes two files to ``--output_dir``:
    - ``embeddings_<name>.npz``: the embedding arrays plus the benchmark
      kingdom/phylum/class/split label arrays;
    - ``summary_<name>.json``: per-subset row counts.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', required=True)
    parser.add_argument('--name', required=True, help='v5 or v6')
    parser.add_argument('--sequences', default='bert_v5_bpe_topo/data/sequences_bpe_expanded.pkl')
    parser.add_argument('--negatives', default='bert_v6_contrastive/data/negatives_scored.pkl')
    parser.add_argument('--benchmark_csv', default='bert_training_v4/downstream_tasks/baseline_data_strict/glycanml/glycan_classification.csv')
    parser.add_argument('--output_dir', default='bert_v6_contrastive/analysis')
    parser.add_argument('--n_train_sample', type=int, default=10000)
    parser.add_argument('--n_neg_sample', type=int, default=5000)
    parser.add_argument('--device', default='cuda')
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    model = load_model(args.checkpoint, device=args.device)

    # Ordered name -> array map. It drives the .npz keys, the shape printout
    # and the summary counts.
    embeddings = {}

    # 1. Training positives
    print("\n=== Training positives ===")
    # NOTE: pickle.load can execute code; only point --sequences/--negatives
    # at trusted files.
    with open(args.sequences, 'rb') as f:
        sequences = pickle.load(f)
    # Seeded once. The train draw, then the easy/medium/hard draws, consume the
    # global RNG in this fixed order, so the samples are reproducible.
    np.random.seed(42)
    embeddings['train_embs'] = extract_batch(
        model, _sample_subset(sequences, args.n_train_sample), device=args.device)

    # 2. Negatives by difficulty
    print("\n=== Negatives ===")
    with open(args.negatives, 'rb') as f:
        negatives = pickle.load(f)
    for level in ('easy', 'medium', 'hard'):
        pool = [neg for neg in negatives if neg.get('difficulty_category') == level]
        embeddings[f'{level}_embs'] = extract_batch(
            model, _sample_subset(pool, args.n_neg_sample), device=args.device)

    # 3. Benchmark glycans (only those whose IUPAC string is in the sequences pickle)
    print("\n=== Benchmark ===")
    iupac_list, taxonomy_labels = load_benchmark_data(args.benchmark_csv)
    matched, matched_idx = iupac_to_tokenized(iupac_list, sequences)
    if matched:
        embeddings['benchmark_embs'] = extract_batch(model, matched, device=args.device)
        benchmark_labels = {k: [v[i] for i in matched_idx] for k, v in taxonomy_labels.items()}
    else:
        # Empty placeholder with the same width as the other embedding arrays.
        train = embeddings['train_embs']
        width = train.shape[1] if train.ndim == 2 else 768
        embeddings['benchmark_embs'] = np.zeros((0, width))
        benchmark_labels = {k: [] for k in taxonomy_labels}

    # Save
    out = os.path.join(args.output_dir, f'embeddings_{args.name}.npz')
    np.savez_compressed(
        out,
        **embeddings,
        benchmark_kingdom=np.array(benchmark_labels['kingdom']),
        benchmark_phylum=np.array(benchmark_labels['phylum']),
        benchmark_class=np.array(benchmark_labels['class']),
        benchmark_split=np.array(benchmark_labels['split']),
    )
    print(f"\nSaved: {out}")
    for k, arr in embeddings.items():
        print(f" {k}: {arr.shape}")

    summary = {
        'model': args.name,
        'n_train': len(embeddings['train_embs']),
        'n_easy': len(embeddings['easy_embs']),
        'n_medium': len(embeddings['medium_embs']),
        'n_hard': len(embeddings['hard_embs']),
        'n_benchmark': len(embeddings['benchmark_embs']),
    }
    # Context manager so the summary is flushed and the handle closed deterministically.
    with open(os.path.join(args.output_dir, f'summary_{args.name}.json'), 'w') as f:
        json.dump(summary, f, indent=2)


if __name__ == '__main__':
    main()