Spaces:
Sleeping
Sleeping
Update causal.py
Browse files
causal.py
CHANGED
|
@@ -1,4 +1,17 @@
|
|
| 1 |
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
| 4 |
import pymc as pm
|
|
@@ -45,7 +58,7 @@ def fit_uplift_binary(
|
|
| 45 |
p = 0
|
| 46 |
|
| 47 |
with pm.Model() as model:
|
| 48 |
-
n_medium = int(
|
| 49 |
mu_re = pm.Normal("mu_re", 0.0, 1.0)
|
| 50 |
sd_re = pm.HalfNormal("sd_re", 1.0)
|
| 51 |
z_re = pm.Normal("z_re", 0.0, 1.0, shape=n_medium)
|
|
@@ -71,9 +84,11 @@ def fit_uplift_binary(
|
|
| 71 |
else:
|
| 72 |
pm.Bernoulli("y_obs", p=p_click, observed=d["y"].values)
|
| 73 |
|
| 74 |
-
idata = pm.sample(
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
| 77 |
|
| 78 |
post = idata.posterior
|
| 79 |
b0_s = post["intercept"].stack(sample=("chain", "draw"))
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
# ---- PyTensor(=PyMCの内部) のコンパイルキャッシュ先を /tmp に固定 ----
|
| 5 |
+
# 環境変数で上書き可: PYTENSOR_BASE=/data/.pytensor(永続化したい場合)
|
| 6 |
+
base = os.environ.get("PYTENSOR_BASE", "/tmp/.pytensor")
|
| 7 |
+
os.makedirs(base, exist_ok=True)
|
| 8 |
+
if "PYTENSOR_FLAGS" in os.environ and "base_compiledir=" not in os.environ["PYTENSOR_FLAGS"]:
|
| 9 |
+
os.environ["PYTENSOR_FLAGS"] = os.environ["PYTENSOR_FLAGS"] + f",base_compiledir={base}"
|
| 10 |
+
elif "PYTENSOR_FLAGS" not in os.environ:
|
| 11 |
+
os.environ["PYTENSOR_FLAGS"] = f"base_compiledir={base}"
|
| 12 |
+
os.environ.setdefault("XDG_CACHE_HOME", "/tmp/.cache")
|
| 13 |
+
|
| 14 |
+
# ---- 以降は従来どおり ----
|
| 15 |
import numpy as np
|
| 16 |
import pandas as pd
|
| 17 |
import pymc as pm
|
|
|
|
| 58 |
p = 0
|
| 59 |
|
| 60 |
with pm.Model() as model:
|
| 61 |
+
n_medium = int(np.max(d["medium_idx"])) + 1
|
| 62 |
mu_re = pm.Normal("mu_re", 0.0, 1.0)
|
| 63 |
sd_re = pm.HalfNormal("sd_re", 1.0)
|
| 64 |
z_re = pm.Normal("z_re", 0.0, 1.0, shape=n_medium)
|
|
|
|
| 84 |
else:
|
| 85 |
pm.Bernoulli("y_obs", p=p_click, observed=d["y"].values)
|
| 86 |
|
| 87 |
+
idata = pm.sample(
|
| 88 |
+
draws=draws, tune=draws, chains=2,
|
| 89 |
+
target_accept=target_accept, random_seed=random_seed,
|
| 90 |
+
progressbar=False
|
| 91 |
+
)
|
| 92 |
|
| 93 |
post = idata.posterior
|
| 94 |
b0_s = post["intercept"].stack(sample=("chain", "draw"))
|