File size: 2,369 Bytes
45e6e77
 
 
 
 
 
 
167c1df
 
45e6e77
 
 
 
 
167c1df
 
45e6e77
 
 
 
 
 
 
 
 
 
 
 
 
167c1df
 
 
 
45e6e77
 
 
 
 
 
167c1df
 
45e6e77
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from __future__ import annotations
import math
import random
from typing import Dict, Any, Tuple, Callable, Optional

from . import storage

class ThompsonBandit:
    """Thompson-sampling bandit that picks a campaign variant by expected value.

    Two-stage Beta model: click-through rate (CTR) and conversion rate (CVR)
    each have a Beta posterior. A variant is scored as ``EV = pc * pv * V``,
    where ``pc`` / ``pv`` are Thompson draws from the CTR / CVR posteriors and
    ``V`` is the campaign's value per conversion.
    Seasonality is injected as a boost to the click-side prior.
    """

    # Counter updates applied by ``update_with_event``: event type ->
    # (metric field, delta) pairs, in application order. The Beta parameters
    # are maintained as "prior + successes" / "prior + failures":
    #   impression -> one tentative click failure   (beta_click += 1)
    #   click      -> turns that failure into a success
    #                 (alpha_click += 1, beta_click -= 1) and opens a
    #                 tentative conversion failure   (beta_conv += 1)
    #   conversion -> turns that failure into a success
    #                 (alpha_conv += 1, beta_conv -= 1)
    # NOTE(review): assumes events for a variant arrive in the order
    # impression -> click -> conversion, and that storage.update_metric adds
    # the (possibly negative) int delta to the field -- confirm against
    # storage. Rows written by earlier versions of this class counted
    # impressions (not failures) in beta_click and never grew beta_conv.
    _EVENT_UPDATES: Dict[str, Tuple[Tuple[str, int], ...]] = {
        "impression": (("impressions", 1), ("beta_click", 1)),
        "click": (
            ("clicks", 1),
            ("alpha_click", 1),
            ("beta_click", -1),
            ("beta_conv", 1),
        ),
        "conversion": (
            ("conversions", 1),
            ("alpha_conv", 1),
            ("beta_conv", -1),
        ),
    }

    def __init__(self, campaign_id: str, seasonality_boost: float = 5.0):
        """
        Args:
            campaign_id: Campaign whose variants are sampled.
            seasonality_boost: Total pseudo-count mass that seasonality adds to
                the click prior, split as ``boost * s`` pseudo-successes and
                ``boost * (1 - s)`` pseudo-failures.
        """
        self.campaign_id = campaign_id
        self.seasonality_boost = seasonality_boost

    def sample_arm(self, context: Dict[str, Any], seasonal_fn: Optional[Callable[[Dict[str, Any]], float]] = None
                   ) -> Tuple[Optional[str], float]:
        """Draw one Thompson sample per variant and return the highest-EV one.

        Args:
            context: Request context. It is used only as the argument to
                ``seasonal_fn``.
            seasonal_fn: Optional callback returning a seasonal click share
                ``s``. The result is clamped to [0.01, 0.99]. A missing
                callback, an exception, or a NaN result falls back to the
                neutral 0.5.

        Returns:
            ``(variant_id, score)`` for the best sampled variant, where
            ``score = pc * pv * value_per_conversion``. Returns
            ``(None, -1.0)`` when the campaign has no metric rows.
        """
        metrics = storage.get_metrics(self.campaign_id)
        if not metrics:
            return None, -1.0
        vpc = storage.get_campaign_value_per_conversion(self.campaign_id)

        s = 0.5
        if seasonal_fn is not None:
            try:
                raw = float(seasonal_fn(context))
            except Exception:
                # Deliberate best-effort: a broken seasonality callback must
                # not prevent serving, so fall back to a neutral prior.
                raw = 0.5
            # NaN would pass straight through min/max (min(0.99, nan) is
            # 0.99), so map it to neutral explicitly before clamping.
            s = 0.5 if math.isnan(raw) else max(0.01, min(0.99, raw))

        # Seasonal pseudo-counts are the same for every variant.
        click_boost = self.seasonality_boost * s
        click_penalty = self.seasonality_boost * (1.0 - s)

        best_score, best_vid = -1.0, None
        for r in metrics:
            ac, bc = float(r["alpha_click"]), float(r["beta_click"])
            av, bv = float(r["alpha_conv"]), float(r["beta_conv"])
            # betavariate requires strictly positive parameters. The clamp
            # also absorbs transient negatives caused by out-of-order events.
            pc = random.betavariate(max(1e-6, ac + click_boost), max(1e-6, bc + click_penalty))
            pv = random.betavariate(max(1e-6, av), max(1e-6, bv))
            score = pc * pv * vpc
            # The ``best_vid is None`` check ensures some variant is returned
            # even if every score is <= -1.0 (e.g. a negative value per
            # conversion).
            if best_vid is None or score > best_score:
                best_score, best_vid = score, r["variant_id"]
        return best_vid, best_score

    @staticmethod
    def update_with_event(campaign_id: str, variant_id: str, event_type: str):
        """Apply one tracking event to a variant's stored counters.

        Args:
            campaign_id: Campaign the variant belongs to.
            variant_id: Variant that received the event.
            event_type: One of ``"impression"``, ``"click"``, ``"conversion"``.

        Raises:
            ValueError: If ``event_type`` is not a known event type.
        """
        try:
            updates = ThompsonBandit._EVENT_UPDATES[event_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable inputs, which should also be
            # reported as unknown event types.
            raise ValueError(f"unknown event_type: {event_type!r}") from None
        for field, delta in updates:
            storage.update_metric(campaign_id, variant_id, field, delta)