Corin1998's picture
Update app/storage.py
e691c59 verified
from __future__ import annotations

import json
import os
import sqlite3
from contextlib import closing
from pathlib import Path
from typing import Optional, Dict, Any, List
# Prefer the persistent volume: use /data/app_data when /data is writable,
# otherwise fall back to the ephemeral /tmp/app_data.
if os.access("/data", os.W_OK):
    DEFAULT_DIR = "/data/app_data"
else:
    DEFAULT_DIR = "/tmp/app_data"
# The APP_DATA_DIR environment variable, when set, overrides the default.
DB_DIR = Path(os.environ.get("APP_DATA_DIR", DEFAULT_DIR))
# Created eagerly at import time so the database file can be opened later.
DB_DIR.mkdir(parents=True, exist_ok=True)
DB_PATH = DB_DIR.joinpath("data.db")
# DDL executed by init_db() via executescript(). Every table uses
# CREATE TABLE IF NOT EXISTS, so re-running it on an existing database keeps data.
# - PRAGMA journal_mode=WAL switches the database file to write-ahead logging.
# - campaigns: one row per campaign_id; constraints_json holds a JSON document
#   (written with json.dumps by upsert_campaign); value_per_conversion defaults to 1.0.
# - variants: text, status and rejection_reason per (campaign_id, variant_id).
# - metrics: per-variant impression/click/conversion counters plus alpha_*/beta_*
#   columns defaulting to 1.0 (presumably Beta(1, 1) priors for a click /
#   conversion bandit -- TODO confirm against the code that reads them).
# - events: event log keyed by an autoincrement id; ts is stored as TEXT.
SCHEMA_SQL = """
PRAGMA journal_mode=WAL;
CREATE TABLE IF NOT EXISTS campaigns (
campaign_id TEXT PRIMARY KEY,
brand TEXT,
product TEXT,
target_audience TEXT,
tone TEXT,
language TEXT,
constraints_json TEXT,
value_per_conversion REAL DEFAULT 1.0
);
CREATE TABLE IF NOT EXISTS variants (
campaign_id TEXT,
variant_id TEXT,
text TEXT,
status TEXT,
rejection_reason TEXT,
PRIMARY KEY (campaign_id, variant_id)
);
CREATE TABLE IF NOT EXISTS metrics (
campaign_id TEXT,
variant_id TEXT,
impressions INTEGER DEFAULT 0,
clicks INTEGER DEFAULT 0,
conversions INTEGER DEFAULT 0,
alpha_click REAL DEFAULT 1.0,
beta_click REAL DEFAULT 1.0,
alpha_conv REAL DEFAULT 1.0,
beta_conv REAL DEFAULT 1.0,
PRIMARY KEY (campaign_id, variant_id)
);
CREATE TABLE IF NOT EXISTS events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
campaign_id TEXT,
variant_id TEXT,
event_type TEXT,
ts TEXT,
value REAL
);
"""
def get_conn():
    """Open and return a new connection to the database at DB_PATH.

    Rows come back as ``sqlite3.Row`` objects, so columns can be read by
    name as well as by index.

    The caller owns the returned connection and must close it. Using it as a
    context manager (``with conn:``) only commits or rolls back the current
    transaction; it does not close the connection.
    """
    connection = sqlite3.connect(DB_PATH)
    connection.row_factory = sqlite3.Row
    return connection
def init_db():
    """Create the schema (SCHEMA_SQL) if it does not exist yet.

    Safe to call repeatedly: every table is created with IF NOT EXISTS.
    The connection is closed before returning.
    """
    # ``closing`` releases the connection; the inner ``con`` context manager
    # commits on success / rolls back on error (it never closes by itself).
    with closing(get_conn()) as con, con:
        con.executescript(SCHEMA_SQL)
def upsert_campaign(
    campaign_id: str,
    brand: str,
    product: str,
    target_audience: str,
    tone: str,
    language: str,
    constraints: Optional[Dict[str, Any]],
    value_per_conversion: float,
):
    """Insert a campaign, or overwrite every column of the existing row.

    Args:
        campaign_id: Primary key of the campaign.
        brand, product, target_audience, tone, language: Stored verbatim.
        constraints: Serialized with ``json.dumps`` into ``constraints_json``;
            ``None`` (or any falsy value) is stored as ``"{}"``.
        value_per_conversion: Stored as-is.

    The write is committed and the connection closed before returning.
    """
    # ``closing`` guarantees the connection is released; the inner ``con``
    # context manager commits on success / rolls back on error.
    with closing(get_conn()) as con, con:
        con.execute(
            """
INSERT INTO campaigns (campaign_id, brand, product, target_audience, tone, language, constraints_json, value_per_conversion)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(campaign_id) DO UPDATE SET
brand=excluded.brand,
product=excluded.product,
target_audience=excluded.target_audience,
tone=excluded.tone,
language=excluded.language,
constraints_json=excluded.constraints_json,
value_per_conversion=excluded.value_per_conversion
""",
            (
                campaign_id,
                brand,
                product,
                target_audience,
                tone,
                language,
                json.dumps(constraints or {}),
                value_per_conversion,
            ),
        )
