bertose-affinose-training-code / code /contrastive /score_negative_difficulty.py
supanthadey1's picture
Add BERTose and AFFINose training code release
1d6f391 verified
Raw
History Blame Contribute Delete
7.99 kB
#!/usr/bin/env python3
"""
Score Negative Difficulty for V6 Curriculum Learning
Calculates similarity between each negative sample and positive samples
to categorize negatives as easy/medium/hard for curriculum learning.
"""
import argparse
import ast
import json
import os
import pickle
import random
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
# Add project paths so the `model` package (and bert_training_v4) can be imported.
# Defaults to the original cluster checkout; set GLYCAN_PROJECT_ROOT to run the
# script from a different location without editing the source.
project_root = Path(os.environ.get(
    'GLYCAN_PROJECT_ROOT',
    '/work/ratul1/supantha/glycan-SD-VS/bert_training_v3/v3.1_cluster_training',
))
sys.path.insert(0, str(project_root))
sys.path.insert(0, str(project_root / 'bert_training_v4'))
from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig
def load_model(checkpoint_path, device='cuda'):
    """Load V5-A MultimodalGlycanBERT model for inference.

    The checkpoint may be either a training checkpoint holding the weights under
    ``'model_state_dict'`` or a bare state dict. The vocabulary size is read from
    the token-embedding weight; every other architecture setting is hard-coded
    below to match the checkpoints this script was written for.

    Args:
        checkpoint_path: path passed to ``torch.load``.
        device: device the model is moved to after loading.

    Returns:
        ``(model, config)`` with the model in eval mode on ``device``.

    Raises:
        KeyError: if the state dict lacks the token-embedding weight needed to
            infer the vocabulary size.
    """
    print(f"Loading model from {checkpoint_path}...")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # Get state dict (training checkpoints wrap it; bare state dicts are accepted too)
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    # Get vocab size from state_dict
    vocab_key = 'seq_embeddings.token_embeddings.weight'
    if vocab_key not in state_dict:
        raise KeyError(
            f"Checkpoint {checkpoint_path} has no '{vocab_key}'; cannot infer vocabulary size"
        )
    vocab_size = state_dict[vocab_key].shape[0]

    # Create config matching the checkpoint (using benchmark script pattern)
    config = MultimodalGlycanBERTConfig(
        seq_vocab_size=vocab_size,
        seq_hidden_size=768,
        seq_num_layers=12,
        seq_num_heads=12,
        seq_max_length=256,
        use_cnn_frontend=True,
        cnn_kernel_size=3
    )
    model = MultimodalGlycanBERT(config)

    # strict=False tolerates partial checkpoints (e.g. missing non-sequence
    # branches), but silently ignoring mismatches could leave the encoder at
    # random init, so report what did not line up.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print(f"  WARNING: {len(result.missing_keys)} model parameters not found in checkpoint "
              f"(left at initialization), e.g. {result.missing_keys[:5]}")
    if result.unexpected_keys:
        print(f"  WARNING: {len(result.unexpected_keys)} checkpoint entries not used by the model, "
              f"e.g. {result.unexpected_keys[:5]}")

    model.to(device)
    model.eval()
    print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} params")
    return model, config
def _parse_sequence_field(value):
    """Return *value* as a sequence, parsing it first if it is stored as a string.

    Some serialized samples hold per-token lists as their repr (e.g. "[3, 17]").
    ``ast.literal_eval`` accepts only Python literals, so a malformed or
    malicious string raises ValueError/SyntaxError instead of being executed
    the way ``eval`` would.
    """
    if isinstance(value, str):
        return ast.literal_eval(value)
    return value


def _fit_length(tensor, max_len):
    """Truncate or right-pad (with 0) a (1, L) tensor along dim 1 to exactly max_len."""
    length = tensor.size(1)
    if length > max_len:
        return tensor[:, :max_len]
    if length < max_len:
        return F.pad(tensor, (0, max_len - length), value=0)
    return tensor


