| |
| """Analyze confidence distribution for ambiguous BPE tokens.""" |
| import sys |
| import torch |
| import torch.nn.functional as F |
| from pathlib import Path |
| import pickle |
| import json |
| import numpy as np |
| from collections import Counter |
| from tqdm import tqdm |
|
|
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
| from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig |
|
|
| def main(): |
| device = torch.device("cuda") |
| |
| |
| with open('../data/bpe_vocabulary_clean.json', 'r') as f: |
| vocab = json.load(f) |
| token_to_id = vocab.get('token_to_id', vocab) |
| id_to_token = {v: k for k, v in token_to_id.items()} |
| |
| with open('../data/bpe_ambiguity_tokens.json', 'r') as f: |
| ambig_data = json.load(f) |
| ambig_ids = set(ambig_data.get('ambiguous_ids', [])) |
| |
| |
| print("Loading model...") |
| checkpoint = torch.load('../checkpoints_v4_bpe_topology/best_v4_bpe_model.pt', map_location=device) |
| cfg = checkpoint['config']['model'] if 'config' in checkpoint else None |
| |
| if cfg: |
| config = MultimodalGlycanBERTConfig( |
| seq_vocab_size=cfg['sequence']['vocab_size'], |
| seq_hidden_size=cfg['sequence']['hidden_size'], |
| seq_num_layers=cfg['sequence']['num_hidden_layers'], |
| seq_num_heads=cfg['sequence']['num_attention_heads'], |
| seq_max_length=cfg['sequence']['max_length'], |
| ms_vocab_size=cfg['mass_spectrometry']['vocab_size'], |
| ms_hidden_size=cfg['mass_spectrometry']['hidden_size'], |
| ms_num_layers=cfg['mass_spectrometry']['num_hidden_layers'], |
| ms_num_heads=cfg['mass_spectrometry']['num_attention_heads'], |
| ms_max_length=cfg['mass_spectrometry']['max_length'], |
| struct_vocab_size=cfg['structure_3d']['vocab_size'], |
| struct_hidden_size=cfg['structure_3d']['hidden_size'], |
| struct_num_layers=cfg['structure_3d']['num_hidden_layers'], |
| struct_num_heads=cfg['structure_3d']['num_attention_heads'], |
| struct_max_length=cfg['structure_3d']['max_length'], |
| fusion_hidden_size=cfg['fusion']['fusion_hidden_size'], |
| fusion_num_layers=cfg['fusion']['fusion_num_layers'], |
| ) |
| model = MultimodalGlycanBERT(config) |
| model.load_state_dict(checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint) |
| model.to(device) |
| model.eval() |
| |
| |
| print("Loading sequences...") |
| with open('../data/sequences_bpe.pkl', 'rb') as f: |
| seqs = pickle.load(f) |
| if isinstance(seqs, dict): |
| seqs = list(seqs.values()) |
| |
| |
| import random |
| random.seed(42) |
| sample = random.sample([s for s in seqs if any(t in ambig_ids for t in s.get('token_ids', []))], min(10000, len(seqs))) |
| |
| print(f"Analyzing {len(sample)} sequences...") |
| |
| all_confidences = [] |
| max_length = 256 |
| pad_id = token_to_id.get('[PAD]', 0) |
| mask_id = token_to_id.get('[MASK]', 4) |
| |
| with torch.no_grad(): |
| for seq in tqdm(sample): |
| token_ids = list(seq.get('token_ids', []))[:max_length] |
| ambig_positions = [i for i, t in enumerate(token_ids) if t in ambig_ids] |
| |
| if not ambig_positions: |
| continue |
| |
| |
| attention_mask = [1] * len(token_ids) + [0] * (max_length - len(token_ids)) |
| token_ids = token_ids + [pad_id] * (max_length - len(token_ids)) |
| |
| |
| masked = token_ids.copy() |
| for pos in ambig_positions: |
| masked[pos] = mask_id |
| |
| |
| seq_t = torch.tensor([masked], dtype=torch.long, device=device) |
| att_t = torch.tensor([attention_mask], dtype=torch.long, device=device) |
| res_t = torch.zeros_like(seq_t) |
| |
| outputs = model( |
| seq_token_ids=seq_t, seq_attention_mask=att_t, seq_residue_ids=res_t, |
| ms_token_ids=torch.zeros(1, 150, dtype=torch.long, device=device), |
| ms_attention_mask=torch.zeros(1, 150, dtype=torch.long, device=device), |
| has_ms=torch.zeros(1, dtype=torch.bool, device=device), |
| struct_token_ids=torch.zeros(1, 200, dtype=torch.long, device=device), |
| struct_attention_mask=torch.zeros(1, 200, dtype=torch.long, device=device), |
| struct_residue_ids=torch.full((1, 200), -1, dtype=torch.long, device=device), |
| has_3d=torch.zeros(1, dtype=torch.bool, device=device), |
| return_dict=True, |
| ) |
| |
| logits = outputs['seq_logits'][0] |
| for pos in ambig_positions: |
| probs = F.softmax(logits[pos], dim=-1) |
| conf = probs.max().item() |
| pred = probs.argmax().item() |
| all_confidences.append({ |
| 'confidence': conf, |
| 'original': token_ids[pos], |
| 'predicted': pred, |
| 'pred_is_valid': pred not in ambig_ids |
| }) |
| |
| |
| confs = [c['confidence'] for c in all_confidences] |
| valid_confs = [c['confidence'] for c in all_confidences if c['pred_is_valid']] |
| |
| print("\n" + "="*60) |
| print("CONFIDENCE DISTRIBUTION ANALYSIS") |
| print("="*60) |
| print(f"Total ambiguous tokens analyzed: {len(confs)}") |
| print(f"Predictions to valid tokens: {len(valid_confs)} ({100*len(valid_confs)/len(confs):.1f}%)") |
| |
| |
| bins = [0.0, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 1.0] |
| print("\nConfidence Distribution (valid predictions only):") |
| for i in range(len(bins)-1): |
| cnt = sum(1 for c in valid_confs if bins[i] <= c < bins[i+1]) |
| pct = 100 * cnt / len(valid_confs) if valid_confs else 0 |
| bar = "█" * int(pct/2) |
| print(f" [{bins[i]:.2f}-{bins[i+1]:.2f}): {cnt:6d} ({pct:5.1f}%) {bar}") |
| |
| |
| print("\nCumulative Resolution by Threshold:") |
| for thresh in [0.5, 0.6, 0.7, 0.8, 0.9]: |
| above = sum(1 for c in valid_confs if c >= thresh) |
| pct = 100 * above / len(confs) if confs else 0 |
| print(f" >= {thresh}: {above:6d} tokens ({pct:5.1f}% of total)") |
| |
| print("="*60) |
|
|
if __name__ == "__main__":
    # Run the analysis only when executed as a script, never on import.
    main()
|
|