Corin1998's picture
Update app/storage.py
d5be10a verified
from __future__ import annotations
import os, csv, json, sqlite3
from pathlib import Path
from typing import Optional, Dict, Any, List, Tuple
# Storage location: prefer the persistent /data volume when it is writable,
# otherwise fall back to a directory under /tmp. APP_DATA_DIR overrides both.
if os.access("/data", os.W_OK):
    DEFAULT_DIR = "/data/app_data"
else:
    DEFAULT_DIR = "/tmp/app_data"
DB_DIR = Path(os.getenv("APP_DATA_DIR", DEFAULT_DIR))
# The directory is created at import time so the database file can be opened later.
DB_DIR.mkdir(parents=True, exist_ok=True)
DB_PATH = DB_DIR.joinpath("data.db")
# DDL script run by init_db(). It switches the journal mode to WAL and creates
# each table only if it does not already exist:
#   campaigns        - one row per campaign, including policy and stop-rule settings
#   variants         - text variants per campaign, keyed by (campaign_id, variant_id)
#   metrics          - counters and alpha/beta values per variant
#   events           - append-only event rows
#   linucb           - LinUCB parameters stored as JSON (per the in-string SQL comment)
#   compliance_logs  - compliance audit log: NG details / LLM fix suggestion
#   audit_logs       - optional operational audit log
# The SQL comments inside the string are written in Japanese and are part of the
# string literal, so they are left unchanged.
SCHEMA_SQL = """
PRAGMA journal_mode=WAL;
CREATE TABLE IF NOT EXISTS campaigns (
campaign_id TEXT PRIMARY KEY,
brand TEXT,
product TEXT,
target_audience TEXT,
tone TEXT,
language TEXT,
constraints_json TEXT,
value_per_conversion REAL DEFAULT 1.0,
policy TEXT DEFAULT 'thompson',
holdout_ratio REAL DEFAULT 0.0,
stop_min_impressions INTEGER DEFAULT 200,
stop_rel_ev_threshold REAL DEFAULT 0.5,
created_at TEXT DEFAULT (datetime('now'))
);
CREATE TABLE IF NOT EXISTS variants (
campaign_id TEXT,
variant_id TEXT,
text TEXT,
status TEXT,
rejection_reason TEXT,
PRIMARY KEY (campaign_id, variant_id)
);
CREATE TABLE IF NOT EXISTS metrics (
campaign_id TEXT,
variant_id TEXT,
impressions INTEGER DEFAULT 0,
clicks INTEGER DEFAULT 0,
conversions INTEGER DEFAULT 0,
alpha_click REAL DEFAULT 1.0,
beta_click REAL DEFAULT 1.0,
alpha_conv REAL DEFAULT 1.0,
beta_conv REAL DEFAULT 1.0,
PRIMARY KEY (campaign_id, variant_id)
);
CREATE TABLE IF NOT EXISTS events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
campaign_id TEXT,
variant_id TEXT,
event_type TEXT,
ts TEXT,
value REAL
);
-- LinUCB のパラメータをJSONで保持
CREATE TABLE IF NOT EXISTS linucb (
campaign_id TEXT,
variant_id TEXT,
d INTEGER,
A_json TEXT,
b_json TEXT,
n_updates INTEGER DEFAULT 0,
PRIMARY KEY (campaign_id, variant_id)
);
-- コンプライアンス監査ログ(NG詳細/LLM修正案)
CREATE TABLE IF NOT EXISTS compliance_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
campaign_id TEXT,
variant_id TEXT,
status TEXT,
ng_rules_json TEXT,
llm_ok INTEGER,
llm_reasons_json TEXT,
llm_fixed TEXT,
ts TEXT DEFAULT (datetime('now'))
);
-- 任意の運用監査ログ
CREATE TABLE IF NOT EXISTS audit_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
campaign_id TEXT,
action TEXT,
payload_json TEXT,
ts TEXT DEFAULT (datetime('now'))
);
"""
def get_conn():
    """Open a new SQLite connection to the application database file.

    Rows are returned as ``sqlite3.Row`` so callers can read columns by name.

    Note: when the returned connection is used as a context manager
    (``with get_conn() as con:``), sqlite3 commits or rolls back the pending
    transaction on exit but does not close the connection.
    """
    connection = sqlite3.connect(DB_PATH)
    connection.row_factory = sqlite3.Row
    return connection
