Danielfonseca1212 commited on
Commit
22b1610
·
verified ·
1 Parent(s): d27f646

Create Xgboost baseline · py

Browse files
Files changed (1) hide show
  1. Xgboost baseline · py +134 -0
Xgboost baseline · py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ baseline/xgboost_baseline.py
3
+ Baseline XGBoost — features planas (flat features).
4
+
5
+ Agrega todas as tabelas em uma única linha por cliente e treina XGBoost.
6
+ Representa a abordagem clássica de ML sem estrutura relacional.
7
+ """
8
+
9
+ import time
10
+ import numpy as np
11
+ import pandas as pd
12
+ from sklearn.model_selection import train_test_split
13
+ from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
14
+ from sklearn.ensemble import GradientBoostingClassifier
15
+ from typing import Dict, Callable
16
+
17
+
18
+ class XGBoostBaseline:
19
+ """
20
+ Usa GradientBoostingClassifier do scikit-learn (equivalente ao XGBoost)
21
+ para máxima compatibilidade no HF Spaces sem dependências extras.
22
+ """
23
+
24
+ def __init__(self, n_estimators: int = 100, max_depth: int = 4):
25
+ self.n_estimators = n_estimators
26
+ self.max_depth = max_depth
27
+
28
+ def _build_flat_features(self, tables: Dict) -> pd.DataFrame:
29
+ """
30
+ Achata todas as tabelas em um DataFrame por cliente.
31
+ Engenharia de features manual — exatamente o que RelGNN evita.
32
+ """
33
+ customers = tables["customers"]
34
+ orders = tables["orders"]
35
+ lineitem = tables["lineitem"]
36
+ supplier = tables["supplier"]
37
+ nation = tables["nation"]
38
+
39
+ feat = customers[["c_custkey", "c_acctbal", "c_nationkey",
40
+ "c_account_age_days", "c_num_prev_orders"]].copy()
41
+
42
+ # Agrega pedidos
43
+ ord_agg = orders.groupby("o_custkey").agg(
44
+ ord_count = ("o_orderkey", "count"),
45
+ ord_total_mean = ("o_totalprice", "mean"),
46
+ ord_total_max = ("o_totalprice", "max"),
47
+ ord_total_std = ("o_totalprice", "std"),
48
+ ord_priority_mean=("o_shippriority","mean"),
49
+ ).reset_index().rename(columns={"o_custkey": "c_custkey"})
50
+ feat = feat.merge(ord_agg, on="c_custkey", how="left")
51
+
52
+ # Agrega linhas de pedido
53
+ li_with_cust = lineitem.merge(
54
+ orders[["o_orderkey","o_custkey"]], on="o_orderkey", how="left"
55
+ )
56
+ li_agg = li_with_cust.groupby("o_custkey").agg(
57
+ li_count = ("l_linenumber", "count"),
58
+ li_qty_mean = ("l_quantity", "mean"),
59
+ li_price_mean = ("l_extendedprice","mean"),
60
+ li_price_max = ("l_extendedprice","max"),
61
+ li_discount_mean= ("l_discount", "mean"),
62
+ li_tax_mean = ("l_tax", "mean"),
63
+ ).reset_index().rename(columns={"o_custkey": "c_custkey"})
64
+ feat = feat.merge(li_agg, on="c_custkey", how="left")
65
+
66
+ # Agrega fornecedores via lineitem
67
+ sup_with_cust = li_with_cust.merge(supplier, left_on="l_suppkey",
68
+ right_on="s_suppkey", how="left")
69
+ sup_agg = sup_with_cust.groupby("o_custkey").agg(
70
+ sup_acctbal_mean = ("s_acctbal", "mean"),
71
+ sup_risk_sum = ("s_risk_flag", "sum"),
72
+ sup_nation_nuniq = ("s_nationkey", "nunique"),
73
+ ).reset_index().rename(columns={"o_custkey": "c_custkey"})
74
+ feat = feat.merge(sup_agg, on="c_custkey", how="left")
75
+
76
+ # Agrega nação
77
+ nat_agg = nation[["n_nationkey","n_regionkey"]].rename(
78
+ columns={"n_nationkey": "c_nationkey"}
79
+ )
80
+ feat = feat.merge(nat_agg, on="c_nationkey", how="left")
81
+
82
+ feat = feat.drop(columns=["c_custkey"], errors="ignore")
83
+ feat = feat.fillna(0)
84
+
85
+ return feat
86
+
87
+ def fit(self, tables: Dict, log_fn: Callable = print):
88
+ t_start = time.time()
89
+ log_fn(" [XGBoost] Construindo features planas (flat)...")
90
+
91
+ X = self._build_flat_features(tables)
92
+
93
+ # Labels
94
+ customers = tables["customers"]
95
+ orders = tables["orders"]
96
+ fraud_by_cust = orders.groupby("o_custkey")["is_fraud"].max()
97
+ y = customers["c_custkey"].map(fraud_by_cust).fillna(0).values.astype(int)
98
+
99
+ X_arr = X.values.astype(np.float32)
100
+ log_fn(f" [XGBoost] Shape features: {X_arr.shape}")
101
+
102
+ idx_tr, idx_te = train_test_split(
103
+ np.arange(len(y)), test_size=0.2, random_state=42,
104
+ stratify=y
105
+ )
106
+
107
+ model = GradientBoostingClassifier(
108
+ n_estimators=self.n_estimators,
109
+ max_depth=self.max_depth,
110
+ learning_rate=0.05,
111
+ subsample=0.8,
112
+ random_state=42,
113
+ )
114
+ model.fit(X_arr[idx_tr], y[idx_tr])
115
+
116
+ probs = model.predict_proba(X_arr[idx_te])[:, 1]
117
+ preds = (probs > 0.5).astype(int)
118
+ y_true = y[idx_te]
119
+
120
+ try:
121
+ auc = roc_auc_score(y_true, probs)
122
+ f1 = f1_score(y_true, preds, zero_division=0)
123
+ precision = precision_score(y_true, preds, zero_division=0)
124
+ recall = recall_score(y_true, preds, zero_division=0)
125
+ except Exception:
126
+ auc = f1 = precision = recall = 0.5
127
+
128
+ train_time = round(time.time() - t_start, 1)
129
+
130
+ return {
131
+ "auc": round(auc, 4), "f1": round(f1, 4),
132
+ "precision": round(precision, 4), "recall": round(recall, 4),
133
+ "train_time": train_time,
134
+ }