Corin1998 committed on
Commit
ab532a1
·
verified ·
1 Parent(s): 5a416a8

Update causal.py

Browse files
Files changed (1) hide show
  1. causal.py +8 -30
causal.py CHANGED
@@ -5,14 +5,6 @@ import pymc as pm
5
  import pytensor.tensor as at
6
  from typing import Dict, Any, Optional
7
 
8
- """
9
- 階層ロジスティック回帰で uplift(コントロールとの差)を推定。
10
- - 目的変数: click (0/1) 推奨。集計単位が日次のときは二項モデルに切替可能。
11
- - 階層: 媒体ランダム効果。
12
- - 処置: creative(control=1 の行を基準)
13
- - 共変量: 任意の数値列(正規化して利用)
14
- """
15
-
16
  def _zscore(df: pd.DataFrame, cols):
17
  x = df[cols].astype(float)
18
  return (x - x.mean()) / (x.std(ddof=0) + 1e-9)
@@ -28,11 +20,8 @@ def fit_uplift_binary(
28
  target_accept: float = 0.9,
29
  random_seed: int = 42,
30
  ) -> Dict[str, Any]:
31
- # 前処理
32
  d = df.copy().reset_index(drop=True)
33
  if outcome_col not in d.columns:
34
- # 集計データ (impr, clicks) から擬似サンプル化
35
- # 二項尤度に切替
36
  d["n"] = d["impressions"].astype(int)
37
  d["y"] = d["clicks"].astype(int)
38
  binomial = True
