VladGeekPro Copilot commited on
Commit
fb5dc97
·
1 Parent(s): 4f49b04

ChangedSupplierCandidatesRule

Browse files

Co-authored-by: Copilot <copilot@github.com>

Files changed (1) hide show
  1. expense_predictor.py +32 -17
expense_predictor.py CHANGED
@@ -52,13 +52,13 @@ def _train_global_model(
52
  supplier_to_idx: dict,
53
  user_to_idx: dict,
54
  debug: bool = False,
55
- ) -> tuple[object | None, float, float, str]:
56
  """Trains ONE global model on ALL records.
57
 
58
  Each sample: (date, supplier_id, user_id, amount)
59
  Features per row: [supplier_idx, user_idx, day, weekday, month,
60
  rolling_mean_3 for supplier, rolling_mean_month for supplier]
61
- Returns: (fitted_model, global_confidence, validation_mae, model_name)
62
  """
63
  # Sort all samples by date to build rolling features correctly.
64
  samples_sorted = sorted(samples, key=lambda s: s[0])
@@ -121,7 +121,7 @@ def _train_global_model(
121
  user_supplier_last_sum[(user_id, supplier_id)] = amount
122
 
123
  if len(X_all) < 10:
124
- return None, 0.5, float("inf"), "fallback"
125
 
126
  X_fit, y_fit, X_val, y_val = _time_split_xy(X_all, y_all)
127
  candidates = _build_candidates()
@@ -148,7 +148,7 @@ def _train_global_model(
148
  best_model = model
149
 
150
  if best_model is None:
151
- return None, 0.5, float("inf"), "fallback"
152
 
153
  baseline_scale = max(1.0, statistics.mean([abs(v) for v in (y_val if y_val else y_fit)]))
154
  global_conf = math.exp(-(best_mae / baseline_scale))
@@ -159,7 +159,7 @@ def _train_global_model(
159
  f"avg_target={baseline_scale:.2f}, global_model_conf={global_conf:.2f}"
160
  )
161
 
162
- return best_model, max(0.0, min(1.0, global_conf)), best_mae, best_name
163
 
164
 
165
  def predict_expenses(expenses: list[dict], target_user_id, debug: bool = False) -> list[dict]:
@@ -189,18 +189,28 @@ def predict_expenses(expenses: list[dict], target_user_id, debug: bool = False)
189
  pct = count / total_records * 100
190
  print(f"[PREDICT] supplier_id={supplier_id} -> {count} records ({pct:.1f}%)")
191
 
192
- # Keep only top 3 suppliers by frequency (different suppliers)
193
- candidates = supplier_history
194
- top_candidate_items = sorted(
195
- candidates.items(),
 
 
 
 
 
 
 
196
  key=lambda item: supplier_freq[item[0]],
197
  reverse=True,
198
- )[:3]
199
 
200
  if debug:
201
- print(f"[PREDICT] Processing top {len(top_candidate_items)} suppliers by frequency")
 
 
 
202
 
203
- if not top_candidate_items:
204
  if debug:
205
  print("[PREDICT] No suppliers found. Returning empty.")
206
  return []
@@ -224,7 +234,7 @@ def predict_expenses(expenses: list[dict], target_user_id, debug: bool = False)
224
  except Exception:
225
  continue
226
 
227
- global_model, global_model_conf, val_mae, model_name = _train_global_model(
228
  all_samples, supplier_to_idx, user_to_idx, debug=debug
229
  )
230
 
@@ -241,10 +251,10 @@ def predict_expenses(expenses: list[dict], target_user_id, debug: bool = False)
241
  user_supplier_last_date[(tx_user, tx_supplier)] = tx_date
242
  user_supplier_last_sum[(tx_user, tx_supplier)] = tx_sum
243
 
244
- # Predict only amount for each of top-3 suppliers.
245
  predictions = []
246
 
247
- for supplier_id, records in top_candidate_items:
248
  s_hist = supplier_amounts_sorted.get(supplier_id, [])
249
  us_hist = user_supplier_amounts_sorted.get((target_user_id, supplier_id), [])
250
 
@@ -339,11 +349,16 @@ def predict_expenses(expenses: list[dict], target_user_id, debug: bool = False)
339
  "sum": round(max(0.0, predicted_amount), 2),
340
  "supplier_id": supplier_id,
341
  "user_id": predicted_user,
 
342
  "confidence": round(confidence, 2)
343
  })
