Spaces:
Sleeping
Sleeping
from __future__ import annotations

import logging
import math
import random
from typing import Dict, Any, Tuple, Callable, Optional

from . import storage
class ThompsonBandit:
    """Thompson-sampling variant selector for a single campaign.

    Each variant has two Beta distributions read from ``storage``: one for the
    click probability (CTR: ``alpha_click``/``beta_click``) and one for the
    conversion probability (CVR: ``alpha_conv``/``beta_conv``).

    Objective: E[value] = p_click * p_conv * value_per_conversion

    Seasonality is applied as pseudo-counts added to the CTR-side Beta
    parameters only.
    """

    def __init__(self, campaign_id: str, seasonality_boost: float = 5.0):
        """
        Args:
            campaign_id: Campaign identifier passed to every ``storage`` lookup.
            seasonality_boost: Total pseudo-count mass that is split between the
                CTR alpha (``boost * s``) and beta (``boost * (1 - s)``), where
                ``s`` is the seasonal score.
        """
        self.campaign_id = campaign_id
        self.seasonality_boost = seasonality_boost

    @staticmethod
    def _seasonal_score(
        context: Dict[str, Any],
        seasonal_fn: Optional[Callable[[Dict[str, Any]], float]],
    ) -> float:
        """Return ``seasonal_fn(context)`` clamped to [0.01, 0.99], or 0.5 if unusable."""
        if seasonal_fn is None:
            return 0.5
        try:
            s = float(seasonal_fn(context))
        except Exception:
            # Best-effort: a broken seasonality hook must not block arm
            # selection, but the failure should still be visible.
            logging.getLogger(__name__).warning(
                "seasonal_fn failed; using neutral seasonal score 0.5",
                exc_info=True,
            )
            return 0.5
        if math.isnan(s):
            # NaN compares False with everything, so min/max would clamp it to
            # 0.99 (the strongest boost). Treat it as "no signal" instead.
            return 0.5
        return max(0.01, min(0.99, s))

    def sample_arm(
        self,
        context: Dict[str, Any],
        seasonal_fn: Optional[Callable[[Dict[str, Any]], float]] = None,  # may be passed by keyword
    ) -> Tuple[Optional[str], float]:
        """Draw one Thompson sample per variant and return the best one.

        Args:
            context: Request context, forwarded unchanged to ``seasonal_fn``.
            seasonal_fn: Optional callable mapping ``context`` to a seasonal
                score. The result is clamped to [0.01, 0.99]. If the call
                raises or returns NaN, the neutral score 0.5 is used.

        Returns:
            ``(variant_id, sampled_score)`` for the variant with the highest
            sampled ``p_click * p_conv * value_per_conversion``, or
            ``(None, -1.0)`` when there are no metrics rows or no variant
            could be selected.
        """
        metrics = storage.get_metrics(self.campaign_id)
        if not metrics:
            return None, -1.0

        # Convert once, matching how the per-row counts are read below.
        vpc = float(storage.get_campaign_value_per_conversion(self.campaign_id))
        s = self._seasonal_score(context, seasonal_fn)

        # Seasonal pseudo-counts are identical for every variant: compute once.
        click_alpha_boost = self.seasonality_boost * s
        click_beta_boost = self.seasonality_boost * (1.0 - s)

        # Seed with -inf rather than -1.0 so that a variant is always chosen
        # when rows exist, even if every sampled score is <= -1.
        best_score, best_variant = -math.inf, None
        for row in metrics:
            ac_eff = float(row["alpha_click"]) + click_alpha_boost
            bc_eff = float(row["beta_click"]) + click_beta_boost
            av, bv = float(row["alpha_conv"]), float(row["beta_conv"])
            # betavariate requires strictly positive shape parameters.
            pc = random.betavariate(max(1e-6, ac_eff), max(1e-6, bc_eff))
            pv = random.betavariate(max(1e-6, av), max(1e-6, bv))
            score = pc * pv * vpc
            if score > best_score:
                best_score, best_variant = score, row["variant_id"]

        if best_variant is None:
            # E.g. every score was NaN; keep the historical "no arm" result.
            return None, -1.0
        return best_variant, best_score
def update_with_event(campaign_id: str, variant_id: str, event_type: str):
    """Record one tracking event for a variant and update its Beta counts.

    Counting convention, shared by the CTR and CVR posteriors: every *trial*
    adds to ``beta_*`` and every *success* adds to ``alpha_*``.

    * ``"impression"``: a CTR trial -> ``impressions``, ``beta_click``.
    * ``"click"``: a CTR success -> ``clicks``, ``alpha_click``. It is also a
      CVR trial -> ``beta_conv``. Without this, the CVR posterior never sees
      clicks that did not convert and drifts toward 1.
    * ``"conversion"``: a CVR success -> ``conversions``, ``alpha_conv``.

    Args:
        campaign_id: Campaign the variant belongs to.
        variant_id: Variant the event is attributed to.
        event_type: One of ``"impression"``, ``"click"``, ``"conversion"``.

    Raises:
        ValueError: If ``event_type`` is not one of the values above.
    """
    # NOTE(review): because trials (not just failures) go into beta_*, the
    # posteriors are Beta(a0 + successes, b0 + trials) rather than the exact
    # Beta(a0 + successes, b0 + trials - successes). This slightly understates
    # high rates but is close for small CTR/CVR. An exact update would require
    # storage.update_metric to accept negative deltas — confirm before changing.
    if event_type == "impression":
        storage.update_metric(campaign_id, variant_id, "impressions", 1)
        storage.update_metric(campaign_id, variant_id, "beta_click", 1)
    elif event_type == "click":
        storage.update_metric(campaign_id, variant_id, "clicks", 1)
        storage.update_metric(campaign_id, variant_id, "alpha_click", 1)
        storage.update_metric(campaign_id, variant_id, "beta_conv", 1)
    elif event_type == "conversion":
        storage.update_metric(campaign_id, variant_id, "conversions", 1)
        storage.update_metric(campaign_id, variant_id, "alpha_conv", 1)
    else:
        raise ValueError("unknown event_type")