"""
relgnn/trainer.py
Loop de treinamento do RelGNN.
Extrai features numΓ©ricas diretamente das tabelas SQL (sem grafo),
agrega por entidade alvo (customers), e treina end-to-end.
"""
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
from typing import Dict, List, Tuple, Callable
from data.routes import AtomicRoute
# ─── FEATURE EXTRACTION ───────────────────────────────────────────────────────
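# Numeric columns to extract from each table; columns missing from a given
# DataFrame are silently skipped at extraction time.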
NUMERIC_COLS = {
"customers": ["c_acctbal", "c_nationkey", "c_account_age_days", "c_num_prev_orders"],
"orders": ["o_totalprice", "o_shippriority"],
"lineitem": ["l_quantity", "l_extendedprice", "l_discount", "l_tax"],
"supplier": ["s_acctbal", "s_nationkey", "s_risk_flag"],
"nation": ["n_nationkey", "n_regionkey"],
"part": ["p_retailprice"],
}
def extract_features(tables: Dict, n_customers: int) -> Tuple[Dict, np.ndarray]:
"""
Extrai features numΓ©ricas das tabelas e agrega por cliente (entidade alvo).
Retorna:
table_features: {table_name: np.ndarray [n_customers, feature_dim]}
labels: np.ndarray [n_customers] (is_fraud)
"""
customers = tables["customers"]
orders = tables["orders"]
    # Labels: 1 if any of the customer's orders is fraudulent
fraud_by_customer = orders.groupby("o_custkey")["is_fraud"].max()
labels = customers["c_custkey"].map(fraud_by_customer).fillna(0).values.astype(float)
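    # Customers without any order fall back to label 0 via fillna(0)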
table_features = {}
    # ── Customers: used directly ─────────────────────────────────────────────
cols = [c for c in NUMERIC_COLS["customers"] if c in customers.columns]
table_features["customers"] = customers[cols].fillna(0).values.astype(np.float32)
    # ── Orders: aggregate per customer (mean + max + count) ──────────────────
order_cols = [c for c in NUMERIC_COLS["orders"] if c in orders.columns]
ord_mean = orders.groupby("o_custkey")[order_cols].mean()
ord_max = orders.groupby("o_custkey")[order_cols].max()
ord_cnt = orders.groupby("o_custkey").size().rename("order_count")
ord_agg = ord_mean.join(ord_max, rsuffix="_max").join(ord_cnt)
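    # Re-index on the customers table so the matrix has one row per customer, aligned with the labels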
ord_agg = customers[["c_custkey"]].set_index("c_custkey").join(ord_agg).fillna(0)
table_features["orders"] = ord_agg.values.astype(np.float32)
    # ── Lineitem: aggregate via orders → customer ────────────────────────────
lineitem = tables["lineitem"]
li_cols = [c for c in NUMERIC_COLS["lineitem"] if c in lineitem.columns]
li_with_cust = lineitem.merge(
orders[["o_orderkey", "o_custkey"]], on="o_orderkey", how="left"
)
li_mean = li_with_cust.groupby("o_custkey")[li_cols].mean()
li_max = li_with_cust.groupby("o_custkey")[li_cols].max()
li_cnt = li_with_cust.groupby("o_custkey").size().rename("lineitem_count")
li_agg = li_mean.join(li_max, rsuffix="_max").join(li_cnt)
li_agg = customers[["c_custkey"]].set_index("c_custkey").join(li_agg).fillna(0)
table_features["lineitem"] = li_agg.values.astype(np.float32)
    # ── Supplier: aggregate via lineitem → orders → customer ─────────────────
supplier = tables["supplier"]
sup_cols = [c for c in NUMERIC_COLS["supplier"] if c in supplier.columns]
sup_with_cust = li_with_cust.merge(supplier, left_on="l_suppkey", right_on="s_suppkey", how="left")
sup_mean = sup_with_cust.groupby("o_custkey")[sup_cols].mean()
sup_agg = customers[["c_custkey"]].set_index("c_custkey").join(sup_mean).fillna(0)
table_features["supplier"] = sup_agg.values.astype(np.float32)
    # ── Nation: direct join ──────────────────────────────────────────────────
nation = tables["nation"]
nat_cols = [c for c in NUMERIC_COLS["nation"] if c in nation.columns]
nat_agg = customers[["c_custkey", "c_nationkey"]].merge(
nation, left_on="c_nationkey", right_on="n_nationkey", how="left"
)[nat_cols].fillna(0)
table_features["nation"] = nat_agg.values.astype(np.float32)
    # ── Part: aggregate via lineitem → customer ──────────────────────────────
part = tables["part"]
par_cols = [c for c in NUMERIC_COLS["part"] if c in part.columns]
par_with_cust = li_with_cust.merge(part, left_on="l_partkey", right_on="p_partkey", how="left")
par_mean = par_with_cust.groupby("o_custkey")[par_cols].mean()
par_agg = customers[["c_custkey"]].set_index("c_custkey").join(par_mean).fillna(0)
table_features["part"] = par_agg.values.astype(np.float32)
    # Normalize features (min-max per column)
for key in table_features:
feat = table_features[key]
col_min = feat.min(axis=0, keepdims=True)
col_max = feat.max(axis=0, keepdims=True)
denom = np.where((col_max - col_min) == 0, 1, col_max - col_min)
table_features[key] = (feat - col_min) / denom
return table_features, labels
# ─── TRAINER ─────────────────────────────────────────────────────────────────
class Trainer:
def __init__(self, model, config):
self.model = model
self.config = config
def fit(
self,
tables: Dict,
routes: List[AtomicRoute],
log_fn: Callable = print,
progress_fn=None,
) -> Tuple[Dict, List[Dict]]:
t_start = time.time()
H = self.config.hidden_dim
D = self.config.dropout
LR = self.config.learning_rate
EPOCHS = self.config.num_epochs
        # 1. Extract features
table_features_np, labels = extract_features(tables, len(tables["customers"]))
feature_dims = {k: v.shape[1] for k, v in table_features_np.items()}
        # 2. Build the model (now that the feature dims are known)
self.model.build(feature_dims, routes)
optimizer = optim.AdamW(self.model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
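        # AdamW with weight decay, plus cosine LR annealing over the full training run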
        # 3. Stratified train/test split
n = len(labels)
idx = np.arange(n)
idx_tr, idx_te = train_test_split(idx, test_size=0.2, random_state=42,
stratify=(labels > 0.5).astype(int))
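        # Helper: slice each table's feature matrix at the given indices and convert to float32 tensors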
def to_tensor(feat_dict, idx):
return {k: torch.tensor(v[idx], dtype=torch.float32)
for k, v in feat_dict.items()}
y_tr = torch.tensor(labels[idx_tr], dtype=torch.float32)
y_te = torch.tensor(labels[idx_te], dtype=torch.float32)
        # Weight for the positive class (fraud is rare)
pos_weight = torch.tensor([(y_tr == 0).sum() / max((y_tr == 1).sum(), 1)])
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
history = []
log_interval = max(1, EPOCHS // 10)
self.model.train()
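        # Full-batch training: one forward/backward pass over all training rows per epoch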
for epoch in range(1, EPOCHS + 1):
optimizer.zero_grad()
feat_tr = to_tensor(table_features_np, idx_tr)
logits, _ = self.model(feat_tr)
loss = loss_fn(logits, y_tr)
loss.backward()
nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
optimizer.step()
scheduler.step()
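            # Periodic evaluation on the held-out split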
if epoch % log_interval == 0 or epoch == EPOCHS:
self.model.eval()
with torch.no_grad():
feat_te = to_tensor(table_features_np, idx_te)
logits_te, _ = self.model(feat_te)
probs_te = torch.sigmoid(logits_te).numpy()
try:
auc = roc_auc_score(labels[idx_te], probs_te)
except Exception:
auc = 0.5
history.append({"epoch": epoch, "loss": float(loss), "auc": auc})
if epoch % (log_interval * 2) == 0 or epoch == EPOCHS:
log_fn(f" Γ‰poca {epoch:3d}/{EPOCHS} | Loss: {float(loss):.4f} | AUC: {auc:.4f}")
self.model.train()
if progress_fn:
pct = 0.30 + 0.35 * (epoch / EPOCHS)
                progress_fn(pct, desc=f"RelGNN training - epoch {epoch}/{EPOCHS}")
        # Final metrics
self.model.eval()
with torch.no_grad():
feat_te = to_tensor(table_features_np, idx_te)
logits_te, attn_info = self.model(feat_te)
probs_te = torch.sigmoid(logits_te).numpy()
preds = (probs_te > 0.5).astype(int)
y_true = labels[idx_te].astype(int)
try:
auc = roc_auc_score(y_true, probs_te)
f1 = f1_score(y_true, preds, zero_division=0)
precision = precision_score(y_true, preds, zero_division=0)
recall = recall_score(y_true, preds, zero_division=0)
except Exception:
auc = f1 = precision = recall = 0.5
train_time = round(time.time() - t_start, 1)
metrics = {
"auc": round(auc, 4),
"f1": round(f1, 4),
"precision": round(precision, 4),
"recall": round(recall, 4),
"train_time": train_time,
}
        # Update route attention weights with the learned (softmax-normalized) values
route_weights = torch.softmax(self.model.hierarchical.route_weights, dim=0)
for i, route in enumerate(routes):
if i < len(route_weights):
route.attention_weight = float(route_weights[i].item())
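                # Flag a route as active only when its softmax weight exceeds 0.15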
route.active = route.attention_weight > 0.15
return metrics, history