Spaces:
Sleeping
Sleeping
Create app/bandit.py
Browse files- app/bandit.py +64 -0
app/bandit.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations

import math
import random
from typing import Dict, Any, Tuple, Callable, Optional

from . import storage
|
| 5 |
+
|
| 6 |
+
class ThompsonBandit:
    """Thompson-sampling bandit using a two-stage Beta model: CTR, then CVR.

    Each variant is scored with one random draw of
    ``p_click * p_conv * value_per_conversion`` and the variant with the
    highest draw wins. Seasonality is applied as pseudo-counts on the click
    (CTR) side only.
    """

    def __init__(self, campaign_id: str, seasonality_boost: float = 5.0):
        """Bind the bandit to one campaign.

        Args:
            campaign_id: Identifier passed to the ``storage`` lookups made by
                ``sample_arm``.
            seasonality_boost: Total pseudo-count mass that ``sample_arm``
                splits between the click alpha and click beta according to
                the seasonality score.
        """
        self.campaign_id = campaign_id
        self.seasonality_boost = seasonality_boost

    def sample_arm(
        self,
        context: Dict[str, Any],
        seasonal_fn: Optional[Callable[[Dict[str, Any]], float]] = None,
    ) -> Tuple[Optional[str], float]:
        """Draw one Thompson sample per variant and return the best one.

        Args:
            context: Opaque request context; it is only forwarded to
                ``seasonal_fn``.
            seasonal_fn: Optional callable (may be passed by keyword) that
                maps ``context`` to a seasonality score. The score is clamped
                to [0.01, 0.99]. If the callable raises, or returns NaN, the
                neutral score 0.5 is used instead.

        Returns:
            ``(variant_id, score)`` for the highest sampled score, or
            ``(None, -1.0)`` when storage has no metrics for the campaign.
            Because the running best starts at -1.0, a variant is only
            selected if its sampled score is greater than -1.0.
        """
        metrics = storage.get_metrics(self.campaign_id)
        if not metrics:
            return None, -1.0

        vpc = storage.get_campaign_value_per_conversion(self.campaign_id)

        # Seasonality score s in (0, 1); 0.5 splits the boost evenly between
        # the click alpha and the click beta.
        s = 0.5
        if seasonal_fn is not None:
            try:
                s = float(seasonal_fn(context))
            except Exception:
                # Deliberate best-effort: a failing seasonal function must
                # not prevent arm selection, so fall back to neutral.
                s = 0.5
            else:
                # Comparisons with NaN are always false, so a bare
                # max/min clamp would turn NaN into 0.99 (maximum boost).
                # Treat NaN like a failure and stay neutral.
                s = 0.5 if math.isnan(s) else max(0.01, min(0.99, s))

        best_score, best_variant = -1.0, None
        for row in metrics:
            ac, bc = float(row["alpha_click"]), float(row["beta_click"])
            av, bv = float(row["alpha_conv"]), float(row["beta_conv"])

            # Nudge the click-side Beta parameters using seasonality: the
            # boost is shared between alpha (weight s) and beta (weight 1-s).
            ac_eff = ac + self.seasonality_boost * s
            bc_eff = bc + self.seasonality_boost * (1.0 - s)

            # betavariate needs strictly positive parameters, hence the floor.
            pc = random.betavariate(max(1e-6, ac_eff), max(1e-6, bc_eff))
            pv = random.betavariate(max(1e-6, av), max(1e-6, bv))
            score = pc * pv * vpc
            if score > best_score:
                best_score, best_variant = score, row["variant_id"]
        return best_variant, best_score

    @staticmethod
    def update_with_event(campaign_id: str, variant_id: str, event_type: str) -> None:
        """Record one event for a variant by incrementing its stored metrics.

        Each recognised event adds 1 to two fields via
        ``storage.update_metric``:

        * ``"impression"``: ``impressions`` and ``beta_click``
        * ``"click"``: ``clicks`` and ``alpha_click``
        * ``"conversion"``: ``conversions`` and ``alpha_conv``

        Args:
            campaign_id: Campaign the variant belongs to.
            variant_id: Variant that received the event.
            event_type: One of ``"impression"``, ``"click"``, ``"conversion"``.

        Raises:
            ValueError: If ``event_type`` is not one of the values above.
        """
        # NOTE(review): nothing in this method increments "beta_conv", so
        # unless it is maintained elsewhere the conversion-side Beta only
        # ever accumulates successes. Likewise a clicked impression adds to
        # both "beta_click" and "alpha_click". Confirm both are intended.
        if event_type == "impression":
            storage.update_metric(campaign_id, variant_id, "impressions", 1)
            storage.update_metric(campaign_id, variant_id, "beta_click", 1)
        elif event_type == "click":
            storage.update_metric(campaign_id, variant_id, "clicks", 1)
            storage.update_metric(campaign_id, variant_id, "alpha_click", 1)
        elif event_type == "conversion":
            storage.update_metric(campaign_id, variant_id, "conversions", 1)
            storage.update_metric(campaign_id, variant_id, "alpha_conv", 1)
        else:
            raise ValueError("unknown event_type")
|