pedrolcs63 commited on
Commit
17a24d5
·
verified ·
1 Parent(s): f7750a7

Upload 3 files

Browse files
Files changed (3) hide show
  1. config.py +14 -32
  2. image_processor.py +47 -0
  3. modeling.py +128 -46
config.py CHANGED
@@ -1,38 +1,20 @@
1
from transformers import PretrainedConfig


class Im2LatexTransformerConfig(PretrainedConfig):
    """
    Configuration for the Im2LatexTransformer model.

    Stores the hyper-parameters of the CNN-encoder / Transformer-decoder
    architecture used for image-to-LaTeX transcription.
    """

    model_type = "im2latex_transformer"

    def __init__(
        self,
        vocab_size: int = 544,
        max_len: int = 512,
        d_model: int = 512,
        nhead: int = 8,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 2048,
        in_channels: int = 1,
        dropout: float = 0.1,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ):
        # Architecture hyper-parameters consumed by modeling.py.
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.d_model = d_model
        self.nhead = nhead
        self.num_decoder_layers = num_decoder_layers
        self.dim_feedforward = dim_feedforward
        self.in_channels = in_channels
        self.dropout = dropout

        # Special-token ids and any remaining kwargs are handled by
        # PretrainedConfig so they round-trip through save/load.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
 
1
from transformers import PretrainedConfig


class Im2LatexTransformerConfig(PretrainedConfig):
    """
    Configuration for the Im2LatexTransformer model.

    All hyper-parameters can be overridden via keyword arguments; unknown
    kwargs are forwarded to `PretrainedConfig`.
    """

    # NOTE(review): Hugging Face convention is a lowercase model_type
    # (e.g. "im2latex_transformer") — kept as-is for checkpoint
    # compatibility; confirm before renaming.
    model_type = "Im2LatexTransformer"

    def __init__(self, **kwargs):
        # Architecture hyper-parameters. BUG FIX: these were hard-coded and
        # **kwargs was silently dropped, so any values saved in
        # config.json were lost on `from_pretrained`. Each value is now
        # popped from kwargs with the previous constant as its default.
        self.vocab_size = kwargs.pop("vocab_size", 544)
        self.max_len = kwargs.pop("max_len", 512)
        self.d_model = kwargs.pop("d_model", 512)
        self.nhead = kwargs.pop("nhead", 8)
        self.num_layers = kwargs.pop("num_layers", 6)
        self.dim_feedforward = kwargs.pop("dim_feedforward", 2048)
        self.dropout = kwargs.pop("dropout", 0.1)
        self.in_channels = kwargs.pop("in_channels", 1)

        # Special-token ids. NOTE(review): `sos_token_id` is not a standard
        # PretrainedConfig field (convention is `bos_token_id`) — kept for
        # backward compatibility; confirm against the tokenizer.
        kwargs.setdefault("pad_token_id", 0)
        kwargs.setdefault("sos_token_id", 1)
        kwargs.setdefault("eos_token_id", 2)

        # Forward everything so PretrainedConfig records and serializes it.
        super().__init__(**kwargs)
image_processor.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from PIL import Image, ImageOps
import numpy as np
import torch
from transformers import ImageProcessingMixin
import os
import json


