bertose-affinose-training-code / code /training /analyze_confidence.py
supanthadey1's picture
Add BERTose and AFFINose training code release
1d6f391 verified
Raw
History Blame Contribute Delete
6.44 kB
#!/usr/bin/env python3
"""Analyze confidence distribution for ambiguous BPE tokens."""
import sys
import torch
import torch.nn.functional as F
from pathlib import Path
import pickle
import json
import numpy as np
from collections import Counter
from tqdm import tqdm
sys.path.insert(0, str(Path(__file__).parent.parent))
from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig
def _build_model_config(cfg):
    """Build a MultimodalGlycanBERTConfig from the checkpoint's ``config['model']`` dict.

    Args:
        cfg: Nested dict with ``sequence``, ``mass_spectrometry``,
            ``structure_3d`` and ``fusion`` sections, as stored in the checkpoint.

    Returns:
        The MultimodalGlycanBERTConfig used to instantiate the model.
    """
    seq = cfg['sequence']
    ms = cfg['mass_spectrometry']
    struct = cfg['structure_3d']
    fusion = cfg['fusion']
    return MultimodalGlycanBERTConfig(
        seq_vocab_size=seq['vocab_size'],
        seq_hidden_size=seq['hidden_size'],
        seq_num_layers=seq['num_hidden_layers'],
        seq_num_heads=seq['num_attention_heads'],
        seq_max_length=seq['max_length'],
        ms_vocab_size=ms['vocab_size'],
        ms_hidden_size=ms['hidden_size'],
        ms_num_layers=ms['num_hidden_layers'],
        ms_num_heads=ms['num_attention_heads'],
        ms_max_length=ms['max_length'],
        struct_vocab_size=struct['vocab_size'],
        struct_hidden_size=struct['hidden_size'],
        struct_num_layers=struct['num_hidden_layers'],
        struct_num_heads=struct['num_attention_heads'],
        struct_max_length=struct['max_length'],
        fusion_hidden_size=fusion['fusion_hidden_size'],
        fusion_num_layers=fusion['fusion_num_layers'],
    )


def _print_confidence_report(records):
    """Print the summary, histogram and cumulative tables for the scored tokens.

    Args:
        records: List of dicts with at least ``'confidence'`` (float) and
            ``'pred_is_valid'`` (bool) keys, one per masked ambiguous token.
    """
    confs = [r['confidence'] for r in records]
    valid_confs = [r['confidence'] for r in records if r['pred_is_valid']]

    print("\n" + "="*60)
    print("CONFIDENCE DISTRIBUTION ANALYSIS")
    print("="*60)
    print(f"Total ambiguous tokens analyzed: {len(confs)}")
    if not confs:
        # Nothing was scored; every percentage below would divide by zero.
        print("No ambiguous tokens were scored; nothing to report.")
        print("="*60)
        return
    print(f"Predictions to valid tokens: {len(valid_confs)} ({100*len(valid_confs)/len(confs):.1f}%)")

    # Histogram
    bins = [0.0, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 1.0]
    last_bin = len(bins) - 2
    print("\nConfidence Distribution (valid predictions only):")
    for i, (lo, hi) in enumerate(zip(bins, bins[1:])):
        # Bins are half-open except the last one, which is closed so that a
        # confidence of exactly 1.0 is still counted.
        closed = i == last_bin
        cnt = sum(1 for c in valid_confs if lo <= c < hi or (closed and c == hi))
        pct = 100 * cnt / len(valid_confs) if valid_confs else 0
        bar = "█" * int(pct/2)
        right = "]" if closed else ")"
        print(f"  [{lo:.2f}-{hi:.2f}{right}: {cnt:6d} ({pct:5.1f}%) {bar}")

    # Cumulative
    print("\nCumulative Resolution by Threshold:")
    for thresh in [0.5, 0.6, 0.7, 0.8, 0.9]:
        above = sum(1 for c in valid_confs if c >= thresh)
        pct = 100 * above / len(confs)
        print(f"  >= {thresh}: {above:6d} tokens ({pct:5.1f}% of total)")
    print("="*60)


