| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| |
|
| | |
class SparseTextEmbedding(nn.Module):
    """
    Text embedding layer followed by multi-head self-attention.

    Maps token ids to ``emb_dim``-wide vectors with ``nn.Embedding`` and then
    applies a single multi-head self-attention layer over the sequence.
    Despite the class name, the embedding is a regular dense ``nn.Embedding``
    (its ``sparse`` option is left at the default).

    Args:
        num_tokens: vocabulary size (number of distinct token ids).
        emb_dim: embedding width; ``nn.MultiheadAttention`` requires it to be
            divisible by ``num_heads``.
        num_heads: number of attention heads. Defaults to 8, the value that
            was previously hard-coded.
    """

    def __init__(self, num_tokens, emb_dim, num_heads=8):
        super().__init__()
        self.embedding = nn.Embedding(num_tokens, emb_dim)
        self.attention = nn.MultiheadAttention(emb_dim, num_heads=num_heads, batch_first=True)

    def forward(self, x, key_padding_mask=None):
        """
        Embed token ids and apply self-attention.

        Args:
            x: integer token ids laid out (batch, seq), because the attention
                layer is constructed with ``batch_first=True``.
            key_padding_mask: optional bool tensor of shape (batch, seq).
                ``True`` marks key positions that must not be attended to
                (e.g. padding). ``None`` (the default) attends to every position.

        Returns:
            Tensor of shape (batch, seq, emb_dim).
        """
        x = self.embedding(x)
        # The attention weights are discarded, so don't ask for them
        # (need_weights=False lets PyTorch skip computing and averaging them).
        x, _ = self.attention(x, x, x, key_padding_mask=key_padding_mask, need_weights=False)
        return x
| |
|
| |
|
| | |
class GenericProcessor(nn.Module):
    """
    Generic projector from numeric feature vectors to embeddings.

    Applies one learned affine map from ``input_dim`` to ``emb_dim`` on the
    last dimension of the input, followed by a ReLU non-linearity.

    Args:
        input_dim: size of the input's last dimension.
        emb_dim: size of the produced embedding.
    """

    def __init__(self, input_dim, emb_dim):
        super().__init__()
        # Single affine projection onto the shared embedding width.
        self.fc = nn.Linear(in_features=input_dim, out_features=emb_dim)

    def forward(self, x):
        """Project ``x`` to ``emb_dim`` and clamp negatives to zero with ReLU."""
        projected = self.fc(x)
        activated = F.relu(projected)
        return activated
| |
|
| |
|
| | |
class TransformerExpert(nn.Module):
    """
    Domain-specific expert made of a stack of Transformer encoder layers.

    The encoder layers are built with PyTorch's default ``batch_first=False``,
    so ``forward`` expects ``x`` laid out as (seq, batch, emb_dim) and returns
    a tensor of the same shape.

    Args:
        emb_dim: model width (``d_model``).
        num_heads: attention heads per layer.
        num_layers: number of stacked encoder layers.
        ff_dim: hidden width of each layer's feed-forward sub-block.
    """

    def __init__(self, emb_dim, num_heads, num_layers, ff_dim):
        super().__init__()
        # nn.TransformerEncoder clones the template layer num_layers times.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=emb_dim,
                nhead=num_heads,
                dim_feedforward=ff_dim,
            ),
            num_layers=num_layers,
        )

    def forward(self, x):
        """Run ``x`` through every encoder layer and return the encoded sequence."""
        encoded = self.transformer_encoder(x)
        return encoded
| |
|
| |
|
| | |
class TransformerDecoderWithCrossAttention(nn.Module):
    """
    Transformer decoder with cross-attention, followed by a linear projection.

    The target sequence self-attends and cross-attends to ``memory``, and the
    result is passed through an ``emb_dim -> emb_dim`` linear layer. The layers
    use PyTorch's default ``batch_first=False``, so tensors are sequence-first:
    ``x`` is (tgt_seq, batch, emb_dim) and ``memory`` is
    (mem_seq, batch, emb_dim). No target mask is passed, so the decoder's
    self-attention is not causal.

    Args:
        emb_dim: model width (``d_model``).
        num_heads: attention heads per layer.
        num_layers: number of stacked decoder layers.
        ff_dim: hidden width of each layer's feed-forward sub-block.
    """

    def __init__(self, emb_dim, num_heads, num_layers, ff_dim):
        super().__init__()
        # nn.TransformerDecoder clones the template layer num_layers times.
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=emb_dim,
                nhead=num_heads,
                dim_feedforward=ff_dim,
            ),
            num_layers=num_layers,
        )
        self.projection = nn.Linear(emb_dim, emb_dim)

    def forward(self, x, memory):
        """Decode ``x`` against ``memory``; return a tensor shaped like ``x``."""
        decoded = self.transformer_decoder(tgt=x, memory=memory)
        projected = self.projection(decoded)
        return projected