def get_embedding(model, sample, device='cuda', max_len=256):
    """Get [CLS] embedding for a sample using sequence encoder only.

    Args:
        model: object exposing ``seq_embeddings(token_ids, branch_depths,
            linkage_types)`` and an iterable ``seq_layers`` of callables.
        sample: mapping with ``token_ids`` (falling back to ``tokens``) and
            optional ``branch_depths`` / ``linkage_types``; each field may be a
            sequence or the string repr of one. Missing depth/linkage fields
            default to all zeros of the token length.
        device: device the input tensors are moved to.
        max_len: each of the three inputs is independently truncated or
            zero-padded to this length.

    Returns:
        The position-0 output of the last encoder layer with the batch dim
        squeezed — presumably shape (hidden_size,); confirm against the model.

    Raises:
        ValueError: if the sample has no token ids, or a string field is not a
            valid Python literal (SyntaxError is also possible for the latter).
    """
    with torch.no_grad():
        # Parse token data
        raw_ids = _parse_sequence_field(sample.get('token_ids', sample.get('tokens', [])))
        if len(raw_ids) == 0:
            # An empty list would become a float tensor and fail deep inside the
            # embedding layer; fail early with a clear message instead.
            raise ValueError("sample has no token ids")
        seq_len = len(raw_ids)

        # Get or create other inputs
        raw_depths = _parse_sequence_field(sample.get('branch_depths', [0] * seq_len))
        raw_links = _parse_sequence_field(sample.get('linkage_types', [0] * seq_len))

        token_ids = torch.tensor(raw_ids, dtype=torch.long).unsqueeze(0).to(device)
        branch_depths = torch.tensor(raw_depths).unsqueeze(0).to(device)
        linkage_types = torch.tensor(raw_links).unsqueeze(0).to(device)

        # Truncate or pad to max_len. Each tensor is fitted on its own length so
        # a field whose length differs from token_ids still ends up aligned.
        # NOTE(review): no attention mask is passed to seq_layers, so zero
        # padding may influence the [CLS] output — confirm the layers mask id 0.
        token_ids = _fit_length(token_ids, max_len)
        branch_depths = _fit_length(branch_depths, max_len)
        linkage_types = _fit_length(linkage_types, max_len)

        # Get sequence embedding through encoder
        x = model.seq_embeddings(token_ids, branch_depths, linkage_types)
        for layer in model.seq_layers:
            x = layer(x)

        # Return [CLS] token embedding (first token)
        return x[:, 0, :].squeeze(0)
def _categorize(avg_sim, easy_threshold=0.3, hard_threshold=0.6):
    """Map a mean cosine similarity to 'easy' (< easy_threshold), 'medium', or 'hard' (>= hard_threshold)."""
    if avg_sim < easy_threshold:
        return 'easy'
    if avg_sim < hard_threshold:
        return 'medium'
    return 'hard'


def _pct(count, total):
    """Percentage of count in total, or 0.0 when total is 0 (avoids ZeroDivisionError)."""
    return 100 * count / total if total else 0.0


def _stats_path_for(output_path):
    """Return the JSON stats path next to output_path (x.pkl -> x_stats.json).

    Unlike ``str.replace('.pkl', ...)`` this never rewrites directory names and
    never returns output_path itself, which would overwrite the pickle with JSON
    when the output does not end in '.pkl'.
    """
    out = Path(output_path)
    stem = out.stem if out.suffix == '.pkl' else out.name
    return str(out.with_name(f"{stem}_stats.json"))


def _compute_stats(negatives, scored, errors, scores):
    """Summarise category counts and score distribution for the stats JSON.

    Category counts reflect the labels written to the output (failed negatives
    carry the 'medium' fallback). avg/std use only ``scores`` — the similarities
    actually computed — so the 0.5 placeholder of failed negatives cannot bias them.
    """
    counts = {'easy': 0, 'medium': 0, 'hard': 0}
    for neg in negatives:
        category = neg.get('difficulty_category')
        if category in counts:
            counts[category] += 1
    return {
        'total': len(negatives),
        'scored': scored,
        'errors': errors,
        **counts,
        'avg_score': float(np.mean(scores)) if scores else 0,
        'std_score': float(np.std(scores)) if scores else 0,
    }


