leitaofilho commited on
Commit
c238e32
·
verified ·
1 Parent(s): 5fa3751

Update EnedinaModel.py

Browse files
Files changed (1) hide show
  1. EnedinaModel.py +90 -49
EnedinaModel.py CHANGED
@@ -3,99 +3,140 @@ import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
 
6
- # Define uma camada de embedding para processar entradas de texto
7
- class TextEmbedding(nn.Module):
 
 
 
 
 
8
  def __init__(self, num_tokens, emb_dim):
9
- super(TextEmbedding, self).__init__()
10
  self.embedding = nn.Embedding(num_tokens, emb_dim)
 
11
 
12
  def forward(self, x):
13
- return self.embedding(x)
 
 
14
 
15
 
16
- # Define um processador genérico para transformar entradas numéricas em embeddings
17
  class GenericProcessor(nn.Module):
 
 
 
 
 
18
  def __init__(self, input_dim, emb_dim):
19
- super(GenericProcessor, self).__init__()
20
  self.fc = nn.Linear(input_dim, emb_dim)
21
 
22
  def forward(self, x):
23
  return F.relu(self.fc(x))
24
 
25
 
26
- # Define um decodificador Transformer com atenção cruzada
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  class TransformerDecoderWithCrossAttention(nn.Module):
 
 
 
 
 
28
  def __init__(self, emb_dim, num_heads, num_layers, ff_dim):
29
- super(TransformerDecoderWithCrossAttention, self).__init__()
30
  transformer_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=ff_dim)
31
  self.transformer_decoder = nn.TransformerDecoder(transformer_layer, num_layers=num_layers)
32
  self.projection = nn.Linear(emb_dim, emb_dim)
33
 
34
  def forward(self, x, memory):
35
- # Ajusta a entrada para o formato esperado [seq_len, batch_size, emb_dim]
36
- # Aqui, assumimos que `x` e `memory` já estão com as dimensões corretas,
37
- # onde a dimensão seq_len de `x` é determinada pelo modelo anterior que processa a entrada.
38
  output = self.transformer_decoder(x, memory)
39
- # Aplica uma camada linear de projeção para mapear a saída do decodificador
40
- # de volta para o tamanho de embedding desejado.
41
  return self.projection(output)
42
 
43
 
44
- # Define o modelo principal que incorpora os componentes acima
45
  class EnedinaModel(nn.Module):
46
- def __init__(self, text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim,
47
- emb_dim=512, num_heads=8, num_layers=6, ff_dim=2048):
48
- super(EnedinaModel, self).__init__()
49
- # Inicializa os componentes do modelo
50
- self.text_embedding = TextEmbedding(text_num_tokens, emb_dim)
 
 
 
 
51
  self.image_processor = GenericProcessor(image_input_dim, emb_dim)
52
  self.equation_processor = GenericProcessor(equation_input_dim, emb_dim)
53
  self.diagram_processor = GenericProcessor(diagram_input_dim, emb_dim)
 
 
 
 
54
  self.transformer_decoder = TransformerDecoderWithCrossAttention(emb_dim, num_heads, num_layers, ff_dim)
55
 
56
  def forward(self, text_input, image_input, equation_input, diagram_input):
57
- # Verifica as dimensões das entradas
58
- assert text_input.dim() == 2, "A entrada de texto deve ter dimensões (batch_size, seq_len)"
59
- assert image_input.dim() == 2, "A entrada de imagem deve ter dimensões (batch_size, image_input_dim)"
60
- assert equation_input.dim() == 2, "A entrada de equação deve ter dimensões (batch_size, equation_input_dim)"
61
- assert diagram_input.dim() == 2, "A entrada de diagrama deve ter dimensões (batch_size, diagram_input_dim)"
62
-
63
- # Processa as entradas através dos respectivos componentes
64
  text_emb = self.text_embedding(text_input)
65
- image_emb = self.image_processor(image_input)
66
- equation_emb = self.equation_processor(equation_input)
67
- diagram_emb = self.diagram_processor(diagram_input)
 
 
 
 
 
 
 
 
 
 
68
 