| |
|
| |
|
| | |
class EnedinaModel(nn.Module):
    """
    Main model: Enedina.

    Combines specialised components to process several input types:

    1. Text token ids go through ``SparseTextEmbedding``. Image, equation and
       diagram features each go through their own ``GenericProcessor``. All
       four end up ``emb_dim`` wide.
    2. Each modality is encoded by its own ``TransformerExpert``. The last
       position of each expert's output is kept as that expert's summary
       vector.
    3. A learned softmax gate over the concatenated summaries weights and
       sums them into one combined vector per sample.
    4. The text embeddings are decoded with cross-attention over that
       combined vector, used as a length-1 memory.

    Args:
        text_num_tokens: vocabulary size for the text embedding.
        image_input_dim: size of ``image_input``'s last dimension.
        equation_input_dim: size of ``equation_input``'s last dimension.
        diagram_input_dim: size of ``diagram_input``'s last dimension.
        emb_dim: shared embedding width.
        num_heads: attention heads for the experts and the decoder.
        num_layers: layers per expert and in the decoder.
        ff_dim: feed-forward width inside every Transformer layer.
    """

    # One expert each for equation, image, diagram and text inputs.
    NUM_EXPERTS = 4

    def __init__(self, text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim, emb_dim=1024,
                 num_heads=16, num_layers=12, ff_dim=4096):
        super().__init__()
        # SparseTextEmbedding is built without an explicit head count, so emb_dim
        # must also be divisible by that class's own default number of heads.
        self.text_embedding = SparseTextEmbedding(text_num_tokens, emb_dim)
        self.image_processor = GenericProcessor(image_input_dim, emb_dim)
        self.equation_processor = GenericProcessor(equation_input_dim, emb_dim)
        self.diagram_processor = GenericProcessor(diagram_input_dim, emb_dim)
        self.experts = nn.ModuleList([
            TransformerExpert(emb_dim, num_heads, num_layers, ff_dim) for _ in range(self.NUM_EXPERTS)
        ])
        # The gate sees every expert summary at once and emits one logit per expert.
        self.gate = nn.Linear(emb_dim * self.NUM_EXPERTS, self.NUM_EXPERTS)
        self.transformer_decoder = TransformerDecoderWithCrossAttention(emb_dim, num_heads, num_layers, ff_dim)

    @staticmethod
    def _as_sequence(emb):
        """Return a modality embedding as (batch, seq, emb_dim); 2-D input gets a length-1 seq axis."""
        if emb.dim() == 2:
            return emb.unsqueeze(1)
        if emb.dim() == 3:
            return emb
        raise ValueError(
            f"expected a 2-D (batch, feat) or 3-D (batch, seq, feat) modality input, "
            f"got embedding of shape {tuple(emb.shape)}"
        )

    def forward(self, text_input, image_input, equation_input, diagram_input):
        """
        Run all modalities through their experts and decode the text.

        Args:
            text_input: integer token ids, (batch, text_seq).
            image_input: image features, (batch, image_input_dim) or
                (batch, seq, image_input_dim).
            equation_input: equation features, (batch, equation_input_dim) or
                (batch, seq, equation_input_dim).
            diagram_input: diagram features, (batch, diagram_input_dim) or
                (batch, seq, diagram_input_dim).

        Returns:
            Decoder output in sequence-first layout: (text_seq, batch, emb_dim).

        Raises:
            ValueError: if a feature input is neither 2-D nor 3-D.
        """
        text_emb = self.text_embedding(text_input)  # (batch, text_seq, emb_dim)
        image_emb = self._as_sequence(self.image_processor(image_input))
        equation_emb = self._as_sequence(self.equation_processor(equation_input))
        diagram_emb = self._as_sequence(self.diagram_processor(diagram_input))

        # Pairing with self.experts is positional: expert 0 -> equations, 1 -> images,
        # 2 -> diagrams, 3 -> text.
        expert_inputs = [equation_emb, image_emb, diagram_emb, text_emb]

        # Experts are sequence-first, so (batch, seq, emb) -> (seq, batch, emb). The
        # last sequence position of each expert output is its (batch, emb) summary.
        expert_outputs = [
            expert(expert_input.permute(1, 0, 2))[-1]
            for expert, expert_input in zip(self.experts, expert_inputs)
        ]

        combined_expert_outputs = torch.cat(expert_outputs, dim=-1)  # (batch, NUM_EXPERTS * emb)
        gate_weights = F.softmax(self.gate(combined_expert_outputs), dim=-1)  # (batch, NUM_EXPERTS)
        expert_outputs_stack = torch.stack(expert_outputs, dim=1)  # (batch, NUM_EXPERTS, emb)
        combined_output = torch.sum(gate_weights.unsqueeze(-1) * expert_outputs_stack, dim=1)  # (batch, emb)

        # Sequence-first decoder inputs: a length-1 memory and the text as the target.
        memory = combined_output.unsqueeze(0)  # (1, batch, emb)
        text_seq_first = text_emb.permute(1, 0, 2)  # (text_seq, batch, emb)

        return self.transformer_decoder(text_seq_first, memory)
| |
|
| |
|
| | |
# Demo configuration. The image/equation/diagram *_seq_len values are not
# used by the demo below, which feeds a single feature vector per sample for
# those modalities.
text_num_tokens = 200000
image_input_dim = 2048
equation_input_dim = 1024
diagram_input_dim = 1024
batch_size = 4
text_seq_len = 1000
image_seq_len = 10
equation_seq_len = 5
diagram_seq_len = 5


def main():
    """Build an EnedinaModel with the demo configuration, run random inputs, and print the output shape."""
    # With the default sizes this is a very large model, so it is built only when the
    # script is run directly, not whenever the module is imported.
    model = EnedinaModel(text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim)
    # eval() turns off the dropout inside the Transformer layers for this inference demo.
    model.eval()

    text_input = torch.randint(0, text_num_tokens, (batch_size, text_seq_len))
    image_input = torch.randn(batch_size, image_input_dim)
    equation_input = torch.randn(batch_size, equation_input_dim)
    diagram_input = torch.randn(batch_size, diagram_input_dim)

    # No gradients are needed for a shape check; no_grad() avoids keeping activations for backward.
    with torch.no_grad():
        output = model(text_input, image_input, equation_input, diagram_input)

    # The model's output is sequence-first: (text_seq_len, batch_size, emb_dim).
    print("A forma de saída do tensor é:", output.shape)


if __name__ == "__main__":
    main()