#!/usr/bin/env python3
"""
Score Negative Difficulty for V6 Curriculum Learning

Calculates similarity between each negative sample and positive samples
to categorize negatives as easy/medium/hard for curriculum learning.
"""

import os
import sys
import torch
import torch.nn.functional as F
import pickle
import json
import argparse
import ast
from collections import Counter
from pathlib import Path
from tqdm import tqdm
import numpy as np
import random

# Add project paths
project_root = Path('/work/ratul1/supantha/glycan-SD-VS/bert_training_v3/v3.1_cluster_training')
sys.path.insert(0, str(project_root))
sys.path.insert(0, str(project_root / 'bert_training_v4'))

from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig


def load_model(checkpoint_path, device='cuda'):
    """Load V5-A MultimodalGlycanBERT model.

    Args:
        checkpoint_path: Path to a checkpoint that is either a raw state dict
            or a dict holding the state dict under 'model_state_dict'.
        device: Device the model is moved to after loading.

    Returns:
        Tuple ``(model, config)`` with the model in eval mode.

    Raises:
        KeyError: If the state dict lacks
            'seq_embeddings.token_embeddings.weight', which is used to infer
            the vocabulary size.
    """
    print(f"Loading model from {checkpoint_path}...")

    # NOTE(security): weights_only=False unpickles arbitrary objects. Only
    # load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # Get state dict
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    # Get vocab size from state_dict
    vocab_size = state_dict['seq_embeddings.token_embeddings.weight'].shape[0]

    # Create config matching the checkpoint (using benchmark script pattern)
    config = MultimodalGlycanBERTConfig(
        seq_vocab_size=vocab_size,
        seq_hidden_size=768,
        seq_num_layers=12,
        seq_num_heads=12,
        seq_max_length=256,
        use_cnn_frontend=True,
        cnn_kernel_size=3
    )

    model = MultimodalGlycanBERT(config)

    # strict=False tolerates key mismatches, so report them: a large number of
    # missing keys means the embeddings come from randomly initialised weights.
    load_result = model.load_state_dict(state_dict, strict=False)
    # getattr keeps this safe if the model overrides load_state_dict and
    # returns something other than torch's (missing_keys, unexpected_keys).
    missing_keys = getattr(load_result, 'missing_keys', [])
    unexpected_keys = getattr(load_result, 'unexpected_keys', [])
    if missing_keys or unexpected_keys:
        print(f"  Warning: {len(missing_keys)} missing and "
              f"{len(unexpected_keys)} unexpected keys in checkpoint")

    model.to(device)
    model.eval()

    print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} params")
    return model, config


def _parse_sequence(value):
    """Return ``value`` unchanged, or parse it if it is a stringified literal."""
    if isinstance(value, str):
        # literal_eval accepts only Python literals, so a crafted string in the
        # data file cannot execute code the way eval() would.
        return ast.literal_eval(value)
    return value