69
- # Ajusta as dimensões dos embeddings para permitir a concatenação
70
- # Importante: Ajusta para que todos os embeddings sejam 3D [batch_size, seq_len, emb_dim]
71
- # Texto já é [batch_size, seq_len, emb_dim] por padrão
72
- # Para imagem, equação e diagrama, adiciona-se uma dimensão seq_len fictícia
73
- image_emb = image_emb.unsqueeze(1) # Agora é [batch_size, 1, emb_dim]
74
- equation_emb = equation_emb.unsqueeze(1) # Agora é [batch_size, 1, emb_dim]
75
- diagram_emb = diagram_emb.unsqueeze(1) # Agora é [batch_size, 1, emb_dim]
76
 
77
- # Concatena os embeddings
78
- combined = torch.cat([text_emb, image_emb, equation_emb, diagram_emb], dim=1)
 
79
 
80
- # Aplica o decodificador Transformer ao embedding combinado
81
- output = self.transformer_decoder(combined, combined)
82
 
83
  return output
84
 
85
 
86
- # Configuração e simulação de entrada
87
- text_num_tokens = 20000
88
- image_input_dim = 512
89
- equation_input_dim = 256
90
- diagram_input_dim = 128
91
  batch_size = 4
92
- seq_len = 10
 
 
 
93
 
94
  # Inicializa o modelo
95
  model = EnedinaModel(text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim)
96
 
97
  # Gera entradas simuladas
98
- text_input = torch.randint(0, text_num_tokens, (batch_size, seq_len))
99
  image_input = torch.randn(batch_size, image_input_dim)
100
  equation_input = torch.randn(batch_size, equation_input_dim)
101
  diagram_input = torch.randn(batch_size, diagram_input_dim)
@@ -103,5 +144,5 @@ diagram_input = torch.randn(batch_size, diagram_input_dim)
103
  # Executa o modelo com as entradas simuladas
104
  output = model(text_input, image_input, equation_input, diagram_input)
105
 
106
- # Imprime a forma da saída para verificação
107
- print(output.shape)
 
3
  import torch.nn.functional as F
4
 
5
 
6
+ # Definição de uma camada de embedding com atenção esparsa para texto
7
+ class SparseTextEmbedding(nn.Module):
8
+ """
9
+ Camada de embedding para texto com atenção multi-cabeça.
10
+ Realiza embeddings de tokens de texto e aplica atenção multi-cabeça.
11
+ """
12
+
13
  def __init__(self, num_tokens, emb_dim):
14
+ super().__init__()
15
  self.embedding = nn.Embedding(num_tokens, emb_dim)
16
+ self.attention = nn.MultiheadAttention(emb_dim, num_heads=8, batch_first=True)
17
 
18
  def forward(self, x):
19
+ x = self.embedding(x)
20
+ x, _ = self.attention(x, x, x)
21
+ return x
22
 
23
 
24
+ # Processador genérico para transformar entradas numéricas em embeddings
25
  class GenericProcessor(nn.Module):
26
+ """
27
+ Processador genérico que transforma entradas numéricas em embeddings.
28
+ Utiliza uma camada linear seguida por uma ativação ReLU.
29
+ """
30
+
31
  def __init__(self, input_dim, emb_dim):
32
+ super().__init__()
33
  self.fc = nn.Linear(input_dim, emb_dim)
34
 
35
  def forward(self, x):
36
  return F.relu(self.fc(x))
37
 
38
 
39
+ # Especialista em transformador para domínios específicos
40
+ class TransformerExpert(nn.Module):
41
+ """
42
+ Especialista em domínio específico usando um encoder Transformer.
43
+ Projetado para processar embeddings e realizar tarefas específicas de domínio.
44
+ """
45
+
46
+ def __init__(self, emb_dim, num_heads, num_layers, ff_dim):
47
+ super().__init__()
48
+ transformer_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=ff_dim)
49
+ self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=num_layers)
50
+
51
+ def forward(self, x):
52
+ return self.transformer_encoder(x)
53
+
54
+
55
+ # Decodificador Transformer com atenção cruzada
56
  class TransformerDecoderWithCrossAttention(nn.Module):
57
+ """
58
+ Decodificador Transformer com atenção cruzada.
59
+ Combina informações de múltiplas fontes e projeta o resultado final.
60
+ """
61
+
62
  def __init__(self, emb_dim, num_heads, num_layers, ff_dim):
63
+ super().__init__()
64
  transformer_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=ff_dim)
65
  self.transformer_decoder = nn.TransformerDecoder(transformer_layer, num_layers=num_layers)
