File size: 28,170 Bytes
5216e16
 
 
 
 
 
 
1fd3202
5216e16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a48805
 
 
 
 
5216e16
 
 
 
 
 
 
 
 
 
 
1fd3202
5216e16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fd3202
 
 
 
 
 
 
 
 
 
 
 
 
1b441ae
1fd3202
 
 
 
 
 
 
 
 
 
 
5216e16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fd3202
 
 
 
5216e16
 
1fd3202
 
 
 
 
 
 
5216e16
1fd3202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5216e16
1fd3202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5216e16
1fd3202
 
 
 
 
 
 
 
 
 
 
 
 
 
5216e16
 
1fd3202
 
 
 
 
 
 
 
 
5216e16
 
 
 
1fd3202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5216e16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
# app.py
import calendar
import logging
import math
import os
from collections import defaultdict
from datetime import datetime, timezone
from time import perf_counter
from typing import Dict, List, Optional, Tuple

from bson import ObjectId
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from pymongo import MongoClient
from pymongo.collection import Collection

# load_dotenv() must run before the os.getenv() calls below so that values
# supplied through a .env file are visible to them.
load_dotenv()

app = FastAPI(title="Expense Prediction API", version="1.0.0")

# ---------- Configurable constants ----------
# Env-backed values are parsed once at import time; a non-numeric value makes
# int()/float() raise ValueError while the module is being imported.
MAX_HISTORY_MONTHS = int(os.getenv("MAX_HISTORY_MONTHS", "36"))  # months to fetch for detection/tuning
SEASONALITY_PERIOD = int(os.getenv("SEASONALITY_PERIOD", "12"))  # monthly seasonality (12 months)
# Minimum seasonal_strength() score ((max bucket mean - min bucket mean) / overall
# mean) at which _predict_next_month switches to the Holt-Winters model.
SEASONALITY_AMPLITUDE_THRESHOLD = float(os.getenv("SEASONALITY_AMPLITUDE_THRESHOLD", "0.18"))
# grid-search limits (keeps tuning light)
ALPHA_GRID = [0.3, 0.5, 0.7]  # level smoothing candidates
BETA_GRID = [0.1, 0.3, 0.5]  # trend smoothing candidates
GAMMA_GRID = [0.1, 0.3, 0.5]  # seasonal smoothing candidates (Holt-Winters only)
MAX_GRID_SEARCH_COMBINATIONS = 30  # safety cap (the full seasonal grid is 3*3*3 = 27)
# ------------------------------------------------

class MonthlyExpense(BaseModel):
    # One calendar month's total, used both for history points and for the forecast.
    year: int
    month: int  # 1-12 as produced by MongoDB's $month / datetime.month; not range-validated here
    total: float = Field(..., description="Total expenses recorded for the month")

class CategoryPrediction(BaseModel):
    # Per-head-category result: the observed monthly totals plus the forecast.
    headCategoryId: str  # the head category's ObjectId rendered with str()
    title: str  # head category title; the endpoint substitutes "Unknown" when missing
    history: List[MonthlyExpense]  # months that had expenses, sorted oldest first by the endpoint
    predictionMonth: MonthlyExpense  # forecast month and its predicted total


class PredictionResponse(BaseModel):
    # Payload carried in APIResponse.data by the expense-prediction endpoint.
    userId: str  # the user id exactly as received in the URL path
    categories: List[CategoryPrediction]  # one entry per head category with expenses in the window
    
class APIResponse(BaseModel):
    # Response envelope of the expense-prediction endpoint.
    status: str  # "success" on the happy path; failures are raised as HTTPException instead
    message: str  # human-readable summary of the outcome
    data: Optional[PredictionResponse] = None  # prediction payload; None when not provided

class MongoConnection:
    """Own the MongoDB client and expose the collections this service reads and writes.

    The connection string comes from the ``MONGO_URI`` environment variable and
    the database is whatever ``get_default_database()`` resolves from it. The
    client is created with ``tz_aware=True`` so datetimes read back carry tzinfo.

    Raises:
        RuntimeError: if ``MONGO_URI`` is unset or empty.
    """

    def __init__(self) -> None:
        uri = os.getenv("MONGO_URI")
        if not uri:
            raise RuntimeError("MONGO_URI is not configured in the environment")

        client = MongoClient(uri, tz_aware=True)
        database = client.get_default_database()

        # Kept on the instance: the health endpoint pings through _client.
        self._client = client
        self._database = database
        self.transactions: Collection = database["transactions"]
        self.headcategories: Collection = database["headcategories"]
        self.api_logs: Collection = database["api_logs"]


