#!/usr/bin/env python3
"""Analyze confidence distribution for ambiguous BPE tokens."""
import bisect
import json
import pickle
import random
import sys
from collections import Counter
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

sys.path.insert(0, str(Path(__file__).parent.parent))
from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig

# Data/checkpoint locations are anchored to the project root (the parent of this
# script's directory -- the same directory inserted into sys.path above), so the
# script works regardless of the current working directory.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
VOCAB_PATH = PROJECT_ROOT / 'data' / 'bpe_vocabulary_clean.json'
AMBIGUITY_PATH = PROJECT_ROOT / 'data' / 'bpe_ambiguity_tokens.json'
CHECKPOINT_PATH = PROJECT_ROOT / 'checkpoints_v4_bpe_topology' / 'best_v4_bpe_model.pt'
SEQUENCES_PATH = PROJECT_ROOT / 'data' / 'sequences_bpe.pkl'

MAX_LENGTH = 256        # sequence tokens fed to the model; longer sequences are truncated
MS_LENGTH = 150         # length of the all-zero placeholder mass-spec input
STRUCT_LENGTH = 200     # length of the placeholder 3D-structure input
SAMPLE_SIZE = 10000     # maximum number of sequences analyzed
SAMPLE_SEED = 42        # seed for reproducible sampling
HIST_EDGES = [0.0, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 1.0]
THRESHOLDS = [0.5, 0.6, 0.7, 0.8, 0.9]


def _build_model(checkpoint, device, source=CHECKPOINT_PATH):
    """Rebuild MultimodalGlycanBERT from the config stored in `checkpoint` and load its weights.

    Raises:
        ValueError: if the checkpoint carries no ``config['model']`` section, since the
            model configuration cannot be reconstructed without it.
    """
    cfg = checkpoint['config']['model'] if 'config' in checkpoint else None
    if not cfg:
        raise ValueError(
            f"Checkpoint {source} has no 'config' -> 'model' section; "
            "cannot reconstruct MultimodalGlycanBERTConfig."
        )
    seq_cfg = cfg['sequence']
    ms_cfg = cfg['mass_spectrometry']
    struct_cfg = cfg['structure_3d']
    fusion_cfg = cfg['fusion']
    config = MultimodalGlycanBERTConfig(
        seq_vocab_size=seq_cfg['vocab_size'],
        seq_hidden_size=seq_cfg['hidden_size'],
        seq_num_layers=seq_cfg['num_hidden_layers'],
        seq_num_heads=seq_cfg['num_attention_heads'],
        seq_max_length=seq_cfg['max_length'],
        ms_vocab_size=ms_cfg['vocab_size'],
        ms_hidden_size=ms_cfg['hidden_size'],
        ms_num_layers=ms_cfg['num_hidden_layers'],
        ms_num_heads=ms_cfg['num_attention_heads'],
        ms_max_length=ms_cfg['max_length'],
        struct_vocab_size=struct_cfg['vocab_size'],
        struct_hidden_size=struct_cfg['hidden_size'],
        struct_num_layers=struct_cfg['num_hidden_layers'],
        struct_num_heads=struct_cfg['num_attention_heads'],
        struct_max_length=struct_cfg['max_length'],
        fusion_hidden_size=fusion_cfg['fusion_hidden_size'],
        fusion_num_layers=fusion_cfg['fusion_num_layers'],
    )
    model = MultimodalGlycanBERT(config)
    # Weights live under 'model_state_dict' when present; otherwise the checkpoint
    # itself is treated as the state dict.
    model.load_state_dict(checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint)
    model.to(device)
    model.eval()
    return model


def _sample_sequences(seqs, ambig_ids, sample_size=SAMPLE_SIZE, seed=SAMPLE_SEED):
    """Return up to `sample_size` sequences containing at least one ambiguous token id."""
    candidates = [s for s in seqs if any(t in ambig_ids for t in s.get('token_ids', []))]
    # k must be capped by the *filtered* population; random.sample raises ValueError
    # if k exceeds it.
    k = min(sample_size, len(candidates))
    # A private seeded generator reproduces random.seed(seed); random.sample(...)
    # without touching the global RNG state.
    return random.Random(seed).sample(candidates, k)


