# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Any, Union

import torch
import torch.nn.functional as F
from torch import nn

from .masking import apply_mask, assert_valid_mask
from .tensor import is_1hot_tensor, add_eos_bos

logger = logging.getLogger(__name__)
def lm_marginal(mlm, x, mask):
    """
    Extract logprobs from an MLM, given an input tensor and a mask.

    NOTE:
        - Each set bit in the mask's L dim is masked *separately*
          (as opposed to all at once), so this performs B*n forward passes.
        - In the future, we might try a single forward pass per sequence,
          i.e. B passes total; we had some success with something akin to
          this in the ESM-1v paper.

    Args:
        x: [B, L, K]
        mask: [B, L, 1], with n set bits in the L dim per sequence.
            If None, a mask of all ones is used.

    Returns:
        logprobs: [B, n, K]
    """
    B, L, K = x.shape
    if mask is None:
        mask = torch.ones(B, L, 1, dtype=torch.bool, device=x.device)
    n = assert_valid_mask(mask, x=x)  # n = number of set bits per sequence.

    # Get coords of the set bits.
    b_coords, l_coords, _ = mask.nonzero(as_tuple=True)  # [B*n], [B*n]
    # Double-check the mask assumption: every sequence has exactly n set bits.
    assert torch.equal(b_coords, torch.arange(B, device=x.device).repeat_interleave(n))

    # Naming: _m1 = mask1 (B*n leading dim, one set bit per row).
    x_m1 = x[b_coords]  # [B*n, L, K]
    mask_m1 = F.one_hot(l_coords, L).unsqueeze(-1).bool()  # [B*n, L, 1]

    # Apply mask, run the MLM forward, select logits at the masked positions.
    x_masked_m1 = apply_mask(x_m1, mask_m1, mlm.vocab.mask_idx)  # [B*n, L, K]
    lm_logits_m1 = mlm(x_masked_m1)['logits']  # [B*n, L, K]
    lm_logits_select = lm_logits_m1.masked_select(mask_m1).reshape(B, n, K)  # [B, n, K]

    # The MLM's 'logits' are pre-softmax values, not log-probabilities,
    # so log-softmax over the vocab dim converts logits -> logprobs.
    lm_logprobs_select = torch.log_softmax(lm_logits_select, dim=-1)
    return lm_logprobs_select  # [B, n, K]
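
# Usage sketch for lm_marginal -- a hypothetical call, not part of this module.
# `esm_model` and `vocab` are assumed to come from elsewhere in the codebase,
# and the sizes below are toy values:
#
#     B, L, K = 2, 10, 24
#     mlm = WrapLmEsm(esm_model, vocab)
#     tokens = torch.randint(0, K, (B, L))               # [B, L] token ids
#     x = F.one_hot(tokens, K).float()                   # [B, L, K] one-hot input
#     mask = torch.zeros(B, L, 1, dtype=torch.bool)
#     mask[:, :3] = True                                 # n = 3 positions per sequence
#     logprobs = lm_marginal(mlm, x, mask)               # [B, 3, K]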
class WrapLM(nn.Module, ABC):
    """Thin wrapper pairing a language model with its vocab."""

    def __init__(self, LM, vocab):
        super().__init__()
        self.model = LM
        self.vocab = vocab

    @abstractmethod
    def forward(self, seq1h, **kwargs):
        raise NotImplementedError
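
# Subclassing sketch (hypothetical `WrapMyLM`): a wrapper only needs to accept
# a one-hot [B, L, K] tensor and return a dict with 'logits' of the same shape,
# which is the contract lm_marginal relies on. WrapLmEsm below is the real example.
#
#     class WrapMyLM(WrapLM):
#         def forward(self, seq1h):
#             seq = seq1h.argmax(-1)  # one-hot -> token ids
#             return {'logits': self.model(seq)}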
class WrapLmEsm(WrapLM):
    def forward(self, seq1h):
        B, L, K = seq1h.shape
        seq1h, seq_start_idx, seq_end_idx = self._prepare_seq(seq1h)
        seq = seq1h.argmax(-1)  # one-hot -> token ids for the ESM forward.
        out = self.model(seq)
        return {
            # Drop the bos/eos positions and any padded vocab columns so the
            # logits line up with the original [B, L, K] input.
            'logits': out['logits'][:, seq_start_idx:seq_end_idx, :K],
        }
    def _prepare_seq(self, seq1h):
        assert is_1hot_tensor(seq1h)
        B, L, K = seq1h.shape
        # Prepend bos/cls and append eos.
        seq1h = add_eos_bos(seq1h, bos_idx=self.vocab.cls_idx, eos_idx=self.vocab.eos_idx)
        seq_start_idx = 1
        seq_end_idx = L + 1
        # In some cases, a vocab padded to 64 positions with dummy characters
        # was used. As a workaround, zero-pad the K dim to the model's embedding
        # size here; forward() slices the padding back off the logits.
        embed_dim = self.model.embed_tokens.weight.size(0)
        if embed_dim != K:
            seq1h = torch.cat([
                seq1h,
                torch.zeros(B, L + 2, embed_dim - K).to(seq1h),
            ], -1)
        return seq1h, seq_start_idx, seq_end_idx
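
# Smoke-test sketch with a hypothetical toy stand-in for the ESM model; it
# assumes add_eos_bos pads L -> L+2 and that `vocab` supplies mask/cls/eos
# indices valid for the toy vocab size below:
#
#     class ToyEsm(nn.Module):
#         def __init__(self, V, d=16):
#             super().__init__()
#             self.embed_tokens = nn.Embedding(V, d)
#             self.head = nn.Linear(d, V)
#         def forward(self, seq):
#             return {'logits': self.head(self.embed_tokens(seq))}
#
#     mlm = WrapLmEsm(ToyEsm(V=24), vocab)
#     x1h = F.one_hot(torch.randint(4, 24, (2, 10)), 24).float()
#     logprobs = lm_marginal(mlm, x1h, mask=None)  # [2, 10, 24]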