Danielfonseca1212 commited on
Commit
5aff06a
Β·
verified Β·
1 Parent(s): b903062

Create dominant model.py

Browse files
Files changed (1) hide show
  1. dominant model.py +270 -0
dominant model.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # dominant_model.py — DOMINANT: Deep Anomaly Detection on Attributed Networks
2
+ # Paper: Ding et al., IJCAI 2019
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from sklearn.metrics import (
8
+ roc_auc_score, average_precision_score,
9
+ f1_score, precision_score, recall_score
10
+ )
11
+
12
+
13
+ # ──────────────────────────────────────────────────────────────
14
+ # GCN LAYER — manual implementation (without torch-sparse)
15
+ # ──────────────────────────────────────────────────────────────
16
class GCNLayer(nn.Module):
    """Single graph-convolution layer written with plain tensor operations.

    The node features are first passed through a linear map; then, for every
    edge ``(row[e], col[e])``, the transformed features of node ``row[e]``
    scaled by ``edge_weight[e]`` are summed into node ``col[e]``.
    """

    def __init__(self, in_dim, out_dim, bias=True):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=bias)

    def forward(self, x, edge_index, edge_weight, n_nos):
        """Return the edge-weighted sum of transformed source features.

        ``n_nos`` is accepted so that every module shares the same call
        signature; it is not read here. The result has the same shape as
        ``self.W(x)``, and nodes that never appear in ``col`` stay at zero.
        """
        src, dst = edge_index
        projected = self.W(x)
        # One message per edge: the source node's features times the weight.
        messages = projected[src] * edge_weight.unsqueeze(1)
        # scatter_add_ needs an index with the same shape as the source.
        target_rows = dst.unsqueeze(1).expand_as(messages)
        out = torch.zeros_like(projected)
        out.scatter_add_(0, target_rows, messages)
        return out
30
+
31
+
32
+ # ──────────────────────────────────────────────────────────────
33
+ # ENCODER — shared GCN
34
+ # ──────────────────────────────────────────────────────────────
35
class GCNEncoder(nn.Module):
    """Two-layer GCN encoder that produces the shared node embeddings.

    Pipeline: ``gc1`` -> ReLU -> BatchNorm1d -> dropout -> ``gc2``.
    """

    def __init__(self, in_dim, hidden_dim, embed_dim, dropout=0.3):
        super().__init__()
        # Submodules are created in the order gc1, gc2, bn1 so that random
        # parameter initialisation draws from the RNG in a fixed sequence.
        self.gc1 = GCNLayer(in_dim, hidden_dim)
        self.gc2 = GCNLayer(hidden_dim, embed_dim)
        self.dropout = dropout
        self.bn1 = nn.BatchNorm1d(hidden_dim)

    def forward(self, x, edge_index, edge_weight, n_nos):
        """Encode node features ``x`` into one embedding row per node."""
        graph = (edge_index, edge_weight, n_nos)
        hidden = F.relu(self.gc1(x, *graph))
        hidden = self.bn1(hidden)
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.gc2(hidden, *graph)
49
+
50
+
51
+ # ──────────────────────────────────────────────────────────────
52
+ # ATTRIBUTE DECODER — reconstructs the original features
53
+ # ──────────────────────────────────────────────────────────────
54
class AttributeDecoder(nn.Module):
    """Two-layer GCN decoder mapping embeddings back to ``out_dim`` features.

    Pipeline: ``gc1`` -> ReLU -> dropout -> ``gc2``.
    """

    def __init__(self, embed_dim, hidden_dim, out_dim, dropout=0.3):
        super().__init__()
        self.gc1 = GCNLayer(embed_dim, hidden_dim)
        self.gc2 = GCNLayer(hidden_dim, out_dim)
        self.dropout = dropout

    def forward(self, z, edge_index, edge_weight, n_nos):
        """Return the reconstructed attribute matrix for embeddings ``z``."""
        graph = (edge_index, edge_weight, n_nos)
        hidden = self.gc1(z, *graph)
        hidden = F.relu(hidden)
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        reconstructed = self.gc2(hidden, *graph)
        return reconstructed