def get_embedding(model, sample, device='cuda', max_len=256):
    """Get [CLS] embedding for a sample using sequence encoder only.

    Args:
        model: Object exposing ``seq_embeddings`` (callable) and ``seq_layers``
            (iterable of callables).
        sample: Mapping read via ``.get``. Token ids come from 'token_ids'
            (falling back to 'tokens'); 'branch_depths' and 'linkage_types'
            default to all zeros. Each may be a sequence or a string holding a
            Python list literal.
        device: Device the input tensors are moved to.
        max_len: Inputs are truncated or zero-padded to this length.

    Returns:
        The encoder output at position 0 with the batch dimension squeezed out
        (presumably a 1-D hidden-size vector — confirm against the model).

    Raises:
        ValueError: If the sample has no token ids, or a string field is not a
            valid Python literal.
    """
    with torch.no_grad():
        # Parse token data
        token_ids = _parse_sequence(sample.get('token_ids', sample.get('tokens', [])))
        token_ids = torch.tensor(token_ids).unsqueeze(0).to(device)
        if token_ids.numel() == 0:
            # Fail with a clear message instead of an obscure dtype error from
            # the embedding lookup on an empty float tensor.
            raise ValueError("sample has no token ids")
        seq_len = token_ids.size(1)

        # Get or create other inputs
        branch_depths = _parse_sequence(sample.get('branch_depths', [0] * seq_len))
        branch_depths = torch.tensor(branch_depths).unsqueeze(0).to(device)

        linkage_types = _parse_sequence(sample.get('linkage_types', [0] * seq_len))
        linkage_types = torch.tensor(linkage_types).unsqueeze(0).to(device)

        # Truncate or pad to max_len
        if seq_len > max_len:
            token_ids = token_ids[:, :max_len]
            branch_depths = branch_depths[:, :max_len]
            linkage_types = linkage_types[:, :max_len]
        elif seq_len < max_len:
            pad_len = max_len - seq_len
            token_ids = F.pad(token_ids, (0, pad_len), value=0)
            branch_depths = F.pad(branch_depths, (0, pad_len), value=0)
            linkage_types = F.pad(linkage_types, (0, pad_len), value=0)

        # Get sequence embedding through encoder
        x = model.seq_embeddings(token_ids, branch_depths, linkage_types)
        for layer in model.seq_layers:
            x = layer(x)

        # Return [CLS] token embedding (first token)
        return x[:, 0, :].squeeze(0)


def _categorize(avg_sim, easy_threshold=0.3, hard_threshold=0.6):
    """Map an average similarity to 'easy', 'medium' or 'hard'."""
    if avg_sim < easy_threshold:
        return 'easy'
    if avg_sim < hard_threshold:
        return 'medium'
    return 'hard'


def _stats_path_for(output_path):
    """Return the sibling '<stem>_stats.json' path for the output file."""
    # Built from the file stem so the result can never equal output_path, even
    # when the output name has no '.pkl' suffix.
    return output_path.with_name(output_path.stem + '_stats.json')