344
 
345
- # Return top 3 by confidence
346
- result = sorted(predictions, key=lambda x: x["confidence"], reverse=True)[:3]
 
 
 
 
347
 
348
  if debug:
349
  print(f"[PREDICT] Final top {len(result)} predictions:")
 
52
  supplier_to_idx: dict,
53
  user_to_idx: dict,
54
  debug: bool = False,
55
+ ) -> tuple[object | None, float, str]:
56
  """Trains ONE global model on ALL records.
57
 
58
  Each sample: (date, supplier_id, user_id, amount)
59
  Features per row: [supplier_idx, user_idx, day, weekday, month,
60
  rolling_mean_3 for supplier, rolling_mean_month for supplier]
61
+ Returns: (fitted_model, global_confidence, model_name)
62
  """
63
  # Sort all samples by date to build rolling features correctly.
64
  samples_sorted = sorted(samples, key=lambda s: s[0])
 
121
  user_supplier_last_sum[(user_id, supplier_id)] = amount
122
 
123
  if len(X_all) < 10:
124
+ return None, 0.5, "fallback"
125
 
126
  X_fit, y_fit, X_val, y_val = _time_split_xy(X_all, y_all)
127
  candidates = _build_candidates()
 
148
  best_model = model
149
 
150
  if best_model is None:
151
+ return None, 0.5, "fallback"
152
 
153
  baseline_scale = max(1.0, statistics.mean([abs(v) for v in (y_val if y_val else y_fit)]))
154
  global_conf = math.exp(-(best_mae / baseline_scale))
 
159
  f"avg_target={baseline_scale:.2f}, global_model_conf={global_conf:.2f}"
160
  )
161
 
162
+ return best_model, max(0.0, min(1.0, global_conf)), best_name
163
 
164
 
165
  def predict_expenses(expenses: list[dict], target_user_id, debug: bool = False) -> list[dict]:
 
189
  pct = count / total_records * 100
190
  print(f"[PREDICT] supplier_id={supplier_id} -> {count} records ({pct:.1f}%)")
191
 
192
+ # Select suppliers whose frequency is strictly greater than 50% of the top supplier frequency.
193
+ max_freq = max(supplier_freq.values()) if supplier_freq else 0
194
+ freq_threshold = 0.5 * max_freq
195
+ candidate_items = [
196
+ item for item in supplier_history.items()
197
+ if supplier_freq[item[0]] > freq_threshold
198
+ ]
199
+
200
+ # Keep candidates sorted by supplier usage frequency (desc).
201
+ candidate_items = sorted(
202
+ candidate_items,
203
  key=lambda item: supplier_freq[item[0]],
204
  reverse=True,
205
+ )
206
 
207
  if debug:
208
+ print(
209
+ f"[PREDICT] Processing {len(candidate_items)} suppliers "
210
+ f"with freq > 50% of max ({freq_threshold:.2f})"
211
+ )
212
 
213
+ if not candidate_items:
214
  if debug:
215
  print("[PREDICT] No suppliers found. Returning empty.")
216
  return []
 
234
  except Exception:
235
  continue
236
 
237
+ global_model, global_model_conf, model_name = _train_global_model(
238
  all_samples, supplier_to_idx, user_to_idx, debug=debug
239
  )
240
 
 
251
  user_supplier_last_date[(tx_user, tx_supplier)] = tx_date
252
  user_supplier_last_sum[(tx_user, tx_supplier)] = tx_sum
253
 
254
+ # Predict amount for each selected supplier.
255
  predictions = []
256
 
257
+ for supplier_id, _records in candidate_items:
258
  s_hist = supplier_amounts_sorted.get(supplier_id, [])
259
  us_hist = user_supplier_amounts_sorted.get((target_user_id, supplier_id), [])
260
 
 
349
  "sum": round(max(0.0, predicted_amount), 2),
350
  "supplier_id": supplier_id,
351
  "user_id": predicted_user,
352
+ "show": True,
353
  "confidence": round(confidence, 2)
354
  })
355
 
356
+ # Return all selected suppliers sorted by frequency desc.
357
+ result = sorted(
358
+ predictions,
359
+ key=lambda x: supplier_freq.get(x["supplier_id"], 0),
360
+ reverse=True,
361
+ )
362
 
363
  if debug:
364
  print(f"[PREDICT] Final top {len(result)} predictions:")