def _ensure_columns():
    """Add columns that older databases may be missing from ``campaigns``.

    ``CREATE TABLE IF NOT EXISTS`` does not alter a table that already exists,
    so columns introduced after the first deployment are added here with
    ``ALTER TABLE ... ADD COLUMN``.

    SQLite does not accept a non-constant default (such as
    ``datetime('now')``) in ``ADD COLUMN``. ``created_at`` is therefore added
    without a default and existing rows are backfilled with the migration
    time. On a database migrated this way, rows inserted later get a NULL
    ``created_at`` unless the writer supplies one.
    """
    # (column name, SQL type, constant default or None, backfill expression or None)
    need_cols = {
        "campaigns": [
            ("policy", "TEXT", "'thompson'", None),
            ("holdout_ratio", "REAL", "0.0", None),
            ("stop_min_impressions", "INTEGER", "200", None),
            ("stop_rel_ev_threshold", "REAL", "0.5", None),
            ("created_at", "TEXT", None, "datetime('now')"),
        ]
    }
    with get_conn() as con:
        for table, cols in need_cols.items():
            # Table and column names come from the literal above, never from
            # caller input, so interpolating them into SQL is safe here.
            have = {r["name"] for r in con.execute(f"PRAGMA table_info({table})")}
            for name, typ, default, backfill in cols:
                if name in have:
                    continue
                ddl = f"ALTER TABLE {table} ADD COLUMN {name} {typ}"
                if default is not None:
                    ddl += f" DEFAULT ({default})"
                con.execute(ddl)
                if backfill is not None:
                    # Emulate the dynamic default for rows that already exist.
                    con.execute(
                        f"UPDATE {table} SET {name} = {backfill} WHERE {name} IS NULL"
                    )
def init_db():
    """Create every table from ``SCHEMA_SQL``, then apply column migrations."""
    connection = get_conn()
    with connection:
        connection.executescript(SCHEMA_SQL)
    # Bring tables created by an older schema version up to date.
    _ensure_columns()
# ============== Campaign/Variant/Metric basics ==============
def upsert_campaign(campaign_id: str, brand: str, product: str, target_audience: str,
                    tone: str, language: str, constraints: Optional[Dict[str, Any]],
                    value_per_conversion: float):
    """Insert a campaign, or overwrite its descriptive fields if it already exists.

    ``constraints`` is stored as JSON text; a falsy value (``None`` or an
    empty dict) is stored as ``{}``. Columns not named in the statement are
    left untouched when an existing row is updated.
    """
    constraints_text = json.dumps(constraints or {}, ensure_ascii=False)
    params = (
        campaign_id,
        brand,
        product,
        target_audience,
        tone,
        language,
        constraints_text,
        value_per_conversion,
    )
    sql = """
        INSERT INTO campaigns (campaign_id, brand, product, target_audience, tone, language, constraints_json, value_per_conversion)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        ON CONFLICT(campaign_id) DO UPDATE SET
        brand=excluded.brand,
        product=excluded.product,
        target_audience=excluded.target_audience,
        tone=excluded.tone,
        language=excluded.language,
        constraints_json=excluded.constraints_json,
        value_per_conversion=excluded.value_per_conversion
        """
    with get_conn() as con:
        con.execute(sql, params)
def set_campaign_settings(campaign_id: str, policy: str, holdout_ratio: float, stop_min_impressions: int, stop_rel_ev_threshold: float):
    """Update the policy and stop-rule settings of one campaign.

    Numeric arguments are coerced to ``float`` / ``int`` before being stored.
    If no campaign has this id, the UPDATE matches no rows.
    """
    params = (
        policy,
        float(holdout_ratio),
        int(stop_min_impressions),
        float(stop_rel_ev_threshold),
        campaign_id,
    )
    sql = """
        UPDATE campaigns SET policy=?, holdout_ratio=?, stop_min_impressions=?, stop_rel_ev_threshold=?
        WHERE campaign_id=?
        """
    with get_conn() as con:
        con.execute(sql, params)
