leitaofilho commited on
Commit
8c61cc8
·
verified ·
1 Parent(s): 1266188

Upload 2 files

Browse files
Files changed (2) hide show
  1. EnedinaModel.py +107 -0
  2. requirements.txt +1 -0
EnedinaModel.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ # Define uma camada de embedding para processar entradas de texto
7
+ class TextEmbedding(nn.Module):
8
+ def __init__(self, num_tokens, emb_dim):
9
+ super(TextEmbedding, self).__init__()
10
+ self.embedding = nn.Embedding(num_tokens, emb_dim)
11
+
12
+ def forward(self, x):
13
+ return self.embedding(x)
14
+
15
+
16
+ # Define um processador genérico para transformar entradas numéricas em embeddings
17
+ class GenericProcessor(nn.Module):
18
+ def __init__(self, input_dim, emb_dim):
19
+ super(GenericProcessor, self).__init__()
20
+ self.fc = nn.Linear(input_dim, emb_dim)
21
+
22
+ def forward(self, x):
23
+ return F.relu(self.fc(x))
24
+
25
+
26
+ # Define um decodificador Transformer com atenção cruzada
27
+ class TransformerDecoderWithCrossAttention(nn.Module):
28
+ def __init__(self, emb_dim, num_heads, num_layers, ff_dim):
29
+ super(TransformerDecoderWithCrossAttention, self).__init__()
30
+ transformer_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=ff_dim)
31
+ self.transformer_decoder = nn.TransformerDecoder(transformer_layer, num_layers=num_layers)
32
+ self.projection = nn.Linear(emb_dim, emb_dim)
33
+
34
+ def forward(self, x, memory):
35
+ # Ajusta a entrada para o formato esperado [seq_len, batch_size, emb_dim]
36
+ # Aqui, assumimos que `x` e `memory` já estão com as dimensões corretas,
37
+ # onde a dimensão seq_len de `x` é determinada pelo modelo anterior que processa a entrada.
38
+ output = self.transformer_decoder(x, memory)
39
+ # Aplica uma camada linear de projeção para mapear a saída do decodificador
40
+ # de volta para o tamanho de embedding desejado.
41
+ return self.projection(output)
42
+
43
+
44
+ # Define o modelo principal que incorpora os componentes acima
45
+ class EnedinaModel(nn.Module):
46
+ def __init__(self, text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim,
47
+ emb_dim=512, num_heads=8, num_layers=6, ff_dim=2048):
48
+ super(EnedinaModel, self).__init__()
49
+ # Inicializa os componentes do modelo
50
+ self.text_embedding = TextEmbedding(text_num_tokens, emb_dim)
51
+ self.image_processor = GenericProcessor(image_input_dim, emb_dim)
52
+ self.equation_processor = GenericProcessor(equation_input_dim, emb_dim)
53
+ self.diagram_processor = GenericProcessor(diagram_input_dim, emb_dim)
54
+ self.transformer_decoder = TransformerDecoderWithCrossAttention(emb_dim, num_heads, num_layers, ff_dim)
55
+
56
+ def forward(self, text_input, image_input, equation_input, diagram_input):
57
+ # Verifica as dimensões das entradas
58
+ assert text_input.dim() == 2, "A entrada de texto deve ter dimensões (batch_size, seq_len)"
59
+ assert image_input.dim() == 2, "A entrada de imagem deve ter dimensões (batch_size, image_input_dim)"
60
+ assert equation_input.dim() == 2, "A entrada de equação deve ter dimensões (batch_size, equation_input_dim)"
61
+ assert diagram_input.dim() == 2, "A entrada de diagrama deve ter dimensões (batch_size, diagram_input_dim)"
62
+
63
+ # Processa as entradas através dos respectivos componentes
64
+ text_emb = self.text_embedding(text_input)
65
+ image_emb = self.image_processor(image_input)
66
+ equation_emb = self.equation_processor(equation_input)
67
+ diagram_emb = self.diagram_processor(diagram_input)
68
+
69
+ # Ajusta as dimensões dos embeddings para permitir a concatenação
70
+ # Importante: Ajusta para que todos os embeddings sejam 3D [batch_size, seq_len, emb_dim]
71
+ # Texto já é [batch_size, seq_len, emb_dim] por padrão
72
+ # Para imagem, equação e diagrama, adiciona-se uma dimensão seq_len fictícia
73
+ image_emb = image_emb.unsqueeze(1) # Agora é [batch_size, 1, emb_dim]
74
+ equation_emb = equation_emb.unsqueeze(1) # Agora é [batch_size, 1, emb_dim]
75
+ diagram_emb = diagram_emb.unsqueeze(1) # Agora é [batch_size, 1, emb_dim]
76
+
77
+ # Concatena os embeddings
78
+ combined = torch.cat([text_emb, image_emb, equation_emb, diagram_emb], dim=1)
79
+
80
+ # Aplica o decodificador Transformer ao embedding combinado
81
+ output = self.transformer_decoder(combined, combined)
82
+
83
+ return output
84
+
85
+
86
+ # Configuração e simulação de entrada
87
+ text_num_tokens = 20000
88
+ image_input_dim = 512
89
+ equation_input_dim = 256
90
+ diagram_input_dim = 128
91
+ batch_size = 4
92
+ seq_len = 10
93
+
94
+ # Inicializa o modelo
95
+ model = EnedinaModel(text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim)
96
+
97
+ # Gera entradas simuladas
98
+ text_input = torch.randint(0, text_num_tokens, (batch_size, seq_len))
99
+ image_input = torch.randn(batch_size, image_input_dim)
100
+ equation_input = torch.randn(batch_size, equation_input_dim)
101
+ diagram_input = torch.randn(batch_size, diagram_input_dim)
102
+
103
+ # Executa o modelo com as entradas simuladas
104
+ output = model(text_input, image_input, equation_input, diagram_input)
105
+
106
+ # Imprime a forma da saída para verificação
107
+ print(output.shape)
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ torch~=2.2.1