class Im2LatexProcessor(ImageProcessingMixin):
    """
    Image pre-processor for the Im2Latex model.

    Converts an input image to a grayscale, white-padded, [0, 1]-normalized
    float tensor of shape (1, H, W) for the CNN encoder.
    """

    def __init__(self, image_size=(256, 256), **kwargs):
        super().__init__(**kwargs)
        # BUG FIX: JSON round-trips tuples as lists; normalize so the size
        # is consistent after `from_pretrained`.
        self.image_size = tuple(image_size)

    def preprocess(self, image: Image.Image) -> torch.Tensor:
        """
        Process a PIL image and return a (1, H, W) float32 tensor in [0, 1].
        """
        img = image.convert("L")                              # grayscale
        img = ImageOps.pad(img, self.image_size, color=255)   # pad with white
        arr = np.asarray(img, dtype=np.float32) / 255.0
        arr = np.expand_dims(arr, 0)  # (1, H, W)
        return torch.tensor(arr, dtype=torch.float32)

    def __call__(self, image_path: str) -> torch.Tensor:
        """
        Process an image file given its path.
        """
        # BUG FIX: close the file handle instead of leaking it.
        with Image.open(image_path) as image:
            return self.preprocess(image)

    def save_pretrained(self, save_directory):
        """
        Save the processor config to `save_directory`.
        """
        # BUG FIX: the target directory may not exist yet.
        os.makedirs(save_directory, exist_ok=True)
        self.image_processor_config = {
            "image_size": self.image_size,
        }
        with open(os.path.join(save_directory, "preprocessor_config.json"), "w") as f:
            json.dump(self.image_processor_config, f)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load the processor config from a local directory.

        NOTE(review): only local paths are supported — a Hub model id will
        fail the `open()` below; confirm whether Hub loading is needed.
        """
        with open(os.path.join(pretrained_model_name_or_path, "preprocessor_config.json"), "r") as f:
            config = json.load(f)
        # Forward extra kwargs (previously dropped) so callers can override.
        return cls(**config, **kwargs)
modeling.py CHANGED
@@ -1,120 +1,202 @@
1
  import torch
2
  import torch.nn as nn
 
3
  from transformers import PreTrainedModel
4
- from config import Im2LatexTransformerConfig
5
 
6
- # A classe CNN pode permanecer a mesma, pois é um módulo interno
7
class CNN(nn.Module):
    """
    Convolutional encoder.

    Maps a batch of images (B, C, H, W) to a sequence of patch features
    (B, S, d_model) consumed as `memory` by the Transformer decoder.
    """

    def __init__(self, config: "Im2LatexTransformerConfig"):
        super(CNN, self).__init__()
        # Three conv blocks; each MaxPool halves H and W, so the output
        # spatial grid is (H/8, W/8) with 128 channels.
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(config.in_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=config.dropout),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=config.dropout),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=config.dropout),
            nn.MaxPool2d(2, 2),
        )
        self.projection = nn.Linear(128, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode images into a feature sequence.

        Args:
            x: Images of shape (B, C, H, W) or a single image (C, H, W).

        Returns:
            torch.Tensor: Features of shape (B, H/8 * W/8, d_model).
        """
        # Ensure a batch dimension.
        if x.dim() == 3:
            x = x.unsqueeze(0)
        x = self.conv_blocks(x)                          # (B, 128, H/8, W/8)
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)   # (B, S, 128)
        x = self.projection(x)                           # (B, S, d_model)
        x = self.dropout(x)
        return x
36
 
37
- # A classe Decoder também pode permanecer a mesma
38
class Decoder(nn.Module):
    """
    Transformer decoder with learned token and positional embeddings,
    projecting decoder states to vocabulary logits.
    """

    def __init__(self, config: "Im2LatexTransformerConfig"):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embedding = nn.Embedding(config.max_len, config.d_model)
        decoder_layer = nn.TransformerDecoderLayer(
            config.d_model, config.nhead, config.dim_feedforward, config.dropout, batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, config.num_decoder_layers)
        self.output_proj = nn.Linear(config.d_model, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, tokens, memory, tgt_mask=None, tgt_key_padding_mask=None):
        """
        Decode target tokens against the encoder memory.

        Args:
            tokens: Target token ids, shape (B, S).
            memory: Encoder features, shape (B, S_enc, d_model).
            tgt_mask: Optional causal attention mask, shape (S, S).
            tgt_key_padding_mask: Optional padding mask, shape (B, S).

        Returns:
            torch.Tensor: Logits of shape (B, S, vocab_size).
        """
        batch_size, seq_len = tokens.shape
        device = tokens.device
        # Token embedding plus learned absolute position embedding.
        token_emb = self.embedding(tokens)                                # (B, S, D)
        positions = torch.arange(0, seq_len, device=device).unsqueeze(0)  # (1, S)
        pos_emb = self.pos_embedding(positions)                           # (1, S, D)
        x = self.dropout(token_emb + pos_emb)
        out = self.transformer_decoder(
            tgt=x, memory=memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask
        )
        logits = self.output_proj(out)
        return logits
