from __future__ import annotations import gradio as gr import plotly.express as px import pandas as pd from data import read_events, aggregate, append_events from bandit import EmpiricalBayesHierarchicalThompson from causal import fit_uplift_binary BANDIT = EmpiricalBayesHierarchicalThompson(min_explore=0.05, margin=0.0, n_draws=10000) REQUIRED_COLS = [ "date","medium","creative","is_control", "impressions","clicks","conversions","cost","features_json" ] # ---------- helpers ---------- def _ok(msg: str) -> str: return f"✅ {msg}" def _warn(msg: str) -> str: return f"⚠️ {msg}" def _err(msg: str) -> str: return f"❌ {msg}" # ---------- data ops ---------- def ui_refresh_tables(): df = read_events() agg = aggregate() # 3つ目の出力(status)は Markdown なので、単なる文字列でOK return df, agg, _ok("Refreshed.") def ui_load_sample(): rows = [ {"date":"2025-09-01","medium":"FB","creative":"A1","is_control":1,"impressions":1000,"clicks":25,"conversions":3,"cost":480,"features_json":"{}"}, {"date":"2025-09-01","medium":"FB","creative":"B1","is_control":0,"impressions":850,"clicks":24,"conversions":2,"cost":395,"features_json":"{}"}, {"date":"2025-09-01","medium":"FB","creative":"C1","is_control":0,"impressions":900,"clicks":21,"conversions":2,"cost":420,"features_json":"{}"}, {"date":"2025-09-01","medium":"FB","creative":"D1","is_control":0,"impressions":820,"clicks":20,"conversions":2,"cost":380,"features_json":"{}"}, {"date":"2025-09-01","medium":"FB","creative":"E1","is_control":0,"impressions":960,"clicks":31,"conversions":3,"cost":490,"features_json":"{}"}, {"date":"2025-09-01","medium":"GDN","creative":"A2","is_control":1,"impressions":1100,"clicks":25,"conversions":2,"cost":545,"features_json":"{}"}, {"date":"2025-09-01","medium":"GDN","creative":"B2","is_control":0,"impressions":990,"clicks":27,"conversions":3,"cost":514,"features_json":"{}"}, {"date":"2025-09-01","medium":"GDN","creative":"C2","is_control":0,"impressions":860,"clicks":19,"conversions":2,"cost":450,"features_json":"{}"}, {"date":"2025-09-01","medium":"GDN","creative":"D2","is_control":0,"impressions":1045,"clicks":33,"conversions":4,"cost":570,"features_json":"{}"}, {"date":"2025-09-01","medium":"GDN","creative":"E2","is_control":0,"impressions":905,"clicks":20,"conversions":2,"cost":462,"features_json":"{}"} ] append_events(pd.DataFrame(rows)) df, agg = read_events(), aggregate() return df, agg, _ok("Sample data loaded.") def ui_upload_csv(file: gr.File): # アウトプットは [grid, grid_agg, status] if file is None: # 表は変更しない → gr.update() で“no change”、statusだけ文字列更新 return gr.update(), gr.update(), _warn("CSVファイルを選択してください。") try: # 文字コードの揺れに対処 try: df = pd.read_csv(file.name) except Exception: df = pd.read_csv(file.name, encoding="utf-8-sig") except Exception as e: return gr.update(), gr.update(), _err(f"CSV読み込みに失敗: {e}") # 区切り子の誤検知(TSV等)を簡易補正 if df.shape[1] == 1 and "," in str(df.columns[0]): try: df = pd.read_csv(file.name, sep=",", engine="python") except Exception: pass missing = [c for c in REQUIRED_COLS if c not in df.columns] if missing: return gr.update(), gr.update(), _err(f"必須列が不足: {missing}") # 軽いサニタイズ for c in ["is_control","impressions","clicks","conversions"]: df[c] = df[c].fillna(0).astype(int) df["cost"] = df.get("cost", 0.0) df["cost"] = df["cost"].fillna(0.0).astype(float) df["features_json"] = df.get("features_json", "{}").fillna("{}").astype(str) append_events(df) df2, agg2 = read_events(), aggregate() return df2, agg2, _ok(f"取り込み完了({len(df)}行)。") # ---------- bandit / causal ---------- def ui_recommend(): agg = aggregate() if agg.empty: return {"message": "No data yet. Dataタブで 'Load Sample Data' を押すか、CSVを取り込んでください。"} return BANDIT.recommend(agg) def ui_plot_posteriors(medium: str): agg = aggregate() if agg.empty: return gr.update(visible=False), "No data" g = agg[agg["medium"].astype(str) == str(medium)].copy() if g.empty: return gr.update(visible=False), f"No data for medium={medium}" g["ctr"] = (g["clicks"] + 1) / (g["impressions"] + 2) fig = px.bar(g, x="creative", y="ctr", color="is_control", barmode="group", title=f"CTR (Laplace) by creative @ {medium}") return gr.Plot(fig), "" def ui_fit_uplift(): agg = aggregate() if agg.empty: return {"message": "No data"} return fit_uplift_binary(agg) # ---------- UI ---------- def build_ui(): with gr.Blocks(title="AdCopy MAB Optimizer Pro") as demo: gr.Markdown("# AdCopy MAB Optimizer Pro — Hierarchical TS + Uplift") with gr.Tab("Data"): status = gr.Markdown("") # ← ここは“文字列”を受け取れる with gr.Row(): btn_refresh = gr.Button("Refresh") btn_seed = gr.Button("Load Sample Data") with gr.Row(): file = gr.File(label="Upload CSV (columns: " + ", ".join(REQUIRED_COLS) + ")", file_types=[".csv"]) btn_upload = gr.Button("Import CSV") grid = gr.Dataframe(headers=["ts","date","medium","creative","is_control","impressions","clicks","conversions","cost","features_json"], wrap=True) grid_agg = gr.Dataframe() btn_refresh.click(ui_refresh_tables, outputs=[grid, grid_agg, status]) btn_seed.click(ui_load_sample, outputs=[grid, grid_agg, status]) btn_upload.click(ui_upload_csv, inputs=[file], outputs=[grid, grid_agg, status]) with gr.Tab("Bandit"): bbtn = gr.Button("Suggest Allocation (TS)") jout = gr.JSON() bbtn.click(ui_recommend, outputs=jout) with gr.Row(): medium = gr.Textbox(label="Medium for Plot", value="FB") plot = gr.Plot(visible=False) msg = gr.Markdown() gr.Button("Plot CTR by Creative").click(ui_plot_posteriors, inputs=[medium], outputs=[plot, msg]) with gr.Tab("Uplift (Causal)"): cbtn = gr.Button("Fit Uplift Model") cout = gr.JSON() cbtn.click(ui_fit_uplift, outputs=cout) return demo