66
  self.projection = nn.Linear(emb_dim, emb_dim)
67
 
68
  def forward(self, x, memory):
 
 
 
69
  output = self.transformer_decoder(x, memory)
 
 
70
  return self.projection(output)
71
 
72
 
73
+ # Modelo principal que incorpora os componentes acima
74
  class EnedinaModel(nn.Module):
75
+ """
76
+ Modelo principal: Enedina.
77
+ Integra diferentes componentes especializados para processar múltiplos tipos de entrada.
78
+ """
79
+
80
+ def __init__(self, text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim, emb_dim=1024,
81
+ num_heads=16, num_layers=12, ff_dim=4096):
82
+ super().__init__()
83
+ self.text_embedding = SparseTextEmbedding(text_num_tokens, emb_dim)
84
  self.image_processor = GenericProcessor(image_input_dim, emb_dim)
85
  self.equation_processor = GenericProcessor(equation_input_dim, emb_dim)
86
  self.diagram_processor = GenericProcessor(diagram_input_dim, emb_dim)
87
+ self.experts = nn.ModuleList([
88
+ TransformerExpert(emb_dim, num_heads, num_layers, ff_dim) for _ in range(4)
89
+ ])
90
+ self.gate = nn.Linear(emb_dim * 4, 4)
91
  self.transformer_decoder = TransformerDecoderWithCrossAttention(emb_dim, num_heads, num_layers, ff_dim)
92
 
93
  def forward(self, text_input, image_input, equation_input, diagram_input):
 
 
 
 
 
 
 
94
  text_emb = self.text_embedding(text_input)
95
+ image_emb = self.image_processor(image_input).unsqueeze(1)
96
+ equation_emb = self.equation_processor(equation_input).unsqueeze(1)
97
+ diagram_emb = self.diagram_processor(diagram_input).unsqueeze(1)
98
+
99
+ # Estrutura dos especialistas
100
+ expert_inputs = [equation_emb, image_emb, diagram_emb, text_emb]
101
+ expert_outputs = []
102
+ for i, expert in enumerate(self.experts):
103
+ expert_output = expert(expert_inputs[i].permute(1, 0, 2))
104
+ expert_outputs.append(expert_output.permute(1, 0, 2)[:, -1, :])
105
+
106
+ # Combina as saídas dos especialistas
107
+ combined_expert_outputs = torch.cat(expert_outputs, dim=-1)
108
 
109
+ # Calcula os pesos do gate e aplica a combinação ponderada das saídas dos especialistas
110
+ gate_weights = F.softmax(self.gate(combined_expert_outputs), dim=-1)
111
+ expert_outputs_stack = torch.stack(expert_outputs, dim=1)
112
+ combined_output = torch.sum(gate_weights.unsqueeze(-1) * expert_outputs_stack, dim=1)
 
 
 
113
 
114
+ # Ajustes de dimensão antes do TransformerDecoder
115
+ combined_output = combined_output.unsqueeze(0)
116
+ text_emb = text_emb.permute(1, 0, 2)
117
 
118
+ # Aplica o decodificador Transformer com atenção cruzada
119
+ output = self.transformer_decoder(text_emb, combined_output)
120
 
121
  return output
122
 
123
 
124
+ # Configuração dos parâmetros do modelo e simulação de entrada para testes
125
+ text_num_tokens = 200000
126
+ image_input_dim = 2048
127
+ equation_input_dim = 1024
128
+ diagram_input_dim = 1024
129
  batch_size = 4
130
+ text_seq_len = 1000
131
+ image_seq_len = 10
132
+ equation_seq_len = 5
133
+ diagram_seq_len = 5
134
 
135
  # Inicializa o modelo
136
  model = EnedinaModel(text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim)
137
 
138
  # Gera entradas simuladas
139
+ text_input = torch.randint(0, text_num_tokens, (batch_size, text_seq_len))
140
  image_input = torch.randn(batch_size, image_input_dim)
141
  equation_input = torch.randn(batch_size, equation_input_dim)
142
  diagram_input = torch.randn(batch_size, diagram_input_dim)
 
144
  # Executa o modelo com as entradas simuladas
145
  output = model(text_input, image_input, equation_input, diagram_input)
146
 
147
+ # Verifica a forma da saída
148
+ print("A forma de saída do tensor é:", output.shape)