File size: 5,856 Bytes
8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 8c61cc8 c238e32 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 | import torch
import torch.nn as nn
import torch.nn.functional as F
# Definição de uma camada de embedding com atenção esparsa para texto
class SparseTextEmbedding(nn.Module):
"""
Camada de embedding para texto com atenção multi-cabeça.
Realiza embeddings de tokens de texto e aplica atenção multi-cabeça.
"""
def __init__(self, num_tokens, emb_dim):
super().__init__()
self.embedding = nn.Embedding(num_tokens, emb_dim)
self.attention = nn.MultiheadAttention(emb_dim, num_heads=8, batch_first=True)
def forward(self, x):
x = self.embedding(x)
x, _ = self.attention(x, x, x)
return x
# Processador genérico para transformar entradas numéricas em embeddings
class GenericProcessor(nn.Module):
"""
Processador genérico que transforma entradas numéricas em embeddings.
Utiliza uma camada linear seguida por uma ativação ReLU.
"""
def __init__(self, input_dim, emb_dim):
super().__init__()
self.fc = nn.Linear(input_dim, emb_dim)
def forward(self, x):
return F.relu(self.fc(x))
# Especialista em transformador para domínios específicos
class TransformerExpert(nn.Module):
"""
Especialista em domínio específico usando um encoder Transformer.
Projetado para processar embeddings e realizar tarefas específicas de domínio.
"""
def __init__(self, emb_dim, num_heads, num_layers, ff_dim):
super().__init__()
transformer_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=ff_dim)
self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=num_layers)
def forward(self, x):
return self.transformer_encoder(x)
# Decodificador Transformer com atenção cruzada
class TransformerDecoderWithCrossAttention(nn.Module):
"""
Decodificador Transformer com atenção cruzada.
Combina informações de múltiplas fontes e projeta o resultado final.
"""
def __init__(self, emb_dim, num_heads, num_layers, ff_dim):
super().__init__()
transformer_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=ff_dim)
self.transformer_decoder = nn.TransformerDecoder(transformer_layer, num_layers=num_layers)
self.projection = nn.Linear(emb_dim, emb_dim)
def forward(self, x, memory):
output = self.transformer_decoder(x, memory)
return self.projection(output)
# Modelo principal que incorpora os componentes acima
class EnedinaModel(nn.Module):
"""
Modelo principal: Enedina.
Integra diferentes componentes especializados para processar múltiplos tipos de entrada.
"""
def __init__(self, text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim, emb_dim=1024,
num_heads=16, num_layers=12, ff_dim=4096):
super().__init__()
self.text_embedding = SparseTextEmbedding(text_num_tokens, emb_dim)
self.image_processor = GenericProcessor(image_input_dim, emb_dim)
self.equation_processor = GenericProcessor(equation_input_dim, emb_dim)
self.diagram_processor = GenericProcessor(diagram_input_dim, emb_dim)
self.experts = nn.ModuleList([
TransformerExpert(emb_dim, num_heads, num_layers, ff_dim) for _ in range(4)
])
self.gate = nn.Linear(emb_dim * 4, 4)
self.transformer_decoder = TransformerDecoderWithCrossAttention(emb_dim, num_heads, num_layers, ff_dim)
def forward(self, text_input, image_input, equation_input, diagram_input):
text_emb = self.text_embedding(text_input)
image_emb = self.image_processor(image_input).unsqueeze(1)
equation_emb = self.equation_processor(equation_input).unsqueeze(1)
diagram_emb = self.diagram_processor(diagram_input).unsqueeze(1)
# Estrutura dos especialistas
expert_inputs = [equation_emb, image_emb, diagram_emb, text_emb]
expert_outputs = []
for i, expert in enumerate(self.experts):
expert_output = expert(expert_inputs[i].permute(1, 0, 2))
expert_outputs.append(expert_output.permute(1, 0, 2)[:, -1, :])
# Combina as saídas dos especialistas
combined_expert_outputs = torch.cat(expert_outputs, dim=-1)
# Calcula os pesos do gate e aplica a combinação ponderada das saídas dos especialistas
gate_weights = F.softmax(self.gate(combined_expert_outputs), dim=-1)
expert_outputs_stack = torch.stack(expert_outputs, dim=1)
combined_output = torch.sum(gate_weights.unsqueeze(-1) * expert_outputs_stack, dim=1)
# Ajustes de dimensão antes do TransformerDecoder
combined_output = combined_output.unsqueeze(0)
text_emb = text_emb.permute(1, 0, 2)
# Aplica o decodificador Transformer com atenção cruzada
output = self.transformer_decoder(text_emb, combined_output)
return output
# Configuração dos parâmetros do modelo e simulação de entrada para testes
text_num_tokens = 200000
image_input_dim = 2048
equation_input_dim = 1024
diagram_input_dim = 1024
batch_size = 4
text_seq_len = 1000
image_seq_len = 10
equation_seq_len = 5
diagram_seq_len = 5
# Inicializa o modelo
model = EnedinaModel(text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim)
# Gera entradas simuladas
text_input = torch.randint(0, text_num_tokens, (batch_size, text_seq_len))
image_input = torch.randn(batch_size, image_input_dim)
equation_input = torch.randn(batch_size, equation_input_dim)
diagram_input = torch.randn(batch_size, diagram_input_dim)
# Executa o modelo com as entradas simuladas
output = model(text_input, image_input, equation_input, diagram_input)
# Verifica a forma da saída
print("A forma de saída do tensor é:", output.shape) |