62
 
 
63
class Im2LatexTransformer(PreTrainedModel):
    """
    Hugging Face wrapper combining the CNN encoder and Transformer decoder.
    """

    config_class = Im2LatexTransformerConfig

    def __init__(self, config: Im2LatexTransformerConfig):
        super().__init__(config)
        self.encoder = CNN(config)
        self.decoder = Decoder(config)

    def forward(self,
                pixel_values: torch.Tensor,
                decoder_input_ids: torch.Tensor,
                decoder_attention_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Run a teacher-forced forward pass.

        The signature follows Hugging Face naming: `pixel_values` for image
        input, `decoder_input_ids` for decoder tokens.

        NOTE(review): `decoder_attention_mask` is forwarded as
        `tgt_key_padding_mask`, whose polarity is True = ignore — the
        opposite of HF's 1 = attend convention; confirm callers invert it.

        Returns:
            torch.Tensor: Logits of shape (B, S, vocab_size).
        """
        memory = self.encoder(pixel_values)

        # Causal mask so position i only attends to positions <= i.
        tgt_mask = None
        if decoder_input_ids is not None:
            device = decoder_input_ids.device
            seq_len = decoder_input_ids.size(1)
            tgt_mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool, device=device), diagonal=1)

        logits = self.decoder(
            tokens=decoder_input_ids,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=decoder_attention_mask
        )
        return logits

    @torch.no_grad()
    def generate(self, pixel_values: torch.Tensor, max_length: int = 50, sos_token_id: int = 1, eos_token_id: int = 2):
        """
        Greedy-decode a token sequence for a single image.

        Args:
            pixel_values: Image tensor, (C, H, W) or (1, C, H, W).
            max_length: Maximum generated length including the SOS token.
            sos_token_id: Start-of-sequence token id.
            eos_token_id: End-of-sequence token id.

        Returns:
            torch.Tensor: Generated ids of shape (1, <=max_length).
        """
        self.eval()

        if pixel_values.dim() == 3:
            pixel_values = pixel_values.unsqueeze(0)
        pixel_values = pixel_values.to(self.device)

        # Encode once; only the decoder runs per step.
        memory = self.encoder(pixel_values)

        generated_sequence = torch.tensor([[sos_token_id]], dtype=torch.long, device=self.device)

        # (Inner `with torch.no_grad():` removed — @torch.no_grad already
        # disables gradients for the whole method.)
        for _ in range(max_length - 1):
            logits = self.decoder(generated_sequence, memory)
            next_token_id = logits[0, -1, :].argmax(-1).item()  # greedy pick
            generated_sequence = torch.cat([
                generated_sequence,
                torch.tensor([[next_token_id]], dtype=torch.long, device=self.device)
            ], dim=1)

            if next_token_id == eos_token_id:
                break

        return generated_sequence
 
 
 
1
  import torch
2
  import torch.nn as nn
3
+ from config import Im2LatexTransformerConfig
4
  from transformers import PreTrainedModel
 
5
 
 
6
class CNN(nn.Module):
    """
    Convolutional encoder for the Im2Latex model.

    Maps images (B, C, H, W) to a feature sequence (B, S, d_model) used as
    `memory` by the Transformer decoder.
    """

    def __init__(self, config: "Im2LatexTransformerConfig"):
        """
        Build the CNN encoder.

        Args:
            config (Im2LatexTransformerConfig): Model configuration.
        """
        super(CNN, self).__init__()

        # Three conv blocks; each MaxPool halves H and W, giving a final
        # grid of (H/8, W/8) with 128 channels.
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(config.in_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=config.dropout),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=config.dropout),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Dropout2d(p=config.dropout),
            nn.MaxPool2d(2, 2)
        )

        self.projection = nn.Linear(128, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode images into a feature sequence.

        Args:
            x (torch.Tensor): Images, (B, C, H, W) or single (C, H, W).

        Returns:
            torch.Tensor: Features of shape (B, H/8 * W/8, d_model).
        """
        # Ensure x has a batch dimension: (B, C, H, W).
        if x.dim() == 3:
            x = x.unsqueeze(0)

        # 1. Convolutional blocks.
        x = self.conv_blocks(x)  # -> (B, C=128, H_out, W_out)

        # 2. Flatten the spatial grid into a sequence for the transformer.
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)  # (B, S=H*W, C)

        # 3. Project to d_model and apply dropout.
        x = self.projection(x)  # (B, S, d_model)
        x = self.dropout(x)

        return x