def main():
    """Score each negative by embedding similarity to a sample of positives.

    Every negative receives 'difficulty_score' (mean cosine similarity to the
    sampled positive embeddings) and 'difficulty_category'; successfully scored
    ones also get 'max_similarity'. Negatives that fail to embed get the fallback
    score 0.5 / 'medium' plus an 'error' message. The annotated list is pickled
    to --output and summary statistics are written next to it as JSON.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--negatives', default='bert_v6_contrastive/data/negatives_150k.pkl')
    parser.add_argument('--positives', default='bert_v5_bpe_topo/data/sequences_bpe_expanded.pkl')
    parser.add_argument('--checkpoint', default='checkpoints_v5_bpe_topo/best_v5_bpe_topo_model.pt')
    parser.add_argument('--output', default='bert_v6_contrastive/data/negatives_scored.pkl')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--sample-pos', type=int, default=100, help='Number of positive samples to compare against')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for sampling positives (makes scores reproducible)')
    parser.add_argument('--easy-threshold', type=float, default=0.3,
                        help='Mean similarity below this is categorized as easy')
    parser.add_argument('--hard-threshold', type=float, default=0.6,
                        help='Mean similarity at or above this is categorized as hard')
    args = parser.parse_args()
    if args.easy_threshold > args.hard_threshold:
        parser.error('--easy-threshold must not exceed --hard-threshold')

    # NOTE(security): pickle.load can execute arbitrary code; only point
    # --negatives / --positives at trusted files.
    # Load negatives
    print("Loading negatives...")
    with open(args.negatives, 'rb') as f:
        negatives = pickle.load(f)
    print(f" Loaded {len(negatives):,} negatives")

    # Load positives
    print("Loading positives...")
    with open(args.positives, 'rb') as f:
        positives = pickle.load(f)
    if isinstance(positives, dict):
        positives = list(positives.values())
    print(f" Loaded {len(positives):,} positives")

    # Load model
    model, _config = load_model(args.checkpoint, args.device)

    # Pre-compute positive embeddings for comparison (seeded so reruns agree)
    print(f"Pre-computing {args.sample_pos} positive embeddings for comparison...")
    rng = random.Random(args.seed)
    sample_positives = rng.sample(positives, min(args.sample_pos, len(positives)))
    pos_embeddings = []
    pos_failures = 0
    for pos in tqdm(sample_positives, desc="Positive embeddings"):
        try:
            pos_embeddings.append(get_embedding(model, pos, args.device))
        except Exception as exc:  # best-effort: skip unusable positives, but report them
            pos_failures += 1
            if pos_failures == 1:
                print(f"  Warning: failed to embed a positive ({type(exc).__name__}: {exc})")
    if not pos_embeddings:
        # torch.stack([]) would otherwise fail with an unhelpful message.
        sys.exit("Error: no positive embeddings could be computed; cannot score negatives.")
    pos_embeddings = torch.stack(pos_embeddings)
    print(f" Got {len(pos_embeddings)} positive embeddings")
    if pos_failures:
        print(f"  Skipped {pos_failures} positives that failed to embed")

    # Score each negative
    print(f"\nScoring {len(negatives):,} negatives...")
    scored = 0
    errors = 0
    valid_scores = []
    for neg in tqdm(negatives):
        try:
            neg_emb = get_embedding(model, neg, args.device)
            # Compare to all sampled positives
            sims = F.cosine_similarity(neg_emb.unsqueeze(0), pos_embeddings, dim=1)
            avg_sim = sims.mean().item()
            max_sim = sims.max().item()
        except Exception as exc:
            # Keep the negative usable downstream with a neutral fallback label.
            neg['difficulty_score'] = 0.5
            neg['difficulty_category'] = 'medium'
            neg['error'] = str(exc)
            errors += 1
            continue
        neg['difficulty_score'] = avg_sim
        neg['max_similarity'] = max_sim
        neg['difficulty_category'] = _categorize(avg_sim, args.easy_threshold, args.hard_threshold)
        valid_scores.append(avg_sim)
        scored += 1

    # Compute stats
    stats = _compute_stats(negatives, scored, errors, valid_scores)
    stats.update({
        'seed': args.seed,
        'easy_threshold': args.easy_threshold,
        'hard_threshold': args.hard_threshold,
        'num_positive_embeddings': len(pos_embeddings),
    })
    total = stats['total']
    print("\n=== Results ===")
    print(f"Scored: {scored:,} / {total:,}")
    print(f"Errors: {errors:,}")
    print(f"Easy: {stats['easy']:,} ({_pct(stats['easy'], total):.1f}%)")
    print(f"Medium: {stats['medium']:,} ({_pct(stats['medium'], total):.1f}%)")
    print(f"Hard: {stats['hard']:,} ({_pct(stats['hard'], total):.1f}%)")
    print(f"Avg Score: {stats['avg_score']:.4f} ± {stats['std_score']:.4f}")

    # Save outputs
    print(f"\nSaving scored negatives to {args.output}...")
    os.makedirs(Path(args.output).parent, exist_ok=True)
    with open(args.output, 'wb') as f:
        pickle.dump(negatives, f)
    stats_path = _stats_path_for(args.output)
    with open(stats_path, 'w') as f:
        json.dump(stats, f, indent=2)
    print(f"Saved stats to {stats_path}")
    print("\nDone!")


if __name__ == '__main__':
    main()