Spaces:
Running
Running
VladGeekPro Copilot commited on
Commit ·
fb5dc97
1
Parent(s): 4f49b04
ChangedSupplierCandidatesRule
Browse filesCo-authored-by: Copilot <copilot@github.com>
- expense_predictor.py +32 -17
expense_predictor.py
CHANGED
|
@@ -52,13 +52,13 @@ def _train_global_model(
|
|
| 52 |
supplier_to_idx: dict,
|
| 53 |
user_to_idx: dict,
|
| 54 |
debug: bool = False,
|
| 55 |
-
) -> tuple[object | None, float,
|
| 56 |
"""Trains ONE global model on ALL records.
|
| 57 |
|
| 58 |
Each sample: (date, supplier_id, user_id, amount)
|
| 59 |
Features per row: [supplier_idx, user_idx, day, weekday, month,
|
| 60 |
rolling_mean_3 for supplier, rolling_mean_month for supplier]
|
| 61 |
-
Returns: (fitted_model, global_confidence,
|
| 62 |
"""
|
| 63 |
# Sort all samples by date to build rolling features correctly.
|
| 64 |
samples_sorted = sorted(samples, key=lambda s: s[0])
|
|
@@ -121,7 +121,7 @@ def _train_global_model(
|
|
| 121 |
user_supplier_last_sum[(user_id, supplier_id)] = amount
|
| 122 |
|
| 123 |
if len(X_all) < 10:
|
| 124 |
-
return None, 0.5,
|
| 125 |
|
| 126 |
X_fit, y_fit, X_val, y_val = _time_split_xy(X_all, y_all)
|
| 127 |
candidates = _build_candidates()
|
|
@@ -148,7 +148,7 @@ def _train_global_model(
|
|
| 148 |
best_model = model
|
| 149 |
|
| 150 |
if best_model is None:
|
| 151 |
-
return None, 0.5,
|
| 152 |
|
| 153 |
baseline_scale = max(1.0, statistics.mean([abs(v) for v in (y_val if y_val else y_fit)]))
|
| 154 |
global_conf = math.exp(-(best_mae / baseline_scale))
|
|
@@ -159,7 +159,7 @@ def _train_global_model(
|
|
| 159 |
f"avg_target={baseline_scale:.2f}, global_model_conf={global_conf:.2f}"
|
| 160 |
)
|
| 161 |
|
| 162 |
-
return best_model, max(0.0, min(1.0, global_conf)),
|
| 163 |
|
| 164 |
|
| 165 |
def predict_expenses(expenses: list[dict], target_user_id, debug: bool = False) -> list[dict]:
|
|
@@ -189,18 +189,28 @@ def predict_expenses(expenses: list[dict], target_user_id, debug: bool = False)
|
|
| 189 |
pct = count / total_records * 100
|
| 190 |
print(f"[PREDICT] supplier_id={supplier_id} -> {count} records ({pct:.1f}%)")
|
| 191 |
|
| 192 |
-
#
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
key=lambda item: supplier_freq[item[0]],
|
| 197 |
reverse=True,
|
| 198 |
-
)
|
| 199 |
|
| 200 |
if debug:
|
| 201 |
-
print(
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
-
if not
|
| 204 |
if debug:
|
| 205 |
print("[PREDICT] No suppliers found. Returning empty.")
|
| 206 |
return []
|
|
@@ -224,7 +234,7 @@ def predict_expenses(expenses: list[dict], target_user_id, debug: bool = False)
|
|
| 224 |
except Exception:
|
| 225 |
continue
|
| 226 |
|
| 227 |
-
global_model, global_model_conf,
|
| 228 |
all_samples, supplier_to_idx, user_to_idx, debug=debug
|
| 229 |
)
|
| 230 |
|
|
@@ -241,10 +251,10 @@ def predict_expenses(expenses: list[dict], target_user_id, debug: bool = False)
|
|
| 241 |
user_supplier_last_date[(tx_user, tx_supplier)] = tx_date
|
| 242 |
user_supplier_last_sum[(tx_user, tx_supplier)] = tx_sum
|
| 243 |
|
| 244 |
-
# Predict
|
| 245 |
predictions = []
|
| 246 |
|
| 247 |
-
for supplier_id,
|
| 248 |
s_hist = supplier_amounts_sorted.get(supplier_id, [])
|
| 249 |
us_hist = user_supplier_amounts_sorted.get((target_user_id, supplier_id), [])
|
| 250 |
|
|
@@ -339,11 +349,16 @@ def predict_expenses(expenses: list[dict], target_user_id, debug: bool = False)
|
|
| 339 |
"sum": round(max(0.0, predicted_amount), 2),
|
| 340 |
"supplier_id": supplier_id,
|
| 341 |
"user_id": predicted_user,
|
|
|
|
| 342 |
"confidence": round(confidence, 2)
|
| 343 |
})
|
| 344 |
|
| 345 |
-
# Return
|
| 346 |
-
result = sorted(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
|
| 348 |
if debug:
|
| 349 |
print(f"[PREDICT] Final top {len(result)} predictions:")
|
|
|
|
| 52 |
supplier_to_idx: dict,
|
| 53 |
user_to_idx: dict,
|
| 54 |
debug: bool = False,
|
| 55 |
+
) -> tuple[object | None, float, str]:
|
| 56 |
"""Trains ONE global model on ALL records.
|
| 57 |
|
| 58 |
Each sample: (date, supplier_id, user_id, amount)
|
| 59 |
Features per row: [supplier_idx, user_idx, day, weekday, month,
|
| 60 |
rolling_mean_3 for supplier, rolling_mean_month for supplier]
|
| 61 |
+
Returns: (fitted_model, global_confidence, model_name)
|
| 62 |
"""
|
| 63 |
# Sort all samples by date to build rolling features correctly.
|
| 64 |
samples_sorted = sorted(samples, key=lambda s: s[0])
|
|
|
|
| 121 |
user_supplier_last_sum[(user_id, supplier_id)] = amount
|
| 122 |
|
| 123 |
if len(X_all) < 10:
|
| 124 |
+
return None, 0.5, "fallback"
|
| 125 |
|
| 126 |
X_fit, y_fit, X_val, y_val = _time_split_xy(X_all, y_all)
|
| 127 |
candidates = _build_candidates()
|
|
|
|
| 148 |
best_model = model
|
| 149 |
|
| 150 |
if best_model is None:
|
| 151 |
+
return None, 0.5, "fallback"
|
| 152 |
|
| 153 |
baseline_scale = max(1.0, statistics.mean([abs(v) for v in (y_val if y_val else y_fit)]))
|
| 154 |
global_conf = math.exp(-(best_mae / baseline_scale))
|
|
|
|
| 159 |
f"avg_target={baseline_scale:.2f}, global_model_conf={global_conf:.2f}"
|
| 160 |
)
|
| 161 |
|
| 162 |
+
return best_model, max(0.0, min(1.0, global_conf)), best_name
|
| 163 |
|
| 164 |
|
| 165 |
def predict_expenses(expenses: list[dict], target_user_id, debug: bool = False) -> list[dict]:
|
|
|
|
| 189 |
pct = count / total_records * 100
|
| 190 |
print(f"[PREDICT] supplier_id={supplier_id} -> {count} records ({pct:.1f}%)")
|
| 191 |
|
| 192 |
+
# Select suppliers whose frequency is strictly greater than 50% of the top supplier frequency.
|
| 193 |
+
max_freq = max(supplier_freq.values()) if supplier_freq else 0
|
| 194 |
+
freq_threshold = 0.5 * max_freq
|
| 195 |
+
candidate_items = [
|
| 196 |
+
item for item in supplier_history.items()
|
| 197 |
+
if supplier_freq[item[0]] > freq_threshold
|
| 198 |
+
]
|
| 199 |
+
|
| 200 |
+
# Keep candidates sorted by supplier usage frequency (desc).
|
| 201 |
+
candidate_items = sorted(
|
| 202 |
+
candidate_items,
|
| 203 |
key=lambda item: supplier_freq[item[0]],
|
| 204 |
reverse=True,
|
| 205 |
+
)
|
| 206 |
|
| 207 |
if debug:
|
| 208 |
+
print(
|
| 209 |
+
f"[PREDICT] Processing {len(candidate_items)} suppliers "
|
| 210 |
+
f"with freq > 50% of max ({freq_threshold:.2f})"
|
| 211 |
+
)
|
| 212 |
|
| 213 |
+
if not candidate_items:
|
| 214 |
if debug:
|
| 215 |
print("[PREDICT] No suppliers found. Returning empty.")
|
| 216 |
return []
|
|
|
|
| 234 |
except Exception:
|
| 235 |
continue
|
| 236 |
|
| 237 |
+
global_model, global_model_conf, model_name = _train_global_model(
|
| 238 |
all_samples, supplier_to_idx, user_to_idx, debug=debug
|
| 239 |
)
|
| 240 |
|
|
|
|
| 251 |
user_supplier_last_date[(tx_user, tx_supplier)] = tx_date
|
| 252 |
user_supplier_last_sum[(tx_user, tx_supplier)] = tx_sum
|
| 253 |
|
| 254 |
+
# Predict amount for each selected supplier.
|
| 255 |
predictions = []
|
| 256 |
|
| 257 |
+
for supplier_id, _records in candidate_items:
|
| 258 |
s_hist = supplier_amounts_sorted.get(supplier_id, [])
|
| 259 |
us_hist = user_supplier_amounts_sorted.get((target_user_id, supplier_id), [])
|
| 260 |
|
|
|
|
| 349 |
"sum": round(max(0.0, predicted_amount), 2),
|
| 350 |
"supplier_id": supplier_id,
|
| 351 |
"user_id": predicted_user,
|
| 352 |
+
"show": True,
|
| 353 |
"confidence": round(confidence, 2)
|
| 354 |
})
|
| 355 |
|
| 356 |
+
# Return all selected suppliers sorted by frequency desc.
|
| 357 |
+
result = sorted(
|
| 358 |
+
predictions,
|
| 359 |
+
key=lambda x: supplier_freq.get(x["supplier_id"], 0),
|
| 360 |
+
reverse=True,
|
| 361 |
+
)
|
| 362 |
|
| 363 |
if debug:
|
| 364 |
print(f"[PREDICT] Final top {len(result)} predictions:")
|