62
 
 
63
class Decoder(nn.Module):
    """
    Transformer decoder with learned token and positional embeddings,
    projecting decoder states to vocabulary logits.
    """

    def __init__(self, config: "Im2LatexTransformerConfig"):
        """
        Build the Transformer decoder.

        Args:
            config (Im2LatexTransformerConfig): Model configuration.
        """
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embedding = nn.Embedding(config.max_len, config.d_model)

        decoder_layer = nn.TransformerDecoderLayer(config.d_model, config.nhead, config.dim_feedforward, config.dropout, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, config.num_layers)

        self.output_proj = nn.Linear(config.d_model, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self,
                tokens: torch.Tensor,
                memory: torch.Tensor,
                tgt_mask: torch.Tensor = None,
                tgt_key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Decode target tokens against the encoder memory.

        Args:
            tokens (torch.Tensor): Target token ids, shape (B, S).
            memory (torch.Tensor): Encoder features, (B, S_enc, d_model).
            tgt_mask (torch.Tensor, optional): Causal mask, (S, S). Defaults to None.
            tgt_key_padding_mask (torch.Tensor, optional): Padding mask, (B, S). Defaults to None.

        Returns:
            torch.Tensor: Next-token logits of shape (B, S, vocab_size).
        """
        # tokens: (Batch, seq_len)
        batch_size, seq_len = tokens.shape
        device = tokens.device

        # 1. Token embedding + learned absolute position embedding.
        token_emb = self.embedding(tokens)                                # (B, S, D)
        positions = torch.arange(0, seq_len, device=device).unsqueeze(0)  # (1, S)
        pos_emb = self.pos_embedding(positions)                           # (1, S, D)

        # 2. Sum embeddings and apply dropout.
        x = self.dropout(token_emb + pos_emb)

        # 3. Run the transformer decoder stack.
        out = self.transformer_decoder(
            tgt=x, memory=memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask
        )

        # 4. Project to vocabulary logits.
        logits = self.output_proj(out)

        return logits
119
 
120
# Hugging Face wrapper around the CNN encoder and Transformer decoder.
class Im2LatexTransformer(PreTrainedModel):
    config_class = Im2LatexTransformerConfig

    def __init__(self, config):
        """
        Build the full encoder-decoder model.

        Args:
            config (Im2LatexTransformerConfig): Model configuration.
        """
        super(Im2LatexTransformer, self).__init__(config)
        self.encoder = CNN(config)
        self.decoder = Decoder(config)

    def forward(self,
                pixel_values: torch.Tensor,
                decoder_input_ids: torch.Tensor,
                decoder_padding_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Run a teacher-forced forward pass.

        Args:
            pixel_values (torch.Tensor): Input images.
            decoder_input_ids (torch.Tensor): Decoder input tokens.
            decoder_padding_mask (torch.Tensor, optional): Padding mask for
                the decoder (True = ignore position). Defaults to None.

        Returns:
            torch.Tensor: Next-token logits of shape (B, S, vocab_size).
        """
        device = pixel_values.device

        # 1. Encode the image.
        memory = self.encoder(pixel_values)

        # 2. Build the causal mask so position i attends only to <= i.
        tgt_mask = None
        if decoder_input_ids is not None:
            seq_len = decoder_input_ids.size(1)
            tgt_mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool, device=device), diagonal=1)

        # 3. Decode.
        logits = self.decoder(decoder_input_ids, memory, tgt_mask, decoder_padding_mask)
        return logits

    @torch.no_grad()
    def generate(self, pixel_values: torch.Tensor, max_length: int = 512, sos_token_id: int = 1, eos_token_id: int = 2):
        """
        Greedy-decode a token sequence from an input image.

        Args:
            pixel_values (torch.Tensor): Input image, (C, H, W) or (1, C, H, W).
            max_length (int, optional): Maximum number of decoding steps. Defaults to 512.
            sos_token_id (int, optional): Start-of-sequence token id. Defaults to 1.
            eos_token_id (int, optional): End-of-sequence token id. Defaults to 2.

        Returns:
            torch.Tensor: Generated token ids, shape (<=max_length + 1,).
        """
        self.eval()  # switch to evaluation mode

        if pixel_values.dim() == 3:
            pixel_values = pixel_values.unsqueeze(0)

        pixel_values = pixel_values.to(self.device)

        # PERF FIX: encode once up front. The previous version called
        # self(pixel_values, ...) every step, re-running the CNN encoder
        # for each generated token; the greedy output is identical.
        memory = self.encoder(pixel_values)

        generated_sequence = torch.tensor([[sos_token_id]], dtype=torch.long, device=self.device)

        for _ in range(max_length):
            # Rebuild the causal mask for the current length, matching what
            # forward() would have produced.
            seq_len = generated_sequence.size(1)
            tgt_mask = torch.triu(
                torch.ones((seq_len, seq_len), dtype=torch.bool, device=self.device), diagonal=1
            )
            logits = self.decoder(generated_sequence, memory, tgt_mask)
            next_token_idx = logits[0, -1, :].argmax(-1).item()  # greedy decoding

            generated_sequence = torch.cat([
                generated_sequence,
                torch.tensor([[next_token_idx]], dtype=torch.long, device=self.device)
            ], dim=1)

            if next_token_idx == eos_token_id:
                break

        return generated_sequence.squeeze(0)  # drop the batch dimension