import torch
import torch.nn as nn

from config import Im2LatexTransformerConfig
from transformers import PreTrainedModel


class CNN(nn.Module):
    """Convolutional encoder: maps an image to a sequence of ``d_model`` vectors.

    Three conv/ReLU/dropout/max-pool stages downsample the input by a factor
    of 8 in each spatial dimension, after which the (H*W) spatial positions
    are flattened into a sequence and projected to the transformer width.
    """

    def __init__(self, config: Im2LatexTransformerConfig):
        """Build the CNN encoder.

        Args:
            config (Im2LatexTransformerConfig): Model configuration; uses
                ``in_channels``, ``dropout`` and ``d_model``.
        """
        super().__init__()
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(config.in_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=config.dropout),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=config.dropout),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=config.dropout),
            nn.MaxPool2d(2, 2),
        )
        self.projection = nn.Linear(128, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode an image batch into a memory sequence.

        Args:
            x (torch.Tensor): Images of shape (B, C, H, W); a single image
                of shape (C, H, W) is also accepted.

        Returns:
            torch.Tensor: Encoder memory of shape (B, H_out * W_out, d_model).
        """
        # Ensure a batch dimension: (B, C, H, W).
        if x.dim() == 3:
            x = x.unsqueeze(0)

        # 1. Convolutional feature extraction -> (B, 128, H/8, W/8).
        x = self.conv_blocks(x)

        # 2. Flatten the spatial grid into a sequence for the transformer:
        #    (B, C, H, W) -> (B, S = H*W, C).
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)

        # 3. Project to d_model and apply dropout.
        x = self.projection(x)  # (B, S, d_model)
        x = self.dropout(x)
        return x


class Decoder(nn.Module):
    """Transformer decoder with learned token and positional embeddings."""

    def __init__(self, config: Im2LatexTransformerConfig):
        """Build the transformer decoder.

        Args:
            config (Im2LatexTransformerConfig): Model configuration; uses
                ``vocab_size``, ``max_len``, ``d_model``, ``nhead``,
                ``dim_feedforward``, ``num_layers`` and ``dropout``.
        """
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        # Learned absolute positional embeddings, capped at config.max_len.
        self.pos_embedding = nn.Embedding(config.max_len, config.d_model)
        decoder_layer = nn.TransformerDecoderLayer(
            config.d_model,
            config.nhead,
            config.dim_feedforward,
            config.dropout,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, config.num_layers)
        self.output_proj = nn.Linear(config.d_model, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        tokens: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: torch.Tensor = None,
        tgt_key_padding_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Decode target tokens against the encoder memory.

        Args:
            tokens (torch.Tensor): Token ids of shape (B, seq_len).
            memory (torch.Tensor): Encoder output of shape (B, S, d_model).
            tgt_mask (torch.Tensor, optional): Causal attention mask.
                Defaults to None.
            tgt_key_padding_mask (torch.Tensor, optional): Padding mask.
                Defaults to None.

        Returns:
            torch.Tensor: Next-token logits of shape (B, seq_len, vocab_size).

        Raises:
            ValueError: If ``seq_len`` exceeds the configured maximum length.
        """
        # tokens: (B, seq_len)
        batch_size, seq_len = tokens.shape
        device = tokens.device

        # Fail loudly instead of letting nn.Embedding raise an opaque
        # index error when the sequence outgrows the positional table.
        if seq_len > self.pos_embedding.num_embeddings:
            raise ValueError(
                f"Sequence length {seq_len} exceeds maximum supported "
                f"length {self.pos_embedding.num_embeddings}"
            )

        # 1. Token + positional embeddings.
        token_emb = self.embedding(tokens)  # (B, seq_len, d_model)
        positions = torch.arange(0, seq_len, device=device).unsqueeze(0)  # (1, seq_len)
        pos_emb = self.pos_embedding(positions)  # (1, seq_len, d_model)

        # 2. Sum the embeddings and apply dropout.
        x = self.dropout(token_emb + pos_emb)

        # 3. Run the transformer decoder.
        out = self.transformer_decoder(
            tgt=x,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
        )

        # 4. Project to vocabulary logits.
        logits = self.output_proj(out)
        return logits
class Im2LatexTransformer(PreTrainedModel):
    """CNN-encoder / transformer-decoder model for image-to-LaTeX generation."""

    config_class = Im2LatexTransformerConfig

    def __init__(self, config):
        """Build the encoder-decoder transformer.

        Args:
            config (Im2LatexTransformerConfig): Configuration for the model.
        """
        super().__init__(config)
        self.encoder = CNN(config)
        self.decoder = Decoder(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        decoder_input_ids: torch.Tensor,
        decoder_padding_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run a full encoder-decoder pass.

        Args:
            pixel_values (torch.Tensor): Input images, shape (B, C, H, W).
            decoder_input_ids (torch.Tensor): Decoder input tokens, shape (B, T).
            decoder_padding_mask (torch.Tensor, optional): Padding mask for
                the decoder. Defaults to None.

        Returns:
            torch.Tensor: Next-token logits of shape (B, T, vocab_size).
        """
        device = pixel_values.device

        # 1. Encode the images into a memory sequence.
        memory = self.encoder(pixel_values)

        # 2. Build the causal mask: True above the diagonal blocks attention
        #    to future positions.
        tgt_mask = None
        if decoder_input_ids is not None:
            seq_len = decoder_input_ids.size(1)
            tgt_mask = torch.triu(
                torch.ones((seq_len, seq_len), dtype=torch.bool, device=device),
                diagonal=1,
            )

        # 3. Decode.
        logits = self.decoder(decoder_input_ids, memory, tgt_mask, decoder_padding_mask)
        return logits

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.Tensor,
        max_length: int = 512,
        sos_token_id: int = 1,
        eos_token_id: int = 2,
    ):
        """Greedily generate token sequences from input images.

        Supports batched input: each sequence stops at its own EOS and is
        padded with ``eos_token_id`` until every sequence in the batch has
        finished. Generated sequences (including the SOS token) never exceed
        ``max_length`` tokens.

        Args:
            pixel_values (torch.Tensor): Input image(s), (C, H, W) or (B, C, H, W).
            max_length (int, optional): Maximum total length of each generated
                sequence, SOS included. Defaults to 512.
            sos_token_id (int, optional): Start-of-sequence token id. Defaults to 1.
            eos_token_id (int, optional): End-of-sequence token id. Defaults to 2.

        Returns:
            torch.Tensor: Generated token ids — shape (L,) for a single image,
            (B, L) for a batch.
        """
        self.eval()  # disable dropout for deterministic decoding

        single_input = pixel_values.dim() == 3
        if single_input:
            pixel_values = pixel_values.unsqueeze(0)
        pixel_values = pixel_values.to(self.device)

        batch_size = pixel_values.size(0)
        generated = torch.full(
            (batch_size, 1), sos_token_id, dtype=torch.long, device=self.device
        )
        finished = torch.zeros(batch_size, dtype=torch.bool, device=self.device)

        # The SOS token already occupies one slot, so at most max_length - 1
        # additional tokens may be appended.
        for _ in range(max_length - 1):
            logits = self(pixel_values, generated)
            next_tokens = logits[:, -1, :].argmax(-1)  # greedy pick, shape (B,)

            # Sequences that already emitted EOS keep emitting EOS (padding).
            next_tokens = torch.where(
                finished, torch.full_like(next_tokens, eos_token_id), next_tokens
            )
            generated = torch.cat([generated, next_tokens.unsqueeze(1)], dim=1)

            finished |= next_tokens == eos_token_id
            if finished.all():
                break

        if single_input or batch_size == 1:
            return generated.squeeze(0)  # keep the original single-image contract
        return generated