from ablang2.pretrained_utils.restoration import AbRestore
from ablang2.pretrained_utils.encodings import AbEncoding
from ablang2.pretrained_utils.alignment import AbAlignment
from ablang2.pretrained_utils.scores import AbScores
import torch
import numpy as np
from ablang2.pretrained_utils.extra_utils import res_to_seq, res_to_list


class HuggingFaceTokenizerAdapter:
    def __init__(self, tokenizer, device):
        self.tokenizer = tokenizer
        self.device = device
        self.pad_token_id = tokenizer.pad_token_id
        self.mask_token_id = getattr(tokenizer, 'mask_token_id', None) or tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
        self.vocab = tokenizer.get_vocab() if hasattr(tokenizer, 'get_vocab') else tokenizer.vocab
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.all_special_tokens = tokenizer.all_special_tokens

    def __call__(self, seqs, pad=True, w_extra_tkns=False, device=None, mode=None):
        if mode == 'decode':
            # In decode mode, seqs is a tensor (or array) of token ids rather than raw
            # strings, so decode directly instead of running the tokenizer on it.
            if isinstance(seqs, torch.Tensor):
                seqs = seqs.cpu().numpy()
            decoded = []
            for seq in seqs:
                chars = [
                    self.inv_vocab.get(int(t), '')
                    for t in seq
                    if self.inv_vocab.get(int(t), '') not in {'-', '*', '<', '>', ''}
                ]
                # Use res_to_seq for formatting, passing a (sequence, length) pair as in the
                # original code. The padded length is not always available, so len(chars)
                # is used as a fallback.
                formatted = res_to_seq([''.join(chars), len(chars)], mode='restore')
                decoded.append(formatted)
            return decoded
        tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
        return tokens['input_ids'].to(self.device if device is None else device)


class HFAbRestore(AbRestore):
    def __init__(self, hf_model, hf_tokenizer, spread=11, device='cpu', ncpu=1):
        super().__init__(spread=spread, device=device, ncpu=ncpu)
        self.used_device = device
        self._hf_model = hf_model
        self.tokenizer = HuggingFaceTokenizerAdapter(hf_tokenizer, device)

    @property
    def AbLang(self):
        def model_call(x):
            output = self._hf_model(x)
            if hasattr(output, 'last_hidden_state'):
                return output.last_hidden_state
            return output
        return model_call


def add_angle_brackets(seq):
    # Assumes input is 'VH|VL', 'VH|' or '|VL'
    if '|' in seq:
        vh, vl = seq.split('|', 1)
    else:
        vh, vl = seq, ''
    return f"<{vh}>|<{vl}>"


class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScores):
    """
    Adapter for using the ablang2 pretrained utilities with a HuggingFace-loaded
    ablang2_paired model and tokenizer. Automatically uses CUDA if available, otherwise CPU.
""" def __init__(self, model, tokenizer, device=None, ncpu=1): super().__init__() if device is None: self.used_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') else: self.used_device = torch.device(device) self.AbLang = model # HuggingFace model instance self.tokenizer = tokenizer self.AbLang.to(self.used_device) self.AbLang.eval() # Always get AbRep from the underlying model if hasattr(self.AbLang, 'model') and hasattr(self.AbLang.model, 'AbRep'): self.AbRep = self.AbLang.model.AbRep else: raise AttributeError("Could not find AbRep in the HuggingFace model or its underlying model.") self.ncpu = ncpu self.spread = 11 # For compatibility with original utilities # The following is no longer needed since all_special_tokens now returns IDs directly # self.tokenizer.all_special_token_ids = [ # self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens # ] # self.tokenizer._all_special_tokens_str = self.tokenizer.all_special_tokens # self.tokenizer.all_special_tokens = [ # self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer._all_special_tokens_str # ] def freeze(self): self.AbLang.eval() def unfreeze(self): self.AbLang.train() def _encode_sequences(self, seqs): # Use HuggingFace-style padding and return PyTorch tensors tokens = self.tokenizer(seqs, padding=True, return_tensors='pt') tokens = extract_input_ids(tokens, self.used_device) return self.AbRep(tokens).last_hidden_states.detach() def _predict_logits(self, seqs): tokens = self.tokenizer(seqs, padding=True, return_tensors='pt') tokens = extract_input_ids(tokens, self.used_device) output = self.AbLang(tokens) if hasattr(output, 'last_hidden_state'): return output.last_hidden_state.detach() return output.detach() def _preprocess_labels(self, labels): labels = extract_input_ids(labels, self.used_device) return labels def __call__(self, seqs, mode='seqcoding', align=False, stepwise_masking=False, fragmented=False, batch_size=50): """ Use different modes for different usecases, mimicking the original pretrained class. """ from ablang2.pretrained import format_seq_input valid_modes = [ 'rescoding', 'seqcoding', 'restore', 'likelihood', 'probability', 'pseudo_log_likelihood', 'confidence' ] if mode not in valid_modes: raise SyntaxError(f"Given mode doesn't exist. Please select one of the following: {valid_modes}.") seqs, chain = format_seq_input(seqs, fragmented=fragmented) if align: numbered_seqs, seqs, number_alignment = self.number_sequences( seqs, chain=chain, fragmented=fragmented ) else: numbered_seqs = None number_alignment = None subset_list = [] for subset in [seqs[x:x+batch_size] for x in range(0, len(seqs), batch_size)]: subset_list.append(getattr(self, mode)(subset, align=align, stepwise_masking=stepwise_masking)) return self.reformat_subsets( subset_list, mode=mode, align=align, numbered_seqs=numbered_seqs, seqs=seqs, number_alignment=number_alignment, ) def pseudo_log_likelihood(self, seqs, **kwargs): """ Original (non-vectorized) pseudo log-likelihood computation matching notebook behavior. 
""" # Format input: join VH and VL with '|' formatted_seqs = [] for s in seqs: if isinstance(s, (list, tuple)): formatted_seqs.append('|'.join(s)) else: formatted_seqs.append(s) # Tokenize all sequences in batch labels = self.tokenizer( formatted_seqs, padding=True, return_tensors='pt' ) labels = extract_input_ids(labels, self.used_device) # Convert special tokens to IDs if isinstance(self.tokenizer.all_special_tokens[0], int): special_token_ids = set(self.tokenizer.all_special_tokens) else: special_token_ids = set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens) pad_token_id = self.tokenizer.pad_token_id mask_token_id = getattr(self.tokenizer, 'mask_token_id', None) if mask_token_id is None: mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) plls = [] with torch.no_grad(): for i, seq_label in enumerate(labels): seq_pll = [] for j, token_id in enumerate(seq_label): if token_id.item() in special_token_ids or token_id.item() == pad_token_id: continue masked = seq_label.clone() masked[j] = mask_token_id logits = self.AbLang(masked.unsqueeze(0)) if hasattr(logits, 'last_hidden_state'): logits = logits.last_hidden_state logits = logits[0, j] nll = torch.nn.functional.cross_entropy( logits.unsqueeze(0), token_id.unsqueeze(0), reduction="none" ) seq_pll.append(-nll.item()) if seq_pll: plls.append(np.mean(seq_pll)) else: plls.append(float('nan')) return np.array(plls) def confidence(self, seqs, **kwargs): """Confidence calculation - match original ablang2 implementation by excluding all special tokens from loss.""" # Format input: join VH and VL with '|' formatted_seqs = [] for s in seqs: if isinstance(s, (list, tuple)): formatted_seqs.append('|'.join(s)) else: formatted_seqs.append(s) plls = [] for seq in formatted_seqs: tokens = self.tokenizer([seq], padding=True, return_tensors='pt') input_ids = extract_input_ids(tokens, self.used_device) with torch.no_grad(): output = self.AbLang(input_ids) if hasattr(output, 'last_hidden_state'): logits = output.last_hidden_state else: logits = output # Get the sequence (remove batch dimension) logits = logits[0] # [seq_len, vocab_size] input_ids = input_ids[0] # [seq_len] # Exclude all special tokens (pad, mask, etc.) 
            if isinstance(self.tokenizer.all_special_tokens[0], int):
                special_token_ids = set(self.tokenizer.all_special_tokens)
            else:
                special_token_ids = set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens)
            valid_mask = ~torch.isin(input_ids, torch.tensor(list(special_token_ids), device=input_ids.device))
            if valid_mask.sum() > 0:
                valid_logits = logits[valid_mask]
                valid_labels = input_ids[valid_mask]
                # Calculate cross-entropy loss over the non-special positions
                nll = torch.nn.functional.cross_entropy(
                    valid_logits, valid_labels, reduction="mean"
                )
                pll = -nll.item()
            else:
                pll = 0.0
            plls.append(pll)

        return np.array(plls, dtype=np.float32)

    def probability(self, seqs, align=False, stepwise_masking=False, **kwargs):
        """
        Probability of mutations - applies softmax to the logits to get probabilities.
        """
        # Format input: join VH and VL with '|'
        formatted_seqs = []
        for s in seqs:
            if isinstance(s, (list, tuple)):
                formatted_seqs.append('|'.join(s))
            else:
                formatted_seqs.append(s)

        # Get logits
        if stepwise_masking:
            # For stepwise masking, this would need an implementation similar to likelihood.
            # This is a simplified version - you might want to implement full stepwise masking.
            logits = self._predict_logits(formatted_seqs)
        else:
            logits = self._predict_logits(formatted_seqs)

        # Apply softmax to get probabilities
        probs = logits.softmax(-1).cpu().numpy()

        if align:
            return probs
        else:
            # Return residue-level probabilities (excluding special tokens)
            return [res_to_list(state, seq) for state, seq in zip(probs, formatted_seqs)]

    def restore(self, seqs, align=False, **kwargs):
        hf_abrestore = HFAbRestore(self.AbLang, self.tokenizer, spread=self.spread, device=self.used_device, ncpu=self.ncpu)
        restored = hf_abrestore.restore(seqs, align=align)
        # Apply angle-bracket formatting
        if isinstance(restored, np.ndarray):
            restored = np.array([add_angle_brackets(seq) for seq in restored])
        else:
            restored = [add_angle_brackets(seq) for seq in restored]
        return restored


def extract_input_ids(tokens, device):
    if hasattr(tokens, 'input_ids'):
        return tokens.input_ids.to(device)
    elif isinstance(tokens, dict):
        if 'input_ids' in tokens:
            return tokens['input_ids'].to(device)
        # Fall back to the first tensor value in the dict
        for v in tokens.values():
            if torch.is_tensor(v):
                return v.to(device)
        raise ValueError("Could not extract input_ids from tokenizer output")
    elif torch.is_tensor(tokens):
        return tokens.to(device)
    else:
        raise ValueError("Could not extract input_ids from tokenizer output")
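

# ---------------------------------------------------------------------------
# Example usage (a minimal sketch, not part of the adapter itself).
# Assumptions: the repository id below is hypothetical - substitute whichever
# HuggingFace checkpoint actually hosts the ablang2_paired weights - and the
# loaded model must expose a `.model.AbRep` submodule, as required by
# AbLang2PairedHuggingFaceAdapter.__init__. The sequences are illustrative,
# truncated VH/VL fragments, not real data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    hf_repo = "your-org/ablang2-paired"  # hypothetical repo id
    tokenizer = AutoTokenizer.from_pretrained(hf_repo, trust_remote_code=True)
    model = AutoModel.from_pretrained(hf_repo, trust_remote_code=True)

    adapter = AbLang2PairedHuggingFaceAdapter(model, tokenizer)

    # One [VH, VL] pair per antibody, as expected by format_seq_input.
    seqs = [
        [
            "EVQLLESGGGLVQPGGSLRLSCAAS",  # truncated example heavy chain
            "DIQMTQSPSSLSASVGDRVTITCRA",  # truncated example light chain
        ],
    ]

    seq_codings = adapter(seqs, mode="seqcoding")        # one embedding per antibody
    plls = adapter(seqs, mode="pseudo_log_likelihood")   # per-sequence pseudo log-likelihood
    print(seq_codings, plls)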