# Module-level connection shared by every request handler. It is created at
# import time, so importing this module raises RuntimeError when MONGO_URI is unset.
mongo = MongoConnection()

# ----------------- Date helpers -----------------
def _first_day_of_month(dt: datetime) -> datetime:
    """Return *dt* moved back to 00:00:00.000000 on day 1 of the same month.

    Only the date/time fields are replaced, so tzinfo is carried over unchanged.
    """
    return dt.replace(microsecond=0, second=0, minute=0, hour=0, day=1)


def _shift_months(dt: datetime, months: int) -> datetime:
    """Move *dt* by *months* calendar months (negative values move backwards).

    The day of month is clamped to the length of the target month, e.g.
    Jan 31 + 1 month -> Feb 28/29. Time of day and tzinfo are preserved.
    """
    extra_years, zero_based_month = divmod(dt.month - 1 + months, 12)
    target_year = dt.year + extra_years
    target_month = zero_based_month + 1
    days_in_target = calendar.monthrange(target_year, target_month)[1]
    return dt.replace(
        year=target_year,
        month=target_month,
        day=min(dt.day, days_in_target),
    )


def month_to_index(year: int, month: int) -> int:
    """Map (year, 1-based month) to a single running month counter.

    Consecutive calendar months map to consecutive integers; ``index_to_month``
    is the inverse.
    """
    zero_based_month = month - 1
    return zero_based_month + 12 * year


def index_to_month(idx: int) -> Tuple[int, int]:
    """Invert ``month_to_index``: turn a running month counter back into (year, month)."""
    year, zero_based_month = divmod(idx, 12)
    return year, zero_based_month + 1

def log_api_event(
    name: str,
    status: str,
    response_time: float,
    user_id: Optional[str] = None,
    error_message: Optional[str] = None,
):
    """Record one API call in the ``api_logs`` collection (best effort).

    Args:
        name: Logical endpoint name stored in the ``name`` field.
        status: Outcome label; callers in this module pass "success" or "failed".
        response_time: Elapsed seconds, stored rounded to 3 decimals.
        user_id: Caller-supplied user id; stored as "anonymous" when falsy.
        error_message: Optional error text; the field is omitted when falsy.

    A failure while writing the record is never propagated, so a logging outage
    cannot break the API. It is reported through the standard ``logging`` module
    instead of being silently discarded, so a broken log sink is noticed.
    """
    payload = {
        "name": name,
        "status": status,
        "response_time": round(response_time, 3),
        "user_id": user_id or "anonymous",
        "date": datetime.now(timezone.utc),
    }

    if error_message:
        payload["error_message"] = error_message

    try:
        mongo.api_logs.insert_one(payload)
    except Exception:
        # Never crash the API because of logging, but leave a trace.
        logging.getLogger(__name__).warning(
            "Failed to write API log event %r (status=%s)", name, status, exc_info=True
        )

# ------------------------------------------------

# ----------------- Time series utilities -----------------
def build_continuous_series(
    history: List[MonthlyExpense],
) -> Tuple[List[Optional[float]], List[Tuple[int, int]]]:
    """Expand sparse monthly records into a gap-free month-by-month series.

    The series spans from the earliest to the latest (year, month) in *history*.
    Months without a record hold ``None`` (hence ``Optional[float]`` in the
    return type; ``impute_missing`` fills them later). If several records share a
    month, the last one in (year, month)-sorted order wins.

    Returns:
        (values, months): ``values[i]`` is the total for ``months[i]`` (a
        ``(year, month)`` tuple) or ``None``. Both lists are empty for empty input.
    """
    if not history:
        return [], []

    # Sorting keeps the "last record wins" rule for duplicate months deterministic.
    totals_by_index = {}
    for item in sorted(history, key=lambda h: (h.year, h.month)):
        totals_by_index[month_to_index(item.year, item.month)] = item.total

    month_range = range(min(totals_by_index), max(totals_by_index) + 1)
    series = [totals_by_index.get(idx) for idx in month_range]
    months = [index_to_month(idx) for idx in month_range]
    return series, months