def main():
    """Score every negative against sampled positives and save the results.

    Each negative dict is annotated in place with 'difficulty_score' (mean
    cosine similarity), 'max_similarity' and 'difficulty_category'. Negatives
    that fail to embed get a neutral score of 0.5, category 'medium' and an
    'error' message. The annotated list is pickled to ``--output`` and summary
    statistics are written next to it as '<stem>_stats.json'.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--negatives', default='bert_v6_contrastive/data/negatives_150k.pkl')
    parser.add_argument('--positives', default='bert_v5_bpe_topo/data/sequences_bpe_expanded.pkl')
    parser.add_argument('--checkpoint', default='checkpoints_v5_bpe_topo/best_v5_bpe_topo_model.pt')
    parser.add_argument('--output', default='bert_v6_contrastive/data/negatives_scored.pkl')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--sample-pos', type=int, default=100,
                        help='Number of positive samples to compare against')
    parser.add_argument('--easy-threshold', type=float, default=0.3,
                        help='Average similarity below this is "easy"')
    parser.add_argument('--hard-threshold', type=float, default=0.6,
                        help='Average similarity at or above this is "hard"')
    parser.add_argument('--seed', type=int, default=None,
                        help='Seed for sampling positives (default: unseeded)')
    args = parser.parse_args()

    if args.sample_pos < 1:
        parser.error('--sample-pos must be at least 1')
    if args.easy_threshold >= args.hard_threshold:
        parser.error('--easy-threshold must be smaller than --hard-threshold')

    if args.seed is not None:
        random.seed(args.seed)

    device = args.device
    if device.startswith('cuda') and not torch.cuda.is_available():
        print("CUDA is not available; falling back to CPU.")
        device = 'cpu'

    # NOTE(security): pickle.load executes arbitrary code from the file. Only
    # use data files from a trusted source.

    # Load negatives
    print("Loading negatives...")
    with open(args.negatives, 'rb') as f:
        negatives = pickle.load(f)
    print(f"  Loaded {len(negatives):,} negatives")

    # Load positives
    print("Loading positives...")
    with open(args.positives, 'rb') as f:
        positives = pickle.load(f)
    if isinstance(positives, dict):
        positives = list(positives.values())
    print(f"  Loaded {len(positives):,} positives")

    # Load model
    model, _ = load_model(args.checkpoint, device)

    # Pre-compute positive embeddings for comparison
    print(f"Pre-computing {args.sample_pos} positive embeddings for comparison...")
    sample_positives = random.sample(positives, min(args.sample_pos, len(positives)))

    pos_embeddings = []
    pos_failures = 0
    first_pos_error = None
    for pos in tqdm(sample_positives, desc="Positive embeddings"):
        try:
            pos_embeddings.append(get_embedding(model, pos, device))
        except Exception as e:
            # Best effort: skip unusable positives, but keep count so the
            # failures are reported instead of vanishing.
            pos_failures += 1
            if first_pos_error is None:
                first_pos_error = e

    if pos_failures:
        print(f"  Skipped {pos_failures} positives that failed to embed "
              f"(first error: {first_pos_error!r})")
    if not pos_embeddings:
        raise RuntimeError(
            "No positive embeddings could be computed; cannot score negatives."
        )

    pos_embeddings = torch.stack(pos_embeddings)
    print(f"  Got {len(pos_embeddings)} positive embeddings")

    # Score each negative
    print(f"\nScoring {len(negatives):,} negatives...")

    scored = 0
    errors = 0
    first_neg_error = None

    for neg in tqdm(negatives):
        try:
            neg_emb = get_embedding(model, neg, device)

            # Compare to all sampled positives
            sims = F.cosine_similarity(neg_emb.unsqueeze(0), pos_embeddings, dim=1)
            avg_sim = sims.mean().item()
            max_sim = sims.max().item()

            # Score based on similarity
            neg['difficulty_score'] = avg_sim
            neg['max_similarity'] = max_sim

            # Categorize
            neg['difficulty_category'] = _categorize(
                avg_sim, args.easy_threshold, args.hard_threshold
            )

            scored += 1

        except Exception as e:
            # Best effort: keep the sample with a neutral placeholder so the
            # output stays aligned with the input list.
            neg['difficulty_score'] = 0.5
            neg['difficulty_category'] = 'medium'
            neg['error'] = str(e)
            errors += 1
            if first_neg_error is None:
                first_neg_error = e

    if first_neg_error is not None:
        print(f"  First scoring error: {first_neg_error!r}")

    # Compute stats
    total = len(negatives)
    category_counts = Counter(n.get('difficulty_category') for n in negatives)
    easy = category_counts['easy']
    medium = category_counts['medium']
    hard = category_counts['hard']

    # NOTE: includes the 0.5 placeholder assigned to negatives that errored.
    scores = [n['difficulty_score'] for n in negatives if 'difficulty_score' in n]

    stats = {
        'total': total,
        'scored': scored,
        'errors': errors,
        'easy': easy,
        'medium': medium,
        'hard': hard,
        'avg_score': float(np.mean(scores)) if scores else 0,
        'std_score': float(np.std(scores)) if scores else 0,
    }

    def percent(count):
        # Guard against an empty negatives list (division by zero).
        return 100 * count / total if total else 0.0

    print("\n=== Results ===")
    print(f"Scored: {scored:,} / {total:,}")
    print(f"Errors: {errors:,}")
    print(f"Easy: {easy:,} ({percent(easy):.1f}%)")
    print(f"Medium: {medium:,} ({percent(medium):.1f}%)")
    print(f"Hard: {hard:,} ({percent(hard):.1f}%)")
    print(f"Avg Score: {stats['avg_score']:.4f} ± {stats['std_score']:.4f}")

    # Save outputs
    print(f"\nSaving scored negatives to {args.output}...")
    output_path = Path(args.output)
    os.makedirs(output_path.parent, exist_ok=True)

    with open(output_path, 'wb') as f:
        pickle.dump(negatives, f)

    stats_path = _stats_path_for(output_path)
    with open(stats_path, 'w') as f:
        json.dump(stats, f, indent=2)

    print(f"Saved stats to {stats_path}")
    print("\nDone!")


if __name__ == '__main__':
    main()