ablang2_seq_restore / adapter.py
hemantn's picture
deployment file added
0e2f128
from ablang2.pretrained_utils.restoration import AbRestore
from ablang2.pretrained_utils.encodings import AbEncoding
from ablang2.pretrained_utils.alignment import AbAlignment
from ablang2.pretrained_utils.scores import AbScores
import torch
import numpy as np
from ablang2.pretrained_utils.extra_utils import res_to_seq, res_to_list
class HuggingFaceTokenizerAdapter:
    """Expose a HuggingFace tokenizer through a ``tokenizer(seqs, ...)`` call
    that either encodes sequences to token ids or decodes token ids to strings.

    Encoding returns the ``input_ids`` tensor moved to the target device.
    Decoding (``mode='decode'``) maps ids back to vocabulary entries, drops
    the entries listed in ``_DECODE_SKIPPED`` and returns a list of strings.
    """

    # Vocabulary entries removed from decoded output.
    _DECODE_SKIPPED = frozenset({'-', '*', '<', '>'})

    def __init__(self, tokenizer, device):
        """Cache the ids and vocabulary lookups needed by ``__call__``.

        Args:
            tokenizer: HuggingFace-style tokenizer providing ``pad_token_id``,
                ``all_special_tokens`` and either ``get_vocab()`` or ``vocab``.
            device: Default device the encoded token ids are moved to.
        """
        self.tokenizer = tokenizer
        self.device = device
        self.pad_token_id = tokenizer.pad_token_id

        # Compare against None explicitly: 0 is a valid mask token id and
        # must not trigger the fallback lookup.
        mask_token_id = getattr(tokenizer, 'mask_token_id', None)
        if mask_token_id is None:
            mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
        self.mask_token_id = mask_token_id

        self.vocab = tokenizer.get_vocab() if hasattr(tokenizer, 'get_vocab') else tokenizer.vocab
        self.inv_vocab = {token_id: token for token, token_id in self.vocab.items()}
        self.all_special_tokens = tokenizer.all_special_tokens

    def __call__(self, seqs, pad=True, w_extra_tkns=False, device=None, mode=None):
        """Encode ``seqs`` to token ids, or decode token ids when ``mode='decode'``.

        Args:
            seqs: Sequences to tokenize, or (in decode mode) a batch of token
                ids given as a tensor or an iterable of id rows.
            pad: Accepted for signature compatibility; not used here
                (padding is always requested from the tokenizer).
            w_extra_tkns: Accepted for signature compatibility; not used here.
            device: Device for the encoded ids; defaults to ``self.device``.
            mode: ``'decode'`` to decode, anything else to encode.

        Returns:
            A tensor of token ids when encoding, or a list of decoded strings
            when decoding.
        """
        # Decide on the mode *before* touching the tokenizer: in decode mode
        # ``seqs`` holds token ids, which must not be tokenized as text.
        if mode == 'decode':
            return self._decode(seqs)

        tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
        return tokens['input_ids'].to(self.device if device is None else device)

    def _decode(self, token_ids):
        """Turn each row of ``token_ids`` into a string of vocabulary entries."""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.cpu().numpy()

        decoded = []
        for row in token_ids:
            chars = []
            for token_id in row:
                # Unknown ids map to '' and are dropped together with the
                # skipped entries.
                char = self.inv_vocab.get(int(token_id), '')
                if char and char not in self._DECODE_SKIPPED:
                    chars.append(char)
            # res_to_seq is handed a (sequence, length) pair; the real length
            # is not available here, so the number of kept tokens is used.
            decoded.append(res_to_seq([''.join(chars), len(chars)], mode='restore'))
        return decoded
class HFAbRestore(AbRestore):
    """AbRestore variant whose model and tokenizer come from HuggingFace objects."""

    def __init__(self, hf_model, hf_tokenizer, spread=11, device='cpu', ncpu=1):
        """Store the HuggingFace model and wrap the tokenizer in an adapter.

        Args:
            hf_model: Callable model; invoked with a batch of token ids.
            hf_tokenizer: HuggingFace tokenizer wrapped by
                ``HuggingFaceTokenizerAdapter``.
            spread: Forwarded to ``AbRestore.__init__``.
            device: Forwarded to ``AbRestore.__init__`` and the adapter.
            ncpu: Forwarded to ``AbRestore.__init__``.
        """
        super().__init__(spread=spread, device=device, ncpu=ncpu)
        self._hf_model = hf_model
        self.used_device = device
        self.tokenizer = HuggingFaceTokenizerAdapter(hf_tokenizer, device)

    @property
    def AbLang(self):
        """Return a callable mapping token ids to the model's output tensor.

        If the model output carries a ``last_hidden_state`` attribute that
        attribute is returned; otherwise the raw output is returned.
        """
        def forward(token_ids):
            result = self._hf_model(token_ids)
            return getattr(result, 'last_hidden_state', result)

        return forward
