""" relgnn/trainer.py Loop de treinamento do RelGNN. Extrai features numéricas diretamente das tabelas SQL (sem grafo), agrega por entidade alvo (customers), e treina end-to-end. """ import time import numpy as np import torch import torch.nn as nn import torch.optim as optim from sklearn.model_selection import train_test_split from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score from typing import Dict, List, Tuple, Callable, Optional from data.routes import AtomicRoute # ─── FEATURE EXTRACTION ─────────────────────────────────────────────────────── NUMERIC_COLS = { "customers": ["c_acctbal", "c_nationkey", "c_account_age_days", "c_num_prev_orders"], "orders": ["o_totalprice", "o_shippriority"], "lineitem": ["l_quantity", "l_extendedprice", "l_discount", "l_tax"], "supplier": ["s_acctbal", "s_nationkey", "s_risk_flag"], "nation": ["n_nationkey", "n_regionkey"], "part": ["p_retailprice"], } def extract_features(tables: Dict, n_customers: int) -> Tuple[Dict, np.ndarray]: """ Extrai features numéricas das tabelas e agrega por cliente (entidade alvo). Retorna: table_features: {table_name: np.ndarray [n_customers, feature_dim]} labels: np.ndarray [n_customers] (is_fraud) """ import pandas as pd customers = tables["customers"] orders = tables["orders"] # Labels: 1 se algum pedido do cliente é fraude fraud_by_customer = orders.groupby("o_custkey")["is_fraud"].max() labels = customers["c_custkey"].map(fraud_by_customer).fillna(0).values.astype(float) table_features = {} # ── Customers: direto ───────────────────────────────────────────────────── cols = [c for c in NUMERIC_COLS["customers"] if c in customers.columns] table_features["customers"] = customers[cols].fillna(0).values.astype(np.float32) # ── Orders: agrega por cliente (mean + max + count) ─────────────────────── order_cols = [c for c in NUMERIC_COLS["orders"] if c in orders.columns] ord_mean = orders.groupby("o_custkey")[order_cols].mean() ord_max = orders.groupby("o_custkey")[order_cols].max() ord_cnt = orders.groupby("o_custkey").size().rename("order_count") ord_agg = ord_mean.join(ord_max, rsuffix="_max").join(ord_cnt) ord_agg = customers[["c_custkey"]].set_index("c_custkey").join(ord_agg).fillna(0) table_features["orders"] = ord_agg.values.astype(np.float32) # ── Lineitem: agrega via orders → customer ──────────────────────────────── lineitem = tables["lineitem"] li_cols = [c for c in NUMERIC_COLS["lineitem"] if c in lineitem.columns] li_with_cust = lineitem.merge( orders[["o_orderkey", "o_custkey"]], on="o_orderkey", how="left" ) li_mean = li_with_cust.groupby("o_custkey")[li_cols].mean() li_max = li_with_cust.groupby("o_custkey")[li_cols].max() li_cnt = li_with_cust.groupby("o_custkey").size().rename("lineitem_count") li_agg = li_mean.join(li_max, rsuffix="_max").join(li_cnt) li_agg = customers[["c_custkey"]].set_index("c_custkey").join(li_agg).fillna(0) table_features["lineitem"] = li_agg.values.astype(np.float32) # ── Supplier: agrega via lineitem → orders → customer ──────────────────── supplier = tables["supplier"] sup_cols = [c for c in NUMERIC_COLS["supplier"] if c in supplier.columns] sup_with_cust = li_with_cust.merge(supplier, left_on="l_suppkey", right_on="s_suppkey", how="left") sup_mean = sup_with_cust.groupby("o_custkey")[sup_cols].mean() sup_agg = customers[["c_custkey"]].set_index("c_custkey").join(sup_mean).fillna(0) table_features["supplier"] = sup_agg.values.astype(np.float32) # ── Nation: join direto ─────────────────────────────────────────────────── nation = tables["nation"] 
    nat_cols = [c for c in NUMERIC_COLS["nation"] if c in nation.columns]
    nat_agg = customers[["c_custkey", "c_nationkey"]].merge(
        nation, left_on="c_nationkey", right_on="n_nationkey", how="left"
    )[nat_cols].fillna(0)
    table_features["nation"] = nat_agg.values.astype(np.float32)

    # ── Part: aggregated via lineitem → customer ──────────────────────────────
    part = tables["part"]
    par_cols = [c for c in NUMERIC_COLS["part"] if c in part.columns]
    par_with_cust = li_with_cust.merge(part, left_on="l_partkey", right_on="p_partkey", how="left")
    par_mean = par_with_cust.groupby("o_custkey")[par_cols].mean()
    par_agg = customers[["c_custkey"]].set_index("c_custkey").join(par_mean).fillna(0)
    table_features["part"] = par_agg.values.astype(np.float32)

    # Normalize features (min-max per column)
    for key in table_features:
        feat = table_features[key]
        col_min = feat.min(axis=0, keepdims=True)
        col_max = feat.max(axis=0, keepdims=True)
        denom = np.where((col_max - col_min) == 0, 1, col_max - col_min)
        table_features[key] = (feat - col_min) / denom

    return table_features, labels


# ─── TRAINER ──────────────────────────────────────────────────────────────────

class Trainer:
    def __init__(self, model, config):
        self.model = model
        self.config = config

    def fit(
        self,
        tables: Dict,
        routes: List[AtomicRoute],
        log_fn: Callable = print,
        progress_fn=None,
    ) -> Tuple[Dict, List[Dict]]:
        t_start = time.time()
        H = self.config.hidden_dim
        D = self.config.dropout
        LR = self.config.learning_rate
        EPOCHS = self.config.num_epochs

        # 1. Extract features
        table_features_np, labels = extract_features(tables, len(tables["customers"]))
        feature_dims = {k: v.shape[1] for k, v in table_features_np.items()}

        # 2. Build the model (now that the feature dims are known)
        self.model.build(feature_dims, routes)
        optimizer = optim.AdamW(self.model.parameters(), lr=LR, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

        # 3. Stratified train/test split
        n = len(labels)
        idx = np.arange(n)
        idx_tr, idx_te = train_test_split(
            idx, test_size=0.2, random_state=42, stratify=(labels > 0.5).astype(int)
        )

        def to_tensor(feat_dict, indices):
            return {k: torch.tensor(v[indices], dtype=torch.float32) for k, v in feat_dict.items()}

        y_tr = torch.tensor(labels[idx_tr], dtype=torch.float32)
        y_te = torch.tensor(labels[idx_te], dtype=torch.float32)

        # Weight for the positive class (fraud is rare)
        n_pos = int((y_tr == 1).sum())
        n_neg = int((y_tr == 0).sum())
        pos_weight = torch.tensor([n_neg / max(n_pos, 1)], dtype=torch.float32)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        history = []
        log_interval = max(1, EPOCHS // 10)

        self.model.train()
        for epoch in range(1, EPOCHS + 1):
            optimizer.zero_grad()
            feat_tr = to_tensor(table_features_np, idx_tr)
            logits, _ = self.model(feat_tr)
            loss = loss_fn(logits, y_tr)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            if epoch % log_interval == 0 or epoch == EPOCHS:
                self.model.eval()
                with torch.no_grad():
                    feat_te = to_tensor(table_features_np, idx_te)
                    logits_te, _ = self.model(feat_te)
                    probs_te = torch.sigmoid(logits_te).numpy()
                try:
                    auc = roc_auc_score(labels[idx_te], probs_te)
                except Exception:
                    auc = 0.5
                history.append({"epoch": epoch, "loss": float(loss), "auc": auc})
                if epoch % (log_interval * 2) == 0 or epoch == EPOCHS:
                    log_fn(f"  Epoch {epoch:3d}/{EPOCHS} | Loss: {float(loss):.4f} | AUC: {auc:.4f}")
                self.model.train()

            if progress_fn:
                pct = 0.30 + 0.35 * (epoch / EPOCHS)
                progress_fn(pct, desc=f"RelGNN training — epoch {epoch}/{EPOCHS}")

        # Final metrics
        self.model.eval()
        with torch.no_grad():
            feat_te = to_tensor(table_features_np, idx_te)
            logits_te, attn_info = self.model(feat_te)
            probs_te = torch.sigmoid(logits_te).numpy()

        preds = (probs_te > 0.5).astype(int)
        y_true = labels[idx_te].astype(int)
        try:
            auc = roc_auc_score(y_true, probs_te)
            f1 = f1_score(y_true, preds, zero_division=0)
            precision = precision_score(y_true, preds, zero_division=0)
            recall = recall_score(y_true, preds, zero_division=0)
        except Exception:
            auc = f1 = precision = recall = 0.5

        train_time = round(time.time() - t_start, 1)
        metrics = {
            "auc": round(auc, 4),
            "f1": round(f1, 4),
            "precision": round(precision, 4),
            "recall": round(recall, 4),
            "train_time": train_time,
        }

        # Update the routes' attention weights with the learned values
        route_weights = torch.softmax(self.model.hierarchical.route_weights, dim=0)
        for i, route in enumerate(routes):
            if i < len(route_weights):
                route.attention_weight = float(route_weights[i].item())
                route.active = route.attention_weight > 0.15

        return metrics, history
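

# Illustrative smoke test for extract_features only (a minimal sketch, not part
# of the training pipeline). It builds tiny synthetic DataFrames with the column
# names assumed above (in this schema, lineitem carries o_orderkey directly and
# orders carries the is_fraud label) and prints the resulting feature shapes.
# Assumes the module is executed from the project root so that data.routes
# resolves for the top-level import.
if __name__ == "__main__":
    import pandas as pd

    _tables = {
        "customers": pd.DataFrame({
            "c_custkey": [1, 2],
            "c_nationkey": [0, 1],
            "c_acctbal": [100.0, 250.0],
            "c_account_age_days": [365, 30],
            "c_num_prev_orders": [5, 0],
        }),
        "orders": pd.DataFrame({
            "o_orderkey": [10, 11],
            "o_custkey": [1, 2],
            "o_totalprice": [500.0, 9000.0],
            "o_shippriority": [0, 1],
            "is_fraud": [0, 1],
        }),
        "lineitem": pd.DataFrame({
            "o_orderkey": [10, 11],
            "l_suppkey": [100, 101],
            "l_partkey": [200, 201],
            "l_quantity": [2.0, 7.0],
            "l_extendedprice": [250.0, 1285.0],
            "l_discount": [0.0, 0.1],
            "l_tax": [0.05, 0.08],
        }),
        "supplier": pd.DataFrame({
            "s_suppkey": [100, 101],
            "s_acctbal": [1000.0, -50.0],
            "s_nationkey": [0, 1],
            "s_risk_flag": [0, 1],
        }),
        "nation": pd.DataFrame({"n_nationkey": [0, 1], "n_regionkey": [0, 0]}),
        "part": pd.DataFrame({"p_partkey": [200, 201], "p_retailprice": [125.0, 180.0]}),
    }

    _feats, _labels = extract_features(_tables, n_customers=len(_tables["customers"]))
    for _name, _arr in _feats.items():
        print(f"{_name}: {_arr.shape}")
    print("labels:", _labels)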