File size: 5,856 Bytes
8c61cc8
 
 
 
 
c238e32
 
 
 
 
 
 
8c61cc8
c238e32
8c61cc8
c238e32
8c61cc8
 
c238e32
 
 
8c61cc8
 
c238e32
8c61cc8
c238e32
 
 
 
 
8c61cc8
c238e32
8c61cc8
 
 
 
 
 
c238e32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c61cc8
c238e32
 
 
 
 
8c61cc8
c238e32
8c61cc8
 
 
 
 
 
 
 
 
c238e32
8c61cc8
c238e32
 
 
 
 
 
 
 
 
8c61cc8
 
 
c238e32
 
 
 
8c61cc8
 
 
 
c238e32
 
 
 
 
 
 
 
 
 
 
 
 
8c61cc8
c238e32
 
 
 
8c61cc8
c238e32
 
 
8c61cc8
c238e32
 
8c61cc8
 
 
 
c238e32
 
 
 
 
8c61cc8
c238e32
 
 
 
8c61cc8
 
 
 
 
c238e32
8c61cc8
 
 
 
 
 
 
c238e32
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import torch
import torch.nn as nn
import torch.nn.functional as F


# Definição de uma camada de embedding com atenção esparsa para texto
class SparseTextEmbedding(nn.Module):
    """
    Camada de embedding para texto com atenção multi-cabeça.
    Realiza embeddings de tokens de texto e aplica atenção multi-cabeça.
    """

    def __init__(self, num_tokens, emb_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_tokens, emb_dim)
        self.attention = nn.MultiheadAttention(emb_dim, num_heads=8, batch_first=True)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.attention(x, x, x)
        return x


# Processador genérico para transformar entradas numéricas em embeddings
class GenericProcessor(nn.Module):
    """
    Processador genérico que transforma entradas numéricas em embeddings.
    Utiliza uma camada linear seguida por uma ativação ReLU.
    """

    def __init__(self, input_dim, emb_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, emb_dim)

    def forward(self, x):
        return F.relu(self.fc(x))


# Especialista em transformador para domínios específicos
class TransformerExpert(nn.Module):
    """
    Especialista em domínio específico usando um encoder Transformer.
    Projetado para processar embeddings e realizar tarefas específicas de domínio.
    """

    def __init__(self, emb_dim, num_heads, num_layers, ff_dim):
        super().__init__()
        transformer_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=ff_dim)
        self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=num_layers)

    def forward(self, x):
        return self.transformer_encoder(x)


# Decodificador Transformer com atenção cruzada
class TransformerDecoderWithCrossAttention(nn.Module):
    """
    Decodificador Transformer com atenção cruzada.
    Combina informações de múltiplas fontes e projeta o resultado final.
    """

    def __init__(self, emb_dim, num_heads, num_layers, ff_dim):
        super().__init__()
        transformer_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=ff_dim)
        self.transformer_decoder = nn.TransformerDecoder(transformer_layer, num_layers=num_layers)
        self.projection = nn.Linear(emb_dim, emb_dim)

    def forward(self, x, memory):
        output = self.transformer_decoder(x, memory)
        return self.projection(output)


# Modelo principal que incorpora os componentes acima
class EnedinaModel(nn.Module):
    """
    Modelo principal: Enedina.
    Integra diferentes componentes especializados para processar múltiplos tipos de entrada.
    """

    def __init__(self, text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim, emb_dim=1024,
                 num_heads=16, num_layers=12, ff_dim=4096):
        super().__init__()
        self.text_embedding = SparseTextEmbedding(text_num_tokens, emb_dim)
        self.image_processor = GenericProcessor(image_input_dim, emb_dim)
        self.equation_processor = GenericProcessor(equation_input_dim, emb_dim)
        self.diagram_processor = GenericProcessor(diagram_input_dim, emb_dim)
        self.experts = nn.ModuleList([
            TransformerExpert(emb_dim, num_heads, num_layers, ff_dim) for _ in range(4)
        ])
        self.gate = nn.Linear(emb_dim * 4, 4)
        self.transformer_decoder = TransformerDecoderWithCrossAttention(emb_dim, num_heads, num_layers, ff_dim)

    def forward(self, text_input, image_input, equation_input, diagram_input):
        text_emb = self.text_embedding(text_input)
        image_emb = self.image_processor(image_input).unsqueeze(1)
        equation_emb = self.equation_processor(equation_input).unsqueeze(1)
        diagram_emb = self.diagram_processor(diagram_input).unsqueeze(1)

        # Estrutura dos especialistas
        expert_inputs = [equation_emb, image_emb, diagram_emb, text_emb]
        expert_outputs = []
        for i, expert in enumerate(self.experts):
            expert_output = expert(expert_inputs[i].permute(1, 0, 2))
            expert_outputs.append(expert_output.permute(1, 0, 2)[:, -1, :])

        # Combina as saídas dos especialistas
        combined_expert_outputs = torch.cat(expert_outputs, dim=-1)

        # Calcula os pesos do gate e aplica a combinação ponderada das saídas dos especialistas
        gate_weights = F.softmax(self.gate(combined_expert_outputs), dim=-1)
        expert_outputs_stack = torch.stack(expert_outputs, dim=1)
        combined_output = torch.sum(gate_weights.unsqueeze(-1) * expert_outputs_stack, dim=1)

        # Ajustes de dimensão antes do TransformerDecoder
        combined_output = combined_output.unsqueeze(0)
        text_emb = text_emb.permute(1, 0, 2)

        # Aplica o decodificador Transformer com atenção cruzada
        output = self.transformer_decoder(text_emb, combined_output)

        return output


# Configuração dos parâmetros do modelo e simulação de entrada para testes
text_num_tokens = 200000
image_input_dim = 2048
equation_input_dim = 1024
diagram_input_dim = 1024
batch_size = 4
text_seq_len = 1000
image_seq_len = 10
equation_seq_len = 5
diagram_seq_len = 5

# Inicializa o modelo
model = EnedinaModel(text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim)

# Gera entradas simuladas
text_input = torch.randint(0, text_num_tokens, (batch_size, text_seq_len))
image_input = torch.randn(batch_size, image_input_dim)
equation_input = torch.randn(batch_size, equation_input_dim)
diagram_input = torch.randn(batch_size, diagram_input_dim)

# Executa o modelo com as entradas simuladas
output = model(text_input, image_input, equation_input, diagram_input)

# Verifica a forma da saída
print("A forma de saída do tensor é:", output.shape)