def add_angle_brackets(seq):
    """Wrap the heavy and light parts of ``seq`` in angle brackets.

    ``seq`` is split at its first ``'|'``; a string without ``'|'`` is
    treated as a heavy part with an empty light part. The result always has
    the form ``'<VH>|<VL>'``.
    """
    heavy, _, light = seq.partition('|')
    return "<" + heavy + ">|<" + light + ">"
class AbLang2PairedHuggingFaceAdapter(AbEncoding, AbRestore, AbAlignment, AbScores):
    """
    Adapter to use pretrained utilities with a HuggingFace-loaded ablang2_paired model and tokenizer.
    Automatically uses CUDA if available, otherwise CPU.
    """

    def __init__(self, model, tokenizer, device=None, ncpu=1):
        """Attach the HuggingFace model and tokenizer and move the model to the device.

        Args:
            model: HuggingFace model instance; must expose ``model.AbRep``.
            tokenizer: HuggingFace tokenizer matching ``model``.
            device: Target device; when ``None``, CUDA is used if available,
                otherwise CPU.
            ncpu: Stored as ``self.ncpu``.

        Raises:
            AttributeError: If ``model.model.AbRep`` cannot be found.
        """
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.used_device = torch.device(device)

        self.AbLang = model  # HuggingFace model instance
        self.tokenizer = tokenizer
        self.AbLang.to(self.used_device)
        self.AbLang.eval()

        # Always get AbRep from the underlying model
        if hasattr(self.AbLang, 'model') and hasattr(self.AbLang.model, 'AbRep'):
            self.AbRep = self.AbLang.model.AbRep
        else:
            raise AttributeError("Could not find AbRep in the HuggingFace model or its underlying model.")

        self.ncpu = ncpu
        self.spread = 11  # For compatibility with original utilities

    def freeze(self):
        """Put the model into evaluation mode."""
        self.AbLang.eval()

    def unfreeze(self):
        """Put the model into training mode."""
        self.AbLang.train()

    # ------------------------------------------------------------------
    # Private helpers shared by the scoring methods
    # ------------------------------------------------------------------
    @staticmethod
    def _hf_join_chains(seqs):
        """Join each (VH, VL) list/tuple with '|'; pass other entries through."""
        return ['|'.join(s) if isinstance(s, (list, tuple)) else s for s in seqs]

    def _hf_special_token_ids(self):
        """Return the tokenizer's special tokens as a set of ids (empty if none)."""
        special = self.tokenizer.all_special_tokens
        if not special:
            return set()
        # all_special_tokens may already hold ids rather than token strings.
        if isinstance(special[0], int):
            return set(special)
        return {self.tokenizer.convert_tokens_to_ids(tok) for tok in special}

    def _hf_mask_token_id(self):
        """Return the mask token id, looking it up from the mask token if needed."""
        mask_token_id = getattr(self.tokenizer, 'mask_token_id', None)
        if mask_token_id is None:
            mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        return mask_token_id

    def _hf_logits(self, token_ids):
        """Run the model and unwrap ``last_hidden_state`` when the output has one."""
        output = self.AbLang(token_ids)
        return getattr(output, 'last_hidden_state', output)

    # ------------------------------------------------------------------
    # Hooks used by the inherited utilities
    # ------------------------------------------------------------------
    def _encode_sequences(self, seqs):
        """Tokenize ``seqs`` and return the detached ``AbRep`` hidden states."""
        # Use HuggingFace-style padding and return PyTorch tensors
        tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
        tokens = extract_input_ids(tokens, self.used_device)
        # The result is detached anyway, so skip building the autograd graph.
        with torch.no_grad():
            return self.AbRep(tokens).last_hidden_states.detach()

    def _predict_logits(self, seqs):
        """Tokenize ``seqs`` and return the model's detached output tensor."""
        tokens = self.tokenizer(seqs, padding=True, return_tensors='pt')
        tokens = extract_input_ids(tokens, self.used_device)
        # The result is detached anyway, so skip building the autograd graph.
        with torch.no_grad():
            return self._hf_logits(tokens).detach()

    def _preprocess_labels(self, labels):
        """Return the token ids contained in ``labels`` on the adapter's device."""
        return extract_input_ids(labels, self.used_device)

    def __call__(self, seqs, mode='seqcoding', align=False, stepwise_masking=False, fragmented=False, batch_size=50):
        """
        Use different modes for different usecases, mimicking the original pretrained class.

        Args:
            seqs: Input sequences, normalised by ``format_seq_input``.
            mode: Name of the method to run on every batch; one of
                'rescoding', 'seqcoding', 'restore', 'likelihood',
                'probability', 'pseudo_log_likelihood', 'confidence'.
            align: Number the sequences first and pass ``align=True`` on.
            stepwise_masking: Forwarded to the selected method.
            fragmented: Forwarded to ``format_seq_input`` and the numbering.
            batch_size: Number of sequences handed to the method per call.

        Returns:
            Whatever the inherited ``reformat_subsets`` builds from the
            per-batch results.

        Raises:
            SyntaxError: If ``mode`` is not one of the valid modes.
        """
        from ablang2.pretrained import format_seq_input

        valid_modes = [
            'rescoding', 'seqcoding', 'restore', 'likelihood', 'probability',
            'pseudo_log_likelihood', 'confidence'
        ]
        if mode not in valid_modes:
            # SyntaxError is kept (rather than ValueError) so existing callers
            # that catch it keep working.
            raise SyntaxError(f"Given mode doesn't exist. Please select one of the following: {valid_modes}.")

        seqs, chain = format_seq_input(seqs, fragmented=fragmented)

        if align:
            numbered_seqs, seqs, number_alignment = self.number_sequences(
                seqs, chain=chain, fragmented=fragmented
            )
        else:
            numbered_seqs = None
            number_alignment = None

        run_mode = getattr(self, mode)
        subset_list = [
            run_mode(seqs[start:start + batch_size], align=align, stepwise_masking=stepwise_masking)
            for start in range(0, len(seqs), batch_size)
        ]

        return self.reformat_subsets(
            subset_list,
            mode=mode,
            align=align,
            numbered_seqs=numbered_seqs,
            seqs=seqs,
            number_alignment=number_alignment,
        )

    def pseudo_log_likelihood(self, seqs, **kwargs):
        """
        Original (non-vectorized) pseudo log-likelihood computation matching notebook behavior.

        Every non-special, non-pad position is masked in turn (one forward
        pass per position) and scored with the negative cross-entropy of the
        true token at that position. The per-sequence result is the mean over
        the scored positions, or NaN when no position was scored.

        Returns:
            numpy array with one value per input sequence.
        """
        # Format input: join VH and VL with '|'
        formatted_seqs = self._hf_join_chains(seqs)

        # Tokenize all sequences in batch
        labels = self.tokenizer(formatted_seqs, padding=True, return_tensors='pt')
        labels = extract_input_ids(labels, self.used_device)

        # Positions holding a special token or padding are never scored.
        skipped_ids = self._hf_special_token_ids()
        skipped_ids.add(self.tokenizer.pad_token_id)
        mask_token_id = self._hf_mask_token_id()

        plls = []
        with torch.no_grad():
            for seq_label in labels:
                seq_pll = []
                for j, token_id in enumerate(seq_label):
                    if token_id.item() in skipped_ids:
                        continue
                    masked = seq_label.clone()
                    masked[j] = mask_token_id
                    logits = self._hf_logits(masked.unsqueeze(0))[0, j]
                    nll = torch.nn.functional.cross_entropy(
                        logits.unsqueeze(0), token_id.unsqueeze(0), reduction="none"
                    )
                    seq_pll.append(-nll.item())
                plls.append(np.mean(seq_pll) if seq_pll else float('nan'))
        return np.array(plls)

    def confidence(self, seqs, **kwargs):
        """Confidence calculation - match original ablang2 implementation by excluding all special tokens from loss.

        Each sequence is passed through the model once, unmasked; the score
        is the negative mean cross-entropy between the output and the input
        tokens over all non-special positions, or 0.0 when there are none.

        Returns:
            float32 numpy array with one value per input sequence.
        """
        # Format input: join VH and VL with '|'
        formatted_seqs = self._hf_join_chains(seqs)

        # Exclude all special tokens (pad, mask, etc.); resolved once because
        # the set does not depend on the sequence.
        special_ids = list(self._hf_special_token_ids())

        plls = []
        for seq in formatted_seqs:
            tokens = self.tokenizer([seq], padding=True, return_tensors='pt')
            input_ids = extract_input_ids(tokens, self.used_device)

            with torch.no_grad():
                logits = self._hf_logits(input_ids)

            # Get the sequence (remove batch dimension)
            logits = logits[0]
            input_ids = input_ids[0]

            # Match the dtype of input_ids so an empty id list still yields
            # an integer tensor for torch.isin.
            special = torch.tensor(special_ids, dtype=input_ids.dtype, device=input_ids.device)
            valid_mask = ~torch.isin(input_ids, special)

            if valid_mask.sum() > 0:
                # Calculate cross-entropy loss
                nll = torch.nn.functional.cross_entropy(
                    logits[valid_mask],
                    input_ids[valid_mask],
                    reduction="mean"
                )
                pll = -nll.item()
            else:
                pll = 0.0
            plls.append(pll)

        return np.array(plls, dtype=np.float32)

    def probability(self, seqs, align=False, stepwise_masking=False, **kwargs):
        """
        Probability of mutations - applies softmax to logits to get probabilities

        Args:
            seqs: Sequences, or (VH, VL) pairs that are joined with '|'.
            align: When true, return the raw probability array; otherwise
                return one ``res_to_list`` result per sequence.
            stepwise_masking: Accepted for interface compatibility; stepwise
                masking is not implemented, so the flag has no effect.
        """
        # Format input: join VH and VL with '|'
        formatted_seqs = self._hf_join_chains(seqs)

        # NOTE: a single unmasked forward pass is used regardless of
        # stepwise_masking (see docstring).
        logits = self._predict_logits(formatted_seqs)

        # Apply softmax to get probabilities
        probs = logits.softmax(-1).cpu().numpy()

        if align:
            return probs
        # Return residue-level probabilities (excluding special tokens)
        return [res_to_list(state, seq) for state, seq in zip(probs, formatted_seqs)]

    def restore(self, seqs, align=False, **kwargs):
        """Restore ``seqs`` via ``HFAbRestore`` and wrap each result as '<VH>|<VL>'.

        Returns:
            A numpy array when the underlying restore returns one, otherwise
            a list.
        """
        hf_abrestore = HFAbRestore(self.AbLang, self.tokenizer, spread=self.spread, device=self.used_device, ncpu=self.ncpu)
        restored = hf_abrestore.restore(seqs, align=align)

        # Apply angle brackets formatting
        bracketed = [add_angle_brackets(seq) for seq in restored]
        if isinstance(restored, np.ndarray):
            return np.array(bracketed)
        return bracketed
def extract_input_ids(tokens, device):
    """Pull the token-id tensor out of a tokenizer result and move it to ``device``.

    Args:
        tokens: One of
            * an object with an ``input_ids`` attribute,
            * a dict with an ``'input_ids'`` entry,
            * a dict without that entry, in which case the first tensor (or
              array-like with ``ndim``) among its values is used,
            * a tensor, which is used directly.
        device: Device the returned tensor is moved to.

    Returns:
        The token-id tensor on ``device``.

    Raises:
        ValueError: If no tensor can be found in ``tokens``.
    """
    if hasattr(tokens, 'input_ids'):
        return tokens.input_ids.to(device)

    if isinstance(tokens, dict):
        if 'input_ids' in tokens:
            return tokens['input_ids'].to(device)
        for value in tokens.values():
            if torch.is_tensor(value):
                return value.to(device)
            if hasattr(value, 'ndim'):
                # Array-likes such as numpy arrays have no ``.to``; convert first.
                return torch.as_tensor(value).to(device)
        # A dict without any tensor falls through to the error below instead
        # of silently returning None.
    elif torch.is_tensor(tokens):
        return tokens.to(device)

    raise ValueError("Could not extract input_ids from tokenizer output")