Corin1998 committed on
Commit
fd78077
·
verified ·
1 Parent(s): da82e08

Create forecast.py

Browse files
Files changed (1) hide show
  1. app/forecast.py +102 -0
app/forecast.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import pandas as pd
3
+ import numpy as np
4
+ from . import storage
5
+
6
+ # 可能なら Prophet / NeuralProphet を使用(無ければフォールバック)
7
+ try:
8
+ from prophet import Prophet
9
+ except Exception:
10
+ Prophet = None
11
+
12
+ try:
13
+ from neuralprophet import NeuralProphet
14
+ except Exception:
15
+ NeuralProphet = None
16
+
17
+
18
class SeasonalityModel:
    """Hourly click-through-rate (CTR) seasonality model for one campaign.

    ``fit`` builds an hourly CTR series from the ``events`` table and trains
    Prophet (preferred) or NeuralProphet when one of them is importable.
    ``expected_ctr`` then returns a CTR estimate for a given hour of day,
    falling back to a simple time-of-day heuristic around ``global_mean``
    whenever no trained model is available.

    Every value returned by ``expected_ctr`` lies in ``[0.01, 0.99]``.
    """

    # Bounds applied to every CTR returned by ``expected_ctr``.
    _MIN_CTR = 0.01
    _MAX_CTR = 0.99

    def __init__(self, campaign_id: str):
        """Create an untrained model for ``campaign_id``."""
        self.campaign_id = campaign_id
        self.model = None
        self.model_type = "none"
        self.global_mean = 0.05  # default CTR used when data is scarce

    @classmethod
    def _clamp(cls, value: float) -> float:
        """Clamp ``value`` into the closed range [_MIN_CTR, _MAX_CTR]."""
        return max(cls._MIN_CTR, min(cls._MAX_CTR, value))

    def _heuristic_ctr(self, hour: int) -> float:
        """Return ``global_mean`` with a small lunch/evening uplift, clamped."""
        base = self.global_mean
        if 11 <= hour <= 13:
            base *= 1.1
        elif 20 <= hour <= 23:
            base *= 1.15
        return self._clamp(base)

    def fit(self):
        """Train the seasonality model from this campaign's stored events.

        Reads ``ts`` and ``event_type`` rows for the campaign, aggregates
        them into an hourly CTR series (clicks / impressions) and fits
        Prophet or NeuralProphet on it. ``global_mean`` is updated whenever
        at least one hour has impressions. On missing data, missing
        libraries or a training failure the model stays in the ``"none"``
        state so ``expected_ctr`` uses the heuristic fallback.
        """
        import logging

        # Start from a clean state so a failed refit never leaves a stale model.
        self.model = None
        self.model_type = "none"

        with storage.get_conn() as con:
            df = pd.read_sql_query(
                "SELECT ts, event_type FROM events WHERE campaign_id=?",
                con,
                params=(self.campaign_id,),
            )

        if df.empty:
            return

        # Parse as UTC and then drop the timezone: Prophet rejects tz-aware
        # ``ds`` values, and predictions are made on naive UTC timestamps.
        parsed = pd.to_datetime(df["ts"], errors="coerce", utc=True)
        df["ts"] = parsed.dt.tz_localize(None)
        df = df.dropna(subset=["ts"])
        if df.empty:
            return
        df["hour"] = df["ts"].dt.floor("h")

        # Event counts per hour, one column per event type.
        agg = (
            df.pivot_table(
                index="hour", columns="event_type", values="ts", aggfunc="count"
            )
            .fillna(0)
        )
        if "impression" not in agg:
            agg["impression"] = 0
        if "click" not in agg:
            agg["click"] = 0

        # CTR is undefined (NaN) for hours without impressions.
        ctr = np.where(
            agg["impression"] > 0, agg["click"] / agg["impression"], np.nan
        )
        if np.all(np.isnan(ctr)):
            return

        self.global_mean = float(np.nanmean(ctr))

        # Training frame for Prophet / NeuralProphet; hours with undefined
        # CTR are filled with the overall mean.
        ds = agg.index.to_series().reset_index(drop=True)
        train = pd.DataFrame(
            {"ds": ds, "y": pd.Series(ctr).fillna(self.global_mean).values}
        )

        try:
            if Prophet is not None:
                m = Prophet(weekly_seasonality=True, daily_seasonality=True)
                m.fit(train)
                self.model = m
                self.model_type = "prophet"
            elif NeuralProphet is not None:
                m = NeuralProphet(weekly_seasonality=True, daily_seasonality=True)
                m.fit(train, freq="H")
                self.model = m
                self.model_type = "neuralprophet"
        except Exception:
            # Training is best-effort: record the failure and fall back to
            # the heuristic instead of propagating the error.
            logging.getLogger(__name__).warning(
                "Seasonality model training failed for campaign %s; "
                "falling back to heuristic CTR",
                self.campaign_id,
                exc_info=True,
            )
            self.model = None
            self.model_type = "none"

    def expected_ctr(self, context: dict) -> float:
        """Return the expected CTR for the hour given in ``context``.

        Args:
            context: Mapping that may contain ``"hour"`` (hour of day);
                a missing or ``None`` value defaults to 12.

        Returns:
            A CTR estimate clamped to ``[0.01, 0.99]``. Without a trained
            model, or when the model yields a non-finite prediction, the
            time-of-day heuristic around ``global_mean`` is used.
        """
        raw_hour = context.get("hour")
        hour = 12 if raw_hour is None else int(raw_hour)

        # No trained model: use the simple heuristic.
        if self.model is None or self.model_type in {None, "none"}:
            return self._heuristic_ctr(hour)

        # Single-point forecast for today (UTC) at the requested hour.
        # The timestamp must be timezone-naive because Prophet raises on
        # tz-aware ``ds`` columns.
        today = pd.Timestamp.now(tz="UTC").tz_localize(None).floor("D")
        target = today + pd.Timedelta(hours=hour)
        column = "yhat" if self.model_type == "prophet" else "yhat1"
        forecast = self.model.predict(pd.DataFrame({"ds": [target]}))
        yhat = float(forecast[column].iloc[0])

        # A NaN/inf prediction would otherwise be clamped to a bogus bound.
        if not np.isfinite(yhat):
            return self._heuristic_ctr(hour)

        return self._clamp(yhat)