Corin1998 commited on
Commit
45e6e77
·
verified ·
1 Parent(s): b5725bf

Create app/bandit.py

Browse files
Files changed (1) hide show
  1. app/bandit.py +64 -0
app/bandit.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import random
3
+ from typing import Dict, Any, Tuple, Callable, Optional
4
+ from . import storage
5
+
6
+ class ThompsonBandit:
7
+ """
8
+ CTR(クリック)とCVR(コンバージョン)の二段ベータ。
9
+ 目的関数: E[value] = p_click * p_conv * value_per_conversion
10
+ 季節性は CTR 側の仮想カウントで補正。
11
+ """
12
+ def __init__(self, campaign_id: str, seasonality_boost: float = 5.0):
13
+ self.campaign_id = campaign_id
14
+ self.seasonality_boost = seasonality_boost
15
+
16
+ def sample_arm(
17
+ self,
18
+ context: Dict[str, Any],
19
+ seasonal_fn: Optional[Callable[[Dict[str, Any]], float]] = None, # ← キーワード引数OK
20
+ ) -> Tuple[Optional[str], float]:
21
+ metrics = storage.get_metrics(self.campaign_id)
22
+ if not metrics:
23
+ return None, -1.0
24
+
25
+ vpc = storage.get_campaign_value_per_conversion(self.campaign_id)
26
+
27
+ # 季節性スコア s ∈ (0, 1)
28
+ s = 0.5
29
+ if seasonal_fn is not None:
30
+ try:
31
+ s = float(seasonal_fn(context))
32
+ s = max(0.01, min(0.99, s))
33
+ except Exception:
34
+ s = 0.5
35
+
36
+ best_score, best_variant = -1.0, None
37
+ for row in metrics:
38
+ ac, bc = float(row["alpha_click"]), float(row["beta_click"])
39
+ av, bv = float(row["alpha_conv"]), float(row["beta_conv"])
40
+
41
+ # 季節性でクリック側の事前分布を微調整
42
+ ac_eff = ac + self.seasonality_boost * s
43
+ bc_eff = bc + self.seasonality_boost * (1.0 - s)
44
+
45
+ pc = random.betavariate(max(1e-6, ac_eff), max(1e-6, bc_eff))
46
+ pv = random.betavariate(max(1e-6, av), max(1e-6, bv))
47
+ score = pc * pv * vpc
48
+ if score > best_score:
49
+ best_score, best_variant = score, row["variant_id"]
50
+ return best_variant, best_score
51
+
52
+ @staticmethod
53
+ def update_with_event(campaign_id: str, variant_id: str, event_type: str):
54
+ if event_type == "impression":
55
+ storage.update_metric(campaign_id, variant_id, "impressions", 1)
56
+ storage.update_metric(campaign_id, variant_id, "beta_click", 1)
57
+ elif event_type == "click":
58
+ storage.update_metric(campaign_id, variant_id, "clicks", 1)
59
+ storage.update_metric(campaign_id, variant_id, "alpha_click", 1)
60
+ elif event_type == "conversion":
61
+ storage.update_metric(campaign_id, variant_id, "conversions", 1)
62
+ storage.update_metric(campaign_id, variant_id, "alpha_conv", 1)
63
+ else:
64
+ raise ValueError("unknown event_type")