def impute_missing(series: List[Optional[float]]) -> List[float]:
    """Return a float copy of *series* with every ``None`` replaced by an estimate.

    * Gaps between two known values are filled by straight-line interpolation.
    * Positions before the first known value copy that value; positions after
      the last known value copy the last one.
    * A series with no known values becomes all zeros; an empty series stays empty.
    """
    if not series:
        return []

    values: List[Optional[float]] = [v if v is None else float(v) for v in series]
    anchors = [pos for pos, v in enumerate(values) if v is not None]
    if not anchors:
        return [0.0] * len(values)

    # Interpolate inside each gap between consecutive known positions.
    for left, right in zip(anchors, anchors[1:]):
        start_val = values[left]
        slope = (values[right] - start_val) / (right - left)
        for pos in range(left + 1, right):
            values[pos] = start_val + slope * (pos - left)

    # Flat-extend the edges.
    head, tail = anchors[0], anchors[-1]
    values[:head] = [values[head]] * head
    values[tail + 1:] = [values[tail]] * (len(values) - tail - 1)

    return [float(v) for v in values]


def seasonal_strength(series: List[float], period: int = SEASONALITY_PERIOD) -> float:
    """Score how strongly *series* repeats with the given *period*.

    Values are bucketed by ``position % period`` (position 0 is the first
    element, not necessarily January). The score is the spread between the
    largest and smallest bucket mean divided by the overall mean; larger
    means a stronger repeating pattern.

    Returns 0.0 when fewer than ``2 * period`` values are available or when the
    overall mean is zero.

    NOTE(review): a strong linear trend also spreads the bucket means apart, so
    purely trending data can score as seasonal.
    """
    if len(series) < 2 * period:
        return 0.0

    buckets: List[List[float]] = [[] for _ in range(period)]
    for position, value in enumerate(series):
        buckets[position % period].append(value)

    bucket_means = [sum(bucket) / len(bucket) if bucket else 0.0 for bucket in buckets]
    overall_mean = sum(series) / len(series) if series else 0.0
    if overall_mean == 0:
        return 0.0
    return (max(bucket_means) - min(bucket_means)) / overall_mean


# ----------------- Forecasting algorithms -----------------
def holt_double_forecast(series: List[float], alpha: float, beta: float, n_forecast: int = 1) -> List[float]:
    """Forecast *n_forecast* steps ahead with Holt's linear (double exponential) smoothing.

    Level starts at the first observation and trend at the first difference.
    *alpha* smooths the level and *beta* the trend. Forecasts are clipped at 0.0.

    Edge cases: an empty series yields ``n_forecast`` zeros; a single observation
    is repeated ``n_forecast`` times (unclipped).
    """
    if not series:
        return [0.0] * n_forecast
    if len(series) == 1:
        return [series[-1]] * n_forecast

    level, trend = series[0], series[1] - series[0]
    for observation in series[1:]:
        previous_level = level
        level = alpha * observation + (1 - alpha) * (level + trend)
        trend = beta * (level - previous_level) + (1 - beta) * trend

    return [max(0.0, level + step * trend) for step in range(1, n_forecast + 1)]


def holt_winters_additive(series: List[float], season_length: int, alpha: float, beta: float, gamma: float, n_forecast: int = 1) -> List[float]:
    """Forecast *n_forecast* steps ahead with additive Holt-Winters smoothing.

    *series* must contain no missing values. Initial level is the mean of the
    first season, and initial trend is the average per-step change between the
    first two seasons. Initial seasonal offsets come from
    ``_initial_seasonal_components``. *alpha*, *beta* and *gamma* smooth level,
    trend and seasonal offsets respectively. Forecasts are clipped at 0.0.

    An empty series yields ``n_forecast`` zeros. With fewer than two full seasons
    the seasonal terms cannot be initialised, so Holt's linear method is used
    instead.
    """
    n = len(series)
    if n == 0:
        return [0.0] * n_forecast
    if n < 2 * season_length:
        return holt_double_forecast(series, alpha, beta, n_forecast)

    seasonals = _initial_seasonal_components(series, season_length)
    first_season = series[:season_length]
    second_season = series[season_length:2 * season_length]
    level = sum(first_season) / season_length
    trend = (sum(second_season) - sum(first_season)) / (season_length * season_length)

    # Fit: walk through every observation updating level, trend and the
    # seasonal offset of its slot.
    for t, observation in enumerate(series):
        slot = t % season_length
        previous_level = level
        level = alpha * (observation - seasonals[slot]) + (1 - alpha) * (level + trend)
        trend = beta * (level - previous_level) + (1 - beta) * trend
        seasonals[slot] = gamma * (observation - level) + (1 - gamma) * seasonals[slot]

    # Forecast: step h past the end uses the seasonal slot of time n + h - 1.
    forecasts = []
    for horizon in range(1, n_forecast + 1):
        slot = (n + horizon - 1) % season_length
        forecasts.append(max(0.0, level + horizon * trend + seasonals[slot]))
    return forecasts


