Corin1998 commited on
Commit
f82b1da
·
verified ·
1 Parent(s): 46cda09

Update causal.py

Browse files
Files changed (1) hide show
  1. causal.py +19 -4
causal.py CHANGED
@@ -1,4 +1,17 @@
1
  from __future__ import annotations
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
  import pandas as pd
4
  import pymc as pm
@@ -45,7 +58,7 @@ def fit_uplift_binary(
45
  p = 0
46
 
47
  with pm.Model() as model:
48
- n_medium = int(d["medium_idx"].max()) + 1 if hasattr(d["medium_idx"], "max") else int(np.max(d["medium_idx"])) + 1
49
  mu_re = pm.Normal("mu_re", 0.0, 1.0)
50
  sd_re = pm.HalfNormal("sd_re", 1.0)
51
  z_re = pm.Normal("z_re", 0.0, 1.0, shape=n_medium)
@@ -71,9 +84,11 @@ def fit_uplift_binary(
71
  else:
72
  pm.Bernoulli("y_obs", p=p_click, observed=d["y"].values)
73
 
74
- idata = pm.sample(draws=draws, tune=draws, chains=2,
75
- target_accept=target_accept, random_seed=random_seed,
76
- progressbar=False)
 
 
77
 
78
  post = idata.posterior
79
  b0_s = post["intercept"].stack(sample=("chain", "draw"))
 
1
  from __future__ import annotations
2
+ import os
3
+
4
+ # ---- PyTensor(=PyMCの内部) のコンパイルキャッシュ先を /tmp に固定 ----
5
+ # 環境変数で上書き可: PYTENSOR_BASE=/data/.pytensor(永続化したい場合)
6
+ base = os.environ.get("PYTENSOR_BASE", "/tmp/.pytensor")
7
+ os.makedirs(base, exist_ok=True)
8
+ if "PYTENSOR_FLAGS" in os.environ and "base_compiledir=" not in os.environ["PYTENSOR_FLAGS"]:
9
+ os.environ["PYTENSOR_FLAGS"] = os.environ["PYTENSOR_FLAGS"] + f",base_compiledir={base}"
10
+ elif "PYTENSOR_FLAGS" not in os.environ:
11
+ os.environ["PYTENSOR_FLAGS"] = f"base_compiledir={base}"
12
+ os.environ.setdefault("XDG_CACHE_HOME", "/tmp/.cache")
13
+
14
+ # ---- 以降は従来どおり ----
15
  import numpy as np
16
  import pandas as pd
17
  import pymc as pm
 
58
  p = 0
59
 
60
  with pm.Model() as model:
61
+ n_medium = int(np.max(d["medium_idx"])) + 1
62
  mu_re = pm.Normal("mu_re", 0.0, 1.0)
63
  sd_re = pm.HalfNormal("sd_re", 1.0)
64
  z_re = pm.Normal("z_re", 0.0, 1.0, shape=n_medium)
 
84
  else:
85
  pm.Bernoulli("y_obs", p=p_click, observed=d["y"].values)
86
 
87
+ idata = pm.sample(
88
+ draws=draws, tune=draws, chains=2,
89
+ target_accept=target_accept, random_seed=random_seed,
90
+ progressbar=False
91
+ )
92
 
93
  post = idata.posterior
94
  b0_s = post["intercept"].stack(sample=("chain", "draw"))