| """ |
| relgnn/trainer.py |
| Loop de treinamento do RelGNN. |
| |
| Extrai features numΓ©ricas diretamente das tabelas SQL (sem grafo), |
| agrega por entidade alvo (customers), e treina end-to-end. |
| """ |
|
|
import time
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score
from sklearn.model_selection import train_test_split

from data.routes import AtomicRoute

# Numeric feature columns considered per table; columns missing from a
# DataFrame are silently skipped by extract_features.
NUMERIC_COLS = {
    "customers": ["c_acctbal", "c_nationkey", "c_account_age_days", "c_num_prev_orders"],
    "orders": ["o_totalprice", "o_shippriority"],
    "lineitem": ["l_quantity", "l_extendedprice", "l_discount", "l_tax"],
    "supplier": ["s_acctbal", "s_nationkey", "s_risk_flag"],
    "nation": ["n_nationkey", "n_regionkey"],
    "part": ["p_retailprice"],
}
|
|
|
|
def extract_features(tables: Dict, n_customers: int) -> Tuple[Dict, np.ndarray]:
    """
    Extracts numeric features from the tables and aggregates them per
    customer (the target entity).

    Returns:
        table_features: {table_name: np.ndarray [n_customers, feature_dim]}
        labels: np.ndarray [n_customers] (is_fraud)
    """
    customers = tables["customers"]
    orders = tables["orders"]

    # Label: a customer is fraudulent if any of their orders is flagged.
    fraud_by_customer = orders.groupby("o_custkey")["is_fraud"].max()
    labels = customers["c_custkey"].map(fraud_by_customer).fillna(0).values.astype(float)

    table_features = {}

    # Customer-level features: used directly, no aggregation needed.
    cols = [c for c in NUMERIC_COLS["customers"] if c in customers.columns]
    table_features["customers"] = customers[cols].fillna(0).values.astype(np.float32)

    # Order features: mean and max per customer, plus the order count.
    order_cols = [c for c in NUMERIC_COLS["orders"] if c in orders.columns]
    ord_mean = orders.groupby("o_custkey")[order_cols].mean()
    ord_max = orders.groupby("o_custkey")[order_cols].max()
    ord_cnt = orders.groupby("o_custkey").size().rename("order_count")
    ord_agg = ord_mean.join(ord_max, rsuffix="_max").join(ord_cnt)
    ord_agg = customers[["c_custkey"]].set_index("c_custkey").join(ord_agg).fillna(0)
    table_features["orders"] = ord_agg.values.astype(np.float32)

    # Lineitem features: joined back to customers through orders, then aggregated.
    lineitem = tables["lineitem"]
    li_cols = [c for c in NUMERIC_COLS["lineitem"] if c in lineitem.columns]
    li_with_cust = lineitem.merge(
        orders[["o_orderkey", "o_custkey"]], on="o_orderkey", how="left"
    )
    li_mean = li_with_cust.groupby("o_custkey")[li_cols].mean()
    li_max = li_with_cust.groupby("o_custkey")[li_cols].max()
    li_cnt = li_with_cust.groupby("o_custkey").size().rename("lineitem_count")
    li_agg = li_mean.join(li_max, rsuffix="_max").join(li_cnt)
    li_agg = customers[["c_custkey"]].set_index("c_custkey").join(li_agg).fillna(0)
    table_features["lineitem"] = li_agg.values.astype(np.float32)

    # Supplier features: reached via customer -> order -> lineitem -> supplier.
    supplier = tables["supplier"]
    sup_cols = [c for c in NUMERIC_COLS["supplier"] if c in supplier.columns]
    sup_with_cust = li_with_cust.merge(supplier, left_on="l_suppkey", right_on="s_suppkey", how="left")
    sup_mean = sup_with_cust.groupby("o_custkey")[sup_cols].mean()
    sup_agg = customers[["c_custkey"]].set_index("c_custkey").join(sup_mean).fillna(0)
    table_features["supplier"] = sup_agg.values.astype(np.float32)

    # Nation features: a direct lookup via the customer's nation key.
    nation = tables["nation"]
    nat_cols = [c for c in NUMERIC_COLS["nation"] if c in nation.columns]
    nat_agg = customers[["c_custkey", "c_nationkey"]].merge(
        nation, left_on="c_nationkey", right_on="n_nationkey", how="left"
    )[nat_cols].fillna(0)
    table_features["nation"] = nat_agg.values.astype(np.float32)

    # Part features: also reached through lineitems, mean-aggregated per customer.
    part = tables["part"]
    par_cols = [c for c in NUMERIC_COLS["part"] if c in part.columns]
    par_with_cust = li_with_cust.merge(part, left_on="l_partkey", right_on="p_partkey", how="left")
    par_mean = par_with_cust.groupby("o_custkey")[par_cols].mean()
    par_agg = customers[["c_custkey"]].set_index("c_custkey").join(par_mean).fillna(0)
    table_features["part"] = par_agg.values.astype(np.float32)

    # Min-max normalize each feature matrix column-wise to [0, 1]; constant
    # columns get a denominator of 1 to avoid division by zero.
    for key in table_features:
        feat = table_features[key]
        col_min = feat.min(axis=0, keepdims=True)
        col_max = feat.max(axis=0, keepdims=True)
        denom = np.where((col_max - col_min) == 0, 1, col_max - col_min)
        table_features[key] = (feat - col_min) / denom

    return table_features, labels
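# A minimal usage sketch for extract_features; the shape comment is an
# inference from the aggregation above, not a guarantee:
#
#     feats, labels = extract_features(tables, len(tables["customers"]))
#     feats["orders"].shape   # (n_customers, 2 * n_order_cols + 1): mean, max, count
#     labels                  # 1.0 where any of the customer's orders is fraudulent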
|
|
|
|
class Trainer:
    def __init__(self, model, config):
        self.model = model
        self.config = config
|
|
    def fit(
        self,
        tables: Dict,
        routes: List[AtomicRoute],
        log_fn: Callable = print,
        progress_fn: Optional[Callable] = None,
    ) -> Tuple[Dict, List[Dict]]:
        t_start = time.time()
        LR = self.config.learning_rate
        EPOCHS = self.config.num_epochs

        # Feature extraction: one aligned feature matrix per table plus labels.
        table_features_np, labels = extract_features(tables, len(tables["customers"]))
        feature_dims = {k: v.shape[1] for k, v in table_features_np.items()}

        # Build the model lazily, now that the feature dimensions are known.
        self.model.build(feature_dims, routes)
        optimizer = optim.AdamW(self.model.parameters(), lr=LR, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

        # Stratified 80/20 split so the fraud ratio is preserved in both sets.
        n = len(labels)
        idx = np.arange(n)
        idx_tr, idx_te = train_test_split(
            idx, test_size=0.2, random_state=42,
            stratify=(labels > 0.5).astype(int),
        )

        def to_tensor(feat_dict, rows):
            return {k: torch.tensor(v[rows], dtype=torch.float32)
                    for k, v in feat_dict.items()}

        y_tr = torch.tensor(labels[idx_tr], dtype=torch.float32)

        # Weight the positive class by the negative/positive ratio to counter
        # the class imbalance typical of fraud labels.
        n_pos = int((y_tr == 1).sum())
        n_neg = int((y_tr == 0).sum())
        pos_weight = torch.tensor([n_neg / max(n_pos, 1)])
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
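        # Worked example of the weighting above (hypothetical counts): with
        # 950 negatives and 50 positives in the training split, pos_weight
        # = 950 / 50 = 19.0, so each positive example contributes 19x as
        # much to the loss as a negative one.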
|
|
        history = []
        log_interval = max(1, EPOCHS // 10)

        # Full-batch training: every epoch sees all training customers at once.
        self.model.train()
        for epoch in range(1, EPOCHS + 1):
            optimizer.zero_grad()
            feat_tr = to_tensor(table_features_np, idx_tr)
            logits, _ = self.model(feat_tr)
            loss = loss_fn(logits, y_tr)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            # Periodic evaluation on the held-out split.
            if epoch % log_interval == 0 or epoch == EPOCHS:
                self.model.eval()
                with torch.no_grad():
                    feat_te = to_tensor(table_features_np, idx_te)
                    logits_te, _ = self.model(feat_te)
                    probs_te = torch.sigmoid(logits_te).numpy()

                # roc_auc_score raises ValueError if the test split happens
                # to contain a single class; fall back to chance level.
                try:
                    auc = roc_auc_score(labels[idx_te], probs_te)
                except ValueError:
                    auc = 0.5

                history.append({"epoch": epoch, "loss": float(loss), "auc": auc})
                if epoch % (log_interval * 2) == 0 or epoch == EPOCHS:
                    log_fn(f"  Epoch {epoch:3d}/{EPOCHS} | Loss: {float(loss):.4f} | AUC: {auc:.4f}")

                self.model.train()

            if progress_fn:
                pct = 0.30 + 0.35 * (epoch / EPOCHS)
                progress_fn(pct, desc=f"RelGNN training - epoch {epoch}/{EPOCHS}")

        # Final evaluation on the held-out split.
        self.model.eval()
        with torch.no_grad():
            feat_te = to_tensor(table_features_np, idx_te)
            logits_te, _ = self.model(feat_te)
            probs_te = torch.sigmoid(logits_te).numpy()

        preds = (probs_te > 0.5).astype(int)
        y_true = labels[idx_te].astype(int)

        # The threshold-based metrics cannot raise with zero_division=0;
        # only the AUC needs the single-class guard.
        f1 = f1_score(y_true, preds, zero_division=0)
        precision = precision_score(y_true, preds, zero_division=0)
        recall = recall_score(y_true, preds, zero_division=0)
        try:
            auc = roc_auc_score(y_true, probs_te)
        except ValueError:
            auc = 0.5

        train_time = round(time.time() - t_start, 1)

        metrics = {
            "auc": round(auc, 4),
            "f1": round(f1, 4),
            "precision": round(precision, 4),
            "recall": round(recall, 4),
            "train_time": train_time,
        }

        # Expose the learned route attention so callers can see which
        # relational paths the model relied on; routes below the 0.15
        # threshold are marked inactive.
        route_weights = torch.softmax(self.model.hierarchical.route_weights, dim=0)
        for i, route in enumerate(routes):
            if i < len(route_weights):
                route.attention_weight = float(route_weights[i].item())
                route.active = route.attention_weight > 0.15

        return metrics, history
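

# A minimal end-to-end sketch. Only Trainer, extract_features, and
# AtomicRoute are defined here; the model object, config object, and the
# table/route loading helpers are hypothetical stand-ins for whatever the
# surrounding project provides:
#
#     tables = load_tables(...)             # dict of pandas DataFrames
#     routes = build_routes(...)            # list of AtomicRoute
#     trainer = Trainer(model, config)      # config must expose
#                                           # learning_rate and num_epochs
#     metrics, history = trainer.fit(tables, routes, log_fn=print)
#     print(metrics["auc"], metrics["f1"])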