def _initial_seasonal_components(series: List[float], season_length: int) -> List[float]:
    """Estimate starting additive seasonal offsets from the complete seasons in *series*.

    For every complete season the season mean is taken. Each slot's offset is
    then the average deviation of that slot from its season mean. A trailing
    partial season is ignored. With no complete season all offsets are 0.0.
    """
    n_seasons = len(series) // season_length
    if n_seasons == 0:
        return [0.0] * season_length

    cycle_means = [
        sum(series[k * season_length:(k + 1) * season_length]) / season_length
        for k in range(n_seasons)
    ]

    components = []
    for offset in range(season_length):
        acc = 0.0
        for k, cycle_mean in enumerate(cycle_means):
            acc += series[k * season_length + offset] - cycle_mean
        components.append(acc / n_seasons)
    return components

# ----------------- Dynamic WMA -----------------
def dynamic_wma(series: List[float], max_len: int = 6) -> float:
    """Weighted moving average over the last *max_len* values with volatility-adaptive weights.

    Base weights decay linearly from the newest value (largest weight) to the
    oldest. The volatility ratio is the mean absolute month-to-month change
    divided by the window mean, clamped to [0, 1]. The weights are blended
    toward equal weights by ``min(0.7, ratio)``, so noisier series give older
    months relatively more weight.

    Args:
        series: Values in chronological order.
        max_len: Maximum window size. Values below 1 are treated as 1. Without
            this clamp, ``series[-0:]`` would return the whole series and the
            equal-weight division would divide by zero.

    Returns:
        The weighted average, clipped at 0.0; 0.0 for an empty series.
    """
    n = len(series)
    if n == 0:
        return 0.0
    take = max(1, min(n, max_len))
    recent = series[-take:]

    # Volatility: mean *absolute* month-to-month change, in the series' units.
    if take >= 2:
        changes = [abs(cur - prev) for prev, cur in zip(recent, recent[1:])]
        vol = sum(changes) / len(changes)
    else:
        vol = 0.0

    # Linearly decaying base weights, newest first: take, take-1, ..., 1 (normalised).
    raw_weights = list(range(take, 0, -1))
    total = sum(raw_weights)
    base_weights = [w / total for w in raw_weights]

    # Normalise volatility by the window mean so the blend is scale-free.
    avg = sum(recent) / len(recent)
    vol_ratio = (vol / avg) if avg else 0.0
    vol_ratio = max(0.0, min(vol_ratio, 1.0))

    # Cap the blend at 0.7 to avoid completely flattening the weights.
    blend = min(0.7, vol_ratio)
    equal_weight = 1.0 / take
    weights = [(1 - blend) * bw + blend * equal_weight for bw in base_weights]

    # Weights are ordered newest -> oldest, so pair them with the reversed window.
    prediction = sum(w * v for w, v in zip(weights, reversed(recent)))
    return max(0.0, prediction)

# ----------------- Parameter tuning (lightweight) -----------------
def walk_forward_cv_mse(series: List[float], forecast_func, params: dict, min_train_size: int = 6) -> float:
    """Score a one-step forecaster by expanding-window walk-forward validation.

    For every cutoff from *min_train_size* to ``len(series) - 1``, the forecaster
    sees ``series[:cutoff]`` and is scored against ``series[cutoff]``.
    ``forecast_func(train, params)`` must return a single number.

    Returns:
        Mean squared error over all cutoffs, or ``inf`` when the series is too
        short (fewer than ``min_train_size + 1`` points) or when any forecast
        raises or returns ``None``. ``inf`` makes the tuner skip that candidate.
    """
    if len(series) < min_train_size + 1:
        return float("inf")

    squared_errors = []
    for cutoff in range(min_train_size, len(series)):
        history, target = series[:cutoff], series[cutoff]
        try:
            forecast = forecast_func(history, params)
        except Exception:
            return float("inf")
        if forecast is None:
            return float("inf")
        squared_errors.append((forecast - target) ** 2)

    if not squared_errors:
        return float("inf")
    return sum(squared_errors) / len(squared_errors)


