Corin1998 committed on
Commit
d5be10a
·
verified ·
1 Parent(s): 6a62501

Update app/storage.py

Browse files
Files changed (1) hide show
  1. app/storage.py +203 -28
app/storage.py CHANGED
@@ -1,11 +1,8 @@
1
  from __future__ import annotations
2
- import os
3
- import sqlite3
4
- import json
5
  from pathlib import Path
6
- from typing import Optional, Dict, Any, List
7
 
8
- # 永続領域 /data が書き込み可なら /data/app_data を使う。なければ /tmp/app_data へフォールバック。
9
  DEFAULT_DIR = "/data/app_data" if os.access("/data", os.W_OK) else "/tmp/app_data"
10
  DB_DIR = Path(os.environ.get("APP_DATA_DIR", DEFAULT_DIR))
11
  DB_DIR.mkdir(parents=True, exist_ok=True)
@@ -21,7 +18,12 @@ CREATE TABLE IF NOT EXISTS campaigns (
21
  tone TEXT,
22
  language TEXT,
23
  constraints_json TEXT,
24
- value_per_conversion REAL DEFAULT 1.0
 
 
 
 
 
25
  );
26
  CREATE TABLE IF NOT EXISTS variants (
27
  campaign_id TEXT,
@@ -51,6 +53,36 @@ CREATE TABLE IF NOT EXISTS events (
51
  ts TEXT,
52
  value REAL
53
  );
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  """
55
 
56
  def get_conn():
@@ -58,20 +90,34 @@ def get_conn():
58
  conn.row_factory = sqlite3.Row
59
  return conn
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def init_db():
62
  with get_conn() as con:
63
  con.executescript(SCHEMA_SQL)
 
64
 
65
- def upsert_campaign(
66
- campaign_id: str,
67
- brand: str,
68
- product: str,
69
- target_audience: str,
70
- tone: str,
71
- language: str,
72
- constraints: Optional[Dict[str, Any]],
73
- value_per_conversion: float,
74
- ):
75
  with get_conn() as con:
76
  con.execute(
77
  """
@@ -86,9 +132,24 @@ def upsert_campaign(
86
  constraints_json=excluded.constraints_json,
87
  value_per_conversion=excluded.value_per_conversion
88
  """,
89
- (campaign_id, brand, product, target_audience, tone, language, json.dumps(constraints or {}), value_per_conversion),
 
 
 
 
 
 
 
 
 
 
90
  )
91
 
 
 
 
 
 
92
  def insert_variant(campaign_id: str, variant_id: str, text: str, status: str, rejection_reason: Optional[str]):
93
  with get_conn() as con:
94
  con.execute(
@@ -99,23 +160,20 @@ def insert_variant(campaign_id: str, variant_id: str, text: str, status: str, re
99
  (campaign_id, variant_id, text, status, rejection_reason),
100
  )
101
  con.execute(
102
- """
103
- INSERT OR IGNORE INTO metrics (campaign_id, variant_id)
104
- VALUES (?, ?)
105
- """,
106
- (campaign_id, variant_id),
107
  )
108
 
 
 
 
 
 
109
  def get_variants(campaign_id: str) -> List[sqlite3.Row]:
110
  with get_conn() as con:
111
  cur = con.execute("SELECT * FROM variants WHERE campaign_id=?", (campaign_id,))
112
  return cur.fetchall()
113
 
114
- def get_variant(campaign_id: str, variant_id: str) -> Optional[sqlite3.Row]:
115
- with get_conn() as con:
116
- cur = con.execute("SELECT * FROM variants WHERE campaign_id=? AND variant_id=?", (campaign_id, variant_id))
117
- return cur.fetchone()
118
-
119
  def get_metrics(campaign_id: str) -> List[sqlite3.Row]:
120
  with get_conn() as con:
121
  cur = con.execute("SELECT * FROM metrics WHERE campaign_id=?", (campaign_id,))
@@ -130,7 +188,7 @@ def log_event(campaign_id: str, variant_id: str, event_type: str, ts: str, value
130
  with get_conn() as con:
131
  con.execute(
132
  "INSERT INTO events (campaign_id, variant_id, event_type, ts, value) VALUES (?, ?, ?, ?, ?)",
133
- (campaign_id, variant_id, event_type, ts, value),
134
  )
135
 
136
  def get_campaign_value_per_conversion(campaign_id: str) -> float:
@@ -138,3 +196,120 @@ def get_campaign_value_per_conversion(campaign_id: str) -> float:
138
  cur = con.execute("SELECT value_per_conversion FROM campaigns WHERE campaign_id=?", (campaign_id,))
139
  row = cur.fetchone()
140
  return float(row[0]) if row else 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
+ import os, csv, json, sqlite3
 
 
3
  from pathlib import Path
4
+ from typing import Optional, Dict, Any, List, Tuple
5
 
 
6
  DEFAULT_DIR = "/data/app_data" if os.access("/data", os.W_OK) else "/tmp/app_data"
7
  DB_DIR = Path(os.environ.get("APP_DATA_DIR", DEFAULT_DIR))
8
  DB_DIR.mkdir(parents=True, exist_ok=True)
 
18
  tone TEXT,
19
  language TEXT,
20
  constraints_json TEXT,
21
+ value_per_conversion REAL DEFAULT 1.0,
22
+ policy TEXT DEFAULT 'thompson',
23
+ holdout_ratio REAL DEFAULT 0.0,
24
+ stop_min_impressions INTEGER DEFAULT 200,
25
+ stop_rel_ev_threshold REAL DEFAULT 0.5,
26
+ created_at TEXT DEFAULT (datetime('now'))
27
  );
28
  CREATE TABLE IF NOT EXISTS variants (
29
  campaign_id TEXT,
 
53
  ts TEXT,
54
  value REAL
55
  );
56
+ -- LinUCB のパラメータをJSONで保持
57
+ CREATE TABLE IF NOT EXISTS linucb (
58
+ campaign_id TEXT,
59
+ variant_id TEXT,
60
+ d INTEGER,
61
+ A_json TEXT,
62
+ b_json TEXT,
63
+ n_updates INTEGER DEFAULT 0,
64
+ PRIMARY KEY (campaign_id, variant_id)
65
+ );
66
+ -- コンプライアンス監査ログ(NG詳細/LLM修正案)
67
+ CREATE TABLE IF NOT EXISTS compliance_logs (
68
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
69
+ campaign_id TEXT,
70
+ variant_id TEXT,
71
+ status TEXT,
72
+ ng_rules_json TEXT,
73
+ llm_ok INTEGER,
74
+ llm_reasons_json TEXT,
75
+ llm_fixed TEXT,
76
+ ts TEXT DEFAULT (datetime('now'))
77
+ );
78
+ -- 任意の運用監査ログ
79
+ CREATE TABLE IF NOT EXISTS audit_logs (
80
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
81
+ campaign_id TEXT,
82
+ action TEXT,
83
+ payload_json TEXT,
84
+ ts TEXT DEFAULT (datetime('now'))
85
+ );
86
  """
87
 
88
  def get_conn():
 
90
  conn.row_factory = sqlite3.Row
91
  return conn
92
 
93
def _ensure_columns():
    """Add columns that older databases are missing (lightweight migration).

    For each table listed below, the existing column names are read with
    ``PRAGMA table_info`` and every missing column is added with
    ``ALTER TABLE ... ADD COLUMN``.

    SQLite does not accept a non-constant default (such as
    ``datetime('now')``) in ``ADD COLUMN`` on a table that already has rows,
    so such columns are added without a DEFAULT and existing rows are
    backfilled with an UPDATE instead. Rows inserted later into a migrated
    table therefore get NULL for that column unless the INSERT sets it.
    """
    # (column name, SQL type, constant DEFAULT literal or None,
    #  SQL expression used to backfill existing rows or None)
    need_cols = {
        "campaigns": [
            ("policy", "TEXT", "'thompson'", None),
            ("holdout_ratio", "REAL", "0.0", None),
            ("stop_min_impressions", "INTEGER", "200", None),
            ("stop_rel_ev_threshold", "REAL", "0.5", None),
            ("created_at", "TEXT", None, "datetime('now')"),
        ]
    }
    with get_conn() as con:
        for table, cols in need_cols.items():
            have = {r["name"] for r in con.execute(f"PRAGMA table_info({table})")}
            for name, typ, default, backfill in cols:
                if name in have:
                    continue
                ddl = f"ALTER TABLE {table} ADD COLUMN {name} {typ}"
                if default is not None:
                    ddl += f" DEFAULT {default}"
                con.execute(ddl)
                if backfill is not None:
                    # Emulate the dynamic default for rows that already exist.
                    con.execute(
                        f"UPDATE {table} SET {name} = {backfill} WHERE {name} IS NULL"
                    )
110
+
111
  def init_db():
112
  with get_conn() as con:
113
  con.executescript(SCHEMA_SQL)
114
+ _ensure_columns()
115
 
116
+ # ============== Campaign/Variant/Metric 基本 ==============
117
+
118
+ def upsert_campaign(campaign_id: str, brand: str, product: str, target_audience: str,
119
+ tone: str, language: str, constraints: Optional[Dict[str, Any]],
120
+ value_per_conversion: float):
 
 
 
 
 
121
  with get_conn() as con:
122
  con.execute(
123
  """
 
132
  constraints_json=excluded.constraints_json,
133
  value_per_conversion=excluded.value_per_conversion
134
  """,
135
+ (campaign_id, brand, product, target_audience, tone, language, json.dumps(constraints or {}, ensure_ascii=False), value_per_conversion),
136
+ )
137
+
138
def set_campaign_settings(campaign_id: str, policy: str, holdout_ratio: float, stop_min_impressions: int, stop_rel_ev_threshold: float):
    """Update the policy / holdout / stop-rule settings of one campaign.

    Numeric arguments are coerced to float/int before being written. The
    statement is a plain UPDATE, so nothing is written when no row matches
    ``campaign_id``.
    """
    with get_conn() as conn:
        params = (
            policy,
            float(holdout_ratio),
            int(stop_min_impressions),
            float(stop_rel_ev_threshold),
            campaign_id,
        )
        conn.execute(
            """
            UPDATE campaigns SET policy=?, holdout_ratio=?, stop_min_impressions=?, stop_rel_ev_threshold=?
            WHERE campaign_id=?
            """,
            params,
        )
147
 
148
def get_campaign(campaign_id: str):
    """Return the campaigns row for ``campaign_id``, or None if there is none."""
    query = "SELECT * FROM campaigns WHERE campaign_id=?"
    with get_conn() as conn:
        return conn.execute(query, (campaign_id,)).fetchone()
152
+
153
  def insert_variant(campaign_id: str, variant_id: str, text: str, status: str, rejection_reason: Optional[str]):
154
  with get_conn() as con:
155
  con.execute(
 
160
  (campaign_id, variant_id, text, status, rejection_reason),
161
  )
162
  con.execute(
163
+ "INSERT OR IGNORE INTO metrics (campaign_id, variant_id) VALUES (?, ?)",
164
+ (campaign_id, variant_id)
 
 
 
165
  )
166
 
167
def get_variant(campaign_id: str, variant_id: str):
    """Return the variants row for the given campaign/variant pair, or None."""
    query = "SELECT * FROM variants WHERE campaign_id=? AND variant_id=?"
    with get_conn() as conn:
        return conn.execute(query, (campaign_id, variant_id)).fetchone()
171
+
172
  def get_variants(campaign_id: str) -> List[sqlite3.Row]:
173
  with get_conn() as con:
174
  cur = con.execute("SELECT * FROM variants WHERE campaign_id=?", (campaign_id,))
175
  return cur.fetchall()
176
 
 
 
 
 
 
177
  def get_metrics(campaign_id: str) -> List[sqlite3.Row]:
178
  with get_conn() as con:
179
  cur = con.execute("SELECT * FROM metrics WHERE campaign_id=?", (campaign_id,))
 
188
  with get_conn() as con:
189
  con.execute(
190
  "INSERT INTO events (campaign_id, variant_id, event_type, ts, value) VALUES (?, ?, ?, ?, ?)",
191
+ (campaign_id, variant_id, event_type, ts, value)
192
  )
193
 
194
  def get_campaign_value_per_conversion(campaign_id: str) -> float:
 
196
  cur = con.execute("SELECT value_per_conversion FROM campaigns WHERE campaign_id=?", (campaign_id,))
197
  row = cur.fetchone()
198
  return float(row[0]) if row else 1.0
199
+
200
+ # ============== Compliance / Audit ==============
201
+
202
def record_compliance_log(campaign_id: str, variant_id: str, status: str,
                          ng_rules: List[str], llm_ok: bool, llm_reasons: List[str], llm_fixed: Optional[str]):
    """Insert one row into compliance_logs.

    ``ng_rules`` and ``llm_reasons`` are stored as JSON text (non-ASCII
    characters kept as-is) and ``llm_ok`` is stored as an integer.
    """
    with get_conn() as conn:
        ng_rules_json = json.dumps(ng_rules, ensure_ascii=False)
        llm_ok_flag = int(llm_ok)
        llm_reasons_json = json.dumps(llm_reasons, ensure_ascii=False)
        conn.execute(
            """
            INSERT INTO compliance_logs (campaign_id, variant_id, status, ng_rules_json, llm_ok, llm_reasons_json, llm_fixed)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            """,
            (campaign_id, variant_id, status, ng_rules_json, llm_ok_flag, llm_reasons_json, llm_fixed),
        )
213
+
214
def audit(campaign_id: str, action: str, payload: Dict[str, Any]):
    """Append an entry to audit_logs with ``payload`` serialized as JSON text."""
    with get_conn() as conn:
        payload_json = json.dumps(payload, ensure_ascii=False)
        conn.execute(
            "INSERT INTO audit_logs (campaign_id, action, payload_json) VALUES (?, ?, ?)",
            (campaign_id, action, payload_json),
        )
220
+
221
+ # ============== LinUCB state ==============
222
+
223
def get_linucb_state(campaign_id: str, variant_id: str):
    """Return the stored (d, A_json, b_json, n_updates) row for a variant, or None."""
    query = "SELECT d, A_json, b_json, n_updates FROM linucb WHERE campaign_id=? AND variant_id=?"
    with get_conn() as conn:
        return conn.execute(query, (campaign_id, variant_id)).fetchone()
229
+
230
def upsert_linucb_state(campaign_id: str, variant_id: str, d: int, A_json: str, b_json: str, n_updates: int):
    """Insert the LinUCB state for a variant, or overwrite it if one exists.

    The (campaign_id, variant_id) pair is the conflict target; on conflict
    all four state columns are replaced with the supplied values.
    """
    values = (campaign_id, variant_id, d, A_json, b_json, n_updates)
    with get_conn() as conn:
        conn.execute(
            """
            INSERT INTO linucb (campaign_id, variant_id, d, A_json, b_json, n_updates)
            VALUES (?, ?, ?, ?, ?, ?)
            ON CONFLICT(campaign_id, variant_id) DO UPDATE SET
            d=excluded.d, A_json=excluded.A_json, b_json=excluded.b_json, n_updates=excluded.n_updates
            """,
            values,
        )
241
+
242
+ # ============== Export / Reset / Stop rules ==============
243
+
244
def export_csv(campaign_id: str, table: str) -> str:
    """Export one campaign's rows from ``table`` to a CSV file.

    The file is written to ``DB_DIR / "export" / f"{campaign_id}_{table}.csv"``.
    When the campaign has no rows in the table, an empty file (no header)
    is still created.

    Args:
        campaign_id: Campaign whose rows are exported.
        table: One of "events", "metrics", "variants", "compliance_logs",
            "audit_logs".

    Returns:
        The path of the written file, as a string.

    Raises:
        ValueError: If ``table`` is not one of the exportable tables.
    """
    allowed_tables = {"events", "metrics", "variants", "compliance_logs", "audit_logs"}
    # The table name is interpolated into the SQL text, so it must be checked
    # with a real exception: an ``assert`` is removed when Python runs with -O.
    if table not in allowed_tables:
        raise ValueError(f"unsupported table for export: {table!r}")

    out_dir = DB_DIR / "export"
    out_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): campaign_id goes into the file name unchanged -- confirm
    # that callers never pass path separators.
    out_path = out_dir / f"{campaign_id}_{table}.csv"

    # Query before opening the file so a failing query does not truncate a
    # previously exported file.
    with get_conn() as con:
        rows = con.execute(
            f"SELECT * FROM {table} WHERE campaign_id=? ORDER BY rowid ASC",
            (campaign_id,),
        ).fetchall()

    with open(out_path, "w", newline="", encoding="utf-8") as f:
        if rows:
            fieldnames = rows[0].keys()
            w = csv.DictWriter(f, fieldnames=fieldnames)
            w.writeheader()
            w.writerows({k: r[k] for k in fieldnames} for r in rows)
        # No rows: the file is still created, just left empty.
    return str(out_path)
261
+
262
def reset_all():
    """Drop every application table and rebuild the schema via init_db()."""
    # Destructive operation: all stored data is lost.
    drop_script = """
        DROP TABLE IF EXISTS linucb;
        DROP TABLE IF EXISTS compliance_logs;
        DROP TABLE IF EXISTS audit_logs;
        DROP TABLE IF EXISTS events;
        DROP TABLE IF EXISTS metrics;
        DROP TABLE IF EXISTS variants;
        DROP TABLE IF EXISTS campaigns;
        """
    with get_conn() as conn:
        conn.executescript(drop_script)
    init_db()
275
+
276
def evaluate_stop_rules(campaign_id: str) -> List[Tuple[str, str]]:
    """Pause variants whose expected value is far below the best variant.

    Stop criteria, both of which must hold for a variant:
      - impressions >= stop_min_impressions
      - EV (CTR * CVR * value_per_conversion) is below
        stop_rel_ev_threshold times the best EV in the campaign

    Matching variants get status "paused" and rejection_reason
    "auto_pause:low_EV" in the variants table.

    Returns:
        [(variant_id, reason), ...] for the variants paused by this call;
        an empty list when the campaign or its metrics are missing, or
        when no variant has a positive EV.
    """
    cfg = get_campaign(campaign_id)
    if not cfg:
        return []

    # Fall back to the defaults only when the setting is NULL. An explicit 0
    # is a legitimate value (no warm-up / never pause) and must not be
    # replaced by the default, which a plain ``or`` would do.
    raw_min_imp = cfg["stop_min_impressions"]
    raw_thresh = cfg["stop_rel_ev_threshold"]
    min_imp = 200 if raw_min_imp is None else int(raw_min_imp)
    thresh = 0.5 if raw_thresh is None else float(raw_thresh)
    # vpc only scales every EV by the same factor; 0/NULL keeps falling back
    # to 1.0 so that the relative comparison below stays meaningful.
    vpc = float(cfg["value_per_conversion"] or 1.0)

    mets = get_metrics(campaign_id)
    if not mets:
        return []

    def ev_of(r):
        # NULL counters are treated as 0.
        imp = int(r["impressions"] or 0)
        clk = int(r["clicks"] or 0)
        conv = int(r["conversions"] or 0)
        if imp <= 0 or clk <= 0:
            return 0.0
        return (clk / imp) * (conv / clk) * vpc

    # Compute each variant's EV once: (variant_id, impressions, ev).
    scored = [(r["variant_id"], int(r["impressions"] or 0), ev_of(r)) for r in mets]
    best_ev = max(ev for _, _, ev in scored)
    if best_ev <= 0.0:
        # Without a positive best EV there is nothing to compare against.
        return []

    paused = [
        (vid, f"EV {ev:.6f} < {thresh:.2f} * best {best_ev:.6f}")
        for vid, imp, ev in scored
        if imp >= min_imp and ev < thresh * best_ev
    ]
    if paused:
        with get_conn() as con:
            con.executemany(
                "UPDATE variants SET status=?, rejection_reason=? WHERE campaign_id=? AND variant_id=?",
                [("paused", "auto_pause:low_EV", campaign_id, vid) for vid, _ in paused],
            )
    return paused