Create Xgboost baseline · py
Browse files- Xgboost baseline · py +134 -0
Xgboost baseline · py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
baseline/xgboost_baseline.py
|
| 3 |
+
Baseline XGBoost — features planas (flat features).
|
| 4 |
+
|
| 5 |
+
Agrega todas as tabelas em uma única linha por cliente e treina XGBoost.
|
| 6 |
+
Representa a abordagem clássica de ML sem estrutura relacional.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import time
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from sklearn.model_selection import train_test_split
|
| 13 |
+
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
|
| 14 |
+
from sklearn.ensemble import GradientBoostingClassifier
|
| 15 |
+
from typing import Dict, Callable
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class XGBoostBaseline:
    """Flat-feature gradient-boosting baseline.

    Uses scikit-learn's ``GradientBoostingClassifier`` as a stand-in for
    XGBoost so that no extra dependency is required (the original author
    notes this is for compatibility on HF Spaces).

    After :meth:`fit` the trained estimator is available as ``model_`` and
    the feature column order as ``feature_names_``.
    """

    def __init__(self, n_estimators: int = 100, max_depth: int = 4):
        """Store the boosting hyper-parameters used by :meth:`fit`.

        Args:
            n_estimators: Number of boosting stages passed to the classifier.
            max_depth: Maximum tree depth passed to the classifier.
        """
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        # Populated by fit(); None until the model has been trained.
        self.model_ = None
        self.feature_names_ = None

    def _build_flat_features(self, tables: Dict) -> pd.DataFrame:
        """Flatten all tables into one feature row per customer.

        This is the manual feature engineering step that a relational model
        would avoid: orders, line items and suppliers are aggregated per
        customer and joined onto the customer table.

        Args:
            tables: Mapping with the DataFrames ``"customers"``, ``"orders"``,
                ``"lineitem"``, ``"supplier"`` and ``"nation"``. The columns
                read from each are the ones referenced in the body below.

        Returns:
            A DataFrame with exactly one row per row of
            ``tables["customers"]`` (same order), without ``c_custkey`` and
            with missing aggregates filled with 0.
        """
        customers = tables["customers"]
        orders = tables["orders"]
        lineitem = tables["lineitem"]
        supplier = tables["supplier"]
        nation = tables["nation"]

        feat = customers[
            [
                "c_custkey",
                "c_acctbal",
                "c_nationkey",
                "c_account_age_days",
                "c_num_prev_orders",
            ]
        ].copy()

        # Per-customer order aggregates.
        ord_agg = (
            orders.groupby("o_custkey")
            .agg(
                ord_count=("o_orderkey", "count"),
                ord_total_mean=("o_totalprice", "mean"),
                ord_total_max=("o_totalprice", "max"),
                ord_total_std=("o_totalprice", "std"),
                ord_priority_mean=("o_shippriority", "mean"),
            )
            .reset_index()
            .rename(columns={"o_custkey": "c_custkey"})
        )
        feat = feat.merge(ord_agg, on="c_custkey", how="left")

        # Per-customer line-item aggregates (line items reach the customer
        # through their order).
        li_with_cust = lineitem.merge(
            orders[["o_orderkey", "o_custkey"]], on="o_orderkey", how="left"
        )
        li_agg = (
            li_with_cust.groupby("o_custkey")
            .agg(
                li_count=("l_linenumber", "count"),
                li_qty_mean=("l_quantity", "mean"),
                li_price_mean=("l_extendedprice", "mean"),
                li_price_max=("l_extendedprice", "max"),
                li_discount_mean=("l_discount", "mean"),
                li_tax_mean=("l_tax", "mean"),
            )
            .reset_index()
            .rename(columns={"o_custkey": "c_custkey"})
        )
        feat = feat.merge(li_agg, on="c_custkey", how="left")

        # Per-customer supplier aggregates via the line items. Only the
        # columns that are actually aggregated are joined, which avoids
        # accidental name collisions with line-item columns.
        supplier_cols = supplier[
            ["s_suppkey", "s_acctbal", "s_risk_flag", "s_nationkey"]
        ]
        sup_with_cust = li_with_cust.merge(
            supplier_cols, left_on="l_suppkey", right_on="s_suppkey", how="left"
        )
        sup_agg = (
            sup_with_cust.groupby("o_custkey")
            .agg(
                sup_acctbal_mean=("s_acctbal", "mean"),
                sup_risk_sum=("s_risk_flag", "sum"),
                sup_nation_nuniq=("s_nationkey", "nunique"),
            )
            .reset_index()
            .rename(columns={"o_custkey": "c_custkey"})
        )
        feat = feat.merge(sup_agg, on="c_custkey", how="left")

        # Region of the customer's nation. Duplicate nation keys are dropped
        # first: a duplicated key would multiply customer rows in the merge
        # and silently misalign the features with the labels built in fit().
        nat_agg = (
            nation[["n_nationkey", "n_regionkey"]]
            .drop_duplicates(subset="n_nationkey")
            .rename(columns={"n_nationkey": "c_nationkey"})
        )
        feat = feat.merge(nat_agg, on="c_nationkey", how="left")

        feat = feat.drop(columns=["c_custkey"], errors="ignore")
        # Customers without orders / line items get 0 for every aggregate.
        feat = feat.fillna(0)

        return feat

    def fit(self, tables: Dict, log_fn: Callable = print):
        """Train the classifier on flat features and evaluate on a hold-out.

        A customer is labelled positive when any of their orders has
        ``is_fraud`` set; customers without orders are labelled 0. The data
        is split 80/20 (stratified, fixed seed) and metrics are computed on
        the 20% hold-out with a 0.5 probability threshold.

        Args:
            tables: Same mapping accepted by ``_build_flat_features``;
                ``tables["orders"]`` must additionally contain ``is_fraud``.
            log_fn: Callable receiving progress messages (defaults to
                ``print``).

        Returns:
            Dict with ``"auc"``, ``"f1"``, ``"precision"``, ``"recall"``
            (each rounded to 4 decimals) and ``"train_time"`` (seconds,
            rounded to 1 decimal, including feature construction).
        """
        t_start = time.perf_counter()
        log_fn(" [XGBoost] Construindo features planas (flat)...")

        X = self._build_flat_features(tables)

        # Labels, in the same row order as the customers table (and
        # therefore as X, which is built by left-merging onto customers).
        customers = tables["customers"]
        orders = tables["orders"]
        fraud_by_cust = orders.groupby("o_custkey")["is_fraud"].max()
        y = customers["c_custkey"].map(fraud_by_cust).fillna(0).values.astype(int)

        X_arr = X.values.astype(np.float32)
        log_fn(f" [XGBoost] Shape features: {X_arr.shape}")

        idx_tr, idx_te = train_test_split(
            np.arange(len(y)),
            test_size=0.2,
            random_state=42,
            stratify=y,
        )

        model = GradientBoostingClassifier(
            n_estimators=self.n_estimators,
            max_depth=self.max_depth,
            learning_rate=0.05,
            subsample=0.8,
            random_state=42,
        )
        model.fit(X_arr[idx_tr], y[idx_tr])

        # Keep the trained estimator so it can be reused after fit().
        self.model_ = model
        self.feature_names_ = list(X.columns)

        probs = model.predict_proba(X_arr[idx_te])[:, 1]
        preds = (probs > 0.5).astype(int)
        y_true = y[idx_te]

        # Threshold metrics are well defined even for degenerate splits
        # thanks to zero_division=0, so they are always reported as computed.
        f1 = f1_score(y_true, preds, zero_division=0)
        precision = precision_score(y_true, preds, zero_division=0)
        recall = recall_score(y_true, preds, zero_division=0)

        # ROC AUC is undefined when the hold-out contains a single class.
        # Only in that case fall back to the chance-level value 0.5, instead
        # of swallowing every exception and overwriting all four metrics.
        if np.unique(y_true).size < 2:
            log_fn(" [XGBoost] AUC indefinida (uma única classe no teste); usando 0.5")
            auc = 0.5
        else:
            auc = roc_auc_score(y_true, probs)

        train_time = round(time.perf_counter() - t_start, 1)

        return {
            "auc": round(auc, 4),
            "f1": round(f1, 4),
            "precision": round(precision, 4),
            "recall": round(recall, 4),
            "train_time": train_time,
        }
|