def _collect_confidences(model, sample, ambig_ids, pad_id, mask_id, device, max_length=MAX_LENGTH):
    """Mask all ambiguous positions of each sequence and record the model's top-1 prediction.

    Returns a list of dicts with keys 'confidence' (max softmax probability),
    'original' (true token id), 'predicted' (argmax token id) and 'pred_is_valid'
    (prediction is not itself an ambiguous token).
    """
    # The MS and 3D-structure inputs are identical placeholders for every sequence
    # (has_ms / has_3d are False), so build them once.
    ms_token_ids = torch.zeros(1, MS_LENGTH, dtype=torch.long, device=device)
    ms_attention_mask = torch.zeros(1, MS_LENGTH, dtype=torch.long, device=device)
    has_ms = torch.zeros(1, dtype=torch.bool, device=device)
    struct_token_ids = torch.zeros(1, STRUCT_LENGTH, dtype=torch.long, device=device)
    struct_attention_mask = torch.zeros(1, STRUCT_LENGTH, dtype=torch.long, device=device)
    struct_residue_ids = torch.full((1, STRUCT_LENGTH), -1, dtype=torch.long, device=device)
    has_3d = torch.zeros(1, dtype=torch.bool, device=device)

    results = []
    with torch.no_grad():
        for seq in tqdm(sample):
            token_ids = list(seq.get('token_ids', []))[:max_length]
            # Truncation can remove every ambiguous token; such sequences contribute nothing.
            ambig_positions = [i for i, t in enumerate(token_ids) if t in ambig_ids]
            if not ambig_positions:
                continue

            # Pad to a fixed length; the attention mask marks real tokens with 1.
            n_pad = max_length - len(token_ids)
            attention_mask = [1] * len(token_ids) + [0] * n_pad
            token_ids = token_ids + [pad_id] * n_pad

            # Replace every ambiguous token with [MASK] so the model must infer it.
            masked = token_ids.copy()
            for pos in ambig_positions:
                masked[pos] = mask_id

            seq_t = torch.tensor([masked], dtype=torch.long, device=device)
            att_t = torch.tensor([attention_mask], dtype=torch.long, device=device)
            outputs = model(
                seq_token_ids=seq_t,
                seq_attention_mask=att_t,
                seq_residue_ids=torch.zeros_like(seq_t),
                ms_token_ids=ms_token_ids,
                ms_attention_mask=ms_attention_mask,
                has_ms=has_ms,
                struct_token_ids=struct_token_ids,
                struct_attention_mask=struct_attention_mask,
                struct_residue_ids=struct_residue_ids,
                has_3d=has_3d,
                return_dict=True,
            )
            logits = outputs['seq_logits'][0]

            for pos in ambig_positions:
                probs = F.softmax(logits[pos], dim=-1)
                pred = probs.argmax().item()
                results.append({
                    'confidence': probs.max().item(),
                    'original': token_ids[pos],
                    'predicted': pred,
                    'pred_is_valid': pred not in ambig_ids,
                })
    return results


def _histogram(values, edges):
    """Count values into bins [e0, e1), ..., [e_{n-2}, e_{n-1}]; the final bin is closed."""
    n_bins = len(edges) - 1
    counts = [0] * n_bins
    for v in values:
        # bisect_right keeps each left edge inclusive: edges[i] <= v < edges[i+1] -> bin i.
        idx = bisect.bisect_right(edges, v) - 1
        if v == edges[-1]:
            # Close the last bin so a saturated softmax maximum (exactly 1.0) is counted.
            idx = n_bins - 1
        if 0 <= idx < n_bins:
            counts[idx] += 1
    return counts


def _print_report(all_confidences, edges=HIST_EDGES, thresholds=THRESHOLDS):
    """Print totals, a confidence histogram, and cumulative resolution per threshold."""
    confs = [c['confidence'] for c in all_confidences]
    valid_confs = [c['confidence'] for c in all_confidences if c['pred_is_valid']]

    print("\n" + "="*60)
    print("CONFIDENCE DISTRIBUTION ANALYSIS")
    print("="*60)
    print(f"Total ambiguous tokens analyzed: {len(confs)}")
    if not confs:
        # Avoid dividing by zero below when nothing survived filtering/truncation.
        print("No ambiguous tokens found in the sampled sequences; nothing to report.")
        print("="*60)
        return
    print(f"Predictions to valid tokens: {len(valid_confs)} ({100*len(valid_confs)/len(confs):.1f}%)")

    print("\nConfidence Distribution (valid predictions only):")
    counts = _histogram(valid_confs, edges)
    last = len(counts) - 1
    for i, cnt in enumerate(counts):
        pct = 100 * cnt / len(valid_confs) if valid_confs else 0
        bar = "█" * int(pct/2)
        closing = "]" if i == last else ")"
        print(f"  [{edges[i]:.2f}-{edges[i+1]:.2f}{closing}: {cnt:6d} ({pct:5.1f}%) {bar}")

    print("\nCumulative Resolution by Threshold:")
    for thresh in thresholds:
        above = sum(1 for c in valid_confs if c >= thresh)
        pct = 100 * above / len(confs)
        print(f"  >= {thresh}: {above:6d} tokens ({pct:5.1f}% of total)")

    print("="*60)


def main():
    """Mask ambiguous BPE tokens in a sample of sequences and report model confidence."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load vocab and ambiguity
    with open(VOCAB_PATH, 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    token_to_id = vocab.get('token_to_id', vocab)

    with open(AMBIGUITY_PATH, 'r', encoding='utf-8') as f:
        ambig_data = json.load(f)
    ambig_ids = set(ambig_data.get('ambiguous_ids', []))

    # Load model
    print("Loading model...")
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    model = _build_model(checkpoint, device)

    # Load sequences (sample)
    print("Loading sequences...")
    # NOTE: pickle can execute arbitrary code on load; only use trusted data files here.
    with open(SEQUENCES_PATH, 'rb') as f:
        seqs = pickle.load(f)
    if isinstance(seqs, dict):
        seqs = list(seqs.values())

    sample = _sample_sequences(seqs, ambig_ids)
    print(f"Analyzing {len(sample)} sequences...")

    pad_id = token_to_id.get('[PAD]', 0)
    mask_id = token_to_id.get('[MASK]', 4)
    all_confidences = _collect_confidences(model, sample, ambig_ids, pad_id, mask_id, device)

    _print_report(all_confidences)


if __name__ == "__main__":
    main()