def get_campaign(campaign_id: str):
    """Return the campaign row with this id, or ``None`` if there is none."""
    query = "SELECT * FROM campaigns WHERE campaign_id=?"
    with get_conn() as con:
        return con.execute(query, (campaign_id,)).fetchone()
def insert_variant(campaign_id: str, variant_id: str, text: str, status: str, rejection_reason: Optional[str]):
    """Store a variant (replacing any existing one) and ensure it has a metrics row.

    The metrics row is inserted with ``INSERT OR IGNORE``, so counters of an
    existing variant are kept when the variant itself is replaced.
    """
    variant_sql = """
        INSERT OR REPLACE INTO variants (campaign_id, variant_id, text, status, rejection_reason)
        VALUES (?, ?, ?, ?, ?)
        """
    metrics_sql = "INSERT OR IGNORE INTO metrics (campaign_id, variant_id) VALUES (?, ?)"
    with get_conn() as con:
        con.execute(variant_sql, (campaign_id, variant_id, text, status, rejection_reason))
        con.execute(metrics_sql, (campaign_id, variant_id))
def get_variant(campaign_id: str, variant_id: str):
    """Return one variant row, or ``None`` if the (campaign, variant) pair is unknown."""
    query = "SELECT * FROM variants WHERE campaign_id=? AND variant_id=?"
    with get_conn() as con:
        return con.execute(query, (campaign_id, variant_id)).fetchone()
def get_variants(campaign_id: str) -> List[sqlite3.Row]:
    """Return every variant row of a campaign (an empty list if there are none)."""
    query = "SELECT * FROM variants WHERE campaign_id=?"
    with get_conn() as con:
        return con.execute(query, (campaign_id,)).fetchall()
def get_metrics(campaign_id: str) -> List[sqlite3.Row]:
    """Return every metrics row of a campaign (an empty list if there are none)."""
    query = "SELECT * FROM metrics WHERE campaign_id=?"
    with get_conn() as con:
        return con.execute(query, (campaign_id,)).fetchall()
def update_metric(campaign_id: str, variant_id: str, field: str, inc: float = 1.0):
    """Add ``inc`` to one counter column of a variant's metrics row.

    Args:
        campaign_id: Campaign the variant belongs to.
        variant_id: Variant whose metrics row is updated.
        field: Column to increment; must be one of the metric columns listed
            below.
        inc: Amount to add (default 1.0).

    Raises:
        ValueError: If ``field`` is not an allowed metric column.
    """
    allowed_fields = {
        "impressions", "clicks", "conversions",
        "alpha_click", "beta_click", "alpha_conv", "beta_conv",
    }
    # The column name is interpolated into the SQL text, so it must be
    # validated with a real check: an ``assert`` is removed under ``python -O``.
    if field not in allowed_fields:
        raise ValueError(f"unknown metric field: {field!r}")
    with get_conn() as con:
        con.execute(
            f"UPDATE metrics SET {field} = {field} + ? WHERE campaign_id=? AND variant_id=?",
            (inc, campaign_id, variant_id),
        )
def log_event(campaign_id: str, variant_id: str, event_type: str, ts: str, value):
    """Append one row to the ``events`` table with the values given by the caller."""
    insert_sql = "INSERT INTO events (campaign_id, variant_id, event_type, ts, value) VALUES (?, ?, ?, ?, ?)"
    record = (campaign_id, variant_id, event_type, ts, value)
    with get_conn() as con:
        con.execute(insert_sql, record)
def get_campaign_value_per_conversion(campaign_id: str) -> float:
    """Return the campaign's value per conversion.

    Falls back to 1.0 when the campaign does not exist or when the stored
    value is NULL (the original raised ``TypeError`` on NULL).
    """
    with get_conn() as con:
        row = con.execute(
            "SELECT value_per_conversion FROM campaigns WHERE campaign_id=?",
            (campaign_id,),
        ).fetchone()
    if row is None or row[0] is None:
        return 1.0
    return float(row[0])
