| import torch |
| import torch.nn as nn |
| from config import Im2LatexTransformerConfig |
| from transformers import PreTrainedModel |
|
|
class CNN(nn.Module):
    def __init__(self, config: Im2LatexTransformerConfig):
        """
        Builds the convolutional encoder.

        Three conv blocks (Conv2d -> ReLU -> Dropout2d -> MaxPool2d), each
        halving the spatial resolution, followed by a linear projection of
        the 128 final channels into the transformer width ``d_model``.

        Args:
            config (Im2LatexTransformerConfig): Configuration providing
                ``in_channels``, ``dropout`` and ``d_model``.
        """
        super().__init__()

        self.conv_blocks = nn.Sequential(
            nn.Conv2d(config.in_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=config.dropout),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=config.dropout),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=config.dropout),
            nn.MaxPool2d(2, 2),
        )

        # 128 = output channels of the last conv block above.
        self.projection = nn.Linear(128, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encodes an image batch into a sequence of feature vectors.

        Args:
            x (torch.Tensor): Images of shape ``(B, C, H, W)``; a single
                image ``(C, H, W)`` is accepted and promoted to a batch of 1.

        Returns:
            torch.Tensor: Features of shape ``(B, H' * W', d_model)`` where
            ``H' = H // 8`` and ``W' = W // 8`` (three 2x2 max-pools).
        """
        # Promote a single unbatched image to a batch of one.
        if x.dim() == 3:
            x = x.unsqueeze(0)

        x = self.conv_blocks(x)

        # Flatten the spatial grid into a token sequence: (B, C, H, W) -> (B, H*W, C).
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)

        x = self.projection(x)
        x = self.dropout(x)

        return x
|
|
class Decoder(nn.Module):
    def __init__(self, config: Im2LatexTransformerConfig):
        """
        Builds a Transformer decoder with learned positional embeddings.

        Args:
            config (Im2LatexTransformerConfig): Configuration providing
                ``vocab_size``, ``max_len``, ``d_model``, ``nhead``,
                ``dim_feedforward``, ``dropout`` and ``num_layers``.
        """
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        # Learned absolute positions; sequences longer than max_len would
        # index out of range here.
        self.pos_embedding = nn.Embedding(config.max_len, config.d_model)

        decoder_layer = nn.TransformerDecoderLayer(
            config.d_model,
            config.nhead,
            config.dim_feedforward,
            config.dropout,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, config.num_layers)

        self.output_proj = nn.Linear(config.d_model, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self,
                tokens: torch.Tensor,
                memory: torch.Tensor,
                tgt_mask: torch.Tensor = None,
                tgt_key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Passes the target tokens through the decoder, attending to the
        encoder memory.

        Args:
            tokens (torch.Tensor): Token ids of shape ``(B, T)``.
            memory (torch.Tensor): Encoder output of shape ``(B, S, d_model)``.
            tgt_mask (torch.Tensor, optional): Causal attention mask over the
                target sequence. Defaults to None.
            tgt_key_padding_mask (torch.Tensor, optional): Padding mask for
                the target tokens. Defaults to None.

        Returns:
            torch.Tensor: Next-token logits of shape ``(B, T, vocab_size)``.
        """
        seq_len = tokens.size(1)
        device = tokens.device

        # Sum token and (broadcast) positional embeddings.
        token_emb = self.embedding(tokens)
        positions = torch.arange(0, seq_len, device=device).unsqueeze(0)
        pos_emb = self.pos_embedding(positions)

        x = self.dropout(token_emb + pos_emb)

        out = self.transformer_decoder(
            tgt=x, memory=memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask
        )

        logits = self.output_proj(out)

        return logits
|
|
| |
class Im2LatexTransformer(PreTrainedModel):
    config_class = Im2LatexTransformerConfig

    def __init__(self, config):
        """
        Builds the full encoder-decoder model (CNN encoder + Transformer
        decoder).

        Args:
            config (Im2LatexTransformerConfig): Configuration for the model
        """
        super().__init__(config)
        self.encoder = CNN(config)
        self.decoder = Decoder(config)

    @staticmethod
    def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
        """Boolean upper-triangular mask blocking attention to future positions."""
        return torch.triu(
            torch.ones((seq_len, seq_len), dtype=torch.bool, device=device), diagonal=1
        )

    def forward(self,
                pixel_values: torch.Tensor,
                decoder_input_ids: torch.Tensor,
                decoder_padding_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Encodes the images and decodes the target tokens with a causal mask.

        Args:
            pixel_values (torch.Tensor): Input images
            decoder_input_ids (torch.Tensor): Decoder input tokens ``(B, T)``
            decoder_padding_mask (torch.Tensor, optional): Padding mask for
                the decoder. Defaults to None.

        Returns:
            torch.Tensor: Next-token logits ``(B, T, vocab_size)``.
        """
        device = pixel_values.device

        memory = self.encoder(pixel_values)

        tgt_mask = None
        if decoder_input_ids is not None:
            tgt_mask = self._causal_mask(decoder_input_ids.size(1), device)

        logits = self.decoder(decoder_input_ids, memory, tgt_mask, decoder_padding_mask)
        return logits

    @torch.no_grad()
    def generate(self, pixel_values: torch.Tensor, max_length: int = 512, sos_token_id: int = 1, eos_token_id: int = 2):
        """
        Greedily generates a token sequence from the input images.

        NOTE(review): only batch element 0 is decoded (``logits[0, -1, :]``);
        presumably callers pass a single image — confirm before batching.

        Args:
            pixel_values (torch.Tensor): Input images
            max_length (int, optional): Maximum length of the generated sequence. Defaults to 512.
            sos_token_id (int, optional): Start of sequence token ID. Defaults to 1.
            eos_token_id (int, optional): End of sequence token ID. Defaults to 2.

        Returns:
            torch.Tensor: Generated sequence of tokens (1-D, includes SOS and,
            if produced within ``max_length`` steps, the EOS token).
        """
        # Remember the caller's mode so generation doesn't permanently
        # switch a training-mode model to eval.
        was_training = self.training
        self.eval()
        try:
            if pixel_values.dim() == 3:
                pixel_values = pixel_values.unsqueeze(0)

            pixel_values = pixel_values.to(self.device)

            # Encode once up front; the previous implementation re-ran the
            # CNN encoder on every decoding step.
            memory = self.encoder(pixel_values)

            generated_sequence = torch.tensor([[sos_token_id]], dtype=torch.long, device=self.device)

            for _ in range(max_length):
                tgt_mask = self._causal_mask(generated_sequence.size(1), self.device)
                logits = self.decoder(generated_sequence, memory, tgt_mask)
                last_logits = logits[0, -1, :]

                next_token_idx = last_logits.argmax(-1).item()

                generated_sequence = torch.cat([
                    generated_sequence,
                    torch.tensor([[next_token_idx]], dtype=torch.long, device=self.device)
                ], dim=1)

                if next_token_idx == eos_token_id:
                    break
        finally:
            if was_training:
                self.train()

        return generated_sequence.squeeze(0)