# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from .tensor import is_1hot_tensor, add_eos_bos
from .masking import apply_mask, assert_valid_mask

logger = logging.getLogger(__name__)

def lm_marginal(mlm, x, mask):
    """
    Utility to extract log-probabilities from an MLM, given an input tensor and a mask.

    NOTE:
      - **Masks each set bit in mask's L dim separately** (as opposed to all at once),
        and therefore performs B*n forward passes.
      - In the future, we might try just a single forward pass per sequence (= B passes);
        we had some success with something akin to this in the ESM-1v paper.

    Args:
        x: [B, L, K]
        mask: [B, L, 1]  # n set bits per row in the L dim.
            If None, use a mask of all ones.

    Returns:
        logprobs: [B, n, K]
    """
    B, L, K = x.shape
    if mask is None:
        mask = torch.ones(B, L, 1, dtype=torch.bool, device=x.device)
    n = assert_valid_mask(mask, x=x)  # validates the mask and returns n, the set bits per row.
    # Get coords of set bits.
    b_coords, l_coords, _ = mask.nonzero(as_tuple=True)  # [B*n], [B*n]
    # Double-check the mask assumptions: set bits come grouped by batch index, n per row.
    assert torch.equal(
        b_coords,
        torch.repeat_interleave(torch.arange(B, device=x.device), n),
    )
    # Naming: _m1 = mask1. (B*n leading dim.)
    x_m1 = x[b_coords]  # [B*n, L, K]
    mask_m1 = F.one_hot(l_coords, L).unsqueeze(-1).bool()  # [B*n, L, 1]
    # Apply mask, mlm forward, select logits.
    x_masked_m1 = apply_mask(x_m1, mask_m1, mlm.vocab.mask_idx)  # [B*n, L, K]
    lm_logits_m1 = mlm(x_masked_m1)['logits']
    # masked_select broadcasts the [B*n, L, 1] mask over K, keeping one position per pass.
    lm_logits_select = lm_logits_m1.masked_select(mask_m1).reshape(B, n, K)  # [B, n, K]
    # The MLM outputs 'logits' = pre-softmax values, not log-probabilities.
    # Log-softmax here to convert pre-softmax logits -> logprobs.
    lm_logprobs_select = torch.log_softmax(lm_logits_select, dim=-1)
    return lm_logprobs_select  # [B, n, K]
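
# Example usage (a minimal sketch; `wrapped_esm` is assumed to be a WrapLmEsm instance
# and `seq1h` a [B, L, K] one-hot batch; the names here are illustrative only):
#
#     mask = torch.zeros(B, L, 1, dtype=torch.bool, device=seq1h.device)
#     mask[:, :5] = True                                # score the first 5 positions of every sequence
#     logprobs = lm_marginal(wrapped_esm, seq1h, mask)  # [B, 5, K] log-probabilities
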
class WrapLM(nn.Module, ABC):
    def __init__(self, LM, vocab):
        super().__init__()
        self.model = LM
        self.vocab = vocab

    @abstractmethod
    def forward(self, seq1h, **kwargs):
        raise NotImplementedError

class WrapLmEsm(WrapLM):
    def forward(self, seq1h):
        B, L, K = seq1h.shape
        seq1h, seq_start_idx, seq_end_idx = self._prepare_seq(seq1h)
        seq = seq1h.argmax(-1)
        out = self.model(seq)
        return {
            # Drop the bos/eos positions and any padded vocab columns (see _prepare_seq).
            'logits': out['logits'][:, seq_start_idx:seq_end_idx, :K],
        }
    def _prepare_seq(self, seq1h):
        assert is_1hot_tensor(seq1h)
        B, L, K = seq1h.shape
        # Prepend bos/cls and append eos.
        seq1h = add_eos_bos(seq1h, bos_idx=self.vocab.cls_idx, eos_idx=self.vocab.eos_idx)
        seq_start_idx = 1
        seq_end_idx = L + 1
        # In some cases, a vocab padded to 64 positions with dummy characters was used.
        # As a workaround, pad the K dim up to the model's vocab size here; the extra
        # columns are sliced off again in forward().
        model_vocab_size = self.model.embed_tokens.weight.size(0)
        if model_vocab_size != K:
            seq1h = torch.cat([
                seq1h,
                torch.zeros(B, L + 2, model_vocab_size - K).to(seq1h)
            ], -1)
        return seq1h, seq_start_idx, seq_end_idx
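
# Example wiring (a minimal sketch, assuming the fair-esm package; the checkpoint name
# and tensor names below are illustrative, not part of this module):
#
#     import esm
#     model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
#     wrapped = WrapLmEsm(model, alphabet)           # alphabet provides cls/eos/mask indices
#     logprobs = lm_marginal(wrapped, seq1h, mask)   # seq1h: [B, L, K] one-hot, mask: [B, L, 1]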