def forecast_wrapper_holt(train: List[float], params: dict) -> float:
    """Adapter for walk-forward CV: one-step Holt linear forecast.

    Missing keys in *params* default to alpha=0.5 and beta=0.3.
    """
    forecasts = holt_double_forecast(
        train,
        params.get("alpha", 0.5),
        params.get("beta", 0.3),
        n_forecast=1,
    )
    return forecasts[0]


def forecast_wrapper_hw(train: List[float], params: dict) -> float:
    """Adapter for walk-forward CV: one-step additive Holt-Winters forecast.

    Missing keys in *params* default to alpha=0.5, beta=0.3, gamma=0.2 and
    season_length=SEASONALITY_PERIOD.
    """
    forecasts = holt_winters_additive(
        train,
        params.get("season_length", SEASONALITY_PERIOD),
        params.get("alpha", 0.5),
        params.get("beta", 0.3),
        params.get("gamma", 0.2),
        n_forecast=1,
    )
    return forecasts[0]


def tune_parameters(series: List[float], seasonal: bool, season_length: int = SEASONALITY_PERIOD) -> dict:
    """Pick smoothing parameters for *series* with a small grid search.

    Candidates are the Cartesian product of ALPHA_GRID x BETA_GRID (x GAMMA_GRID
    when *seasonal*), truncated to MAX_GRID_SEARCH_COMBINATIONS. Each candidate
    is scored with ``walk_forward_cv_mse``. For the seasonal model the minimum
    training size is ``max(6, season_length)``; for the linear model it is 6.

    Returns:
        The lowest-MSE parameter dict (ties keep the first candidate in grid
        order). If every candidate scores ``inf`` (e.g. the series is too short
        for validation), fixed defaults are returned instead.

    NOTE(review): holt_winters_additive only uses seasonal terms once it has
    ``2 * season_length`` points, so early CV folds of the seasonal grid are
    effectively scored with Holt's linear fallback.
    """
    if seasonal:
        grid = [
            {"alpha": a, "beta": b, "gamma": g, "season_length": season_length}
            for a in ALPHA_GRID
            for b in BETA_GRID
            for g in GAMMA_GRID
        ]
        forecast_func = forecast_wrapper_hw
        min_train_size = max(6, season_length)
    else:
        grid = [{"alpha": a, "beta": b} for a in ALPHA_GRID for b in BETA_GRID]
        forecast_func = forecast_wrapper_holt
        min_train_size = 6

    best = None
    best_score = float("inf")
    # Slicing is a no-op when the grid is already within the safety cap.
    for params in grid[:MAX_GRID_SEARCH_COMBINATIONS]:
        score = walk_forward_cv_mse(series, forecast_func, params, min_train_size=min_train_size)
        if score < best_score:
            best_score = score
            best = params

    if best is None:
        if seasonal:
            return {"alpha": 0.5, "beta": 0.3, "gamma": 0.2, "season_length": season_length}
        return {"alpha": 0.5, "beta": 0.3}

    return best

