import torch
import numpy as np

from ablang2.pretrained_utils.restoration import AbRestore
from ablang2.pretrained_utils.encodings import AbEncoding
from ablang2.pretrained_utils.alignment import AbAlignment
from ablang2.pretrained_utils.scores import AbScores
from ablang2.pretrained_utils.extra_utils import res_to_seq, res_to_list
class HuggingFaceTokenizerAdapter:
    """Wraps a HuggingFace tokenizer so it matches the callable interface the
    ablang2 pretrained utilities expect, for both encoding and decoding."""

    def __init__(self, tokenizer, device):
        self.tokenizer = tokenizer
        self.device = device
        self.pad_token_id = tokenizer.pad_token_id
        self.mask_token_id = getattr(tokenizer, 'mask_token_id', None) or tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
        self.vocab = tokenizer.get_vocab() if hasattr(tokenizer, 'get_vocab') else tokenizer.vocab
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.all_special_tokens = tokenizer.all_special_tokens

    def __call__(self, seqs, pad=True, w_extra_tkns=False, device=None, mode=None):
        if mode == 'decode':
            # In decode mode, seqs is a tensor (or array) of token ids, so it
            # must not be passed through the HuggingFace tokenizer.
            if isinstance(seqs, torch.Tensor):
                seqs = seqs.cpu().numpy()
            decoded = []
            for seq in seqs:
                # Map ids back to tokens, dropping layout/special characters.
                chars = [
                    self.inv_vocab.get(int(t), '')
                    for t in seq
                    if self.inv_vocab.get(int(t), '') not in {'-', '*', '<', '>', ''}
                ]
                # res_to_seq expects a (sequence, length) pair, as in the
                # original code; the true length is not always available,
                # so len(chars) is used as a fallback.
                decoded.append(res_to_seq([''.join(chars), len(chars)], mode='restore'))
            return decoded
        tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
        return tokens['input_ids'].to(self.device if device is None else device)
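
# Example (sketch): round-trip through the adapter, assuming `tok` is an
# ablang2-paired tokenizer loaded via transformers (hypothetical usage):
#   adapter = HuggingFaceTokenizerAdapter(tok, device='cpu')
#   ids = adapter(['EVQLV...|DIQMT...'])   # encode -> LongTensor of token ids
#   seqs = adapter(ids, mode='decode')     # decode -> list of restored strings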
class HFAbRestore(AbRestore):
    """AbRestore variant that drives a HuggingFace-loaded model instead of the
    native ablang2 weights."""

    def __init__(self, hf_model, hf_tokenizer, spread=11, device='cpu', ncpu=1):
        super().__init__(spread=spread, device=device, ncpu=ncpu)
        self.used_device = device
        self._hf_model = hf_model
        self.tokenizer = HuggingFaceTokenizerAdapter(hf_tokenizer, device)

    @property
    def AbLang(self):
        # AbRestore calls self.AbLang(x) as if it were the raw model; exposing
        # this as a property that returns a callable keeps that interface intact.
        def model_call(x):
            output = self._hf_model(x)
            if hasattr(output, 'last_hidden_state'):
                return output.last_hidden_state
            return output
        return model_call
def add_angle_brackets(seq):
    # Assumes input is 'VH|VL', 'VH|' or '|VL'.
    if '|' in seq:
        vh, vl = seq.split('|', 1)
    else:
        vh, vl = seq, ''
    return f"<{vh}>|<{vl}>"
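
# Example: add_angle_brackets('EVQLV|DIQMT') -> '<EVQLV>|<DIQMT>'
#          add_angle_brackets('EVQLV')       -> '<EVQLV>|<>'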
class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScores):
    """
    Adapter that lets the ablang2 pretrained utilities run on a
    HuggingFace-loaded ablang2_paired model and tokenizer.
    Automatically uses CUDA if available, otherwise CPU.
    """
    def __init__(self, model, tokenizer, device=None, ncpu=1):
        super().__init__()
        if device is None:
            self.used_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.used_device = torch.device(device)
        self.AbLang = model  # HuggingFace model instance
        self.tokenizer = tokenizer
        self.AbLang.to(self.used_device)
        self.AbLang.eval()
        # The encoder (AbRep) always comes from the underlying model.
        if hasattr(self.AbLang, 'model') and hasattr(self.AbLang.model, 'AbRep'):
            self.AbRep = self.AbLang.model.AbRep
        else:
            raise AttributeError("Could not find AbRep in the HuggingFace model or its underlying model.")
        self.ncpu = ncpu
        self.spread = 11  # For compatibility with the original utilities
    def freeze(self):
        self.AbLang.eval()

    def unfreeze(self):
        self.AbLang.train()

    def _encode_sequences(self, seqs):
        # HuggingFace-style padding, returned as PyTorch tensors.
        tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
        tokens = extract_input_ids(tokens, self.used_device)
        return self.AbRep(tokens).last_hidden_states.detach()

    def _predict_logits(self, seqs):
        tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
        tokens = extract_input_ids(tokens, self.used_device)
        output = self.AbLang(tokens)
        if hasattr(output, 'last_hidden_state'):
            return output.last_hidden_state.detach()
        return output.detach()

    def _preprocess_labels(self, labels):
        return extract_input_ids(labels, self.used_device)
    def __call__(self, seqs, mode='seqcoding', align=False, stepwise_masking=False, fragmented=False, batch_size=50):
        """
        Dispatch to the different modes, mimicking the original pretrained class.
        """
        from ablang2.pretrained import format_seq_input
        valid_modes = [
            'rescoding', 'seqcoding', 'restore', 'likelihood', 'probability',
            'pseudo_log_likelihood', 'confidence'
        ]
        if mode not in valid_modes:
            raise SyntaxError(f"Given mode doesn't exist. Please select one of the following: {valid_modes}.")
        seqs, chain = format_seq_input(seqs, fragmented=fragmented)
        if align:
            numbered_seqs, seqs, number_alignment = self.number_sequences(
                seqs, chain=chain, fragmented=fragmented
            )
        else:
            numbered_seqs = None
            number_alignment = None
        # Process in batches and let reformat_subsets stitch the results together.
        subset_list = []
        for subset in [seqs[x:x + batch_size] for x in range(0, len(seqs), batch_size)]:
            subset_list.append(getattr(self, mode)(subset, align=align, stepwise_masking=stepwise_masking))
        return self.reformat_subsets(
            subset_list,
            mode=mode,
            align=align,
            numbered_seqs=numbered_seqs,
            seqs=seqs,
            number_alignment=number_alignment,
        )
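
    # Example (sketch): given `adapter = AbLang2PairedHuggingFaceAdapter(model, tok)`,
    # a paired sequence is passed as a (VH, VL) tuple and dispatched by mode:
    #   adapter([('EVQLV...', 'DIQMT...')], mode='seqcoding')             # per-sequence embedding
    #   adapter([('EVQLV...', 'DIQMT...')], mode='pseudo_log_likelihood') # one score per sequence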
    def pseudo_log_likelihood(self, seqs, **kwargs):
        """
        Original (non-vectorized) pseudo log-likelihood matching notebook
        behavior: mask one residue at a time, score it, average over the sequence.
        """
        # Format input: join VH and VL with '|'.
        formatted_seqs = ['|'.join(s) if isinstance(s, (list, tuple)) else s for s in seqs]
        # Tokenize all sequences in one batch.
        labels = self.tokenizer(formatted_seqs, padding=True, return_tensors='pt')
        labels = extract_input_ids(labels, self.used_device)
        # Resolve special tokens to IDs (the adapter may expose either form).
        if isinstance(self.tokenizer.all_special_tokens[0], int):
            special_token_ids = set(self.tokenizer.all_special_tokens)
        else:
            special_token_ids = set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens)
        pad_token_id = self.tokenizer.pad_token_id
        mask_token_id = getattr(self.tokenizer, 'mask_token_id', None)
        if mask_token_id is None:
            mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        plls = []
        with torch.no_grad():
            for seq_label in labels:
                seq_pll = []
                for j, token_id in enumerate(seq_label):
                    # Skip padding and other special positions.
                    if token_id.item() in special_token_ids or token_id.item() == pad_token_id:
                        continue
                    # Mask position j and score the true token there.
                    masked = seq_label.clone()
                    masked[j] = mask_token_id
                    logits = self.AbLang(masked.unsqueeze(0))
                    if hasattr(logits, 'last_hidden_state'):
                        logits = logits.last_hidden_state
                    logits = logits[0, j]
                    nll = torch.nn.functional.cross_entropy(
                        logits.unsqueeze(0), token_id.unsqueeze(0), reduction="none"
                    )
                    seq_pll.append(-nll.item())
                plls.append(np.mean(seq_pll) if seq_pll else float('nan'))
        return np.array(plls)
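
    # Example (sketch): adapter.pseudo_log_likelihood([('EVQLV...', 'DIQMT...')])
    # returns an np.ndarray with one averaged log-likelihood per input. Note that
    # every residue costs a full forward pass, so long sequences are slow.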
    def confidence(self, seqs, **kwargs):
        """Confidence score: matches the original ablang2 implementation by
        excluding all special tokens from the loss (single unmasked pass)."""
        # Format input: join VH and VL with '|'.
        formatted_seqs = ['|'.join(s) if isinstance(s, (list, tuple)) else s for s in seqs]
        # Resolve special tokens to IDs once, outside the loop.
        if isinstance(self.tokenizer.all_special_tokens[0], int):
            special_token_ids = set(self.tokenizer.all_special_tokens)
        else:
            special_token_ids = set(self.tokenizer.convert_tokens_to_ids(tok) for tok in self.tokenizer.all_special_tokens)
        plls = []
        for seq in formatted_seqs:
            tokens = self.tokenizer([seq], padding=True, return_tensors='pt')
            input_ids = extract_input_ids(tokens, self.used_device)
            with torch.no_grad():
                output = self.AbLang(input_ids)
            logits = output.last_hidden_state if hasattr(output, 'last_hidden_state') else output
            # Drop the batch dimension.
            logits = logits[0]        # [seq_len, vocab_size]
            input_ids = input_ids[0]  # [seq_len]
            # Exclude all special tokens (pad, mask, etc.) from the loss.
            valid_mask = ~torch.isin(input_ids, torch.tensor(list(special_token_ids), device=input_ids.device))
            if valid_mask.sum() > 0:
                nll = torch.nn.functional.cross_entropy(
                    logits[valid_mask], input_ids[valid_mask], reduction="mean"
                )
                pll = -nll.item()
            else:
                pll = 0.0
            plls.append(pll)
        return np.array(plls, dtype=np.float32)
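
    # Unlike pseudo_log_likelihood, this uses a single unmasked forward pass per
    # sequence, so it is much cheaper; e.g. (sketch)
    #   adapter.confidence([('EVQLV...', 'DIQMT...')])  # -> float32 array, one score per input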
    def probability(self, seqs, align=False, stepwise_masking=False, **kwargs):
        """
        Per-residue mutation probabilities: softmax over the model logits.
        """
        # Format input: join VH and VL with '|'.
        formatted_seqs = ['|'.join(s) if isinstance(s, (list, tuple)) else s for s in seqs]
        # NOTE: stepwise masking is not implemented here; with or without it,
        # the unmasked sequence is scored in a single forward pass.
        logits = self._predict_logits(formatted_seqs)
        # Softmax over the vocabulary dimension gives per-position probabilities.
        probs = logits.softmax(-1).cpu().numpy()
        if align:
            return probs
        # Otherwise trim to residue-level probabilities (special tokens removed).
        return [res_to_list(state, seq) for state, seq in zip(probs, formatted_seqs)]
    def restore(self, seqs, align=False, **kwargs):
        hf_abrestore = HFAbRestore(
            self.AbLang, self.tokenizer,
            spread=self.spread, device=self.used_device, ncpu=self.ncpu
        )
        restored = hf_abrestore.restore(seqs, align=align)
        # Re-apply the '<VH>|<VL>' angle-bracket formatting.
        if isinstance(restored, np.ndarray):
            return np.array([add_angle_brackets(seq) for seq in restored])
        return [add_angle_brackets(seq) for seq in restored]
def extract_input_ids(tokens, device):
    """Pull the input_ids tensor out of whatever the tokenizer returned
    (BatchEncoding, plain dict, or a bare tensor) and move it to `device`."""
    if hasattr(tokens, 'input_ids'):
        return tokens.input_ids.to(device)
    if isinstance(tokens, dict):
        if 'input_ids' in tokens:
            return tokens['input_ids'].to(device)
        # Fall back to the first tensor-like value in the dict.
        for v in tokens.values():
            if torch.is_tensor(v) or hasattr(v, 'ndim'):
                return v.to(device)
        raise ValueError("Could not extract input_ids from tokenizer output")
    if torch.is_tensor(tokens):
        return tokens.to(device)
    raise ValueError("Could not extract input_ids from tokenizer output")
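
# Minimal usage sketch, assuming the ablang2_paired weights and tokenizer are
# published on the HuggingFace Hub with remote code (the model id below is a
# placeholder, not a confirmed repository):
#
#   from transformers import AutoModel, AutoTokenizer
#
#   model = AutoModel.from_pretrained("some-org/ablang2-paired", trust_remote_code=True)
#   tokenizer = AutoTokenizer.from_pretrained("some-org/ablang2-paired", trust_remote_code=True)
#   adapter = AbLang2PairedHuggingFaceAdapter(model, tokenizer)
#   scores = adapter([('EVQLVESGGGLVQ...', 'DIQMTQSPSSLSA...')], mode='confidence')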