Danielfonseca1212 committed on
Commit
c9959a3
·
verified ·
1 Parent(s): 0b037bd

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +236 -0
model.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ relgnn/model.py
3
+ Core RelGNN — Atenção sobre Rotas Atômicas (sem grafo estático).
4
+
5
+ Arquitetura:
6
+ 1. TableEncoder: embeddings por tabela via MLP sobre features numéricas
7
+ 2. RouteAttention: attention ao longo de cada rota (sequência de tabelas)
8
+ 3. HierarchicalRouteAgg: agrega múltiplas rotas com pesos aprendidos
9
+ 4. FraudHead: classificador binário final
10
+ """
11
+
12
+ from dataclasses import dataclass
13
+ from typing import List, Dict, Tuple, Optional
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+ from data.routes import AtomicRoute
20
+
21
+
22
+ # ─── CONFIG ───────────────────────────────────────────────────────────────────
23
+
24
@dataclass
class RelGNNConfig:
    """Hyperparameters for the RelGNN model.

    Only ``hidden_dim``, ``dropout``, ``num_heads`` and ``seed`` are read by
    the model code in this file; ``num_epochs`` and ``learning_rate`` are
    presumably consumed by the trainer -- TODO confirm against the trainer.
    """

    # Width of every table embedding and route embedding.
    hidden_dim: int = 64
    # Not read in this file (see class docstring).
    num_epochs: int = 50
    # Not read in this file (see class docstring).
    learning_rate: float = 1e-3
    # Dropout probability handed to every sub-module that takes one.
    dropout: float = 0.2
    # Head count for nn.MultiheadAttention; PyTorch requires that
    # hidden_dim be divisible by this value.
    num_heads: int = 4
    # Passed to torch.manual_seed when the model is constructed.
    seed: int = 42
32
+
33
+
34
+ # ─── TABLE ENCODER ────────────────────────────────────────────────────────────
35
+
36
class TableEncoder(nn.Module):
    """Encode one table's feature vectors into ``hidden_dim``-sized embeddings.

    The encoder is a two-layer MLP applied directly to the numeric columns;
    no conversion to a graph is involved.
    """

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.2):
        """Build the MLP.

        Args:
            input_dim: number of input features per row.
            hidden_dim: size of the output embedding.
            dropout: dropout probability applied after the first activation.
        """
        super().__init__()
        # The intermediate layer is twice as wide as the output embedding.
        wide_dim = hidden_dim * 2
        self.net = nn.Sequential(
            nn.Linear(input_dim, wide_dim),
            nn.LayerNorm(wide_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(wide_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the embedding of ``x``; the last dimension becomes ``hidden_dim``."""
        return self.net(x)