65
+
66
+
67
+ # ──────────────────────────────────────────────────────────────
68
+ # STRUCTURE DECODER — reconstructs the adjacency via inner product
69
+ # ──────────────────────────────────────────────────────────────
70
class StructureDecoder(nn.Module):
    """Decoder that scores edges from a one-layer GCN transform of ``z``.

    Pipeline: ``gc1`` -> ReLU -> dropout, followed by a sigmoid of the inner
    product between the two endpoint representations of every listed edge.
    """

    def __init__(self, embed_dim, hidden_dim, dropout=0.3):
        super().__init__()
        self.gc1 = GCNLayer(embed_dim, hidden_dim)
        self.dropout = dropout

    def forward(self, z, edge_index, edge_weight, n_nos):
        """Return ``(edge_probabilities, node_representations)``.

        Only the edges present in ``edge_index`` are scored, which avoids
        materialising the dense ``sigmoid(H @ H^T)`` matrix.
        """
        hidden = self.gc1(z, edge_index, edge_weight, n_nos)
        hidden = F.relu(hidden)
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)

        src, dst = edge_index
        # Row-wise dot product of the two endpoint vectors of each edge.
        logits = torch.sum(hidden[src] * hidden[dst], dim=1)
        return torch.sigmoid(logits), hidden
84
+
85
+
86
+ # ──────────────────────────────────────────────────────────────
87
+ # DOMINANT COMPLETO
88
+ # ──────────────────────────────────────────────────────────────
89
class DOMINANT(nn.Module):
    """
    Deep Anomaly Detection on Attributed Networks.
    Ding et al., IJCAI 2019.

    Loss          = alpha * L_structure + (1 - alpha) * L_attribute
    Anomaly score = alpha * err_struct(v) + (1 - alpha) * err_attr(v)
    """

    def __init__(self, in_dim, hidden_dim=64, embed_dim=32,
                 alpha=0.5, dropout=0.3):
        super().__init__()
        # Weight of the structure term; (1 - alpha) weights the attribute term.
        self.alpha = alpha
        self.encoder = GCNEncoder(in_dim, hidden_dim, embed_dim, dropout)
        self.attr_dec = AttributeDecoder(embed_dim, hidden_dim, in_dim, dropout)
        self.struct_dec = StructureDecoder(embed_dim, hidden_dim, dropout)

    def forward(self, x, edge_index, edge_weight, n_nos):
        """Run encoder and both decoders.

        Returns
        -------
        tuple
            ``(z, x_hat, a_hat, h_struct)``: node embeddings, reconstructed
            attributes, per-edge probabilities for the edges in
            ``edge_index``, and the structure decoder's node representations.
        """
        # Encode
        z = self.encoder(x, edge_index, edge_weight, n_nos)

        # Decode attributes
        x_hat = self.attr_dec(z, edge_index, edge_weight, n_nos)

        # Decode structure
        a_hat, h_struct = self.struct_dec(z, edge_index, edge_weight, n_nos)

        return z, x_hat, a_hat, h_struct

    def compute_loss(self, x, edge_index, x_hat, a_hat):
        """Compute the training loss and the per-node error terms.

        * attribute error: mean squared difference between ``x`` and
          ``x_hat`` along the feature axis, one value per node;
        * structure error: binary cross-entropy of ``a_hat`` against an
          all-ones target (every listed edge exists), averaged per node over
          the edges whose source index (``edge_index[0]``) is that node.
          Nodes that are never a source get a structure error of 0.

        Returns
        -------
        tuple
            ``(loss, err_attr, err_struct)`` where ``loss`` is a scalar that
            carries gradients and the two error vectors are detached.
        """
        row, _ = edge_index
        n_nodes = x.shape[0]

        # Attribute error per node
        err_attr = ((x - x_hat) ** 2).mean(dim=1)

        # Structure error per edge. The target and the accumulators are
        # derived from a_hat / err_edge so that they share its dtype and
        # device; fixed float32 tensors would make scatter_add fail for
        # float64 or half-precision inputs.
        a_true = torch.ones_like(a_hat)
        err_edge = F.binary_cross_entropy(a_hat, a_true, reduction='none')

        # Aggregate the structure error per node (mean over outgoing edges).
        err_sum = err_edge.new_zeros(n_nodes).scatter_add(0, row, err_edge)
        count = err_edge.new_zeros(n_nodes).scatter_add(
            0, row, torch.ones_like(err_edge))
        # clamp avoids a division by zero for nodes without outgoing edges.
        err_struct = err_sum / count.clamp(min=1)

        # Total loss
        loss = (self.alpha * err_struct + (1 - self.alpha) * err_attr).mean()
        return loss, err_attr.detach(), err_struct.detach()

    def anomaly_score(self, err_attr, err_struct):
        """Combined anomaly score, min-max scaled; higher = more suspicious."""
        score = self.alpha * err_struct + (1 - self.alpha) * err_attr
        # Scale to [0, 1]; the epsilon keeps the division finite when every
        # node has the same score.
        mn, mx = score.min(), score.max()
        return (score - mn) / (mx - mn + 1e-8)
