File size: 6,792 Bytes
b0377a8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 | import torch
import torch.nn as nn
from config import Im2LatexTransformerConfig
from transformers import PreTrainedModel
class CNN(nn.Module):
    """Convolutional image encoder that turns an image into a feature sequence.

    Three identical stages (3x3 conv -> ReLU -> Dropout2d -> 2x2 max-pool)
    raise the channel count 32 -> 64 -> 128 while each pool halves H and W
    (floor division). The final feature map is flattened row-major into a
    sequence of H_out * W_out vectors and projected to ``config.d_model``.
    """

    def __init__(self, config: Im2LatexTransformerConfig):
        """
        Builds a CNN model
        Args:
            config (Im2LatexTransformerConfig): Configuration for the model; this
                module reads ``in_channels``, ``dropout`` and ``d_model``.
        """
        super().__init__()
        stages = []
        channels = config.in_channels
        for width in (32, 64, 128):
            stages += [
                nn.Conv2d(channels, width, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Dropout2d(p=config.dropout),
                nn.MaxPool2d(2, 2),
            ]
            channels = width
        # Flat list in stage order keeps Sequential indices (and therefore
        # state_dict keys such as conv_blocks.0/4/8.*) the same as a
        # hand-written nn.Sequential.
        self.conv_blocks = nn.Sequential(*stages)
        self.projection = nn.Linear(channels, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encodes images into a sequence of d_model-sized vectors
        Args:
            x (torch.Tensor): Image tensor, (C, H, W) or (B, C, H, W); a
                3-D input is treated as a batch of one.
        Returns:
            torch.Tensor: Features of shape (B, H_out * W_out, d_model).
        """
        # Add a batch dimension for a single image.
        if x.dim() == 3:
            x = x.unsqueeze(0)
        feature_map = self.conv_blocks(x)  # (B, 128, H_out, W_out)
        # Channels last, then merge the two spatial axes row-major:
        # (B, C, H, W) -> (B, H, W, C) -> (B, H*W, C)
        sequence = feature_map.permute(0, 2, 3, 1).flatten(1, 2)
        projected = self.projection(sequence)  # (B, H*W, d_model)
        return self.dropout(projected)
class Decoder(nn.Module):
    """Transformer decoder that maps token ids plus encoder memory to logits.

    Token ids are embedded, summed with learned absolute position embeddings
    (one per position, ``config.max_len`` positions in total), passed through
    ``config.num_layers`` batch-first ``nn.TransformerDecoderLayer`` blocks
    attending to ``memory``, and projected to vocabulary logits.
    """

    def __init__(self, config: Im2LatexTransformerConfig):
        """
        Builds a Transformer decoder
        Args:
            config (Im2LatexTransformerConfig): Configuration for the model; this
                module reads ``vocab_size``, ``max_len``, ``d_model``, ``nhead``,
                ``dim_feedforward``, ``dropout`` and ``num_layers``.
        """
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        # Learned absolute positions: inputs may be at most config.max_len tokens long.
        self.pos_embedding = nn.Embedding(config.max_len, config.d_model)
        decoder_layer = nn.TransformerDecoderLayer(
            config.d_model, config.nhead, config.dim_feedforward, config.dropout, batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, config.num_layers)
        self.output_proj = nn.Linear(config.d_model, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self,
                tokens: torch.Tensor,
                memory: torch.Tensor,
                tgt_mask: torch.Tensor = None,
                tgt_key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Passes the input through the decoder
        Args:
            tokens (torch.Tensor): Token ids of shape (batch, seq_len).
            memory (torch.Tensor): Encoder output, batch-first
                (batch, src_len, d_model).
            tgt_mask (torch.Tensor, optional): (seq_len, seq_len) attention mask
                over target positions, e.g. a causal mask. Defaults to None.
            tgt_key_padding_mask (torch.Tensor, optional): (batch, seq_len) mask
                marking padded target positions. Defaults to None.
        Returns:
            torch.Tensor: Next-token logits of shape (batch, seq_len, vocab_size).
        Raises:
            IndexError: If seq_len exceeds the number of learned positions
                (config.max_len).
        """
        # Unpacking also enforces that tokens is 2-D.
        _, seq_len = tokens.shape
        max_positions = self.pos_embedding.num_embeddings
        if seq_len > max_positions:
            # Fail early with a readable message instead of an opaque embedding
            # lookup error (a device-side assert on CUDA).
            raise IndexError(
                f"Sequence length {seq_len} exceeds the maximum of "
                f"{max_positions} positions supported by the decoder (config.max_len)"
            )
        # 1. Token embeddings + positional embeddings (positions broadcast over the batch).
        positions = torch.arange(seq_len, device=tokens.device).unsqueeze(0)  # (1, seq_len)
        x = self.dropout(self.embedding(tokens) + self.pos_embedding(positions))  # (B, S, D)
        # 2. Self-attention over the target prefix + cross-attention to memory.
        out = self.transformer_decoder(
            tgt=x, memory=memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask
        )
        # 3. Project to vocabulary logits.
        return self.output_proj(out)
# Wrapper model: CNN encoder + Transformer decoder exposed as a Hugging Face PreTrainedModel.
class Im2LatexTransformer(PreTrainedModel):
    """Image-to-LaTeX model: a CNN encoder followed by a Transformer decoder."""

    config_class = Im2LatexTransformerConfig

    def __init__(self, config):
        """
        Builds a Transformer
        Args:
            config (Im2LatexTransformerConfig): Configuration for the model
        """
        super(Im2LatexTransformer, self).__init__(config)
        self.encoder = CNN(config)
        self.decoder = Decoder(config)

    @staticmethod
    def _causal_mask(seq_len: int, device) -> torch.Tensor:
        """Return a (seq_len, seq_len) bool mask, True strictly above the diagonal.

        True entries are the future positions each target token may not attend to.
        """
        return torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool, device=device), diagonal=1)

    def forward(self,
                pixel_values: torch.Tensor,
                decoder_input_ids: torch.Tensor,
                decoder_padding_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Passes the input through the transformer (teacher forcing)
        Args:
            pixel_values (torch.Tensor): Input images, (C, H, W) or (B, C, H, W).
            decoder_input_ids (torch.Tensor): Decoder input tokens, (B, seq_len).
            decoder_padding_mask (torch.Tensor, optional): Padding mask for the
                decoder, forwarded as ``tgt_key_padding_mask``. Defaults to None.
        Returns:
            torch.Tensor: Next-token logits of shape (B, seq_len, vocab_size).
        """
        # 1. Encode the images into a memory sequence.
        memory = self.encoder(pixel_values)
        # 2. Causal mask so position i only attends to positions <= i.
        tgt_mask = self._causal_mask(decoder_input_ids.size(1), pixel_values.device)
        # 3. Decode.
        return self.decoder(decoder_input_ids, memory, tgt_mask, decoder_padding_mask)

    @torch.no_grad()
    def generate(self, pixel_values: torch.Tensor, max_length: int = 512, sos_token_id: int = 1, eos_token_id: int = 2):
        """
        Greedily decodes token sequences from the input images
        Args:
            pixel_values (torch.Tensor): A single image (C, H, W) or a batch
                (B, C, H, W).
            max_length (int, optional): Maximum number of decoding steps (tokens
                generated after the start token). Also capped at the decoder's
                number of learned positions, since every step feeds the whole
                prefix to the decoder. Defaults to 512.
            sos_token_id (int, optional): Start of sequence token ID. Defaults to 1.
            eos_token_id (int, optional): End of sequence token ID. Defaults to 2.
        Returns:
            torch.Tensor: Token ids starting with ``sos_token_id``. Shape (T,)
            for a single image (or a batch of one), (B, T) for B > 1. Rows that
            emit EOS before the others are padded with ``eos_token_id``.
        Note:
            Runs in eval mode; each submodule's previous train/eval mode is
            restored before returning.
        """
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            if pixel_values.dim() == 3:
                pixel_values = pixel_values.unsqueeze(0)
            device = self.device
            pixel_values = pixel_values.to(device)

            # The encoder output does not depend on the generated prefix: compute it once.
            memory = self.encoder(pixel_values)
            batch_size = memory.size(0)

            # Step k feeds a prefix of length k + 1, which must fit in the positional table.
            max_steps = min(max_length, self.decoder.pos_embedding.num_embeddings)

            sequences = torch.full((batch_size, 1), sos_token_id, dtype=torch.long, device=device)
            finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
            for _ in range(max_steps):
                tgt_mask = self._causal_mask(sequences.size(1), device)
                logits = self.decoder(sequences, memory, tgt_mask)
                next_tokens = logits[:, -1, :].argmax(-1)  # greedy decoding, (B,)
                # Rows that already produced EOS keep emitting EOS as padding.
                next_tokens = next_tokens.masked_fill(finished, eos_token_id)
                sequences = torch.cat([sequences, next_tokens.unsqueeze(1)], dim=1)
                finished |= next_tokens == eos_token_id
                if bool(finished.all()):
                    break
            # Drop the batch dimension for a single image (no-op when B > 1).
            return sequences.squeeze(0)
        finally:
            for module, was_training in previous_modes:
                module.training = was_training