LogicGoInfotechSpaces commited on
Commit
5216e16
·
verified ·
1 Parent(s): 417c49f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +752 -0
app.py ADDED
@@ -0,0 +1,752 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import calendar
3
+ import math
4
+ import os
5
+ from collections import defaultdict
6
+ from datetime import datetime, timezone
7
+ from typing import Dict, List, Optional, Tuple
8
+
9
+ from bson import ObjectId
10
+ from dotenv import load_dotenv
11
+ from fastapi import FastAPI, HTTPException
12
+ from pydantic import BaseModel, Field
13
+ from pymongo import MongoClient
14
+ from pymongo.collection import Collection
15
+
16
+ load_dotenv()
17
+
18
+ app = FastAPI(title="Expense Prediction API", version="1.0.0")
19
+
20
+ # ---------- Configurable constants ----------
21
+ MAX_HISTORY_MONTHS = int(os.getenv("MAX_HISTORY_MONTHS", "36")) # months to fetch for detection/tuning
22
+ SEASONALITY_PERIOD = int(os.getenv("SEASONALITY_PERIOD", "12")) # monthly seasonality (12 months)
23
+ SEASONALITY_AMPLITUDE_THRESHOLD = float(os.getenv("SEASONALITY_AMPLITUDE_THRESHOLD", "0.18"))
24
+ # grid-search limits (keeps tuning light)
25
+ ALPHA_GRID = [0.3, 0.5, 0.7]
26
+ BETA_GRID = [0.1, 0.3, 0.5]
27
+ GAMMA_GRID = [0.1, 0.3, 0.5]
28
+ MAX_GRID_SEARCH_COMBINATIONS = 30 # safety cap
29
+ # ------------------------------------------------
30
+
31
+ class MonthlyExpense(BaseModel):
32
+ year: int
33
+ month: int
34
+ total: float = Field(..., description="Total expenses recorded for the month")
35
+
36
+
37
+ class CategoryPrediction(BaseModel):
38
+ headCategoryId: str
39
+ title: str
40
+ history: List[MonthlyExpense]
41
+ predictionMonth: MonthlyExpense
42
+
43
+
44
+ class PredictionResponse(BaseModel):
45
+ userId: str
46
+ categories: List[CategoryPrediction]
47
+
48
+
49
+ class MongoConnection:
50
+ def __init__(self) -> None:
51
+ mongo_uri = os.getenv("MONGO_URI")
52
+ if not mongo_uri:
53
+ raise RuntimeError("MONGO_URI is not configured in the environment")
54
+
55
+ self._client = MongoClient(mongo_uri, tz_aware=True)
56
+ self._database = self._client.get_default_database()
57
+ self.transactions: Collection = self._database["transactions"]
58
+ self.headcategories: Collection = self._database["headcategories"]
59
+
60
+
61
+ mongo = MongoConnection()
62
+
63
+ # ----------------- Date helpers -----------------
64
+ def _first_day_of_month(dt: datetime) -> datetime:
65
+ return dt.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
66
+
67
+
68
+ def _shift_months(dt: datetime, months: int) -> datetime:
69
+ month_index = dt.month - 1 + months
70
+ year = dt.year + month_index // 12
71
+ month = month_index % 12 + 1
72
+ last_day = calendar.monthrange(year, month)[1]
73
+ day = min(dt.day, last_day)
74
+ return dt.replace(year=year, month=month, day=day)
75
+
76
+
77
+ def month_to_index(year: int, month: int) -> int:
78
+ return year * 12 + (month - 1)
79
+
80
+
81
+ def index_to_month(idx: int) -> Tuple[int, int]:
82
+ year = idx // 12
83
+ month = (idx % 12) + 1
84
+ return year, month
85
+ # ------------------------------------------------
86
+
87
+ # ----------------- Time series utilities -----------------
88
+ def build_continuous_series(history: List[MonthlyExpense]) -> Tuple[List[float], List[Tuple[int, int]]]:
89
+ """
90
+ Given sparse monthly history items (year, month, total), build a continuous series
91
+ covering from earliest to latest month in history. Missing months are represented by None.
92
+ Returns (values_list_with_none, list_of_(year,month)_corresponding).
93
+ """
94
+ if not history:
95
+ return [], []
96
+
97
+ # sort history
98
+ history_sorted = sorted(history, key=lambda h: (h.year, h.month))
99
+ start_idx = month_to_index(history_sorted[0].year, history_sorted[0].month)
100
+ end_idx = month_to_index(history_sorted[-1].year, history_sorted[-1].month)
101
+ length = end_idx - start_idx + 1
102
+
103
+ idx_to_val = {}
104
+ for h in history_sorted:
105
+ idx = month_to_index(h.year, h.month)
106
+ idx_to_val[idx] = h.total
107
+
108
+ series = []
109
+ months = []
110
+ for i in range(start_idx, end_idx + 1):
111
+ months.append(index_to_month(i))
112
+ series.append(idx_to_val.get(i, None))
113
+
114
+ return series, months
115
+
116
+
117
+ def impute_missing(series: List[Optional[float]]) -> List[float]:
118
+ """
119
+ Fill missing values (None) by linear interpolation. If leading/trailing Nones remain,
120
+ forward/backfill with nearest value or 0 if no data.
121
+ """
122
+ n = len(series)
123
+ if n == 0:
124
+ return []
125
+
126
+ arr = [None if v is None else float(v) for v in series]
127
+
128
+ # collect indices of non-None
129
+ known = [i for i, v in enumerate(arr) if v is not None]
130
+
131
+ if not known:
132
+ # all missing -> return zeros
133
+ return [0.0] * n
134
+
135
+ # linear interpolation between known points
136
+ for i in range(len(known) - 1):
137
+ a = known[i]
138
+ b = known[i + 1]
139
+ va = arr[a]
140
+ vb = arr[b]
141
+ step = (vb - va) / (b - a)
142
+ for j in range(a + 1, b):
143
+ arr[j] = va + step * (j - a)
144
+
145
+ # fill leading
146
+ first = known[0]
147
+ for i in range(0, first):
148
+ arr[i] = arr[first]
149
+
150
+ # fill trailing
151
+ last = known[-1]
152
+ for i in range(last + 1, n):
153
+ arr[i] = arr[last]
154
+
155
+ return [float(x) for x in arr]
156
+
157
+
158
+ def seasonal_strength(series: List[float], period: int = SEASONALITY_PERIOD) -> float:
159
+ """
160
+ Estimate seasonality strength for monthly data.
161
+ Returns amplitude_ratio = (max_month_mean - min_month_mean) / overall_mean
162
+ Higher value => stronger seasonality.
163
+ Requires at least 2 * period data points for a reliable estimate.
164
+ """
165
+ n = len(series)
166
+ if n < 2 * period:
167
+ return 0.0
168
+
169
+ # compute month-of-year means
170
+ month_buckets = [[] for _ in range(period)]
171
+ for idx, val in enumerate(series):
172
+ month = idx % period
173
+ month_buckets[month].append(val)
174
+
175
+ month_means = [ (sum(b)/len(b)) if b else 0.0 for b in month_buckets ]
176
+ overall_mean = sum(series) / len(series) if series else 0.0
177
+ if overall_mean == 0:
178
+ return 0.0
179
+ amplitude = max(month_means) - min(month_means)
180
+ return amplitude / overall_mean
181
+
182
+
183
+ # ----------------- Forecasting algorithms -----------------
184
+ def holt_double_forecast(series: List[float], alpha: float, beta: float, n_forecast: int = 1) -> List[float]:
185
+ """
186
+ Holt's linear method (double exponential smoothing).
187
+ Returns list of length n_forecast (forecast ahead).
188
+ """
189
+ n = len(series)
190
+ if n == 0:
191
+ return [0.0] * n_forecast
192
+ if n == 1:
193
+ return [series[-1]] * n_forecast
194
+
195
+ level = series[0]
196
+ trend = series[1] - series[0]
197
+
198
+ for t in range(1, n):
199
+ value = series[t]
200
+ prev_level = level
201
+ level = alpha * value + (1 - alpha) * (level + trend)
202
+ trend = beta * (level - prev_level) + (1 - beta) * trend
203
+
204
+ # forecast h steps ahead
205
+ forecasts = [level + (i + 1) * trend for i in range(n_forecast)]
206
+ return [max(0.0, f) for f in forecasts]
207
+
208
+
209
+ def holt_winters_additive(series: List[float], season_length: int, alpha: float, beta: float, gamma: float, n_forecast: int = 1) -> List[float]:
210
+ """
211
+ Additive Holt-Winters seasonal method.
212
+ series: list of floats (no missing) where season_length is known (e.g., 12)
213
+ """
214
+ n = len(series)
215
+ if n == 0:
216
+ return [0.0] * n_forecast
217
+ if n < season_length * 2:
218
+ # not enough data to initialize seasonals reliably -> fallback to holt_double
219
+ return holt_double_forecast(series, alpha, beta, n_forecast)
220
+
221
+ # initialize level, trend, seasonals
222
+ seasonals = _initial_seasonal_components(series, season_length)
223
+ level = sum(series[:season_length]) / season_length
224
+ trend = (sum(series[season_length:2*season_length]) - sum(series[:season_length])) / (season_length * season_length)
225
+
226
+ result = []
227
+ for i in range(n + n_forecast):
228
+ if i < n:
229
+ val = series[i]
230
+ last_level = level
231
+ level = alpha * (val - seasonals[i % season_length]) + (1 - alpha) * (level + trend)
232
+ trend = beta * (level - last_level) + (1 - beta) * trend
233
+ seasonals[i % season_length] = gamma * (val - level) + (1 - gamma) * seasonals[i % season_length]
234
+ # in-sample prediction (not used)
235
+ else:
236
+ # forecast
237
+ m = i - n + 1
238
+ forecast = level + m * trend + seasonals[i % season_length]
239
+ result.append(max(0.0, forecast))
240
+
241
+ # ensure length matches n_forecast
242
+ return result[:n_forecast]
243
+
244
+
245
+ def _initial_seasonal_components(series: List[float], season_length: int) -> List[float]:
246
+ """
247
+ Initialize seasonality components by averaging.
248
+ """
249
+ seasonals = [0.0] * season_length
250
+ n_seasons = len(series) // season_length
251
+ if n_seasons == 0:
252
+ return seasonals
253
+ season_averages = []
254
+ for j in range(n_seasons):
255
+ start = j * season_length
256
+ season_avg = sum(series[start:start + season_length]) / season_length
257
+ season_averages.append(season_avg)
258
+ for i in range(season_length):
259
+ s = 0.0
260
+ for j in range(n_seasons):
261
+ s += series[j * season_length + i] - season_averages[j]
262
+ seasonals[i] = s / n_seasons
263
+ return seasonals
264
+
265
+ # ----------------- Dynamic WMA -----------------
266
+ def dynamic_wma(series: List[float], max_len: int = 6) -> float:
267
+ """
268
+ Compute a dynamic WMA using up to max_len most recent months.
269
+ The weights adapt based on volatility: higher volatility -> smoother (older months get more weight).
270
+ """
271
+ n = len(series)
272
+ if n == 0:
273
+ return 0.0
274
+ take = min(n, max_len)
275
+ recent = series[-take:]
276
+ # compute month-to-month relative changes
277
+ if len(recent) >= 2:
278
+ changes = [abs(recent[i] - recent[i - 1]) for i in range(1, len(recent))]
279
+ vol = sum(changes) / len(changes) if changes else 0.0
280
+ else:
281
+ vol = 0.0
282
+
283
+ # base weights favor recent months
284
+ base_weights = [ (i + 1) for i in range(take) ] # 1..take
285
+ base_weights = list(reversed(base_weights)) # newest highest
286
+ total = sum(base_weights)
287
+ base_weights = [w/total for w in base_weights]
288
+
289
+ # adaptation factor: more vol -> flatten weights
290
+ # vol_ratio normalized roughly w.r.t average magnitude
291
+ avg = sum(recent) / len(recent) if recent else 1.0
292
+ vol_ratio = (vol / avg) if avg else 0.0
293
+ # clamp vol_ratio
294
+ vol_ratio = max(0.0, min(vol_ratio, 1.0))
295
+
296
+ # blend between base_weights and equal weights
297
+ equal_weights = [1.0 / take] * take
298
+ blend = min(0.7, vol_ratio) # limit blend to avoid extreme flattening
299
+ weights = [(1 - blend) * bw + blend * ew for bw, ew in zip(base_weights, equal_weights)]
300
+ # compute prediction
301
+ prediction = sum(w * v for w, v in zip(weights, reversed(recent))) # reversed so weights map newest->oldest
302
+ return max(0.0, prediction)
303
+
304
+ # ----------------- Parameter tuning (lightweight) -----------------
305
+ def walk_forward_cv_mse(series: List[float], forecast_func, params: dict, min_train_size: int = 6) -> float:
306
+ """
307
+ Perform walk-forward validation computing MSE. forecast_func must accept (train_series, params) and return a single-step forecast.
308
+ """
309
+ n = len(series)
310
+ if n < min_train_size + 1:
311
+ # not enough data to validate -> return large error so tuner avoids complex models
312
+ return float("inf")
313
+
314
+ errors = []
315
+ # iterate rolling window
316
+ for split in range(min_train_size, n):
317
+ train = series[:split]
318
+ actual = series[split]
319
+ try:
320
+ pred = forecast_func(train, params)
321
+ except Exception:
322
+ return float("inf")
323
+ if pred is None:
324
+ return float("inf")
325
+ errors.append((pred - actual) ** 2)
326
+ return sum(errors) / len(errors) if errors else float("inf")
327
+
328
+
329
+ def forecast_wrapper_holt(train: List[float], params: dict) -> float:
330
+ alpha = params.get("alpha", 0.5)
331
+ beta = params.get("beta", 0.3)
332
+ return holt_double_forecast(train, alpha, beta, n_forecast=1)[0]
333
+
334
+
335
+ def forecast_wrapper_hw(train: List[float], params: dict) -> float:
336
+ alpha = params.get("alpha", 0.5)
337
+ beta = params.get("beta", 0.3)
338
+ gamma = params.get("gamma", 0.2)
339
+ season_length = params.get("season_length", SEASONALITY_PERIOD)
340
+ return holt_winters_additive(train, season_length, alpha, beta, gamma, n_forecast=1)[0]
341
+
342
+
343
+ def tune_parameters(series: List[float], seasonal: bool, season_length: int = SEASONALITY_PERIOD) -> dict:
344
+ """
345
+ Lightweight grid search for (alpha, beta, gamma) returning best params.
346
+ Uses walk-forward CV to score parameter combinations.
347
+ """
348
+ best = None
349
+ best_score = float("inf")
350
+ combos_tested = 0
351
+
352
+ if seasonal:
353
+ grid = []
354
+ for a in ALPHA_GRID:
355
+ for b in BETA_GRID:
356
+ for g in GAMMA_GRID:
357
+ grid.append({"alpha": a, "beta": b, "gamma": g, "season_length": season_length})
358
+ else:
359
+ grid = [{"alpha": a, "beta": b} for a in ALPHA_GRID for b in BETA_GRID]
360
+
361
+ # cap combos
362
+ if len(grid) > MAX_GRID_SEARCH_COMBINATIONS:
363
+ grid = grid[:MAX_GRID_SEARCH_COMBINATIONS]
364
+
365
+ for params in grid:
366
+ combos_tested += 1
367
+ if seasonal:
368
+ score = walk_forward_cv_mse(series, forecast_wrapper_hw, params, min_train_size=max(6, season_length))
369
+ else:
370
+ score = walk_forward_cv_mse(series, forecast_wrapper_holt, params, min_train_size=6)
371
+ if score < best_score:
372
+ best_score = score
373
+ best = params
374
+
375
+ if best is None:
376
+ # fallback default
377
+ if seasonal:
378
+ return {"alpha": 0.5, "beta": 0.3, "gamma": 0.2, "season_length": season_length}
379
+ else:
380
+ return {"alpha": 0.5, "beta": 0.3}
381
+
382
+ return best
383
+
384
+ # ----------------- Top-level predictor combining everything -----------------
385
+ def _predict_next_month(history: List[MonthlyExpense]) -> float:
386
+ """
387
+ Comprehensive predictor:
388
+ - builds continuous series and imputes missing months
389
+ - auto-detects seasonality
390
+ - tunes parameters (lightweight) per series
391
+ - uses Holt-Winters if seasonal, else Holt
392
+ - fallback to dynamic WMA for very short/noisy series
393
+ """
394
+ if not history:
395
+ return 0.0
396
+
397
+ # limit history length to MAX_HISTORY_MONTHS (use most recent months)
398
+ history_sorted = sorted(history, key=lambda h: (h.year, h.month))
399
+ if len(history_sorted) > MAX_HISTORY_MONTHS:
400
+ history_sorted = history_sorted[-MAX_HISTORY_MONTHS:]
401
+
402
+ # Build continuous series (may contain Nones for missing months)
403
+ series_with_none, months = build_continuous_series(history_sorted)
404
+ series = impute_missing(series_with_none)
405
+
406
+ # if after imputation all zeros, return 0
407
+ if all(v == 0.0 for v in series):
408
+ return 0.0
409
+
410
+ n = len(series)
411
+
412
+ # If very short history (<=2) use simple rules / dynamic WMA
413
+ if n <= 2:
414
+ return round(dynamic_wma(series, max_len=2), 2)
415
+
416
+ # Seasonality detection: needs at least 2 * season_length samples for reliability
417
+ season_strength = seasonal_strength(series, period=SEASONALITY_PERIOD)
418
+ is_seasonal = season_strength >= SEASONALITY_AMPLITUDE_THRESHOLD and n >= 2 * SEASONALITY_PERIOD
419
+
420
+ # If not much data but still some seasonality signal present and we have at least season_length points,
421
+ # we can still attempt seasonal HW but with care.
422
+ season_length_used = SEASONALITY_PERIOD if is_seasonal else None
423
+
424
+ # Tuning: per-series personalized coefficients
425
+ try:
426
+ tuned = tune_parameters(series, seasonal=is_seasonal, season_length=season_length_used or SEASONALITY_PERIOD)
427
+ except Exception:
428
+ tuned = None
429
+
430
+ # If tuning failed or not enough data, fallback defaults
431
+ if tuned is None:
432
+ if is_seasonal:
433
+ tuned = {"alpha": 0.5, "beta": 0.3, "gamma": 0.2, "season_length": SEASONALITY_PERIOD}
434
+ else:
435
+ tuned = {"alpha": 0.5, "beta": 0.3}
436
+
437
+ # Edge case: if the series is extremely volatile compared to mean, prefer dynamic WMA (more robust)
438
+ mean_val = sum(series) / len(series) if series else 0.0
439
+ diffs = [abs(series[i] - series[i - 1]) for i in range(1, len(series))] if len(series) >= 2 else [0.0]
440
+ avg_diff = sum(diffs) / len(diffs) if diffs else 0.0
441
+ volatility_ratio = (avg_diff / mean_val) if mean_val else 0.0
442
+
443
+ if volatility_ratio > 1.0 and n < 6:
444
+ # extremely volatile and short history -> WMA is safer
445
+ pred = dynamic_wma(series, max_len=min(6, n))
446
+ return round(pred, 2)
447
+
448
+ # Choose model
449
+ if is_seasonal:
450
+ alpha = tuned.get("alpha", 0.5)
451
+ beta = tuned.get("beta", 0.3)
452
+ gamma = tuned.get("gamma", 0.2)
453
+ season_length = tuned.get("season_length", SEASONALITY_PERIOD)
454
+ pred = holt_winters_additive(series, season_length, alpha, beta, gamma, n_forecast=1)[0]
455
+ else:
456
+ alpha = tuned.get("alpha", 0.5)
457
+ beta = tuned.get("beta", 0.3)
458
+ pred = holt_double_forecast(series, alpha, beta, n_forecast=1)[0]
459
+
460
+ # final safety clamps
461
+ if math.isnan(pred) or pred is None or pred < 0:
462
+ # fallback to recent avg
463
+ pred = sum(series[-3:]) / min(3, len(series))
464
+
465
+ return round(float(pred), 2)
466
+
467
+
468
+ # ----------------- API endpoint -----------------
469
+ @app.get("/users/{user_id}/expense-prediction", response_model=PredictionResponse)
470
+ def predict_expense(user_id: str) -> PredictionResponse:
471
+ try:
472
+ user_object_id = ObjectId(user_id)
473
+ except Exception as exc:
474
+ raise HTTPException(status_code=400, detail="Invalid user id") from exc
475
+
476
+ now = datetime.now(timezone.utc)
477
+ # fetch up to MAX_HISTORY_MONTHS of history
478
+ start_period = _shift_months(_first_day_of_month(now), -MAX_HISTORY_MONTHS + 1)
479
+ prediction_month = _shift_months(_first_day_of_month(now), 1)
480
+
481
+ pipeline = [
482
+ {
483
+ "$match": {
484
+ "user": user_object_id,
485
+ "type": "EXPENSE",
486
+ "headCategory": {"$ne": None},
487
+ "date": {"$gte": start_period},
488
+ }
489
+ },
490
+ {
491
+ "$project": {
492
+ "amount": 1,
493
+ "headCategory": 1,
494
+ "year": {"$year": "$date"},
495
+ "month": {"$month": "$date"},
496
+ }
497
+ },
498
+ {
499
+ "$group": {
500
+ "_id": {
501
+ "headCategory": "$headCategory",
502
+ "year": "$year",
503
+ "month": "$month",
504
+ },
505
+ "total": {"$sum": "$amount"},
506
+ }
507
+ },
508
+ {
509
+ "$lookup": {
510
+ "from": "headcategories",
511
+ "localField": "_id.headCategory",
512
+ "foreignField": "_id",
513
+ "as": "headCategoryDoc",
514
+ }
515
+ },
516
+ {"$unwind": "$headCategoryDoc"},
517
+ {"$sort": {"_id.headCategory": 1, "_id.year": 1, "_id.month": 1}},
518
+ ]
519
+
520
+ results = list(mongo.transactions.aggregate(pipeline))
521
+
522
+ grouped: Dict[ObjectId, Dict[str, List[MonthlyExpense]]] = defaultdict(lambda: {"history": []})
523
+
524
+ for item in results:
525
+ head_category_id: ObjectId = item["_id"]["headCategory"]
526
+ category_record = grouped[head_category_id]
527
+ category_record["title"] = item["headCategoryDoc"].get("title", "Unknown")
528
+ category_record["history"].append(
529
+ MonthlyExpense(
530
+ year=item["_id"]["year"],
531
+ month=item["_id"]["month"],
532
+ total=float(item["total"]),
533
+ )
534
+ )
535
+
536
+ categories: List[CategoryPrediction] = []
537
+ for head_category_id, record in grouped.items():
538
+ history = sorted(record["history"], key=lambda doc: (doc.year, doc.month))
539
+ predicted_total = _predict_next_month(history)
540
+
541
+ categories.append(
542
+ CategoryPrediction(
543
+ headCategoryId=str(head_category_id),
544
+ title=record.get("title", "Unknown"),
545
+ history=history,
546
+ predictionMonth=MonthlyExpense(
547
+ year=prediction_month.year,
548
+ month=prediction_month.month,
549
+ total=predicted_total,
550
+ ),
551
+ )
552
+ )
553
+
554
+ return PredictionResponse(userId=user_id, categories=categories)
555
+
556
+
557
+ # Optional: health check
558
+ @app.get("/health")
559
+ def health():
560
+ return {"status": "healthy"}
561
+
562
+
563
+
564
+
565
+
566
+
567
+
568
+
569
+
570
+
571
+
572
+ # import calendar
573
+ # import os
574
+ # from collections import defaultdict
575
+ # from datetime import datetime, timezone
576
+ # from typing import Dict, List
577
+
578
+ # from bson import ObjectId
579
+ # from dotenv import load_dotenv
580
+ # from fastapi import FastAPI, HTTPException
581
+ # from pydantic import BaseModel, Field
582
+ # from pymongo import MongoClient
583
+ # from pymongo.collection import Collection
584
+
585
+ # load_dotenv()
586
+
587
+ # app = FastAPI(title="Expense Prediction API", version="1.0.0")
588
+
589
+
590
+ # class MonthlyExpense(BaseModel):
591
+ # year: int
592
+ # month: int
593
+ # total: float = Field(..., description="Total expenses recorded for the month")
594
+
595
+
596
+ # class CategoryPrediction(BaseModel):
597
+ # headCategoryId: str
598
+ # title: str
599
+ # history: List[MonthlyExpense]
600
+ # predictionMonth: MonthlyExpense
601
+
602
+
603
+ # class PredictionResponse(BaseModel):
604
+ # userId: str
605
+ # categories: List[CategoryPrediction]
606
+
607
+
608
+ # class MongoConnection:
609
+ # def __init__(self) -> None:
610
+ # mongo_uri = os.getenv("MONGO_URI")
611
+ # if not mongo_uri:
612
+ # raise RuntimeError("MONGO_URI is not configured in the environment")
613
+
614
+ # self._client = MongoClient(mongo_uri, tz_aware=True)
615
+ # self._database = self._client.get_default_database()
616
+ # self.transactions: Collection = self._database["transactions"]
617
+ # self.headcategories: Collection = self._database["headcategories"]
618
+
619
+
620
+ # mongo = MongoConnection()
621
+
622
+
623
+ # def _first_day_of_month(dt: datetime) -> datetime:
624
+ # return dt.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
625
+
626
+
627
+ # def _shift_months(dt: datetime, months: int) -> datetime:
628
+ # month_index = dt.month - 1 + months
629
+ # year = dt.year + month_index // 12
630
+ # month = month_index % 12 + 1
631
+ # last_day = calendar.monthrange(year, month)[1]
632
+ # day = min(dt.day, last_day)
633
+ # return dt.replace(year=year, month=month, day=day)
634
+
635
+
636
+ # # -----------------------------------------------------------
637
+ # # NEW: Weighted Moving Average-based prediction function
638
+ # # -----------------------------------------------------------
639
+
640
+ # def _predict_next_month(history: List[MonthlyExpense]) -> float:
641
+ # """Predict next month's expense using Weighted Moving Average (WMA)."""
642
+ # totals = [h.total for h in history]
643
+
644
+ # # Only one month → Just repeat last month
645
+ # if len(totals) == 1:
646
+ # return round(totals[-1], 2)
647
+
648
+ # # Two months → Slight smoothing
649
+ # if len(totals) == 2:
650
+ # last, prev = totals[-1], totals[-2]
651
+ # prediction = last * 0.7 + prev * 0.3
652
+ # return round(prediction, 2)
653
+
654
+ # # Three or more months → Use 3-month WMA (0.5, 0.3, 0.2)
655
+ # last3 = totals[-3:]
656
+ # weights = [0.2, 0.3, 0.5] # oldest → newest
657
+ # prediction = sum(v * w for v, w in zip(last3, weights))
658
+
659
+ # return round(prediction, 2)
660
+
661
+
662
+ # # -----------------------------------------------------------
663
+ # # EXPENSE PREDICTION ENDPOINT
664
+ # # -----------------------------------------------------------
665
+
666
+ # @app.get("/users/{user_id}/expense-prediction", response_model=PredictionResponse)
667
+ # def predict_expense(user_id: str) -> PredictionResponse:
668
+ # try:
669
+ # user_object_id = ObjectId(user_id)
670
+ # except Exception as exc:
671
+ # raise HTTPException(status_code=400, detail="Invalid user id") from exc
672
+
673
+ # now = datetime.now(timezone.utc)
674
+ # start_period = _shift_months(_first_day_of_month(now), -2)
675
+ # prediction_month = _shift_months(_first_day_of_month(now), 1)
676
+
677
+ # pipeline = [
678
+ # {
679
+ # "$match": {
680
+ # "user": user_object_id,
681
+ # "type": "EXPENSE",
682
+ # "headCategory": {"$ne": None},
683
+ # "date": {"$gte": start_period},
684
+ # }
685
+ # },
686
+ # {
687
+ # "$project": {
688
+ # "amount": 1,
689
+ # "headCategory": 1,
690
+ # "year": {"$year": "$date"},
691
+ # "month": {"$month": "$date"},
692
+ # }
693
+ # },
694
+ # {
695
+ # "$group": {
696
+ # "_id": {
697
+ # "headCategory": "$headCategory",
698
+ # "year": "$year",
699
+ # "month": "$month",
700
+ # },
701
+ # "total": {"$sum": "$amount"},
702
+ # }
703
+ # },
704
+ # {
705
+ # "$lookup": {
706
+ # "from": "headcategories",
707
+ # "localField": "_id.headCategory",
708
+ # "foreignField": "_id",
709
+ # "as": "headCategoryDoc",
710
+ # }
711
+ # },
712
+ # {"$unwind": "$headCategoryDoc"},
713
+ # {"$sort": {"_id.headCategory": 1, "_id.year": 1, "_id.month": 1}},
714
+ # ]
715
+
716
+ # results = list(mongo.transactions.aggregate(pipeline))
717
+
718
+ # grouped: Dict[ObjectId, Dict[str, List[MonthlyExpense]]] = defaultdict(
719
+ # lambda: {"history": []}
720
+ # )
721
+
722
+ # for item in results:
723
+ # head_category_id: ObjectId = item["_id"]["headCategory"]
724
+ # category_record = grouped[head_category_id]
725
+ # category_record["title"] = item["headCategoryDoc"].get("title", "Unknown")
726
+ # category_record["history"].append(
727
+ # MonthlyExpense(
728
+ # year=item["_id"]["year"],
729
+ # month=item["_id"]["month"],
730
+ # total=float(item["total"]),
731
+ # )
732
+ # )
733
+
734
+ # categories: List[CategoryPrediction] = []
735
+ # for head_category_id, record in grouped.items():
736
+ # history = sorted(record["history"], key=lambda doc: (doc.year, doc.month))
737
+ # predicted_total = _predict_next_month(history)
738
+
739
+ # categories.append(
740
+ # CategoryPrediction(
741
+ # headCategoryId=str(head_category_id),
742
+ # title=record.get("title", "Unknown"),
743
+ # history=history,
744
+ # predictionMonth=MonthlyExpense(
745
+ # year=prediction_month.year,
746
+ # month=prediction_month.month,
747
+ # total=predicted_total,
748
+ # ),
749
+ # )
750
+ # )
751
+
752
+ # return PredictionResponse(userId=user_id, categories=categories)