55
+
56
+
57
+ # ─── ROUTE ATTENTION ──────────────────────────────────────────────────────────
58
+
59
class RouteAttention(nn.Module):
    """Attention over the hops of a single atomic route.

    Takes the sequence of hop embeddings ``[h1, h2, ..., hK]``
    (``K = n_hops + 1``) and returns one aggregated embedding for the route,
    using multi-head scaled dot-product self-attention between the hops.
    """

    def __init__(self, hidden_dim: int, num_heads: int = 4, dropout: float = 0.2):
        """Create the attention layer, the post-attention norm and the MLP.

        Args:
            hidden_dim: embedding size of every hop (must be divisible by
                ``num_heads``, as required by ``nn.MultiheadAttention``).
            num_heads: number of attention heads.
            dropout: dropout probability used in attention and in the MLP.
        """
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
        )

    def forward(self, hop_embeddings: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Aggregate the hops of a route into one embedding.

        Args:
            hop_embeddings: ``[batch, n_hops, hidden_dim]``
        Returns:
            route_emb: ``[batch, hidden_dim]`` -- representation of the route
            alpha: ``[batch, n_hops]`` -- attention weights per hop, as seen
                from the first hop and averaged over heads
        """
        # Self-attention between the hops of the route. With the default
        # arguments the returned weights are already averaged over heads,
        # i.e. shaped [batch, n_hops, n_hops].
        attn_out, alpha = self.attn(hop_embeddings, hop_embeddings, hop_embeddings)

        # Residual connection + layer norm.
        attn_out = self.norm(attn_out + hop_embeddings)

        # Pool by reading the first token, which has attended to every other
        # hop; callers are expected to put the target table first.
        first_token = attn_out[:, 0, :]  # [batch, hidden_dim]
        route_emb = first_token + self.mlp(first_token)

        # Row 0 of the attention matrix is the distribution the first token
        # puts over the hops. The weights are 3-D here, so no further head
        # reduction is applied (reducing again and then indexing with three
        # subscripts raises IndexError).
        alpha_weights = alpha[:, 0, :]  # [batch, n_hops]
        return route_emb, alpha_weights
104
+
105
+
106
+ # ─── HIERARCHICAL ROUTE AGGREGATOR ───────────────────────────────────────────
107
+
108
class HierarchicalRouteAgg(nn.Module):
    """Combine the embeddings of several routes using learned weights.

    Each route gets one learnable scalar; the scalars are softmax-normalised
    so every route contributes a different share to the final prediction.
    """

    def __init__(self, hidden_dim: int, num_routes: int):
        """Create one weight per route (initialised to 1) and the output projection.

        Args:
            hidden_dim: size of each route embedding.
            num_routes: number of routes that ``forward`` will receive.
        """
        super().__init__()
        self.route_weights = nn.Parameter(torch.ones(num_routes))
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, route_embeddings: List[torch.Tensor]) -> torch.Tensor:
        """Return the weighted sum of the route embeddings, projected.

        Args:
            route_embeddings: list of ``[batch, hidden_dim]`` tensors, one per route
        Returns:
            agg: ``[batch, hidden_dim]``
        Raises:
            ValueError: if the number of embeddings differs from ``num_routes``.
        """
        expected = self.route_weights.shape[0]
        if len(route_embeddings) != expected:
            # Without this check a mismatch involving a length of 1 would
            # broadcast silently and produce wrongly weighted sums.
            raise ValueError(
                f"expected {expected} route embeddings, got {len(route_embeddings)}"
            )

        stacked = torch.stack(route_embeddings, dim=1)   # [batch, R, hidden]
        weights = F.softmax(self.route_weights, dim=0)   # [R]
        weighted = (stacked * weights.view(1, -1, 1)).sum(dim=1)
        return self.output_proj(weighted)
129
+
130
+
131
+ # ─── FRAUD HEAD ───────────────────────────────────────────────────────────────
132
+
133
class FraudHead(nn.Module):
    """Classification head that maps an embedding to a single logit per row."""

    def __init__(self, hidden_dim: int, dropout: float = 0.2):
        """Build a two-layer MLP: ``hidden_dim -> hidden_dim // 2 -> 1``.

        Args:
            hidden_dim: size of the incoming embedding.
            dropout: dropout probability applied after the activation.
        """
        super().__init__()
        half_dim = hidden_dim // 2
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, half_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(half_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw logits (no sigmoid applied), with the trailing size-1 dim removed."""
        return self.net(x).squeeze(-1)  # [batch]
145
+
146
+
147
+ # ─── RELGNN ───────────────────────────────────────────────────────────────────
148
+
149
class RelGNN(nn.Module):
    """Complete RelGNN model.

    Flow:
        SQL tables
          -> TableEncoder (one per table)
          -> RouteAttention (one per atomic route)
          -> HierarchicalRouteAgg
          -> FraudHead
          -> sigmoid(logit) = P(fraud)

    The sub-modules are created by ``build``, not by ``__init__``; ``forward``
    reads attributes that exist only after ``build`` has been called. The
    sigmoid is not applied here: ``forward`` returns raw logits.
    """

    def __init__(self, config: RelGNNConfig):
        """Store the config and seed torch's global RNG with ``config.seed``."""
        super().__init__()
        self.config = config
        torch.manual_seed(config.seed)

    def build(self, feature_dims: Dict[str, int], routes: List[AtomicRoute]):
        """Instantiate the sub-modules once the feature dimensions are known.

        Args:
            feature_dims: ``{table_name: number_of_features}``; one
                TableEncoder is created per entry.
            routes: atomic routes; one RouteAttention is created per route.
        """
        H = self.config.hidden_dim
        D = self.config.dropout

        self.table_encoders = nn.ModuleDict({
            table: TableEncoder(dim, H, D)
            for table, dim in feature_dims.items()
        })

        self.route_attns = nn.ModuleList([
            RouteAttention(H, self.config.num_heads, D)
            for _ in routes
        ])

        self.hierarchical = HierarchicalRouteAgg(H, len(routes))
        self.fraud_head = FraudHead(H, D)
        self.routes = routes

    def forward(
        self,
        table_features: Dict[str, torch.Tensor],
    ) -> Tuple[torch.Tensor, Dict]:
        """Run the full pipeline.

        Args:
            table_features: ``{table_name: [batch, feature_dim]}``
        Returns:
            logits: ``[batch]``
            attention_info: dict with the attention weights per route, keyed
                ``"route_<index>"``; routes that fell back to a plain table
                embedding have no entry.
        """
        # 1. Encode every table for which features were supplied.
        table_embs = {
            table: encoder(table_features[table])
            for table, encoder in self.table_encoders.items()
            if table in table_features
        }

        # 2. Attention per atomic route.
        route_embs = []
        attention_info = {}

        for i, (route, attn_module) in enumerate(zip(self.routes, self.route_attns)):
            # Collect the tables of the route that actually have an embedding.
            # `route.path` is treated as an indexable sequence of table names.
            available = [t for t in route.path if t in table_embs]
            if len(available) < 2:
                # Incomplete route: use the embedding of the route's first
                # table, or of the first encoded table if that one is missing.
                e = table_embs.get(route.path[0], list(table_embs.values())[0])
                route_embs.append(e)
                continue

            hop_list = [table_embs[t] for t in available]
            hop_tensor = torch.stack(hop_list, dim=1)  # [batch, K, H]

            route_emb, alpha = attn_module(hop_tensor)
            route_embs.append(route_emb)
            attention_info[f"route_{i}"] = alpha.detach().cpu().numpy()

        # 3. Aggregate the routes hierarchically.
        agg = self.hierarchical(route_embs)

        # 4. Fraud classifier.
        logits = self.fraud_head(agg)

        return logits, attention_info

    def fit(self, tables, routes, log_fn=print, progress_fn=None):
        """Full training wrapper: delegates to ``relgnn.trainer.Trainer.fit``.

        The trainer is imported inside the method rather than at module level.
        """
        from relgnn.trainer import Trainer
        trainer = Trainer(self, self.config)
        return trainer.fit(tables, routes, log_fn=log_fn, progress_fn=progress_fn)