| |
| """ |
Score Negative Difficulty for V6 Curriculum Learning

Computes the cosine similarity between each negative sample's [CLS] embedding
and the [CLS] embeddings of a random subset of positive samples, then uses the
average similarity to categorize each negative as easy/medium/hard for
curriculum learning.
| """ |
|
|
import argparse
import ast
import json
import os
import pickle
import random
import sys
from collections import Counter
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
|
|
| |
# Root of the training project that holds the project-local modules (e.g.
# model.multimodal_glycan_bert_v3). Defaults to the original cluster location;
# set GLYCAN_PROJECT_ROOT to run the script from a different checkout.
project_root = Path(os.environ.get(
    'GLYCAN_PROJECT_ROOT',
    '/work/ratul1/supantha/glycan-SD-VS/bert_training_v3/v3.1_cluster_training',
))
# Both directories are prepended to sys.path so they shadow installed packages;
# bert_training_v4 is inserted last and is therefore searched first.
sys.path.insert(0, str(project_root))
sys.path.insert(0, str(project_root / 'bert_training_v4'))
|
|
| from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig |
|
|
|
|
| def load_model(checkpoint_path, device='cuda'): |
| """Load V5-A MultimodalGlycanBERT model.""" |
| print(f"Loading model from {checkpoint_path}...") |
| checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False) |
| |
| |
| if 'model_state_dict' in checkpoint: |
| state_dict = checkpoint['model_state_dict'] |
| else: |
| state_dict = checkpoint |
| |
| |
| vocab_size = state_dict['seq_embeddings.token_embeddings.weight'].shape[0] |
| |
| |
| config = MultimodalGlycanBERTConfig( |
| seq_vocab_size=vocab_size, |
| seq_hidden_size=768, |
| seq_num_layers=12, |
| seq_num_heads=12, |
| seq_max_length=256, |
| use_cnn_frontend=True, |
| cnn_kernel_size=3 |
| ) |
| |
| model = MultimodalGlycanBERT(config) |
| model.load_state_dict(state_dict, strict=False) |
| model.to(device) |
| model.eval() |
| print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} params") |
| return model, config |
|
|
|
|
| def get_embedding(model, sample, device='cuda', max_len=256): |
| """Get [CLS] embedding for a sample using sequence encoder only.""" |
| with torch.no_grad(): |
| |
| token_ids = sample.get('token_ids', sample.get('tokens', [])) |
| if isinstance(token_ids, str): |
| token_ids = eval(token_ids) |
| token_ids = torch.tensor(token_ids).unsqueeze(0).to(device) |
| |
| |
| branch_depths = sample.get('branch_depths', [0] * len(token_ids[0])) |
| if isinstance(branch_depths, str): |
| branch_depths = eval(branch_depths) |
| branch_depths = torch.tensor(branch_depths).unsqueeze(0).to(device) |
| |
| linkage_types = sample.get('linkage_types', [0] * len(token_ids[0])) |
| if isinstance(linkage_types, str): |
| linkage_types = eval(linkage_types) |
| linkage_types = torch.tensor(linkage_types).unsqueeze(0).to(device) |
| |
| |
| if token_ids.size(1) > max_len: |
| token_ids = token_ids[:, :max_len] |
| branch_depths = branch_depths[:, :max_len] |
| linkage_types = linkage_types[:, :max_len] |
| elif token_ids.size(1) < max_len: |
| pad_len = max_len - token_ids.size(1) |
| token_ids = F.pad(token_ids, (0, pad_len), value=0) |
| branch_depths = F.pad(branch_depths, (0, pad_len), value=0) |
| linkage_types = F.pad(linkage_types, (0, pad_len), value=0) |
| |
| |
| x = model.seq_embeddings(token_ids, branch_depths, linkage_types) |
| for layer in model.seq_layers: |
| x = layer(x) |
| |
| |
| return x[:, 0, :].squeeze(0) |
|
|
|
|
def _parse_args():
    """Parse and validate the command-line options for the scoring script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--negatives', default='bert_v6_contrastive/data/negatives_150k.pkl')
    parser.add_argument('--positives', default='bert_v5_bpe_topo/data/sequences_bpe_expanded.pkl')
    parser.add_argument('--checkpoint', default='checkpoints_v5_bpe_topo/best_v5_bpe_topo_model.pt')
    parser.add_argument('--output', default='bert_v6_contrastive/data/negatives_scored.pkl')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--sample-pos', type=int, default=100, help='Number of positive samples to compare against')
    parser.add_argument('--seed', type=int, default=None,
                        help='Random seed for choosing the positive subset (unseeded by default)')
    parser.add_argument('--easy-threshold', type=float, default=0.3,
                        help='Average similarity below this is categorized "easy"')
    parser.add_argument('--hard-threshold', type=float, default=0.6,
                        help='Average similarity at or above this is categorized "hard"')
    args = parser.parse_args()
    if args.sample_pos < 1:
        parser.error('--sample-pos must be at least 1')
    if args.easy_threshold > args.hard_threshold:
        parser.error('--easy-threshold must not exceed --hard-threshold')
    return args


def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    # NOTE(security): unpickling can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _categorize(score, easy_threshold, hard_threshold):
    """Map an average similarity to 'easy', 'medium' or 'hard'."""
    if score < easy_threshold:
        return 'easy'
    if score < hard_threshold:
        return 'medium'
    return 'hard'


def _percent(count, total):
    """Return ``count`` as a percentage of ``total`` (0.0 when ``total`` is 0)."""
    return 100 * count / total if total else 0.0


def _embed_positives(model, positives, n_samples, device):
    """Embed a random subset of ``positives`` and stack the results row-wise.

    Positives that fail to embed are skipped (best-effort), but their number is
    reported. Raises RuntimeError if none could be embedded.
    """
    chosen = random.sample(positives, min(n_samples, len(positives)))
    embeddings = []
    failures = 0
    for pos in tqdm(chosen, desc="Positive embeddings"):
        try:
            embeddings.append(get_embedding(model, pos, device))
        except Exception:  # best-effort: a malformed positive is skipped, not fatal
            failures += 1
    if failures:
        print(f"  Warning: skipped {failures} positive(s) that could not be embedded")
    if not embeddings:
        raise RuntimeError(
            f"Could not embed any of the {len(chosen)} sampled positives; "
            "cannot score negatives without a reference set"
        )
    return torch.stack(embeddings)


def _score_negatives(model, negatives, pos_embeddings, device, easy_threshold, hard_threshold):
    """Annotate every negative dict in place with its similarity-based difficulty.

    Successful negatives get ``difficulty_score`` (mean cosine similarity to the
    positive embeddings), ``max_similarity`` and ``difficulty_category``.
    Failures get a placeholder score of 0.5, category ``'medium'`` and an
    ``'error'`` message.

    Returns:
        ``(scored, errors)`` counts.
    """
    scored = 0
    errors = 0
    for neg in tqdm(negatives):
        try:
            neg_emb = get_embedding(model, neg, device)
            sims = F.cosine_similarity(neg_emb.unsqueeze(0), pos_embeddings, dim=1)
            avg_sim = sims.mean().item()
            max_sim = sims.max().item()
        except Exception as exc:  # best-effort: one bad sample must not abort the run
            neg['difficulty_score'] = 0.5
            neg['difficulty_category'] = 'medium'
            neg['error'] = str(exc)
            errors += 1
            continue
        neg['difficulty_score'] = avg_sim
        neg['max_similarity'] = max_sim
        neg['difficulty_category'] = _categorize(avg_sim, easy_threshold, hard_threshold)
        # Drop an error left over from an earlier run on the same data.
        neg.pop('error', None)
        scored += 1
    return scored, errors


def _compute_stats(negatives, scored, errors, easy_threshold, hard_threshold):
    """Summarize category counts and score distribution for the stats JSON.

    Category counts include failed negatives (stored as 'medium'), matching
    what downstream consumers of the pickle see. The mean/std exclude the 0.5
    placeholders of failed negatives so they do not bias the distribution.
    """
    counts = Counter(n.get('difficulty_category') for n in negatives)
    scores = [n['difficulty_score'] for n in negatives
              if 'difficulty_score' in n and 'error' not in n]
    return {
        'total': len(negatives),
        'scored': scored,
        'errors': errors,
        'easy': counts['easy'],
        'medium': counts['medium'],
        'hard': counts['hard'],
        'avg_score': float(np.mean(scores)) if scores else 0,
        'std_score': float(np.std(scores)) if scores else 0,
        'easy_threshold': easy_threshold,
        'hard_threshold': hard_threshold,
    }


def main():
    """Score every negative against sampled positives and save pickle + stats JSON."""
    args = _parse_args()
    if args.seed is not None:
        random.seed(args.seed)

    print("Loading negatives...")
    negatives = _load_pickle(args.negatives)
    print(f" Loaded {len(negatives):,} negatives")

    print("Loading positives...")
    positives = _load_pickle(args.positives)
    if isinstance(positives, dict):
        positives = list(positives.values())
    print(f" Loaded {len(positives):,} positives")

    model, _ = load_model(args.checkpoint, args.device)

    print(f"Pre-computing {args.sample_pos} positive embeddings for comparison...")
    pos_embeddings = _embed_positives(model, positives, args.sample_pos, args.device)
    print(f" Got {len(pos_embeddings)} positive embeddings")

    print(f"\nScoring {len(negatives):,} negatives...")
    scored, errors = _score_negatives(
        model, negatives, pos_embeddings, args.device,
        args.easy_threshold, args.hard_threshold,
    )

    stats = _compute_stats(negatives, scored, errors, args.easy_threshold, args.hard_threshold)
    total = stats['total']

    print("\n=== Results ===")
    print(f"Scored: {scored:,} / {total:,}")
    print(f"Errors: {errors:,}")
    print(f"Easy: {stats['easy']:,} ({_percent(stats['easy'], total):.1f}%)")
    print(f"Medium: {stats['medium']:,} ({_percent(stats['medium'], total):.1f}%)")
    print(f"Hard: {stats['hard']:,} ({_percent(stats['hard'], total):.1f}%)")
    print(f"Avg Score: {stats['avg_score']:.4f} ± {stats['std_score']:.4f}")

    output_path = Path(args.output)
    print(f"\nSaving scored negatives to {args.output}...")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'wb') as f:
        pickle.dump(negatives, f)

    # Build the stats path from the file stem. str.replace('.pkl', ...) returned
    # the output path unchanged when it lacked '.pkl' (overwriting the pickle)
    # and also rewrote '.pkl' inside directory names.
    stats_path = output_path.with_name(f"{output_path.stem}_stats.json")
    with open(stats_path, 'w') as f:
        json.dump(stats, f, indent=2)
    print(f"Saved stats to {stats_path}")

    print("\nDone!")
|
|
|
|
if __name__ == '__main__':
    # Entry point when run as a script; importing the module does not start scoring.
    main()
|
|