# ----------------- Top-level predictor combining everything -----------------
def _predict_next_month(history: List[MonthlyExpense]) -> float:
    """Forecast the next total for one category from its monthly history.

    Steps:
    1. keep only the most recent ``MAX_HISTORY_MONTHS`` calendar months;
    2. build a gap-free monthly series, interpolating months that have no record
       (they are treated as unknown, not as zero spend);
    3. <= 2 months of data -> dynamic WMA;
    4. fewer than 6 months and mean absolute change > mean level -> dynamic WMA;
    5. otherwise tune and run Holt-Winters (when seasonality is detected) or
       Holt's linear method, forecasting one step past the last month.

    Returns:
        The forecast rounded to 2 decimals; 0.0 for empty or all-zero history.

    NOTE(review): the forecast is one step past the *last month present in
    history*. If a category's latest record is older than the current month,
    the caller's "next month" label is further ahead than this horizon.
    """
    if not history:
        return 0.0

    history_sorted = sorted(history, key=lambda h: (h.year, h.month))

    # Limit by calendar span, not entry count: a sparse history can hold
    # MAX_HISTORY_MONTHS records spread over many more months, which would grow
    # the interpolated series beyond the intended window.
    window = max(MAX_HISTORY_MONTHS, 1)
    latest_idx = month_to_index(history_sorted[-1].year, history_sorted[-1].month)
    history_sorted = [
        h for h in history_sorted if month_to_index(h.year, h.month) > latest_idx - window
    ]

    series_with_none, _months = build_continuous_series(history_sorted)
    series = impute_missing(series_with_none)

    if all(v == 0.0 for v in series):
        return 0.0

    n = len(series)

    # Too little data to fit a trend.
    if n <= 2:
        return round(dynamic_wma(series, max_len=2), 2)

    # Extremely volatile *and* short series: the WMA is more robust than a fitted
    # trend. Checked before tuning so the grid search is skipped in that case.
    mean_val = sum(series) / n
    avg_diff = sum(abs(cur - prev) for prev, cur in zip(series, series[1:])) / (n - 1)
    volatility_ratio = (avg_diff / mean_val) if mean_val else 0.0
    if volatility_ratio > 1.0 and n < 6:
        return round(dynamic_wma(series, max_len=min(6, n)), 2)

    # Seasonal model only with a clear amplitude signal and at least two full cycles.
    strength = seasonal_strength(series, period=SEASONALITY_PERIOD)
    is_seasonal = strength >= SEASONALITY_AMPLITUDE_THRESHOLD and n >= 2 * SEASONALITY_PERIOD

    # Per-series coefficients; any tuning failure falls back to fixed defaults.
    try:
        tuned = tune_parameters(series, seasonal=is_seasonal, season_length=SEASONALITY_PERIOD)
    except Exception:
        tuned = None
    if tuned is None:
        if is_seasonal:
            tuned = {"alpha": 0.5, "beta": 0.3, "gamma": 0.2, "season_length": SEASONALITY_PERIOD}
        else:
            tuned = {"alpha": 0.5, "beta": 0.3}

    alpha = tuned.get("alpha", 0.5)
    beta = tuned.get("beta", 0.3)
    if is_seasonal:
        gamma = tuned.get("gamma", 0.2)
        season_length = tuned.get("season_length", SEASONALITY_PERIOD)
        pred = holt_winters_additive(series, season_length, alpha, beta, gamma, n_forecast=1)[0]
    else:
        pred = holt_double_forecast(series, alpha, beta, n_forecast=1)[0]

    # Final safety net. Test None first (math.isnan(None) raises TypeError) and
    # reject infinities, which are not valid JSON numbers in the API response.
    if pred is None or math.isnan(pred) or math.isinf(pred) or pred < 0:
        pred = sum(series[-3:]) / min(3, n)

    return round(float(pred), 2)


