# Im2LatexTransformer — modeling.py
# Uploaded to the Hugging Face Hub by pedrolcs63 via huggingface_hub (commit b0377a8, verified).
import torch
import torch.nn as nn
from config import Im2LatexTransformerConfig
from transformers import PreTrainedModel
class CNN(nn.Module):
    """Convolutional encoder that turns an image into a sequence of d_model feature vectors."""

    def __init__(self, config: Im2LatexTransformerConfig):
        """
        Build the CNN encoder.

        Args:
            config (Im2LatexTransformerConfig): Model configuration; this module
                reads `in_channels`, `dropout` and `d_model`.
        """
        super().__init__()
        # Three conv -> ReLU -> Dropout2d -> MaxPool stages, doubling channels
        # (in_channels -> 32 -> 64 -> 128) and halving H/W at each stage.
        widths = (config.in_channels, 32, 64, 128)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Dropout2d(p=config.dropout),
                nn.MaxPool2d(2, 2),
            ])
        self.conv_blocks = nn.Sequential(*stages)
        self.projection = nn.Linear(128, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode an image (batch) into a flattened feature sequence.

        Args:
            x (torch.Tensor): Images of shape (B, C, H, W) or (C, H, W).

        Returns:
            torch.Tensor: Features of shape (B, H_out * W_out, d_model).
        """
        # Guarantee a batch dimension: (B, C, H, W).
        if x.dim() == 3:
            x = x.unsqueeze(0)
        # Convolutional stack -> (B, 128, H_out, W_out).
        feats = self.conv_blocks(x)
        b, c, h, w = feats.shape
        # Flatten the spatial grid row-major into a sequence: (B, H_out * W_out, 128).
        seq = feats.permute(0, 2, 3, 1).reshape(b, h * w, c)
        # Project to d_model and regularize.
        return self.dropout(self.projection(seq))
class Decoder(nn.Module):
    """Transformer decoder with learned token and positional embeddings."""

    def __init__(self, config: Im2LatexTransformerConfig):
        """
        Build the Transformer decoder stack.

        Args:
            config (Im2LatexTransformerConfig): Model configuration; this module
                reads `vocab_size`, `d_model`, `max_len`, `nhead`,
                `dim_feedforward`, `dropout` and `num_layers`.
        """
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        # Learned absolute positions, capped at config.max_len.
        self.pos_embedding = nn.Embedding(config.max_len, config.d_model)
        layer = nn.TransformerDecoderLayer(
            config.d_model,
            config.nhead,
            config.dim_feedforward,
            config.dropout,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(layer, config.num_layers)
        self.output_proj = nn.Linear(config.d_model, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self,
                tokens: torch.Tensor,
                memory: torch.Tensor,
                tgt_mask: torch.Tensor=None,
                tgt_key_padding_mask: torch.Tensor=None) -> torch.Tensor:
        """
        Decode the target tokens against the encoder memory.

        Args:
            tokens (torch.Tensor): Token ids of shape (B, seq_len).
            memory (torch.Tensor): Encoder features of shape (B, S_mem, d_model).
            tgt_mask (torch.Tensor, optional): Causal attention mask. Defaults to None.
            tgt_key_padding_mask (torch.Tensor, optional): Padding mask. Defaults to None.

        Returns:
            torch.Tensor: Next-token logits of shape (B, seq_len, vocab_size).
        """
        seq_len = tokens.size(1)
        # Position ids (1, seq_len) broadcast across the batch.
        pos_ids = torch.arange(seq_len, device=tokens.device).unsqueeze(0)
        # Summed token + positional embeddings, regularized with dropout.
        x = self.dropout(self.embedding(tokens) + self.pos_embedding(pos_ids))
        decoded = self.transformer_decoder(
            tgt=x,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
        )
        # Project hidden states to vocabulary logits.
        return self.output_proj(decoded)
# A classe Transformer wrapper também está correta.
class Im2LatexTransformer(PreTrainedModel):
    """CNN-encoder / Transformer-decoder model mapping formula images to LaTeX token ids."""

    config_class = Im2LatexTransformerConfig

    def __init__(self, config):
        """
        Build the full encoder-decoder model.

        Args:
            config (Im2LatexTransformerConfig): Configuration for the model.
        """
        super().__init__(config)
        self.encoder = CNN(config)
        self.decoder = Decoder(config)

    def forward(self,
                pixel_values: torch.Tensor,
                decoder_input_ids: torch.Tensor,
                decoder_padding_mask: torch.Tensor=None) -> torch.Tensor:
        """
        Run a teacher-forced forward pass.

        Args:
            pixel_values (torch.Tensor): Input images, (B, C, H, W).
            decoder_input_ids (torch.Tensor): Decoder input token ids, (B, seq_len).
            decoder_padding_mask (torch.Tensor, optional): Padding mask for the
                decoder tokens. Defaults to None.

        Returns:
            torch.Tensor: Next-token logits, (B, seq_len, vocab_size).
        """
        device = pixel_values.device
        # 1. Encode the images into a feature sequence (memory).
        memory = self.encoder(pixel_values)
        # 2. Build a boolean causal mask: True above the diagonal blocks
        #    attention to future positions.
        tgt_mask = None
        if decoder_input_ids is not None:
            seq_len = decoder_input_ids.size(1)
            tgt_mask = torch.triu(
                torch.ones((seq_len, seq_len), dtype=torch.bool, device=device),
                diagonal=1,
            )
        # 3. Decode against the memory.
        return self.decoder(decoder_input_ids, memory, tgt_mask, decoder_padding_mask)

    @torch.no_grad()
    def generate(self, pixel_values: torch.Tensor, max_length: int = 512, sos_token_id: int = 1, eos_token_id: int = 2):
        """
        Greedily decode LaTeX token sequences from input images.

        Improvements over the previous implementation:
          * supports batched `pixel_values` (the old code hard-coded a batch of
            one and would break on B > 1);
          * encodes the images once instead of re-running the CNN on every
            decoding step (eval-mode encoding is deterministic, so the result
            is unchanged);
          * restores the module's training mode afterwards instead of leaving
            the model permanently in eval mode.

        Args:
            pixel_values (torch.Tensor): Images, (B, C, H, W) or (C, H, W).
            max_length (int, optional): Maximum number of generated tokens
                (excluding SOS). Defaults to 512.
            sos_token_id (int, optional): Start-of-sequence token id. Defaults to 1.
            eos_token_id (int, optional): End-of-sequence token id. Defaults to 2.

        Returns:
            torch.Tensor: Generated token ids — shape (L,) for a single image
            (batch dimension squeezed, as before), (B, L) for a batch.
            Sequences that finish early are padded with `eos_token_id`.
        """
        was_training = self.training
        self.eval()  # disable dropout for deterministic greedy decoding
        try:
            if pixel_values.dim() == 3:
                pixel_values = pixel_values.unsqueeze(0)
            pixel_values = pixel_values.to(self.device)
            batch_size = pixel_values.size(0)
            # Encode once; the memory is reused at every decoding step.
            memory = self.encoder(pixel_values)
            sequences = torch.full(
                (batch_size, 1), sos_token_id, dtype=torch.long, device=self.device
            )
            finished = torch.zeros(batch_size, dtype=torch.bool, device=self.device)
            for _ in range(max_length):
                seq_len = sequences.size(1)
                tgt_mask = torch.triu(
                    torch.ones((seq_len, seq_len), dtype=torch.bool, device=self.device),
                    diagonal=1,
                )
                logits = self.decoder(sequences, memory, tgt_mask)
                next_tokens = logits[:, -1, :].argmax(dim=-1)  # (B,) greedy picks
                # Sequences that already emitted EOS keep emitting EOS (acts as padding).
                next_tokens = torch.where(
                    finished, torch.full_like(next_tokens, eos_token_id), next_tokens
                )
                sequences = torch.cat([sequences, next_tokens.unsqueeze(1)], dim=1)
                finished |= next_tokens.eq(eos_token_id)
                if bool(finished.all()):
                    break
        finally:
            # Put the model back the way we found it.
            if was_training:
                self.train()
        # Squeeze the batch dim for the single-image case (previous behavior).
        return sequences.squeeze(0)