from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from einops import rearrange
from tqdm import tqdm
from transformers import PreTrainedModel, PretrainedConfig

from model.architectures.transformer import EncoderDecoderTransformer
from model.architectures.crossformer import EncoderDecoderCrossFormer
from model.hf_configs import Seq2SeqConfig, Seq2SeqCrossConfig


class Seq2SeqTransformer(PreTrainedModel):
    """
    Custom Transformer for sequence-to-sequence tasks, wrapped as a Hugging Face
    PreTrainedModel so it works with the standard save/load and Trainer APIs.
    """
    config_class = Seq2SeqConfig
    base_model_prefix = "transformer"

    def __init__(self, config: PretrainedConfig, device: Optional[str] = None):
        super().__init__(config)
        self.softmax = nn.Softmax(dim=-1)

        self.transformer = EncoderDecoderTransformer(
            src_vocab_size=config.vocab_size_src,
            tgt_vocab_size=config.vocab_size_tgt,
            embed_dim=config.d_model,
            num_heads=config.n_heads,
            ff_dim=config.d_ff,
            num_encoder_layers=config.n_layers,
            num_decoder_layers=config.n_layers,
            max_seq_length=config.sequence_length
        )

    def _init_weights(self, module: nn.Module):
        # Xavier-initialise linear layers; biases start at zero.
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def _create_padding_mask(self, ids: torch.LongTensor) -> torch.FloatTensor:
        """Builds a key-padding mask so padded positions cannot interfere with attention."""
        is_padding = ids.eq(self.config.pad_token_id)

        # Padded positions become -inf and real tokens 1.0; this float mask is passed
        # straight through to the underlying attention layers.
        mask = is_padding.float()
        mask = mask.masked_fill(is_padding, float('-inf'))
        mask = mask.masked_fill(~is_padding, 1.0)
        return mask

    def _shift_right(self, x: torch.LongTensor) -> torch.LongTensor:
        """Prepares decoder inputs for teacher forcing by right-shifting the label tokens."""
        # Prepend BOS and drop the last label token, so position t predicts label t.
        shifted = torch.full(
            (*x.shape[:-1], 1),
            self.config.bos_token_id,
            dtype=x.dtype,
            device=x.device
        )
        shifted = torch.cat([shifted, x[:, :-1]], dim=-1)
        return shifted

    def _add_beginning_of_stream(self, x: torch.LongTensor) -> torch.LongTensor:
        """Helper method to add the BOS token to the beginning of input sequences."""
        bos = torch.full(
            (*x.shape[:-1], 1),
            self.config.bos_token_id,
            dtype=x.dtype,
            device=x.device
        )
        return torch.cat([bos, x], dim=-1)

    def _add_end_of_stream(self, x: torch.LongTensor) -> torch.LongTensor:
        """Helper method to add the EOS token to the end of label sequences."""
        eos = torch.full(
            (*x.shape[:-1], 1),
            self.config.eos_token_id,
            dtype=x.dtype,
            device=x.device
        )
        return torch.cat([x, eos], dim=-1)
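
    # How the special tokens flow through forward() below (illustrative sequences):
    #   encoder input : [BOS, x1, ..., xn]    (BOS prepended to input_ids)
    #   labels        : [y1, ..., ym, EOS]    (EOS appended to labels)
    #   decoder input : [BOS, y1, ..., ym]    (labels shifted right, last token dropped)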

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        **kwargs
    ) -> Union[Tuple, dict]:
        # Prepend BOS to the encoder inputs.
        input_ids = self._add_beginning_of_stream(input_ids)

        # Append EOS to the labels (only when labels are provided, e.g. during training).
        if labels is not None:
            labels = self._add_end_of_stream(labels)

        # Teacher forcing: decoder inputs are the right-shifted labels unless given explicitly.
        if decoder_input_ids is None and labels is not None:
            decoder_input_ids = self._shift_right(labels)

        src_key_padding_mask = self._create_padding_mask(input_ids)
        tgt_key_padding_mask = self._create_padding_mask(decoder_input_ids)

        outputs = self.transformer(
            src=input_ids,
            tgt=decoder_input_ids,
            src_mask=attention_mask,
            tgt_mask=decoder_attention_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
            loss = loss_fct(outputs.view(-1, self.config.vocab_size_tgt), labels.view(-1))

        return dict(
            loss=loss,
            logits=outputs,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_length: Optional[int] = None,
        temperature: float = 1.0,
        do_sample: bool = False,
        **kwargs
    ) -> torch.LongTensor:
        """Autoregressive decoding: greedy by default, multinomial sampling if ``do_sample``."""
        batch_size = input_ids.shape[0]
        max_length = max_length or self.config.max_length or 128

        # Start every sequence with the BOS token.
        decoder_input_ids = torch.full(
            (batch_size, 1),
            self.config.bos_token_id,
            dtype=torch.long,
            device=input_ids.device
        )

        for _ in range(max_length - 1):
            outputs = self.forward(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask,
            )

            # Only the logits at the last position are needed to pick the next token.
            next_token_logits = outputs["logits"][:, -1, :]

            if do_sample:
                # Temperature-scaled multinomial sampling.
                scaled_logits = next_token_logits / temperature
                next_token_probs = self.softmax(scaled_logits)
                next_token = torch.multinomial(
                    next_token_probs, num_samples=1
                ).squeeze(-1)
            else:
                # Greedy decoding.
                next_token = next_token_logits.argmax(dim=-1)

            decoder_input_ids = torch.cat(
                [decoder_input_ids, next_token.unsqueeze(-1)],
                dim=-1
            )

            # Stop early once every sequence in the batch has produced an EOS token.
            if (decoder_input_ids == self.config.eos_token_id).any(dim=-1).all():
                break

        return decoder_input_ids
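
# Minimal usage sketch (illustrative only). It assumes Seq2SeqConfig exposes the fields
# referenced above (vocab_size_src, vocab_size_tgt, d_model, n_heads, d_ff, n_layers,
# sequence_length, pad/bos/eos token ids, max_length); the keyword names and values below
# are hypothetical:
#
#   config = Seq2SeqConfig(vocab_size_src=1000, vocab_size_tgt=1000, d_model=256,
#                          n_heads=8, d_ff=1024, n_layers=4, sequence_length=128,
#                          pad_token_id=0, bos_token_id=1, eos_token_id=2)
#   model = Seq2SeqTransformer(config)
#   out = model(input_ids=src_ids, labels=tgt_ids)      # training: out["loss"], out["logits"]
#   generated = model.generate(src_ids, max_length=64)  # inference: greedy decoding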


class Seq2SeqCrossFormer(Seq2SeqTransformer):
    """CrossFormer wrapper predicting over a discrete vocabulary."""
    config_class = Seq2SeqCrossConfig

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.softmax = nn.Softmax(dim=-1)

        # Swap the parent's EncoderDecoderTransformer for the CrossFormer variant.
        self.transformer = EncoderDecoderCrossFormer(
            source_sequence_dimension=config.source_sequence_dimension,
            target_sequence_dimension=config.target_sequence_dimension,
            router_dim=config.router_dim,
            src_vocab_size=config.vocab_size_src,
            tgt_vocab_size=config.vocab_size_tgt,
            embed_dim=config.d_model,
            num_heads=config.n_heads,
            ff_dim=config.d_ff,
            num_encoder_layers=config.n_layers,
            num_decoder_layers=config.n_layers,
            max_seq_length=config.sequence_length
        )

    def _shift_right(self, x: torch.LongTensor) -> torch.LongTensor:
        """
        Prepares decoder inputs for teacher forcing by right-shifting the label tokens.
        Handles 3D (B, S, C) tensors.
        """
        # BOS slice along the time axis, then drop the last timestep of the labels.
        shape = list(x.shape)
        shape[-2] = 1

        shifted = torch.full(
            shape,
            self.config.bos_token_id,
            dtype=x.dtype,
            device=x.device
        )
        shifted = torch.cat([shifted, x[..., :-1, :]], dim=-2)
        return shifted

    def _add_beginning_of_stream(self, x: torch.LongTensor) -> torch.LongTensor:
        """
        Helper method to add the BOS token to the beginning of input sequences.
        Handles 3D (B, S, C) tensors.
        """
        shape = list(x.shape)
        shape[-2] = 1
        bos = torch.full(
            shape,
            self.config.bos_token_id,
            dtype=x.dtype,
            device=x.device
        )
        return torch.cat([bos, x], dim=-2)

    def _add_end_of_stream(self, x: torch.LongTensor) -> torch.LongTensor:
        """
        Helper method to add the EOS token to the end of label sequences.
        Handles 3D (B, S, C) tensors.
        """
        shape = list(x.shape)
        shape[-2] = 1

        eos = torch.full(
            shape,
            self.config.eos_token_id,
            dtype=x.dtype,
            device=x.device
        )
        return torch.cat([x, eos], dim=-2)

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        **kwargs
    ):
        # Prepend BOS along the time axis of the encoder inputs.
        input_ids = self._add_beginning_of_stream(input_ids)

        # Append EOS along the time axis of the labels (training only).
        if labels is not None:
            labels = self._add_end_of_stream(labels)

        # Teacher forcing: decoder inputs are the right-shifted labels unless given explicitly.
        if decoder_input_ids is None and labels is not None:
            decoder_input_ids = self._shift_right(labels)

        # Key-padding masks are computed per channel, then the channel axis is folded
        # into the batch axis, as expected by the time-attention blocks.
        src_src_key_padding_time_mask = rearrange(
            self._create_padding_mask(input_ids),
            'b s c -> (b c) s'
        )

        tgt_tgt_key_padding_time_mask = rearrange(
            self._create_padding_mask(decoder_input_ids),
            'b s c -> (b c) s'
        )

        outputs = self.transformer(
            src=input_ids,
            tgt=decoder_input_ids,
            src_src_time_mask=kwargs.get("src_src_time_mask"),
            src_src_dimension_mask=kwargs.get("src_src_dimension_mask"),
            src_src_key_padding_time_mask=src_src_key_padding_time_mask,
            tgt_tgt_time_mask=kwargs.get("tgt_tgt_time_mask"),
            tgt_tgt_dimension_mask=kwargs.get("tgt_tgt_dimension_mask"),
            tgt_tgt_key_padding_time_mask=tgt_tgt_key_padding_time_mask,
            tgt_src_dimension_mask=kwargs.get("tgt_src_dimension_mask")
        )

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(
                ignore_index=self.config.pad_token_id
            )
            loss = loss_fct(
                outputs.view(-1, self.config.vocab_size_tgt), labels.view(-1)
            )

        return dict(
            loss=loss,
            logits=outputs,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_length: Optional[int] = None,
        temperature: float = 1.0,
        do_sample: bool = False,
        **kwargs
    ) -> torch.LongTensor:
        """Autoregressive decoding over a (batch, timesteps, channels) token grid."""
        batch_size, timesteps, channels = input_ids.shape
        max_length = max_length or self.config.max_length or 128

        # Pre-allocate the decoder grid with padding; the channel count follows the
        # configured target dimensionality (previously a hard-coded constant).
        decoder_input_ids = torch.full(
            (batch_size, timesteps + 1, self.config.target_sequence_dimension),
            self.config.pad_token_id,
            dtype=torch.long,
            device=input_ids.device
        )

        # The first timestep of every channel is the BOS token.
        decoder_input_ids[:, 0, :] = self.config.bos_token_id

        # Fill positions 1..timesteps one step at a time (optionally capped by max_length).
        num_steps = min(timesteps, max_length)
        for t in tqdm(range(num_steps)):
            outputs = self.forward(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask
            )

            # Logits at position t predict the tokens of position t + 1 for every channel.
            next_token_logits = outputs["logits"][:, t, :]

            if do_sample:
                scaled_logits = next_token_logits / temperature
                next_token_probs = self.softmax(scaled_logits)
                # torch.multinomial needs a 2D input, so flatten the channel axis first
                # and restore it afterwards.
                flat_probs = next_token_probs.reshape(-1, next_token_probs.shape[-1])
                next_token = torch.multinomial(flat_probs, num_samples=1)
                next_token = next_token.reshape(next_token_probs.shape[:-1])
            else:
                next_token = next_token_logits.argmax(dim=-1)

            decoder_input_ids[:, t + 1, :] = next_token

            # Stop early once every channel of every sequence has emitted EOS.
            if (next_token == self.config.eos_token_id).all():
                break

        return decoder_input_ids
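
# Minimal usage sketch for the CrossFormer wrapper (illustrative only; Seq2SeqCrossConfig
# is assumed to expose source_sequence_dimension, target_sequence_dimension and router_dim
# in addition to the fields used by Seq2SeqTransformer, and src_ids / tgt_ids are 3D
# (batch, timesteps, channels) LongTensors):
#
#   config = Seq2SeqCrossConfig(...)
#   model = Seq2SeqCrossFormer(config)
#   out = model(input_ids=src_ids, labels=tgt_ids)   # training: out["loss"], out["logits"]
#   grid = model.generate(src_ids)                   # (batch, timesteps + 1, target channels)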