# ----------------- API endpoint -----------------
@app.get("/users/{user_id}/expense-prediction", response_model=APIResponse)
def predict_expense(user_id: str):
    """Return per-category expense history and a next-month forecast for a user.

    The user's EXPENSE transactions since the first day of the month
    ``MAX_HISTORY_MONTHS - 1`` months ago are aggregated into monthly totals per
    head category. Each category is forecast with ``_predict_next_month``, and
    the result is wrapped in an ``APIResponse``. Every call is recorded with
    ``log_api_event``. Groups whose head category has no matching
    ``headcategories`` document are dropped by the ``$unwind`` stage.

    Raises:
        HTTPException(400): ``user_id`` is not a valid ObjectId.
        HTTPException(500): any failure while querying or predicting; the real
            error text is only written to the API log, not returned.

    NOTE(review): the ``$match`` has no upper date bound, so the current (still
    incomplete) month is part of the history fed to the forecaster. Confirm
    this is intended.
    """
    start_time = perf_counter()

    try:
        user_object_id = ObjectId(user_id)
    except Exception as exc:
        log_api_event(
            name="Expense Prediction",
            status="failed",
            # Report the real elapsed time (was hard-coded 0) so failed requests
            # are comparable with successful ones in api_logs.
            response_time=perf_counter() - start_time,
            user_id=user_id,
            error_message="Invalid user id",
        )
        raise HTTPException(status_code=400, detail="Invalid user id") from exc

    try:
        now = datetime.now(timezone.utc)
        start_period = _shift_months(_first_day_of_month(now), -MAX_HISTORY_MONTHS + 1)
        prediction_month = _shift_months(_first_day_of_month(now), 1)

        pipeline = [
            {
                "$match": {
                    "user": user_object_id,
                    "type": "EXPENSE",
                    "headCategory": {"$ne": None},
                    "date": {"$gte": start_period},
                }
            },
            {
                "$project": {
                    "amount": 1,
                    "headCategory": 1,
                    "year": {"$year": "$date"},
                    "month": {"$month": "$date"},
                }
            },
            {
                "$group": {
                    "_id": {
                        "headCategory": "$headCategory",
                        "year": "$year",
                        "month": "$month",
                    },
                    "total": {"$sum": "$amount"},
                }
            },
            {
                "$lookup": {
                    "from": "headcategories",
                    "localField": "_id.headCategory",
                    "foreignField": "_id",
                    "as": "headCategoryDoc",
                }
            },
            {"$unwind": "$headCategoryDoc"},
            {"$sort": {"_id.headCategory": 1, "_id.year": 1, "_id.month": 1}},
        ]

        results = list(mongo.transactions.aggregate(pipeline))

        # head category id -> {"title": str, "history": [MonthlyExpense, ...]}
        grouped: Dict[ObjectId, dict] = defaultdict(lambda: {"history": []})

        for item in results:
            head_category_id: ObjectId = item["_id"]["headCategory"]
            category_record = grouped[head_category_id]
            category_record["title"] = item["headCategoryDoc"].get("title", "Unknown")
            category_record["history"].append(
                MonthlyExpense(
                    year=item["_id"]["year"],
                    month=item["_id"]["month"],
                    total=float(item["total"]),
                )
            )

        categories: List[CategoryPrediction] = []
        for head_category_id, record in grouped.items():
            history = sorted(record["history"], key=lambda doc: (doc.year, doc.month))
            predicted_total = _predict_next_month(history)

            categories.append(
                CategoryPrediction(
                    headCategoryId=str(head_category_id),
                    title=record.get("title", "Unknown"),
                    history=history,
                    predictionMonth=MonthlyExpense(
                        year=prediction_month.year,
                        month=prediction_month.month,
                        total=predicted_total,
                    ),
                )
            )

        response_data = PredictionResponse(userId=user_id, categories=categories)

        log_api_event(
            name="Expense Prediction",
            status="success",
            response_time=perf_counter() - start_time,
            user_id=user_id,
        )

        return APIResponse(
            status="success",
            message="Expense prediction generated successfully",
            data=response_data,
        )

    except Exception as exc:
        log_api_event(
            name="Expense Prediction",
            status="failed",
            response_time=perf_counter() - start_time,
            user_id=user_id,
            error_message=str(exc),
        )
        raise HTTPException(status_code=500, detail="Internal server error") from exc


@app.get("/health")
def health():
    """Health probe backed by a MongoDB ``ping`` command.

    If the ping succeeds, returns a small status payload with the current UTC
    timestamp. If it fails, responds with HTTP 503 and a detail dict that
    includes the driver's error text.
    """
    try:
        mongo._client.admin.command("ping")
        healthy_payload = {
            "status": "ok",
            "message": "Service is healthy",
            "timestamp": datetime.now(timezone.utc),
        }
    except Exception as exc:
        # NOTE(review): str(exc) is returned to any caller of this endpoint and
        # may reveal database connection details; consider logging it instead.
        failure_detail = {
            "status": "down",
            "message": "Database connectivity failed",
            "error": str(exc),
        }
        raise HTTPException(status_code=503, detail=failure_detail)
    return healthy_payload












# import calendar
# import os
# from collections import defaultdict
# from datetime import datetime, timezone
# from typing import Dict, List

# from bson import ObjectId
# from dotenv import load_dotenv
# from fastapi import FastAPI, HTTPException
# from pydantic import BaseModel, Field
# from pymongo import MongoClient
# from pymongo.collection import Collection

# load_dotenv()

# app = FastAPI(title="Expense Prediction API", version="1.0.0")


# class MonthlyExpense(BaseModel):
#     year: int
#     month: int
#     total: float = Field(..., description="Total expenses recorded for the month")


# class CategoryPrediction(BaseModel):
#     headCategoryId: str
#     title: str
#     history: List[MonthlyExpense]
#     predictionMonth: MonthlyExpense


# class PredictionResponse(BaseModel):
#     userId: str
#     categories: List[CategoryPrediction]


# class MongoConnection:
#     def __init__(self) -> None:
#         mongo_uri = os.getenv("MONGO_URI")
#         if not mongo_uri:
#             raise RuntimeError("MONGO_URI is not configured in the environment")

#         self._client = MongoClient(mongo_uri, tz_aware=True)
#         self._database = self._client.get_default_database()
#         self.transactions: Collection = self._database["transactions"]
#         self.headcategories: Collection = self._database["headcategories"]


