Corin1998's picture
Update app/bandit.py
57af9d9 verified
from __future__ import annotations
import logging
import math
import random
from typing import Dict, Any, Tuple, Callable, Optional

from . import storage
class ThompsonBandit:
    """Two-stage Beta Thompson sampling over a campaign's variants.

    Each variant row from ``storage.get_metrics`` carries two Beta posteriors:
    CTR (``alpha_click`` / ``beta_click``) and CVR (``alpha_conv`` / ``beta_conv``).

    Objective: E[value] = p_click * p_conv * value_per_conversion.
    Seasonality is applied as pseudo-counts on the CTR side only.
    """

    def __init__(self, campaign_id: str, seasonality_boost: float = 5.0):
        """
        Args:
            campaign_id: Campaign whose variant metrics are read from ``storage``.
            seasonality_boost: Total pseudo-count mass added to the CTR prior on
                each draw, split as ``boost * s`` to alpha and
                ``boost * (1 - s)`` to beta, where ``s`` is the seasonal score.
        """
        self.campaign_id = campaign_id
        self.seasonality_boost = seasonality_boost

    @staticmethod
    def _seasonal_score(
        seasonal_fn: Optional[Callable[[Dict[str, Any]], float]],
        context: Dict[str, Any],
    ) -> float:
        """Return the seasonal score for ``context``, clamped to [0.01, 0.99].

        Falls back to the neutral 0.5 when no function is given, when it raises,
        or when it returns NaN (``min``/``max`` would otherwise clamp NaN to 0.99).
        """
        if seasonal_fn is None:
            return 0.5
        try:
            s = float(seasonal_fn(context))
        except Exception:
            # Best-effort: a broken seasonal model must not block arm selection,
            # but the failure should be visible.
            logging.getLogger(__name__).warning(
                "seasonal_fn failed; using neutral seasonality 0.5", exc_info=True
            )
            return 0.5
        if math.isnan(s):
            return 0.5
        return max(0.01, min(0.99, s))

    def sample_arm(
        self,
        context: Dict[str, Any],
        seasonal_fn: Optional[Callable[[Dict[str, Any]], float]] = None,  # may be passed by keyword
    ) -> Tuple[Optional[str], float]:
        """Draw one Thompson sample per variant and return the best one.

        Args:
            context: Passed unchanged to ``seasonal_fn``.
            seasonal_fn: Optional callable mapping ``context`` to a seasonal
                score; see ``_seasonal_score`` for clamping and fallbacks.

        Returns:
            ``(variant_id, sampled_score)`` for the variant with the highest
            sampled ``p_click * p_conv * value_per_conversion``, or
            ``(None, -1.0)`` when the campaign has no metric rows.
        """
        metrics = storage.get_metrics(self.campaign_id)
        if not metrics:
            return None, -1.0

        # float() so a Decimal from the DB layer can be multiplied with floats.
        vpc = float(storage.get_campaign_value_per_conversion(self.campaign_id))
        s = self._seasonal_score(seasonal_fn, context)

        # Seasonal pseudo-counts for the CTR prior; identical for every variant.
        click_boost = self.seasonality_boost * s
        skip_boost = self.seasonality_boost * (1.0 - s)

        best_variant: Optional[str] = None
        best_score = -1.0
        for row in metrics:
            ac_eff = float(row["alpha_click"]) + click_boost
            bc_eff = float(row["beta_click"]) + skip_boost
            av, bv = float(row["alpha_conv"]), float(row["beta_conv"])
            # betavariate requires alpha > 0 and beta > 0.
            pc = random.betavariate(max(1e-6, ac_eff), max(1e-6, bc_eff))
            pv = random.betavariate(max(1e-6, av), max(1e-6, bv))
            score = pc * pv * vpc
            # Always accept the first row so a variant is returned even when all
            # scores are <= -1.0 (e.g. a negative value_per_conversion).
            if best_variant is None or score > best_score:
                best_score, best_variant = score, row["variant_id"]
        return best_variant, best_score

    @staticmethod
    def update_with_event(campaign_id: str, variant_id: str, event_type: str):
        """Record one event for a variant and update its Beta posteriors.

        - ``"impression"``: impressions +1, beta_click +1 (a CTR trial).
        - ``"click"``: clicks +1, alpha_click +1 (CTR success),
          beta_conv +1 (a click is a CVR trial).
        - ``"conversion"``: conversions +1, alpha_conv +1 (CVR success).

        NOTE: trials are added to beta and never subtracted on success, so each
        posterior mean tracks successes / (successes + trials) rather than
        successes / trials. That is still monotone in the true rate. An exact
        update would also decrement beta on success.
        TODO: confirm storage.update_metric accepts negative deltas before doing so.

        Raises:
            ValueError: if ``event_type`` is not one of the values above.
        """
        if event_type == "impression":
            storage.update_metric(campaign_id, variant_id, "impressions", 1)
            storage.update_metric(campaign_id, variant_id, "beta_click", 1)
        elif event_type == "click":
            storage.update_metric(campaign_id, variant_id, "clicks", 1)
            storage.update_metric(campaign_id, variant_id, "alpha_click", 1)
            # Without this, beta_conv never grows and the CVR posterior drifts
            # toward 1 as conversions accumulate, regardless of the clicks
            # that did not convert.
            storage.update_metric(campaign_id, variant_id, "beta_conv", 1)
        elif event_type == "conversion":
            storage.update_metric(campaign_id, variant_id, "conversions", 1)
            storage.update_metric(campaign_id, variant_id, "alpha_conv", 1)
        else:
            raise ValueError("unknown event_type")