# AdCopy_MAB_OptimizerPro / dashboard.py
# Corin1998's picture
# Update dashboard.py
# 46cda09 verified
from __future__ import annotations
import gradio as gr
import plotly.express as px
import pandas as pd
from data import read_events, aggregate, append_events
from bandit import EmpiricalBayesHierarchicalThompson
from causal import fit_uplift_binary
# Single module-level recommender instance, constructed once at import time
# and used by ui_recommend().
BANDIT = EmpiricalBayesHierarchicalThompson(
    min_explore=0.05,
    margin=0.0,
    n_draws=10_000,
)
# Columns an uploaded CSV must contain; ui_upload_csv() rejects files that
# lack any of them, and the upload widget label lists them for the user.
REQUIRED_COLS = [
    "date",
    "medium",
    "creative",
    "is_control",
    "impressions",
    "clicks",
    "conversions",
    "cost",
    "features_json",
]
# ---------- helpers ----------
def _ok(msg: str) -> str:
return f"✅ {msg}"
def _warn(msg: str) -> str:
return f"⚠️ {msg}"
def _err(msg: str) -> str:
return f"❌ {msg}"
# ---------- data ops ----------
def ui_refresh_tables():
    """Reload the raw event table and its aggregate for the Data tab.

    Returns:
        A 3-tuple ``(events, aggregated, status)``: the results of
        ``read_events()`` and ``aggregate()`` followed by a success message.
    """
    events = read_events()
    summary = aggregate()
    # The third output (status) is a Markdown component, so a plain string
    # is all it needs.
    status_text = _ok("Refreshed.")
    return events, summary, status_text
def ui_load_sample():
    """Append a fixed ten-row sample data set and return refreshed tables.

    The sample holds five creatives for each of two media ("FB" and "GDN"),
    one of which per medium is flagged as the control.

    Returns:
        A 3-tuple ``(events, aggregated, status)`` read back after the append.
    """
    fields = (
        "date", "medium", "creative", "is_control",
        "impressions", "clicks", "conversions", "cost", "features_json",
    )
    day = "2025-09-01"
    no_features = "{}"
    # (medium, creative, is_control, impressions, clicks, conversions, cost)
    metrics = (
        ("FB", "A1", 1, 1000, 25, 3, 480),
        ("FB", "B1", 0, 850, 24, 2, 395),
        ("FB", "C1", 0, 900, 21, 2, 420),
        ("FB", "D1", 0, 820, 20, 2, 380),
        ("FB", "E1", 0, 960, 31, 3, 490),
        ("GDN", "A2", 1, 1100, 25, 2, 545),
        ("GDN", "B2", 0, 990, 27, 3, 514),
        ("GDN", "C2", 0, 860, 19, 2, 450),
        ("GDN", "D2", 0, 1045, 33, 4, 570),
        ("GDN", "E2", 0, 905, 20, 2, 462),
    )
    # Rebuild one dict per row so the DataFrame sees the same records and
    # the same column order as a literal list of dicts would give.
    rows = [dict(zip(fields, (day, *rec, no_features))) for rec in metrics]
    append_events(pd.DataFrame(rows))
    events = read_events()
    summary = aggregate()
    return events, summary, _ok("Sample data loaded.")
def ui_upload_csv(file: gr.File):
    """Import an uploaded CSV into the event store and refresh both tables.

    The three return values map to the outputs ``[grid, grid_agg, status]``.
    On every failure path the two tables are left unchanged (a bare
    ``gr.update()``) and only the status text is replaced.

    Args:
        file: Value delivered by the file-upload input. Either an object
            exposing the uploaded file's path as ``.name`` or a plain path
            string; ``None`` when no file was selected.

    Returns:
        ``(events, aggregated, status)`` after a successful import, otherwise
        ``(gr.update(), gr.update(), status)`` with a warning or error text.
    """
    if file is None:
        # Leave the tables as they are; only the status message changes.
        return gr.update(), gr.update(), _warn("CSVファイルを選択してください。")

    # Accept both wrapper objects (path in ``.name``) and bare path strings.
    path = getattr(file, "name", file)

    # Cope with encoding variance: default decoding first, then UTF-8 with BOM.
    df = None
    used_encoding = None
    last_error = None
    for enc in (None, "utf-8-sig"):
        try:
            df = pd.read_csv(path, encoding=enc)
        except Exception as e:
            last_error = e
        else:
            used_encoding = enc
            break
    if df is None:
        return gr.update(), gr.update(), _err(f"CSV読み込みに失敗: {last_error}")

    # A single parsed column means the delimiter was not a comma (TSV, ';', ...).
    # Re-reading with sep="," would only repeat the default parse, so let the
    # Python engine sniff the delimiter instead.
    if df.shape[1] == 1:
        try:
            df = pd.read_csv(path, sep=None, engine="python", encoding=used_encoding)
        except Exception:
            # Best effort only: keep the first parse and let the
            # required-column check below report what is missing.
            pass

    missing = [c for c in REQUIRED_COLS if c not in df.columns]
    if missing:
        return gr.update(), gr.update(), _err(f"必須列が不足: {missing}")

    # Light sanitising. Every column touched here is guaranteed to exist by
    # the check above. Non-numeric cells make astype() raise; report that in
    # the status line instead of letting the handler crash.
    try:
        for c in ("is_control", "impressions", "clicks", "conversions"):
            df[c] = df[c].fillna(0).astype(int)
        df["cost"] = df["cost"].fillna(0.0).astype(float)
    except (ValueError, TypeError) as e:
        return gr.update(), gr.update(), _err(f"数値列の変換に失敗: {e}")
    df["features_json"] = df["features_json"].fillna("{}").astype(str)

    append_events(df)
    df2, agg2 = read_events(), aggregate()
    return df2, agg2, _ok(f"取り込み完了({len(df)}行)。")