# ============== Compliance / Audit ==============
def record_compliance_log(campaign_id: str, variant_id: str, status: str,
                          ng_rules: List[str], llm_ok: bool, llm_reasons: List[str], llm_fixed: Optional[str]):
    """Append one compliance-check result to ``compliance_logs``.

    ``ng_rules`` and ``llm_reasons`` are stored as JSON text and ``llm_ok`` as
    an integer. ``ts`` is not supplied, so the column default applies.
    """
    record = (
        campaign_id,
        variant_id,
        status,
        json.dumps(ng_rules, ensure_ascii=False),
        int(llm_ok),
        json.dumps(llm_reasons, ensure_ascii=False),
        llm_fixed,
    )
    insert_sql = """
        INSERT INTO compliance_logs (campaign_id, variant_id, status, ng_rules_json, llm_ok, llm_reasons_json, llm_fixed)
        VALUES (?, ?, ?, ?, ?, ?, ?)
        """
    with get_conn() as con:
        con.execute(insert_sql, record)
def audit(campaign_id: str, action: str, payload: Dict[str, Any]):
    """Append an entry to ``audit_logs``; ``payload`` is stored as JSON text."""
    payload_text = json.dumps(payload, ensure_ascii=False)
    insert_sql = "INSERT INTO audit_logs (campaign_id, action, payload_json) VALUES (?, ?, ?)"
    with get_conn() as con:
        con.execute(insert_sql, (campaign_id, action, payload_text))
# ============== LinUCB state ==============
def get_linucb_state(campaign_id: str, variant_id: str):
    """Return the stored LinUCB row ``(d, A_json, b_json, n_updates)`` or ``None``."""
    query = "SELECT d, A_json, b_json, n_updates FROM linucb WHERE campaign_id=? AND variant_id=?"
    with get_conn() as con:
        state = con.execute(query, (campaign_id, variant_id)).fetchone()
    return state
def upsert_linucb_state(campaign_id: str, variant_id: str, d: int, A_json: str, b_json: str, n_updates: int):
    """Insert the LinUCB state of a variant, or overwrite it if one is already stored.

    ``A_json`` and ``b_json`` are written as given; the caller is responsible
    for serialising them.
    """
    record = (campaign_id, variant_id, d, A_json, b_json, n_updates)
    upsert_sql = """
        INSERT INTO linucb (campaign_id, variant_id, d, A_json, b_json, n_updates)
        VALUES (?, ?, ?, ?, ?, ?)
        ON CONFLICT(campaign_id, variant_id) DO UPDATE SET
        d=excluded.d, A_json=excluded.A_json, b_json=excluded.b_json, n_updates=excluded.n_updates
        """
    with get_conn() as con:
        con.execute(upsert_sql, record)
# ============== Export / Reset / Stop rules ==============
def export_csv(campaign_id: str, table: str) -> str:
    """Export one campaign's rows from ``table`` to a UTF-8 CSV file.

    The file is written to ``<DB_DIR>/export/<campaign_id>_<table>.csv`` with
    a header row followed by the rows in ``rowid`` order. When the campaign
    has no rows in the table, an empty file is still created.

    Args:
        campaign_id: Campaign whose rows are exported; also used in the file name.
        table: One of ``events``, ``metrics``, ``variants``,
            ``compliance_logs`` or ``audit_logs``.

    Returns:
        The path of the written file, as a string.

    Raises:
        ValueError: If ``table`` is not exportable, or if ``campaign_id``
            would make the file name contain a path separator.
    """
    exportable = {"events", "metrics", "variants", "compliance_logs", "audit_logs"}
    # The table name is interpolated into the SQL text, so it must be checked
    # with a real test: an ``assert`` is removed under ``python -O``.
    if table not in exportable:
        raise ValueError(f"table not exportable: {table!r}")

    file_name = f"{campaign_id}_{table}.csv"
    # Refuse ids that would place the file outside the export directory.
    if Path(file_name).name != file_name:
        raise ValueError(f"campaign_id is not usable in a file name: {campaign_id!r}")

    out_dir = DB_DIR / "export"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / file_name

    # Query first so a failing SELECT does not leave a truncated file behind.
    with get_conn() as con:
        rows = con.execute(
            f"SELECT * FROM {table} WHERE campaign_id=? ORDER BY rowid ASC",
            (campaign_id,),
        ).fetchall()

    with open(out_path, "w", newline="", encoding="utf-8") as f:
        if rows:
            writer = csv.writer(f)
            writer.writerow(rows[0].keys())
            writer.writerows(tuple(r) for r in rows)
        # With no rows the file is intentionally left empty (no header).
    return str(out_path)