@@ -41,7 +30,6 @@ def fit_uplift_binary(
41
  d["n"] = 1
42
  binomial = False
43
 
44
- # creative の one-hot と control 基準
45
  creatives = d[creative_col].astype(str).unique().tolist()
46
  control_creatives = d.loc[d[control_flag_col] == 1, creative_col].astype(str).unique().tolist()
47
  control_ref = control_creatives[0] if len(control_creatives) else creatives[0]
@@ -49,7 +37,6 @@ def fit_uplift_binary(
49
  d["creative_idx"] = d[creative_col].astype(str).apply(lambda x: creatives.index(x)).astype(int)
50
  d["medium_idx"] = d[medium_col].astype(str).astype('category').cat.codes.values
51
 
52
- # 特徴量
53
  X = None
54
  if feature_cols:
55
  X = _zscore(d, feature_cols).values
@@ -58,27 +45,24 @@ def fit_uplift_binary(
58
  p = 0
59
 
60
  with pm.Model() as model:
61
- # ランダム効果: 媒体
62
  n_medium = int(d["medium_idx"].max()) + 1
63
  mu_re = pm.Normal("mu_re", 0.0, 1.0)
64
  sd_re = pm.HalfNormal("sd_re", 1.0)
65
  z_re = pm.Normal("z_re", 0.0, 1.0, shape=n_medium)
66
  b_medium = pm.Deterministic("b_medium", mu_re + z_re * sd_re)
67
 
68
- # creative 固定効果(baseline = control_ref)
69
  n_creative = len(creatives)
70
  b0 = pm.Normal("intercept", 0.0, 1.5)
71
  b_cre = pm.Normal("b_cre_raw", 0.0, 1.0, shape=n_creative)
72
- # 基準調整
73
  ref_idx = creatives.index(control_ref)
74
  b_cre_adj = at.set_subtensor(b_cre[ref_idx], 0.0)
75
 
76
- # 連続特徴量
77
  if p > 0:
78
  b_x = pm.Normal("b_x", 0.0, 1.0, shape=p)
79
- lin = b0 + b_cre_adj[d["creative_idx"].values] + b_medium[d["medium_idx"]].values + at.dot(X, b_x)
80
  else:
81
- lin = b0 + b_cre_adj[d["creative_idx"].values] + b_medium[d["medium_idx"]].values
82
 
83
  p_click = pm.Deterministic("p_click", pm.math.sigmoid(lin))
84
 
@@ -87,18 +71,16 @@ def fit_uplift_binary(
87
  else:
88
  pm.Bernoulli("y_obs", p=p_click, observed=d["y"].values)
89
 
90
- idata = pm.sample(draws=draws, tune=draws, chains=2, target_accept=target_accept, random_seed=random_seed, progressbar=False)
 
 
91
 
92
- # uplift: 各 creative の p_click を control_ref と比較
93
  post = idata.posterior
94
- # サンプル次元: chain, draw
95
  b0_s = post["intercept"].stack(sample=("chain", "draw"))
96
  b_cre_s = post["b_cre_raw"].stack(sample=("chain", "draw"))
97
- # uplift は "平均的ユーザー" と "平均的媒体効果" を前提に比較
98
  mu_re = post["mu_re"].stack(sample=("chain", "draw"))
99
 
100
- def sigmoid(x):
101
- return 1 / (1 + np.exp(-x))
102
 
103
  results = []
104
  for cr in creatives:
@@ -113,8 +95,4 @@ def fit_uplift_binary(
113
  "control_ref": control_ref,
114
  })
115
 
116
- return {
117
- "control_ref": control_ref,
118
- "creatives": creatives,
119
- "results": results,
120
- }
 
5
  import pytensor.tensor as at
6
  from typing import Dict, Any, Optional
7
 
 
 
 
 
 
 
 
 
8
  def _zscore(df: pd.DataFrame, cols):
9
  x = df[cols].astype(float)
10
  return (x - x.mean()) / (x.std(ddof=0) + 1e-9)
 
20
  target_accept: float = 0.9,
21
  random_seed: int = 42,
22
  ) -> Dict[str, Any]:
 
23
  d = df.copy().reset_index(drop=True)
24
  if outcome_col not in d.columns:
 
 
25
  d["n"] = d["impressions"].astype(int)
26
  d["y"] = d["clicks"].astype(int)
27
  binomial = True
 
30
  d["n"] = 1
31
  binomial = False
32
 
 
33
  creatives = d[creative_col].astype(str).unique().tolist()
34
  control_creatives = d.loc[d[control_flag_col] == 1, creative_col].astype(str).unique().tolist()
35
  control_ref = control_creatives[0] if len(control_creatives) else creatives[0]
 
37
  d["creative_idx"] = d[creative_col].astype(str).apply(lambda x: creatives.index(x)).astype(int)
38
  d["medium_idx"] = d[medium_col].astype(str).astype('category').cat.codes.values
39
 
 
40
  X = None
41
  if feature_cols:
42
  X = _zscore(d, feature_cols).values
 
45
  p = 0
46
 
47
  with pm.Model() as model:
 
48
  n_medium = int(d["medium_idx"].max()) + 1
49
  mu_re = pm.Normal("mu_re", 0.0, 1.0)
50
  sd_re = pm.HalfNormal("sd_re", 1.0)
51
  z_re = pm.Normal("z_re", 0.0, 1.0, shape=n_medium)
52
  b_medium = pm.Deterministic("b_medium", mu_re + z_re * sd_re)
53
 
 
54
  n_creative = len(creatives)
55
  b0 = pm.Normal("intercept", 0.0, 1.5)
56
  b_cre = pm.Normal("b_cre_raw", 0.0, 1.0, shape=n_creative)
57
+
58
  ref_idx = creatives.index(control_ref)
59
  b_cre_adj = at.set_subtensor(b_cre[ref_idx], 0.0)
60
 
 
61
  if p > 0:
62
  b_x = pm.Normal("b_x", 0.0, 1.0, shape=p)
63
+ lin = b0 + b_cre_adj[d["creative_idx"].values] + b_medium[d["medium_idx"].values] + at.dot(X, b_x)
64
  else:
65
+ lin = b0 + b_cre_adj[d["creative_idx"].values] + b_medium[d["medium_idx"].values]
66
 
67
  p_click = pm.Deterministic("p_click", pm.math.sigmoid(lin))
68
 
 
71
  else:
72
  pm.Bernoulli("y_obs", p=p_click, observed=d["y"].values)
73
 
74
+ idata = pm.sample(draws=draws, tune=draws, chains=2,
75
+ target_accept=target_accept, random_seed=random_seed,
76
+ progressbar=False)
77
 
 
78
  post = idata.posterior
 
79
  b0_s = post["intercept"].stack(sample=("chain", "draw"))
80
  b_cre_s = post["b_cre_raw"].stack(sample=("chain", "draw"))
 
81
  mu_re = post["mu_re"].stack(sample=("chain", "draw"))
82
 
83
+ def sigmoid(x): return 1 / (1 + np.exp(-x))
 
84
 
85
  results = []
86
  for cr in creatives:
 
95
  "control_ref": control_ref,
96
  })
97
 
98
+ return {"control_ref": control_ref, "creatives": creatives, "results": results}