Spaces:
Sleeping
Sleeping
Update causal.py
Browse files
causal.py
CHANGED
|
@@ -5,14 +5,6 @@ import pymc as pm
|
|
| 5 |
import pytensor.tensor as at
|
| 6 |
from typing import Dict, Any, Optional
|
| 7 |
|
| 8 |
-
"""
|
| 9 |
-
階層ロジスティック回帰で uplift(コントロールとの差)を推定。
|
| 10 |
-
- 目的変数: click (0/1) 推奨。集計単位が日次のときは二項モデルに切替可能。
|
| 11 |
-
- 階層: 媒体ランダム効果。
|
| 12 |
-
- 処置: creative(control=1 の行を基準)
|
| 13 |
-
- 共変量: 任意の数値列(正規化して利用)
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
def _zscore(df: pd.DataFrame, cols):
|
| 17 |
x = df[cols].astype(float)
|
| 18 |
return (x - x.mean()) / (x.std(ddof=0) + 1e-9)
|
|
@@ -28,11 +20,8 @@ def fit_uplift_binary(
|
|
| 28 |
target_accept: float = 0.9,
|
| 29 |
random_seed: int = 42,
|
| 30 |
) -> Dict[str, Any]:
|
| 31 |
-
# 前処理
|
| 32 |
d = df.copy().reset_index(drop=True)
|
| 33 |
if outcome_col not in d.columns:
|
| 34 |
-
# 集計データ (impr, clicks) から擬似サンプル化
|
| 35 |
-
# 二項尤度に切替
|
| 36 |
d["n"] = d["impressions"].astype(int)
|
| 37 |
d["y"] = d["clicks"].astype(int)
|
| 38 |
binomial = True
|
|
@@ -41,7 +30,6 @@ def fit_uplift_binary(
|
|
| 41 |
d["n"] = 1
|
| 42 |
binomial = False
|
| 43 |
|
| 44 |
-
# creative の one-hot と control 基準
|
| 45 |
creatives = d[creative_col].astype(str).unique().tolist()
|
| 46 |
control_creatives = d.loc[d[control_flag_col] == 1, creative_col].astype(str).unique().tolist()
|
| 47 |
control_ref = control_creatives[0] if len(control_creatives) else creatives[0]
|
|
@@ -49,7 +37,6 @@ def fit_uplift_binary(
|
|
| 49 |
d["creative_idx"] = d[creative_col].astype(str).apply(lambda x: creatives.index(x)).astype(int)
|
| 50 |
d["medium_idx"] = d[medium_col].astype(str).astype('category').cat.codes.values
|
| 51 |
|
| 52 |
-
# 特徴量
|
| 53 |
X = None
|
| 54 |
if feature_cols:
|
| 55 |
X = _zscore(d, feature_cols).values
|
|
@@ -58,27 +45,24 @@ def fit_uplift_binary(
|
|
| 58 |
p = 0
|
| 59 |
|
| 60 |
with pm.Model() as model:
|
| 61 |
-
# ランダム効果: 媒体
|
| 62 |
n_medium = int(d["medium_idx"].max()) + 1
|
| 63 |
mu_re = pm.Normal("mu_re", 0.0, 1.0)
|
| 64 |
sd_re = pm.HalfNormal("sd_re", 1.0)
|
| 65 |
z_re = pm.Normal("z_re", 0.0, 1.0, shape=n_medium)
|
| 66 |
b_medium = pm.Deterministic("b_medium", mu_re + z_re * sd_re)
|
| 67 |
|
| 68 |
-
# creative 固定効果(baseline = control_ref)
|
| 69 |
n_creative = len(creatives)
|
| 70 |
b0 = pm.Normal("intercept", 0.0, 1.5)
|
| 71 |
b_cre = pm.Normal("b_cre_raw", 0.0, 1.0, shape=n_creative)
|
| 72 |
-
|
| 73 |
ref_idx = creatives.index(control_ref)
|
| 74 |
b_cre_adj = at.set_subtensor(b_cre[ref_idx], 0.0)
|
| 75 |
|
| 76 |
-
# 連続特徴量
|
| 77 |
if p > 0:
|
| 78 |
b_x = pm.Normal("b_x", 0.0, 1.0, shape=p)
|
| 79 |
-
lin = b0 + b_cre_adj[d["creative_idx"].values] + b_medium[d["medium_idx"]
|
| 80 |
else:
|
| 81 |
-
lin = b0 + b_cre_adj[d["creative_idx"].values] + b_medium[d["medium_idx"]
|
| 82 |
|
| 83 |
p_click = pm.Deterministic("p_click", pm.math.sigmoid(lin))
|
| 84 |
|
|
@@ -87,18 +71,16 @@ def fit_uplift_binary(
|
|
| 87 |
else:
|
| 88 |
pm.Bernoulli("y_obs", p=p_click, observed=d["y"].values)
|
| 89 |
|
| 90 |
-
idata = pm.sample(draws=draws, tune=draws, chains=2,
|
|
|
|
|
|
|
| 91 |
|
| 92 |
-
# uplift: 各 creative の p_click を control_ref と比較
|
| 93 |
post = idata.posterior
|
| 94 |
-
# サンプル次元: chain, draw
|
| 95 |
b0_s = post["intercept"].stack(sample=("chain", "draw"))
|
| 96 |
b_cre_s = post["b_cre_raw"].stack(sample=("chain", "draw"))
|
| 97 |
-
# uplift は "平均的ユーザー" と "平均的媒体効果" を前提に比較
|
| 98 |
mu_re = post["mu_re"].stack(sample=("chain", "draw"))
|
| 99 |
|
| 100 |
-
def sigmoid(x):
|
| 101 |
-
return 1 / (1 + np.exp(-x))
|
| 102 |
|
| 103 |
results = []
|
| 104 |
for cr in creatives:
|
|
@@ -113,8 +95,4 @@ def fit_uplift_binary(
|
|
| 113 |
"control_ref": control_ref,
|
| 114 |
})
|
| 115 |
|
| 116 |
-
return {
|
| 117 |
-
"control_ref": control_ref,
|
| 118 |
-
"creatives": creatives,
|
| 119 |
-
"results": results,
|
| 120 |
-
}
|
|
|
|
| 5 |
import pytensor.tensor as at
|
| 6 |
from typing import Dict, Any, Optional
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
def _zscore(df: pd.DataFrame, cols, eps: float = 1e-9):
    """Standardize the selected columns of ``df`` (z-score).

    Args:
        df: Source frame; it is not modified.
        cols: Column label(s) used to select from ``df`` via ``df[cols]``.
        eps: Small constant added to each standard deviation so that a
            constant column yields zeros instead of a division by zero.
            Defaults to the previously hard-coded ``1e-9``.

    Returns:
        The selection cast to float, centred on its mean and divided by its
        population standard deviation (``ddof=0``) plus ``eps``.
    """
    x = df[cols].astype(float)
    # ddof=0: scale by the population (not sample) standard deviation.
    return (x - x.mean()) / (x.std(ddof=0) + eps)
|
|
|
|
| 20 |
target_accept: float = 0.9,
|
| 21 |
random_seed: int = 42,
|
| 22 |
) -> Dict[str, Any]:
|
|
|
|
| 23 |
d = df.copy().reset_index(drop=True)
|
| 24 |
if outcome_col not in d.columns:
|
|
|
|
|
|
|
| 25 |
d["n"] = d["impressions"].astype(int)
|
| 26 |
d["y"] = d["clicks"].astype(int)
|
| 27 |
binomial = True
|
|
|
|
| 30 |
d["n"] = 1
|
| 31 |
binomial = False
|
| 32 |
|
|
|
|
| 33 |
creatives = d[creative_col].astype(str).unique().tolist()
|
| 34 |
control_creatives = d.loc[d[control_flag_col] == 1, creative_col].astype(str).unique().tolist()
|
| 35 |
control_ref = control_creatives[0] if len(control_creatives) else creatives[0]
|
|
|
|
| 37 |
d["creative_idx"] = d[creative_col].astype(str).apply(lambda x: creatives.index(x)).astype(int)
|
| 38 |
d["medium_idx"] = d[medium_col].astype(str).astype('category').cat.codes.values
|
| 39 |
|
|
|
|
| 40 |
X = None
|
| 41 |
if feature_cols:
|
| 42 |
X = _zscore(d, feature_cols).values
|
|
|
|
| 45 |
p = 0
|
| 46 |
|
| 47 |
with pm.Model() as model:
|
|
|
|
| 48 |
n_medium = int(d["medium_idx"].max()) + 1
|
| 49 |
mu_re = pm.Normal("mu_re", 0.0, 1.0)
|
| 50 |
sd_re = pm.HalfNormal("sd_re", 1.0)
|
| 51 |
z_re = pm.Normal("z_re", 0.0, 1.0, shape=n_medium)
|
| 52 |
b_medium = pm.Deterministic("b_medium", mu_re + z_re * sd_re)
|
| 53 |
|
|
|
|
| 54 |
n_creative = len(creatives)
|
| 55 |
b0 = pm.Normal("intercept", 0.0, 1.5)
|
| 56 |
b_cre = pm.Normal("b_cre_raw", 0.0, 1.0, shape=n_creative)
|
| 57 |
+
|
| 58 |
ref_idx = creatives.index(control_ref)
|
| 59 |
b_cre_adj = at.set_subtensor(b_cre[ref_idx], 0.0)
|
| 60 |
|
|
|
|
| 61 |
if p > 0:
|
| 62 |
b_x = pm.Normal("b_x", 0.0, 1.0, shape=p)
|
| 63 |
+
lin = b0 + b_cre_adj[d["creative_idx"].values] + b_medium[d["medium_idx"].values] + at.dot(X, b_x)
|
| 64 |
else:
|
| 65 |
+
lin = b0 + b_cre_adj[d["creative_idx"].values] + b_medium[d["medium_idx"].values]
|
| 66 |
|
| 67 |
p_click = pm.Deterministic("p_click", pm.math.sigmoid(lin))
|
| 68 |
|
|
|
|
| 71 |
else:
|
| 72 |
pm.Bernoulli("y_obs", p=p_click, observed=d["y"].values)
|
| 73 |
|
| 74 |
+
idata = pm.sample(draws=draws, tune=draws, chains=2,
|
| 75 |
+
target_accept=target_accept, random_seed=random_seed,
|
| 76 |
+
progressbar=False)
|
| 77 |
|
|
|
|
| 78 |
post = idata.posterior
|
|
|
|
| 79 |
b0_s = post["intercept"].stack(sample=("chain", "draw"))
|
| 80 |
b_cre_s = post["b_cre_raw"].stack(sample=("chain", "draw"))
|
|
|
|
| 81 |
mu_re = post["mu_re"].stack(sample=("chain", "draw"))
|
| 82 |
|
| 83 |
+
def sigmoid(x):
    """Numerically stable logistic function, ``1 / (1 + exp(-x))``.

    The naive form overflows inside ``exp(-x)`` for strongly negative ``x``
    and emits a RuntimeWarning. Here ``log(1 + exp(-x))`` is computed with
    ``np.logaddexp``, which does not overflow, and then exponentiated.

    Args:
        x: Scalar or array-like of logits. Only NumPy ufuncs and unary minus
            are applied, so inputs that dispatch ufuncs should keep their
            container type (NOTE(review): confirm for the xarray objects
            passed by the caller).

    Returns:
        Values in ``[0, 1]`` with the same shape as ``x``.
    """
    # sigmoid(x) = exp(-log(1 + exp(-x))); logaddexp(0, -x) is that log term.
    return np.exp(-np.logaddexp(0.0, -x))
|
|
|
|
| 84 |
|
| 85 |
results = []
|
| 86 |
for cr in creatives:
|
|
|
|
| 95 |
"control_ref": control_ref,
|
| 96 |
})
|
| 97 |
|
| 98 |
+
return {"control_ref": control_ref, "creatives": creatives, "results": results}
|
|
|
|
|
|
|
|
|
|
|
|