Danielfonseca1212 committed on
Commit cf49f9c · verified · 1 Parent(s): c9959a3

Create trainer.py

Files changed (1)
  1. trainer.py +232 -0
trainer.py ADDED
"""
relgnn/trainer.py
Training loop for RelGNN.

Extracts numeric features directly from the SQL tables (no graph),
aggregates them per target entity (customers), and trains end-to-end.
"""

import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
from typing import Dict, List, Tuple, Callable

from data.routes import AtomicRoute

# ─── FEATURE EXTRACTION ───────────────────────────────────────────────────────

NUMERIC_COLS = {
    "customers": ["c_acctbal", "c_nationkey", "c_account_age_days", "c_num_prev_orders"],
    "orders": ["o_totalprice", "o_shippriority"],
    "lineitem": ["l_quantity", "l_extendedprice", "l_discount", "l_tax"],
    "supplier": ["s_acctbal", "s_nationkey", "s_risk_flag"],
    "nation": ["n_nationkey", "n_regionkey"],
    "part": ["p_retailprice"],
}

def extract_features(tables: Dict, n_customers: int) -> Tuple[Dict, np.ndarray]:
    """
    Extracts numeric features from the tables and aggregates them per
    customer (the target entity).

    Returns:
        table_features: {table_name: np.ndarray [n_customers, feature_dim]}
        labels: np.ndarray [n_customers] (is_fraud)
    """
    customers = tables["customers"]
    orders = tables["orders"]

    # Labels: 1 if any of the customer's orders is fraudulent
    fraud_by_customer = orders.groupby("o_custkey")["is_fraud"].max()
    labels = customers["c_custkey"].map(fraud_by_customer).fillna(0).values.astype(float)

    table_features = {}

    # ── Customers: used directly ──────────────────────────────────────────────
    cols = [c for c in NUMERIC_COLS["customers"] if c in customers.columns]
    table_features["customers"] = customers[cols].fillna(0).values.astype(np.float32)

    # ── Orders: aggregated per customer (mean + max + count) ──────────────────
    order_cols = [c for c in NUMERIC_COLS["orders"] if c in orders.columns]
    ord_mean = orders.groupby("o_custkey")[order_cols].mean()
    ord_max = orders.groupby("o_custkey")[order_cols].max()
    ord_cnt = orders.groupby("o_custkey").size().rename("order_count")

    ord_agg = ord_mean.join(ord_max, rsuffix="_max").join(ord_cnt)
    ord_agg = customers[["c_custkey"]].set_index("c_custkey").join(ord_agg).fillna(0)
    table_features["orders"] = ord_agg.values.astype(np.float32)

    # ── Lineitem: aggregated via orders → customer ────────────────────────────
    lineitem = tables["lineitem"]
    li_cols = [c for c in NUMERIC_COLS["lineitem"] if c in lineitem.columns]
    li_with_cust = lineitem.merge(
        orders[["o_orderkey", "o_custkey"]], on="o_orderkey", how="left"
    )
    li_mean = li_with_cust.groupby("o_custkey")[li_cols].mean()
    li_max = li_with_cust.groupby("o_custkey")[li_cols].max()
    li_cnt = li_with_cust.groupby("o_custkey").size().rename("lineitem_count")
    li_agg = li_mean.join(li_max, rsuffix="_max").join(li_cnt)
    li_agg = customers[["c_custkey"]].set_index("c_custkey").join(li_agg).fillna(0)
    table_features["lineitem"] = li_agg.values.astype(np.float32)

    # ── Supplier: aggregated via lineitem → orders → customer ─────────────────
    supplier = tables["supplier"]
    sup_cols = [c for c in NUMERIC_COLS["supplier"] if c in supplier.columns]
    sup_with_cust = li_with_cust.merge(supplier, left_on="l_suppkey", right_on="s_suppkey", how="left")
    sup_mean = sup_with_cust.groupby("o_custkey")[sup_cols].mean()
    sup_agg = customers[["c_custkey"]].set_index("c_custkey").join(sup_mean).fillna(0)
    table_features["supplier"] = sup_agg.values.astype(np.float32)

    # ── Nation: direct join ───────────────────────────────────────────────────
    nation = tables["nation"]
    nat_cols = [c for c in NUMERIC_COLS["nation"] if c in nation.columns]
    nat_agg = customers[["c_custkey", "c_nationkey"]].merge(
        nation, left_on="c_nationkey", right_on="n_nationkey", how="left"
    )[nat_cols].fillna(0)
    table_features["nation"] = nat_agg.values.astype(np.float32)

    # ── Part: aggregated via lineitem → customer ──────────────────────────────
    part = tables["part"]
    par_cols = [c for c in NUMERIC_COLS["part"] if c in part.columns]
    par_with_cust = li_with_cust.merge(part, left_on="l_partkey", right_on="p_partkey", how="left")
    par_mean = par_with_cust.groupby("o_custkey")[par_cols].mean()
    par_agg = customers[["c_custkey"]].set_index("c_custkey").join(par_mean).fillna(0)
    table_features["part"] = par_agg.values.astype(np.float32)

    # Normalize features (min-max per column)
    for key in table_features:
        feat = table_features[key]
        col_min = feat.min(axis=0, keepdims=True)
        col_max = feat.max(axis=0, keepdims=True)
        denom = np.where((col_max - col_min) == 0, 1, col_max - col_min)
        table_features[key] = (feat - col_min) / denom

    return table_features, labels
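
# Shape sketch for extract_features (illustrative only, assuming the TPC-H-style
# schema declared in NUMERIC_COLS above):
#
#   feats, labels = extract_features(tables, len(tables["customers"]))
#   feats["customers"].shape  # (n_customers, up to 4)               raw numeric columns
#   feats["orders"].shape     # (n_customers, 2 * n_order_cols + 1)  mean + max + count
#   labels.shape              # (n_customers,)  1.0 if any of the customer's orders is fraud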

# ─── TRAINER ─────────────────────────────────────────────────────────────────

class Trainer:
    def __init__(self, model, config):
        self.model = model
        self.config = config

    def fit(
        self,
        tables: Dict,
        routes: List[AtomicRoute],
        log_fn: Callable = print,
        progress_fn=None,
    ) -> Tuple[Dict, List[Dict]]:

        t_start = time.time()
        LR = self.config.learning_rate
        EPOCHS = self.config.num_epochs

        # 1. Extract features
        table_features_np, labels = extract_features(tables, len(tables["customers"]))

        feature_dims = {k: v.shape[1] for k, v in table_features_np.items()}

        # 2. Build the model (now that the feature dims are known)
        self.model.build(feature_dims, routes)
        optimizer = optim.AdamW(self.model.parameters(), lr=LR, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

        # 3. Stratified train/test split
        n = len(labels)
        idx = np.arange(n)
        idx_tr, idx_te = train_test_split(idx, test_size=0.2, random_state=42,
                                          stratify=(labels > 0.5).astype(int))

        def to_tensor(feat_dict, rows):
            return {k: torch.tensor(v[rows], dtype=torch.float32)
                    for k, v in feat_dict.items()}

        y_tr = torch.tensor(labels[idx_tr], dtype=torch.float32)

        # Weight the positive class (fraud is rare)
        n_pos = int((y_tr == 1).sum().item())
        n_neg = int((y_tr == 0).sum().item())
        pos_weight = torch.tensor([n_neg / max(n_pos, 1)], dtype=torch.float32)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
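        # Worked example (illustrative numbers, not from this repo): with 950
        # legitimate and 50 fraudulent customers in the training split,
        # pos_weight = 950 / 50 = 19, so each fraud example contributes 19x
        # to the BCE loss, roughly rebalancing the two classes.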

        history = []
        log_interval = max(1, EPOCHS // 10)

        # The feature tensors never change, so build them once outside the loop
        feat_tr = to_tensor(table_features_np, idx_tr)
        feat_te = to_tensor(table_features_np, idx_te)

        self.model.train()
        for epoch in range(1, EPOCHS + 1):
            optimizer.zero_grad()
            logits, _ = self.model(feat_tr)
            loss = loss_fn(logits, y_tr)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            if epoch % log_interval == 0 or epoch == EPOCHS:
                self.model.eval()
                with torch.no_grad():
                    logits_te, _ = self.model(feat_te)
                    probs_te = torch.sigmoid(logits_te).numpy()

                try:
                    auc = roc_auc_score(labels[idx_te], probs_te)
                except ValueError:  # test split contains a single class
                    auc = 0.5

                history.append({"epoch": epoch, "loss": loss.item(), "auc": auc})
                if epoch % (log_interval * 2) == 0 or epoch == EPOCHS:
                    log_fn(f"  Epoch {epoch:3d}/{EPOCHS} | Loss: {loss.item():.4f} | AUC: {auc:.4f}")

                self.model.train()

            if progress_fn:
                # Map training onto the 30%-65% band of the overall progress bar
                pct = 0.30 + 0.35 * (epoch / EPOCHS)
                progress_fn(pct, desc=f"RelGNN training - epoch {epoch}/{EPOCHS}")

        # Final metrics
        self.model.eval()
        with torch.no_grad():
            logits_te, attn_info = self.model(feat_te)
            probs_te = torch.sigmoid(logits_te).numpy()

        preds = (probs_te > 0.5).astype(int)
        y_true = labels[idx_te].astype(int)

        try:
            auc = roc_auc_score(y_true, probs_te)
        except ValueError:  # single-class test split
            auc = 0.5
        # With zero_division=0 these never raise, so they need no guard
        f1 = f1_score(y_true, preds, zero_division=0)
        precision = precision_score(y_true, preds, zero_division=0)
        recall = recall_score(y_true, preds, zero_division=0)

        train_time = round(time.time() - t_start, 1)

        metrics = {
            "auc": round(auc, 4),
            "f1": round(f1, 4),
            "precision": round(precision, 4),
            "recall": round(recall, 4),
            "train_time": train_time,
        }

        # Write the learned attention weights back onto the routes
        route_weights = torch.softmax(self.model.hierarchical.route_weights, dim=0)
        for i, route in enumerate(routes):
            if i < len(route_weights):
                route.attention_weight = float(route_weights[i].item())
                route.active = route.attention_weight > 0.15

        return metrics, history
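

# A minimal driver sketch (illustrative only: `load_tables`, `build_routes`, and
# `RelGNNModel` are hypothetical names, not part of this commit; the model is
# assumed to expose .build(feature_dims, routes), return (logits, attn_info)
# from forward, and own a .hierarchical.route_weights parameter, as used above):
#
#   tables = load_tables()                # dict of pandas DataFrames
#   routes = build_routes(tables)         # List[AtomicRoute]
#   model = RelGNNModel(config)
#   trainer = Trainer(model, config)      # config needs learning_rate, num_epochs
#   metrics, history = trainer.fit(tables, routes)
#   print(metrics)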