| |
| """ |
| IPA Self-Distillation for BPE Tokenized Glycan Sequences |
| Iterative Pseudo-Annotation (IPA) to resolve ambiguous BPE tokens. |
| """ |
|
|
| import sys |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from pathlib import Path |
| from typing import List, Dict, Set, Tuple |
| import logging |
| import pickle |
| import json |
| import argparse |
| from copy import deepcopy |
| from tqdm import tqdm |
|
|
| |
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
| from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class BPEIPADistiller:
    """Iterative Pseudo-Annotation (IPA) self-distillation over BPE token IDs.

    Token IDs listed as ambiguous in the ambiguity file are replaced by the
    ``[MASK]`` ID, the model is run with only the sequence modality populated,
    and a masked position is overwritten with the model's top prediction when
    that prediction is confident enough and names a concrete vocabulary token.
    """

    def __init__(
        self,
        checkpoint_path: str,
        vocab_path: str,
        ambiguity_path: str,
        device: str = "cuda",
        threshold: float = 0.8,
        batch_size: int = 32,
    ):
        """Load the vocabulary, the ambiguity table and the model.

        Args:
            checkpoint_path: Path to a ``torch.save``-d checkpoint.
            vocab_path: JSON file holding either ``{"token_to_id": {...}}`` or
                a flat ``{token: id}`` mapping.
            ambiguity_path: JSON file with an ``ambiguous_ids`` list and an
                optional ``ambiguous_tokens`` mapping.
            device: Torch device string the model and batches are placed on.
            threshold: Minimum softmax probability required to accept a
                prediction.
            batch_size: Number of sequences per forward pass.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # Fail here rather than with a ZeroDivisionError inside run_ipa.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.checkpoint_path = Path(checkpoint_path)
        self.vocab_path = Path(vocab_path)
        self.device = torch.device(device)
        self.threshold = threshold
        self.batch_size = batch_size

        logger.info(f"Loading BPE vocabulary from {vocab_path}")
        with open(vocab_path, 'r', encoding='utf-8') as f:
            vocab = json.load(f)

        self.vocab = vocab
        # Accept both the wrapped and the flat vocabulary layout.
        self.token_to_id = vocab.get('token_to_id', vocab)
        self.id_to_token = {v: k for k, v in self.token_to_id.items()}
        self.vocab_size = len(self.token_to_id)

        # Special-token IDs, with fixed fallbacks when the vocabulary does
        # not name them.
        self.pad_id = self.token_to_id.get('[PAD]', 0)
        self.mask_id = self.token_to_id.get('[MASK]', 4)
        self.start_id = self.token_to_id.get('[START]', 2)
        self.end_id = self.token_to_id.get('[END]', 3)

        logger.info(f"Vocabulary size: {self.vocab_size}")
        logger.info(f"MASK token ID: {self.mask_id}")

        logger.info(f"Loading ambiguity data from {ambiguity_path}")
        with open(ambiguity_path, 'r', encoding='utf-8') as f:
            ambig_data = json.load(f)

        self.ambiguous_token_ids = set(ambig_data.get('ambiguous_ids', []))
        self.ambiguous_tokens = ambig_data.get('ambiguous_tokens', {})

        logger.info(f"Ambiguous BPE tokens: {len(self.ambiguous_token_ids)}")

        logger.info(f"Loading model from {checkpoint_path}")
        self.model = self._load_model()
        self.model.to(self.device)
        self.model.eval()

        logger.info(f"BPE IPA Distiller initialized (threshold={threshold})")

    def _load_model(self) -> "MultimodalGlycanBERT":
        """Build the model from the checkpoint's config and load its weights.

        The architecture is read from ``checkpoint['config']['model']`` when
        the checkpoint has a ``config`` entry; otherwise the hard-coded
        defaults below are used. Weights come from
        ``checkpoint['model_state_dict']`` when present, else the checkpoint
        object itself is treated as the state dict.
        """
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)

        if 'config' in checkpoint:
            cfg = checkpoint['config']['model']
        else:
            cfg = {
                'sequence': {'vocab_size': 2200, 'hidden_size': 768, 'num_hidden_layers': 12,
                             'num_attention_heads': 12, 'max_length': 256},
                'mass_spectrometry': {'vocab_size': 242, 'hidden_size': 384,
                                      'num_hidden_layers': 6, 'num_attention_heads': 6, 'max_length': 150},
                'structure_3d': {'vocab_size': 1024, 'hidden_size': 512,
                                 'num_hidden_layers': 8, 'num_attention_heads': 8, 'max_length': 200},
                'fusion': {'fusion_hidden_size': 768, 'fusion_num_layers': 2},
            }

        seq_cfg = cfg['sequence']
        ms_cfg = cfg['mass_spectrometry']
        struct_cfg = cfg['structure_3d']
        fusion_cfg = cfg['fusion']

        config = MultimodalGlycanBERTConfig(
            seq_vocab_size=seq_cfg['vocab_size'],
            seq_hidden_size=seq_cfg['hidden_size'],
            seq_num_layers=seq_cfg['num_hidden_layers'],
            seq_num_heads=seq_cfg['num_attention_heads'],
            seq_max_length=seq_cfg['max_length'],
            ms_vocab_size=ms_cfg['vocab_size'],
            ms_hidden_size=ms_cfg['hidden_size'],
            ms_num_layers=ms_cfg['num_hidden_layers'],
            ms_num_heads=ms_cfg['num_attention_heads'],
            ms_max_length=ms_cfg['max_length'],
            struct_vocab_size=struct_cfg['vocab_size'],
            struct_hidden_size=struct_cfg['hidden_size'],
            struct_num_layers=struct_cfg['num_hidden_layers'],
            struct_num_heads=struct_cfg['num_attention_heads'],
            struct_max_length=struct_cfg['max_length'],
            fusion_hidden_size=fusion_cfg['fusion_hidden_size'],
            fusion_num_layers=fusion_cfg['fusion_num_layers'],
        )

        model = MultimodalGlycanBERT(config)
        model.load_state_dict(checkpoint.get('model_state_dict', checkpoint)
                              if 'model_state_dict' in checkpoint else checkpoint)
        return model

    def find_ambiguous_sequences(self, sequences: List[Dict]) -> Tuple[List[Dict], List[Dict], Dict]:
        """Split ``sequences`` by whether they contain any ambiguous token ID.

        Args:
            sequences: Dicts whose ``token_ids`` entry (default: empty) is an
                iterable of token IDs.

        Returns:
            ``(ambiguous, clean, stats)``. Entries of ``ambiguous`` are shallow
            copies carrying an extra ``_ambig_count`` key with the number of
            ambiguous tokens found at the time of this call; entries of
            ``clean`` are the original objects. ``stats`` holds the counts and
            the percentage of ambiguous sequences (0 for empty input).
        """
        ambiguous, clean = [], []
        total_ambig_tokens = 0

        for seq in sequences:
            ambig_count = sum(
                1 for tid in seq.get('token_ids', []) if tid in self.ambiguous_token_ids
            )
            if ambig_count == 0:
                clean.append(seq)
                continue

            seq_copy = seq.copy()
            seq_copy['_ambig_count'] = ambig_count
            ambiguous.append(seq_copy)
            total_ambig_tokens += ambig_count

        stats = {
            'total_sequences': len(sequences),
            'ambiguous_sequences': len(ambiguous),
            'clean_sequences': len(clean),
            'total_ambiguous_tokens': total_ambig_tokens,
            'ambiguous_percentage': 100 * len(ambiguous) / len(sequences) if sequences else 0,
        }

        return ambiguous, clean, stats

    @torch.no_grad()
    def resolve_batch(self, sequences: List[Dict], max_length: int = 256) -> Tuple[List[Dict], int, int]:
        """Mask ambiguous tokens in one batch and accept confident predictions.

        Args:
            sequences: Dicts with a ``token_ids`` entry; inputs are not
                mutated (each result is a deep copy).
            max_length: Sequences are truncated/padded to this length for the
                forward pass. Ambiguous positions at or beyond it are never
                resolved but still count as remaining.

        Returns:
            ``(resolved_sequences, total_resolved, total_remaining)`` where
            the counters are numbers of tokens across the batch. A sequence
            with at least one replacement gets ``is_distilled = True`` and one
            ``resolved_positions`` record per replacement.

        A prediction is accepted only if its probability reaches
        ``self.threshold``, it maps to a token in the vocabulary, it is not a
        special token (PAD/MASK/START/END) and it is not itself ambiguous.
        """
        if not sequences:
            return [], 0, 0

        masked_rows, attention_rows, ambig_positions_per_seq = [], [], []

        for seq in sequences:
            token_ids = list(seq.get('token_ids', []))
            # Positions are taken from the untruncated sequence.
            ambig_positions_per_seq.append(
                [i for i, tid in enumerate(token_ids) if tid in self.ambiguous_token_ids]
            )

            token_ids = token_ids[:max_length]
            pad_length = max_length - len(token_ids)

            masked_rows.append(
                [self.mask_id if tid in self.ambiguous_token_ids else tid for tid in token_ids]
                + [self.pad_id] * pad_length
            )
            attention_rows.append([1] * len(token_ids) + [0] * pad_length)

        seq_tensor = torch.tensor(masked_rows, dtype=torch.long, device=self.device)
        attention_tensor = torch.tensor(attention_rows, dtype=torch.long, device=self.device)
        residue_tensor = torch.zeros_like(seq_tensor)

        batch_size = seq_tensor.shape[0]
        # Placeholder lengths for the absent modalities; they equal the
        # max_length values of the fallback config in _load_model.
        # NOTE(review): confirm they also suit checkpoints that ship a config.
        ms_len, struct_len = 150, 200

        outputs = self.model(
            seq_token_ids=seq_tensor,
            seq_attention_mask=attention_tensor,
            seq_residue_ids=residue_tensor,
            ms_token_ids=torch.zeros(batch_size, ms_len, dtype=torch.long, device=self.device),
            ms_attention_mask=torch.zeros(batch_size, ms_len, dtype=torch.long, device=self.device),
            has_ms=torch.zeros(batch_size, dtype=torch.bool, device=self.device),
            struct_token_ids=torch.zeros(batch_size, struct_len, dtype=torch.long, device=self.device),
            struct_attention_mask=torch.zeros(batch_size, struct_len, dtype=torch.long, device=self.device),
            struct_residue_ids=torch.full((batch_size, struct_len), -1, dtype=torch.long, device=self.device),
            has_3d=torch.zeros(batch_size, dtype=torch.bool, device=self.device),
            return_dict=True,
        )

        # One softmax/argmax over the whole batch, moved off the device once,
        # instead of a device round-trip per ambiguous position.
        top_probs, top_ids = F.softmax(outputs['seq_logits'], dim=-1).max(dim=-1)
        top_probs, top_ids = top_probs.cpu(), top_ids.cpu()

        # Special tokens are never a valid resolution of an ambiguous token.
        special_ids = {self.pad_id, self.mask_id, self.start_id, self.end_id}

        resolved_sequences = []
        total_resolved, total_remaining = 0, 0

        for i, seq in enumerate(sequences):
            resolved = deepcopy(seq)
            token_ids = list(resolved.get('token_ids', []))
            n_resolved = 0

            for pos in ambig_positions_per_seq[i]:
                if pos >= max_length:
                    # Truncated away before the forward pass; no prediction.
                    continue

                confidence = top_probs[i, pos].item()
                pred_id = top_ids[i, pos].item()

                accepted = (
                    confidence >= self.threshold
                    and pred_id in self.id_to_token
                    and pred_id not in special_ids
                    and pred_id not in self.ambiguous_token_ids
                )
                if not accepted:
                    continue

                original_token = self.id_to_token.get(token_ids[pos], '?')
                new_token = self.id_to_token[pred_id]

                token_ids[pos] = pred_id
                n_resolved += 1

                resolved.setdefault('resolved_positions', []).append({
                    'pos': pos, 'original': original_token,
                    'resolved': new_token, 'confidence': confidence,
                })

            resolved['token_ids'] = token_ids
            if n_resolved > 0:
                resolved['is_distilled'] = True

            resolved_sequences.append(resolved)
            total_resolved += n_resolved
            total_remaining += sum(1 for tid in token_ids if tid in self.ambiguous_token_ids)

        return resolved_sequences, total_resolved, total_remaining

    def run_ipa(self, sequences: List[Dict], iterations: int = 3) -> Tuple[List[Dict], Dict]:
        """Run up to ``iterations`` rounds of resolution over ``sequences``.

        Only sequences that initially contain ambiguous tokens are passed
        through the model; every round re-processes all of them and the loop
        stops early once a round resolves nothing.

        Returns:
            ``(all_sequences, stats)``. ``all_sequences`` is the initially
            clean sequences followed by the processed ones. ``stats`` has the
            keys ``initial``, ``iterations`` (one entry per round run),
            ``threshold`` and ``final``.
        """
        logger.info(f"Starting BPE IPA self-distillation (threshold={self.threshold}, iterations={iterations})")

        ambiguous, clean, initial_stats = self.find_ambiguous_sequences(sequences)

        logger.info(f"Initial: {initial_stats['ambiguous_sequences']} ambiguous sequences ({initial_stats['ambiguous_percentage']:.1f}%)")
        logger.info(f"Total ambiguous tokens: {initial_stats['total_ambiguous_tokens']}")

        stats = {'initial': initial_stats, 'iterations': [], 'threshold': self.threshold}
        current_sequences = ambiguous

        for iteration in range(iterations):
            logger.info(f"\n=== Iteration {iteration + 1}/{iterations} ===")

            resolved_this_iter, total_remaining = 0, 0
            new_sequences = []

            batch_starts = range(0, len(current_sequences), self.batch_size)
            for start_idx in tqdm(batch_starts, desc=f"Iteration {iteration + 1}"):
                batch = current_sequences[start_idx:start_idx + self.batch_size]

                resolved_batch, n_resolved, n_remaining = self.resolve_batch(batch)
                new_sequences.extend(resolved_batch)
                resolved_this_iter += n_resolved
                total_remaining += n_remaining

            current_sequences = new_sequences
            stats['iterations'].append({
                'iteration': iteration + 1,
                'tokens_resolved': resolved_this_iter,
                'tokens_remaining': total_remaining,
            })

            logger.info(f"  Resolved {resolved_this_iter} tokens, {total_remaining} remaining")

            if resolved_this_iter == 0:
                logger.info("  No progress, stopping early")
                break

        all_sequences = clean + current_sequences
        # Only the statistics of the re-scan are needed here.
        _, _, final_stats = self.find_ambiguous_sequences(current_sequences)

        stats['final'] = {
            'total_sequences': len(all_sequences),
            'fully_resolved': initial_stats['ambiguous_sequences'] - final_stats['ambiguous_sequences'],
            'still_ambiguous': final_stats['ambiguous_sequences'],
            'remaining_ambiguous_tokens': final_stats['total_ambiguous_tokens'],
            'total_tokens_resolved': sum(it['tokens_resolved'] for it in stats['iterations']),
        }

        logger.info("\n=== IPA Complete ===")
        logger.info(f"Total tokens resolved: {stats['final']['total_tokens_resolved']}")

        return all_sequences, stats
|
|
|
|
def main(argv=None):
    """Command-line entry point: load sequences, run IPA, write the results.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``, so the script behaves as
            before; passing a list allows programmatic invocation.

    Writes the processed sequences as a pickle to ``--output`` and the run
    statistics as JSON next to it (same name with the suffix replaced by
    ``.stats.json``), then prints a summary to stdout.
    """
    parser = argparse.ArgumentParser(description="BPE IPA Self-Distillation")
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--sequences", type=str, required=True)
    parser.add_argument("--vocab", type=str, required=True)
    parser.add_argument("--ambiguity", type=str, required=True)
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument("--threshold", type=float, default=0.8)
    parser.add_argument("--iterations", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--device", type=str, default="cuda")

    args = parser.parse_args(argv)

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    logger.info(f"Loading sequences from {args.sequences}")
    # SECURITY: unpickling executes code embedded in the file. Only point
    # --sequences at files from a trusted source.
    with open(args.sequences, 'rb') as f:
        sequences = pickle.load(f)

    # A dict input is reduced to its values; the keys are not used.
    if isinstance(sequences, dict):
        sequences = list(sequences.values())

    logger.info(f"Loaded {len(sequences)} sequences")

    distiller = BPEIPADistiller(
        checkpoint_path=args.checkpoint,
        vocab_path=args.vocab,
        ambiguity_path=args.ambiguity,
        device=args.device,
        threshold=args.threshold,
        batch_size=args.batch_size,
    )

    expanded_sequences, stats = distiller.run_ipa(sequences=sequences, iterations=args.iterations)

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'wb') as f:
        pickle.dump(expanded_sequences, f)

    logger.info(f"Saved {len(expanded_sequences)} expanded sequences to {output_path}")

    stats_path = output_path.with_suffix('.stats.json')
    with open(stats_path, 'w', encoding='utf-8') as f:
        json.dump(stats, f, indent=2)

    initial, final = stats['initial'], stats['final']
    rule = "=" * 60
    print("\n" + rule)
    print("BPE IPA SELF-DISTILLATION SUMMARY")
    print(rule)
    print(f"Input sequences: {initial['total_sequences']}")
    print(f"Initially ambiguous: {initial['ambiguous_sequences']} ({initial['ambiguous_percentage']:.1f}%)")
    print(f"Ambiguous tokens: {initial['total_ambiguous_tokens']}")
    print(f"Threshold: {stats['threshold']}")
    print(f"Iterations run: {len(stats['iterations'])}")
    print(f"Tokens resolved: {final['total_tokens_resolved']}")
    print(f"Still ambiguous: {final['still_ambiguous']}")
    print(rule)
|
|
|
|
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|