# mongo = MongoConnection()


# def _first_day_of_month(dt: datetime) -> datetime:
#     return dt.replace(day=1, hour=0, minute=0, second=0, microsecond=0)


# def _shift_months(dt: datetime, months: int) -> datetime:
#     month_index = dt.month - 1 + months
#     year = dt.year + month_index // 12
#     month = month_index % 12 + 1
#     last_day = calendar.monthrange(year, month)[1]
#     day = min(dt.day, last_day)
#     return dt.replace(year=year, month=month, day=day)


# # -----------------------------------------------------------
# # NEW: Weighted Moving Average-based prediction function
# # -----------------------------------------------------------

# def _predict_next_month(history: List[MonthlyExpense]) -> float:
#     """Predict next month's expense using Weighted Moving Average (WMA)."""
#     totals = [h.total for h in history]

#     # Only one month → Just repeat last month
#     if len(totals) == 1:
#         return round(totals[-1], 2)

#     # Two months → Slight smoothing
#     if len(totals) == 2:
#         last, prev = totals[-1], totals[-2]
#         prediction = last * 0.7 + prev * 0.3
#         return round(prediction, 2)

#     # Three or more months → Use 3-month WMA (0.5, 0.3, 0.2)
#     last3 = totals[-3:]
#     weights = [0.2, 0.3, 0.5]  # oldest → newest
#     prediction = sum(v * w for v, w in zip(last3, weights))

#     return round(prediction, 2)


# # -----------------------------------------------------------
# # EXPENSE PREDICTION ENDPOINT
# # -----------------------------------------------------------

# @app.get("/users/{user_id}/expense-prediction", response_model=PredictionResponse)
# def predict_expense(user_id: str) -> PredictionResponse:
#     try:
#         user_object_id = ObjectId(user_id)
#     except Exception as exc:
#         raise HTTPException(status_code=400, detail="Invalid user id") from exc

#     now = datetime.now(timezone.utc)
#     start_period = _shift_months(_first_day_of_month(now), -2)
#     prediction_month = _shift_months(_first_day_of_month(now), 1)

#     pipeline = [
#         {
#             "$match": {
#                 "user": user_object_id,
#                 "type": "EXPENSE",
#                 "headCategory": {"$ne": None},
#                 "date": {"$gte": start_period},
#             }
#         },
#         {
#             "$project": {
#                 "amount": 1,
#                 "headCategory": 1,
#                 "year": {"$year": "$date"},
#                 "month": {"$month": "$date"},
#             }
#         },
#         {
#             "$group": {
#                 "_id": {
#                     "headCategory": "$headCategory",
#                     "year": "$year",
#                     "month": "$month",
#                 },
#                 "total": {"$sum": "$amount"},
#             }
#         },
#         {
#             "$lookup": {
#                 "from": "headcategories",
#                 "localField": "_id.headCategory",
#                 "foreignField": "_id",
#                 "as": "headCategoryDoc",
#             }
#         },
#         {"$unwind": "$headCategoryDoc"},
#         {"$sort": {"_id.headCategory": 1, "_id.year": 1, "_id.month": 1}},
#     ]

#     results = list(mongo.transactions.aggregate(pipeline))

#     grouped: Dict[ObjectId, Dict[str, List[MonthlyExpense]]] = defaultdict(
#         lambda: {"history": []}
#     )

#     for item in results:
#         head_category_id: ObjectId = item["_id"]["headCategory"]
#         category_record = grouped[head_category_id]
#         category_record["title"] = item["headCategoryDoc"].get("title", "Unknown")
#         category_record["history"].append(
#             MonthlyExpense(
#                 year=item["_id"]["year"],
#                 month=item["_id"]["month"],
#                 total=float(item["total"]),
#             )
#         )

#     categories: List[CategoryPrediction] = []
#     for head_category_id, record in grouped.items():
#         history = sorted(record["history"], key=lambda doc: (doc.year, doc.month))
#         predicted_total = _predict_next_month(history)

#         categories.append(
#             CategoryPrediction(
#                 headCategoryId=str(head_category_id),
#                 title=record.get("title", "Unknown"),
#                 history=history,
#                 predictionMonth=MonthlyExpense(
#                     year=prediction_month.year,
#                     month=prediction_month.month,
#                     total=predicted_total,
#                 ),
#             )
#         )

#     return PredictionResponse(userId=user_id, categories=categories)