def main():
    """Report how confidently the model re-predicts ambiguous BPE tokens.

    Loads the BPE vocabulary, the set of ambiguous token ids and the trained
    checkpoint, samples up to 10,000 sequences that contain at least one
    ambiguous token, replaces every ambiguous token with ``[MASK]``, runs the
    model with the MS and 3D modalities flagged as absent, and prints the
    distribution of the top softmax probability at the masked positions.

    Raises:
        ValueError: If the checkpoint carries no ``'config'`` entry, since the
            model architecture cannot be reconstructed without it.
    """
    import random

    # Fall back to CPU so the analysis still runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load vocab and ambiguity
    with open('../data/bpe_vocabulary_clean.json', 'r') as f:
        vocab = json.load(f)
    # The vocabulary file is either {'token_to_id': {...}} or the mapping itself.
    token_to_id = vocab.get('token_to_id', vocab)
    with open('../data/bpe_ambiguity_tokens.json', 'r') as f:
        ambig_data = json.load(f)
    ambig_ids = set(ambig_data.get('ambiguous_ids', []))

    # Load model
    print("Loading model...")
    checkpoint = torch.load('../checkpoints_v4_bpe_topology/best_v4_bpe_model.pt', map_location=device)
    cfg = checkpoint['config']['model'] if 'config' in checkpoint else None
    if not cfg:
        raise ValueError(
            "Checkpoint does not contain a 'config' entry with model settings; "
            "cannot rebuild MultimodalGlycanBERT."
        )
    model = MultimodalGlycanBERT(_build_model_config(cfg))
    # Accept both a full training checkpoint and a bare state dict.
    model.load_state_dict(checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint)
    model.to(device)
    model.eval()

    # Load sequences (sample)
    print("Loading sequences...")
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # run this against a trusted sequences_bpe.pkl.
    with open('../data/sequences_bpe.pkl', 'rb') as f:
        seqs = pickle.load(f)
    if isinstance(seqs, dict):
        seqs = list(seqs.values())

    # Sample up to 10k sequences that contain at least one ambiguous token.
    # The sample size is bounded by the *filtered* population; bounding it by
    # len(seqs) makes random.sample raise when few sequences are ambiguous.
    random.seed(42)
    candidates = [s for s in seqs if any(t in ambig_ids for t in s.get('token_ids', []))]
    sample = random.sample(candidates, min(10000, len(candidates)))
    print(f"Analyzing {len(sample)} sequences...")

    all_confidences = []
    max_length = 256
    # Fallback ids 0 and 4 are used when the vocabulary lacks the special tokens.
    pad_id = token_to_id.get('[PAD]', 0)
    mask_id = token_to_id.get('[MASK]', 4)

    # Placeholder inputs for the absent MS / 3D modalities. They do not depend
    # on the sequence, so build them once instead of on every iteration.
    # NOTE(review): lengths 150 and 200 are hard-coded; presumably they match
    # the model's ms/struct max_length. Confirm against the checkpoint config.
    absent_modalities = dict(
        ms_token_ids=torch.zeros(1, 150, dtype=torch.long, device=device),
        ms_attention_mask=torch.zeros(1, 150, dtype=torch.long, device=device),
        has_ms=torch.zeros(1, dtype=torch.bool, device=device),
        struct_token_ids=torch.zeros(1, 200, dtype=torch.long, device=device),
        struct_attention_mask=torch.zeros(1, 200, dtype=torch.long, device=device),
        struct_residue_ids=torch.full((1, 200), -1, dtype=torch.long, device=device),
        has_3d=torch.zeros(1, dtype=torch.bool, device=device),
    )

    with torch.no_grad():
        for seq in tqdm(sample):
            token_ids = list(seq.get('token_ids', []))[:max_length]
            ambig_positions = [i for i, t in enumerate(token_ids) if t in ambig_ids]
            if not ambig_positions:
                # All ambiguous tokens were beyond the truncation point.
                continue

            # Pad
            pad_len = max_length - len(token_ids)
            attention_mask = [1] * len(token_ids) + [0] * pad_len
            token_ids = token_ids + [pad_id] * pad_len

            # Mask ambiguous
            masked = token_ids.copy()
            for pos in ambig_positions:
                masked[pos] = mask_id

            # Forward
            seq_t = torch.tensor([masked], dtype=torch.long, device=device)
            att_t = torch.tensor([attention_mask], dtype=torch.long, device=device)
            res_t = torch.zeros_like(seq_t)
            outputs = model(
                seq_token_ids=seq_t, seq_attention_mask=att_t, seq_residue_ids=res_t,
                return_dict=True,
                **absent_modalities,
            )
            logits = outputs['seq_logits'][0]

            for pos in ambig_positions:
                # Top-1 probability and its token id at the masked position.
                conf_t, pred_t = F.softmax(logits[pos], dim=-1).max(dim=-1)
                pred = pred_t.item()
                all_confidences.append({
                    'confidence': conf_t.item(),
                    'original': token_ids[pos],
                    'predicted': pred,
                    # "Valid" means the model resolved the mask to a
                    # non-ambiguous token.
                    'pred_is_valid': pred not in ambig_ids
                })

    # Analyze
    _print_confidence_report(all_confidences)
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()