Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import numpy as np | |
| import pandas as pd | |
| import pymc as pm | |
| import pytensor.tensor as at | |
| from typing import Dict, Any, Optional | |
| """ | |
| 階層ロジスティック回帰で uplift(コントロールとの差)を推定。 | |
| - 目的変数: click (0/1) 推奨。集計単位が日次のときは二項モデルに切替可能。 | |
| - 階層: 媒体ランダム効果。 | |
| - 処置: creative(control=1 の行を基準) | |
| - 共変量: 任意の数値列(正規化して利用) | |
| """ | |
def _zscore(df: pd.DataFrame, cols):
    """Return the selected columns of *df* standardized column-wise.

    Each column in ``cols`` is cast to float, centered on its mean and
    divided by its population standard deviation (``ddof=0``). A small
    constant (1e-9) is added to the denominator so constant columns map
    to zeros instead of dividing by zero.
    """
    values = df[cols].astype(float)
    centered = values - values.mean()
    # Epsilon keeps the division finite for zero-variance columns.
    scale = values.std(ddof=0) + 1e-9
    return centered / scale
def fit_uplift_binary(
    df: pd.DataFrame,
    outcome_col: str = "click_bin",
    medium_col: str = "medium",
    creative_col: str = "creative",
    control_flag_col: str = "is_control",
    feature_cols: Optional[list] = None,
    draws: int = 1000,
    target_accept: float = 0.9,
    random_seed: int = 42,
) -> Dict[str, Any]:
    """Fit a hierarchical logistic model and estimate per-creative uplift.

    The linear predictor is ``intercept + creative effect + medium random
    effect (+ z-scored features)``. The creative effect of the control
    reference is pinned to zero, so every other creative's coefficient is
    a log-odds contrast against the control.

    Args:
        df: Input rows. Either row-level data containing ``outcome_col``
            (0/1), or aggregated data containing ``impressions`` and
            ``clicks`` columns (used when ``outcome_col`` is absent; a
            Binomial likelihood is used in that case).
        outcome_col: Name of the binary outcome column.
        medium_col: Column identifying the medium (random-effect group).
        creative_col: Column identifying the creative (treatment).
        control_flag_col: Column whose value ``1`` marks control rows. The
            first creative seen among those rows is the reference; if no
            row is flagged, the first creative in the data is used.
        feature_cols: Optional numeric covariate columns; z-scored before
            entering the model.
        draws: Number of posterior draws per chain; also used as the
            number of tuning steps.
        target_accept: Passed through to ``pm.sample``.
        random_seed: Passed through to ``pm.sample``.

    Returns:
        A dict with ``control_ref`` (reference creative), ``creatives``
        (creative labels in model order) and ``results`` (one dict per
        creative with ``creative``, ``uplift_mean``, ``uplift_p_gt0`` and
        ``control_ref``). Uplift is the difference in click probability
        versus the control, evaluated at the mean medium effect
        (``mu_re``) and with the feature term omitted (i.e. at the mean of
        the z-scored features).

    Raises:
        KeyError: If ``outcome_col`` is absent and ``impressions`` or
            ``clicks`` is missing as well.
    """
    # Preprocessing: work on a copy so the caller's frame is untouched.
    d = df.copy().reset_index(drop=True)
    if outcome_col not in d.columns:
        # Aggregated data (impressions, clicks): switch to a Binomial likelihood.
        missing = [c for c in ("impressions", "clicks") if c not in d.columns]
        if missing:
            raise KeyError(
                f"'{outcome_col}' not found and aggregated columns are missing: {missing}"
            )
        d["n"] = d["impressions"].astype(int)
        d["y"] = d["clicks"].astype(int)
        binomial = True
    else:
        d["y"] = d[outcome_col].astype(int)
        d["n"] = 1
        binomial = False

    # Creative indexing with the control creative as the baseline.
    creative_str = d[creative_col].astype(str)
    creatives = creative_str.unique().tolist()
    control_creatives = creative_str[d[control_flag_col] == 1].unique().tolist()
    control_ref = control_creatives[0] if control_creatives else creatives[0]
    # Dict lookup instead of list.index per row (which is O(rows * creatives)).
    creative_pos = {name: i for i, name in enumerate(creatives)}
    d["creative_idx"] = creative_str.map(creative_pos).astype(int)
    d["medium_idx"] = d[medium_col].astype(str).astype("category").cat.codes.values

    # Plain NumPy index arrays: a pandas Series must not be used to index a
    # PyTensor variable, and tensors have no ``.values`` attribute.
    creative_idx = d["creative_idx"].values
    medium_idx = d["medium_idx"].values
    ref_idx = creative_pos[control_ref]

    # Covariates (z-scored).
    X = None
    if feature_cols:
        X = _zscore(d, feature_cols).values
        p = X.shape[1]
    else:
        p = 0

    with pm.Model():
        # Random effect: medium (non-centered parameterization).
        n_medium = int(d["medium_idx"].max()) + 1
        mu_re = pm.Normal("mu_re", 0.0, 1.0)
        sd_re = pm.HalfNormal("sd_re", 1.0)
        z_re = pm.Normal("z_re", 0.0, 1.0, shape=n_medium)
        b_medium = pm.Deterministic("b_medium", mu_re + z_re * sd_re)

        # Creative fixed effects (baseline = control_ref).
        n_creative = len(creatives)
        b0 = pm.Normal("intercept", 0.0, 1.5)
        b_cre = pm.Normal("b_cre_raw", 0.0, 1.0, shape=n_creative)
        # Pin the reference creative's effect to zero. The raw entry at
        # ref_idx is still sampled but never enters the likelihood.
        b_cre_adj = at.set_subtensor(b_cre[ref_idx], 0.0)

        lin = b0 + b_cre_adj[creative_idx] + b_medium[medium_idx]
        if p > 0:
            # Continuous covariates.
            b_x = pm.Normal("b_x", 0.0, 1.0, shape=p)
            lin = lin + at.dot(X, b_x)

        p_click = pm.Deterministic("p_click", pm.math.sigmoid(lin))
        if binomial:
            pm.Binomial("y_obs", n=d["n"].values, p=p_click, observed=d["y"].values)
        else:
            pm.Bernoulli("y_obs", p=p_click, observed=d["y"].values)

        idata = pm.sample(
            draws=draws,
            tune=draws,
            chains=2,
            target_accept=target_accept,
            random_seed=random_seed,
            progressbar=False,
        )

    # Uplift: compare each creative's click probability against control_ref.
    post = idata.posterior
    # Flatten (chain, draw) into a single sample dimension.
    b0_s = post["intercept"].stack(sample=("chain", "draw"))
    b_cre_s = post["b_cre_raw"].stack(sample=("chain", "draw"))
    # Compare at the "average user" and the "average medium effect".
    mu_re_s = post["mu_re"].stack(sample=("chain", "draw"))

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    # The control's creative effect is exactly 0 in the model, so its linear
    # predictor must NOT include b_cre_raw[ref_idx]: that raw entry is
    # unconstrained by the data and would only add prior noise.
    lin_c = b0_s + mu_re_s
    p_control = sigmoid(lin_c)

    results = []
    for idx, cr in enumerate(creatives):
        if idx == ref_idx:
            lin_t = lin_c
        else:
            # NOTE(review): relies on the auto-generated dimension name
            # "b_cre_raw_dim_0" for the unnamed shape axis - confirm against
            # the installed PyMC/ArviZ version.
            lin_t = lin_c + b_cre_s.isel(b_cre_raw_dim_0=idx)
        uplift = sigmoid(lin_t) - p_control
        results.append({
            "creative": cr,
            "uplift_mean": float(uplift.mean().item()),
            "uplift_p_gt0": float((uplift > 0).mean().item()),
            "control_ref": control_ref,
        })

    return {
        "control_ref": control_ref,
        "creatives": creatives,
        "results": results,
    }