def insert_variant(campaign_id: str, variant_id: str, text: str, status: str, rejection_reason: Optional[str]):
    """Store a variant and make sure it has a metrics row.

    The variant row is written with INSERT OR REPLACE, so an existing
    (campaign_id, variant_id) row is overwritten. The metrics row is written
    with INSERT OR IGNORE, so existing counters are left untouched. Both
    statements are committed together when the ``con`` block exits, and the
    connection is closed before returning.
    """
    with closing(get_conn()) as con, con:
        con.execute(
            """
INSERT OR REPLACE INTO variants (campaign_id, variant_id, text, status, rejection_reason)
VALUES (?, ?, ?, ?, ?)
""",
            (campaign_id, variant_id, text, status, rejection_reason),
        )
        con.execute(
            """
INSERT OR IGNORE INTO metrics (campaign_id, variant_id)
VALUES (?, ?)
""",
            (campaign_id, variant_id),
        )
def get_variants(campaign_id: str) -> List[sqlite3.Row]:
    """Return all variant rows for ``campaign_id`` (empty list if none).

    Rows are fully fetched before the connection is closed, so they remain
    usable after this function returns.
    """
    # Read-only: no transaction handling needed, just make sure we close.
    with closing(get_conn()) as con:
        cur = con.execute("SELECT * FROM variants WHERE campaign_id=?", (campaign_id,))
        return cur.fetchall()
def get_variant(campaign_id: str, variant_id: str) -> Optional[sqlite3.Row]:
    """Return the variant row for (campaign_id, variant_id), or ``None``.

    The connection is closed before returning; the returned row holds its
    own data and stays usable.
    """
    # Read-only: no transaction handling needed, just make sure we close.
    with closing(get_conn()) as con:
        cur = con.execute("SELECT * FROM variants WHERE campaign_id=? AND variant_id=?", (campaign_id, variant_id))
        return cur.fetchone()
def get_metrics(campaign_id: str) -> List[sqlite3.Row]:
    """Return all metrics rows for ``campaign_id`` (empty list if none).

    Rows are fully fetched before the connection is closed, so they remain
    usable after this function returns.
    """
    # Read-only: no transaction handling needed, just make sure we close.
    with closing(get_conn()) as con:
        cur = con.execute("SELECT * FROM metrics WHERE campaign_id=?", (campaign_id,))
        return cur.fetchall()
# Columns of ``metrics`` that update_metric may increment. Column names cannot
# be bound as SQL parameters, so ``field`` is interpolated into the statement
# and must be checked against this whitelist first.
_METRIC_FIELDS = frozenset(
    {"impressions", "clicks", "conversions", "alpha_click", "beta_click", "alpha_conv", "beta_conv"}
)


def update_metric(campaign_id: str, variant_id: str, field: str, inc: float = 1.0):
    """Add ``inc`` to one column of the metrics row for (campaign_id, variant_id).

    Args:
        campaign_id, variant_id: Key of the metrics row to update.
        field: Column to increment; must be one of ``_METRIC_FIELDS``.
        inc: Amount added to the column (default 1.0).

    Raises:
        ValueError: If ``field`` is not an allowed metrics column.

    Has no effect if no metrics row exists for the key (insert_variant creates
    one). The update is committed and the connection closed before returning.
    """
    # An explicit raise (unlike ``assert``) stays active under ``python -O``,
    # which matters because ``field`` is spliced into the SQL text below.
    if field not in _METRIC_FIELDS:
        raise ValueError(f"unknown metrics field: {field!r}")
    with closing(get_conn()) as con, con:
        con.execute(
            f"UPDATE metrics SET {field} = {field} + ? WHERE campaign_id=? AND variant_id=?",
            (inc, campaign_id, variant_id),
        )
def log_event(campaign_id: str, variant_id: str, event_type: str, ts: str, value: Optional[float]):
    """Insert one row into the ``events`` table.

    Args:
        campaign_id, variant_id: Identify the variant the event belongs to.
        event_type: Free-form event label, stored verbatim.
        ts: Timestamp stored as TEXT exactly as given (presumably ISO-8601 --
            confirm with callers).
        value: Stored in the REAL ``value`` column; may be ``None``.

    The insert is committed and the connection closed before returning.
    """
    with closing(get_conn()) as con, con:
        con.execute(
            "INSERT INTO events (campaign_id, variant_id, event_type, ts, value) VALUES (?, ?, ?, ?, ?)",
            (campaign_id, variant_id, event_type, ts, value),
        )
def get_campaign_value_per_conversion(campaign_id: str) -> float:
    """Return the campaign's ``value_per_conversion`` as a float.

    Falls back to 1.0 (the column's schema default) when the campaign does
    not exist or its stored value is NULL. The connection is closed before
    returning.
    """
    with closing(get_conn()) as con:
        row = con.execute(
            "SELECT value_per_conversion FROM campaigns WHERE campaign_id=?", (campaign_id,)
        ).fetchone()
    # A NULL can be stored when upsert_campaign is called with
    # value_per_conversion=None; float(None) would raise TypeError.
    if row is None or row[0] is None:
        return 1.0
    return float(row[0])