def reset_all():
    """Destructive operation: drop every application table, then recreate the schema."""
    drop_script = """
        DROP TABLE IF EXISTS linucb;
        DROP TABLE IF EXISTS compliance_logs;
        DROP TABLE IF EXISTS audit_logs;
        DROP TABLE IF EXISTS events;
        DROP TABLE IF EXISTS metrics;
        DROP TABLE IF EXISTS variants;
        DROP TABLE IF EXISTS campaigns;
        """
    connection = get_conn()
    with connection:
        connection.executescript(drop_script)
    # Rebuild empty tables so the module is usable again right away.
    init_db()
def evaluate_stop_rules(campaign_id: str) -> List[Tuple[str, str]]:
    """Pause variants whose expected value lags far behind the best variant.

    Stop criteria (both must hold for a variant):
      - impressions >= the campaign's ``stop_min_impressions``
      - EV (mean CTR * mean CVR * value_per_conversion) is below
        ``stop_rel_ev_threshold`` times the best EV in the campaign

    Matching variants get ``status='paused'`` and
    ``rejection_reason='auto_pause:low_EV'``.

    Settings stored as NULL fall back to 200 / 0.5 / 1.0. An explicit 0 is
    honoured: a threshold of 0.0 pauses nothing and a minimum of 0 makes every
    variant eligible.

    Returns:
        ``[(variant_id, reason), ...]`` for the variants paused by this call;
        an empty list when the campaign is unknown, has no metrics, or no
        variant has a positive EV.
    """
    cfg = get_campaign(campaign_id)
    if not cfg:
        return []

    def setting(name, fallback):
        # Only NULL means "not configured"; ``value or fallback`` would also
        # replace a deliberately configured 0.
        value = cfg[name]
        return fallback if value is None else value

    min_imp = int(setting("stop_min_impressions", 200))
    thresh = float(setting("stop_rel_ev_threshold", 0.5))
    vpc = float(setting("value_per_conversion", 1.0))

    mets = get_metrics(campaign_id)
    if not mets:
        return []

    def ev_of(r):
        imp = int(r["impressions"])
        clk = int(r["clicks"])
        conv = int(r["conversions"])
        if imp <= 0 or clk <= 0:
            return 0.0
        ctr = clk / imp
        cvr = conv / clk
        return ctr * cvr * vpc

    scored = [(r, ev_of(r)) for r in mets]
    # NOTE(review): the best EV is taken over all variants, including ones
    # below min_imp or already paused - confirm this is intended.
    best_ev = max((ev for _, ev in scored), default=0.0)
    if best_ev <= 0.0:
        # Nothing has converted yet, so there is no baseline to compare against.
        return []

    cutoff = thresh * best_ev
    paused = []
    with get_conn() as con:
        for r, ev in scored:
            if int(r["impressions"]) < min_imp:
                continue
            if ev < cutoff:
                vid = r["variant_id"]
                con.execute(
                    "UPDATE variants SET status=?, rejection_reason=? WHERE campaign_id=? AND variant_id=?",
                    ("paused", "auto_pause:low_EV", campaign_id, vid),
                )
                paused.append((vid, f"EV {ev:.6f} < {thresh:.2f} * best {best_ev:.6f}"))
    return paused