149
+
150
+
151
+ # ──────────────────────────────────────────────────────────────
152
+ # TRAINER
153
+ # ──────────────────────────────────────────────────────────────
154
class TrainerDOMINANT:
    """Training and evaluation harness for a DOMINANT model on one graph.

    ``data`` must expose ``x`` (node feature tensor), ``edge_index`` and
    ``y`` (node labels, used only to compute evaluation metrics).
    ``edge_weight`` is forwarded unchanged to the model on every call.

    Attributes set during use:
        historico      -- per-epoch ``'loss'`` and ``'auc'`` lists.
        melhor_auc     -- best training-time AUC seen so far.
        melhor_estado  -- copy of the model state saved at that epoch.
        scores_finais  -- NumPy array with the latest anomaly scores kept.
        embeddings     -- NumPy node embeddings from ``metricas_completas``.
    """

    def __init__(self, data, edge_weight, hidden_dim=64, embed_dim=32,
                 alpha=0.5, lr=0.005, dropout=0.3):
        self.data = data
        self.edge_index = data.edge_index
        self.edge_weight = edge_weight
        self.n_nos = data.x.shape[0]

        # Place the model on the same device as the features it will consume.
        self.model = DOMINANT(
            in_dim=data.x.shape[1],
            hidden_dim=hidden_dim,
            embed_dim=embed_dim,
            alpha=alpha,
            dropout=dropout,
        ).to(data.x.device)
        self.opt = torch.optim.Adam(
            self.model.parameters(), lr=lr, weight_decay=1e-4)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.opt, patience=10, factor=0.5, min_lr=1e-5)

        self.historico = {'loss': [], 'auc': []}
        self.melhor_auc = 0.0
        self.melhor_estado = None
        self.scores_finais = None
        self.embeddings = None

    @staticmethod
    def _to_numpy(tensor):
        """Detach ``tensor``, copy it to host memory and return a NumPy array."""
        return tensor.detach().cpu().numpy()

    @staticmethod
    def _auc_seguro(y_true, scores):
        """Return ROC-AUC, or 0.5 when ``y_true`` has fewer than two classes."""
        if len(np.unique(y_true)) > 1:
            return roc_auc_score(y_true, scores)
        return 0.5

    def treinar_epoca(self):
        """Run one full-graph optimisation step.

        Returns ``(loss_value, err_attr, err_struct)``; the two error vectors
        come from the forward pass made *before* the parameter update.
        """
        self.model.train()
        z, x_hat, a_hat, _ = self.model(
            self.data.x, self.edge_index,
            self.edge_weight, self.n_nos)
        loss, err_attr, err_struct = self.model.compute_loss(
            self.data.x, self.edge_index, x_hat, a_hat)
        self.opt.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 before stepping.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.opt.step()
        return loss.item(), err_attr, err_struct

    def avaliar(self, err_attr, err_struct):
        """Return ``(auc, scores)`` for the given per-node error vectors."""
        scores = self._to_numpy(
            self.model.anomaly_score(err_attr, err_struct))
        y_true = self._to_numpy(self.data.y)
        return self._auc_seguro(y_true, scores), scores

    def treinar(self, epocas=100, callback=None):
        """Train for ``epocas`` epochs and restore the best-AUC parameters.

        ``callback``, when given, is called after every epoch as
        ``callback(epoch, epocas, loss, auc)``.
        """
        for ep in range(1, epocas + 1):
            loss, err_attr, err_struct = self.treinar_epoca()
            auc, scores = self.avaliar(err_attr, err_struct)
            self.scheduler.step(loss)

            self.historico['loss'].append(loss)
            self.historico['auc'].append(auc)

            if auc > self.melhor_auc:
                self.melhor_auc = auc
                # NOTE(review): the snapshot is taken after the optimizer
                # step, whereas `auc` was measured on the forward pass that
                # preceded it (in training mode).
                self.melhor_estado = {
                    k: v.clone()
                    for k, v in self.model.state_dict().items()}
                self.scores_finais = scores

            if callback:
                callback(ep, epocas, loss, auc)

        if self.melhor_estado is not None:
            self.model.load_state_dict(self.melhor_estado)

    def metricas_completas(self):
        """Evaluate the current model in eval mode and return a metrics dict.

        Also refreshes ``self.embeddings`` and ``self.scores_finais``.
        Keys: ``auc``, ``ap``, ``f1``, ``precision``, ``recall``, ``scores``,
        ``y_true``, ``err_attr``, ``err_struct``, ``embeddings``, ``thresh``,
        ``preds``. ``auc`` is 0.5 when the labels contain a single class,
        matching ``avaliar``.
        """
        self.model.eval()
        with torch.no_grad():
            z, x_hat, a_hat, _ = self.model(
                self.data.x, self.edge_index,
                self.edge_weight, self.n_nos)
            _, err_attr, err_struct = self.model.compute_loss(
                self.data.x, self.edge_index, x_hat, a_hat)

        scores = self._to_numpy(
            self.model.anomaly_score(err_attr, err_struct))
        y_true = self._to_numpy(self.data.y)
        self.embeddings = self._to_numpy(z)
        self.scores_finais = scores

        # Top-k threshold: flag as many nodes as there are labelled
        # anomalies (ties at the threshold are all flagged). Falls back to
        # 0.5 when there are no positive labels.
        k = int(y_true.sum())
        thresh = np.sort(scores)[-k] if k > 0 else 0.5
        preds = (scores >= thresh).astype(int)

        # Per-node error decomposition
        err_a = self._to_numpy(err_attr)
        err_s = self._to_numpy(err_struct)

        return {
            'auc': self._auc_seguro(y_true, scores),
            'ap': average_precision_score(y_true, scores),
            'f1': f1_score(y_true, preds, zero_division=0),
            'precision': precision_score(y_true, preds, zero_division=0),
            'recall': recall_score(y_true, preds, zero_division=0),
            'scores': scores,
            'y_true': y_true,
            'err_attr': err_a,
            'err_struct': err_s,
            'embeddings': self.embeddings,
            'thresh': thresh,
            'preds': preds,
        }

    def get_top_anomalias(self, n=20):
        """Return the ``n`` highest-scoring nodes, most anomalous first.

        Each entry is a dict with ``idx`` (node index), ``score`` and
        ``label_real`` (the node's label from ``data.y``). Returns an empty
        list when no scores have been computed yet.
        """
        if self.scores_finais is None:
            return []
        top_idx = np.argsort(self.scores_finais)[::-1][:n]
        return [
            {
                'idx': int(idx),
                'score': float(self.scores_finais[idx]),
                'label_real': int(self.data.y[idx]),
            }
            for idx in top_idx
        ]