# ---------- bandit / causal ----------
def ui_recommend():
    """Return the bandit's recommendation for the current aggregate.

    Returns:
        Whatever ``BANDIT.recommend`` produces for the aggregated data, or a
        ``{"message": ...}`` dict when there is no data yet.
    """
    summary = aggregate()
    if not summary.empty:
        return BANDIT.recommend(summary)
    # Nothing to recommend from: point the user at the Data tab.
    return {"message": "No data yet. Dataタブで 'Load Sample Data' を押すか、CSVを取り込んでください。"}
def ui_plot_posteriors(medium: str):
    """Build a bar chart of smoothed CTR per creative for one medium.

    The two return values map to the outputs ``[plot, msg]``.

    Args:
        medium: Medium to filter on; compared as a string against the
            aggregate's ``medium`` column.

    Returns:
        ``(plot_update, message)``. When there is no matching data the plot
        is hidden and the message explains why; otherwise the plot is given
        the new figure, made visible, and the message is cleared.
    """
    agg = aggregate()
    if agg.empty:
        return gr.update(visible=False), "No data"

    g = agg[agg["medium"].astype(str) == str(medium)].copy()
    if g.empty:
        return gr.update(visible=False), f"No data for medium={medium}"

    # Add-one (Laplace) smoothing keeps the rate defined at zero impressions.
    g["ctr"] = (g["clicks"] + 1) / (g["impressions"] + 2)
    fig = px.bar(g, x="creative", y="ctr", color="is_control", barmode="group",
                 title=f"CTR (Laplace) by creative @ {medium}")
    # The plot component starts hidden and the branches above hide it again,
    # so visibility must be switched on explicitly together with the value;
    # setting only the value would leave the chart invisible.
    return gr.update(value=fig, visible=True), ""
def ui_fit_uplift():
    """Fit the uplift model on the current aggregate.

    Returns:
        The result of ``fit_uplift_binary`` on the aggregated data, or
        ``{"message": "No data"}`` when the aggregate is empty.
    """
    summary = aggregate()
    if not summary.empty:
        return fit_uplift_binary(summary)
    return {"message": "No data"}
# ---------- UI ----------
def build_ui():
    """Assemble the Gradio Blocks app and wire its buttons to the callbacks.

    The app has three tabs: "Data" (refresh, sample load, CSV import and the
    two tables), "Bandit" (allocation suggestion and CTR plot) and
    "Uplift (Causal)" (uplift model fit).

    Returns:
        The constructed ``gr.Blocks`` instance; it is not launched here.
    """
    with gr.Blocks(title="AdCopy MAB Optimizer Pro") as demo:
        gr.Markdown("# AdCopy MAB Optimizer Pro — Hierarchical TS + Uplift")
        with gr.Tab("Data"):
            status = gr.Markdown("") # this component accepts a plain string as its value
            with gr.Row():
                btn_refresh = gr.Button("Refresh")
                btn_seed = gr.Button("Load Sample Data")
            with gr.Row():
                file = gr.File(label="Upload CSV (columns: " + ", ".join(REQUIRED_COLS) + ")", file_types=[".csv"])
                btn_upload = gr.Button("Import CSV")
            # Raw events table; its headers add a leading "ts" column to the CSV columns.
            grid = gr.Dataframe(headers=["ts","date","medium","creative","is_control","impressions","clicks","conversions","cost","features_json"], wrap=True)
            # Aggregated view of the same events.
            grid_agg = gr.Dataframe()
            # All three data actions write to the same outputs: both tables plus the status line.
            btn_refresh.click(ui_refresh_tables, outputs=[grid, grid_agg, status])
            btn_seed.click(ui_load_sample, outputs=[grid, grid_agg, status])
            btn_upload.click(ui_upload_csv, inputs=[file], outputs=[grid, grid_agg, status])
        with gr.Tab("Bandit"):
            bbtn = gr.Button("Suggest Allocation (TS)")
            jout = gr.JSON()
            bbtn.click(ui_recommend, outputs=jout)
            # NOTE(review): indentation was not preserved in the copy under
            # review; which of the following components sit inside this Row
            # is assumed — confirm against the intended layout.
            with gr.Row():
                medium = gr.Textbox(label="Medium for Plot", value="FB")
                # Created hidden; ui_plot_posteriors returns the update for it.
                plot = gr.Plot(visible=False)
                msg = gr.Markdown()
            gr.Button("Plot CTR by Creative").click(ui_plot_posteriors, inputs=[medium], outputs=[plot, msg])
        with gr.Tab("Uplift (Causal)"):
            cbtn = gr.Button("Fit Uplift Model")
            cout = gr.JSON()
            cbtn.click(ui_fit_uplift, outputs=cout)
    return demo