Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import gradio as gr | |
| import plotly.express as px | |
| import pandas as pd | |
| from data import read_events, aggregate, append_events | |
| from bandit import EmpiricalBayesHierarchicalThompson | |
| from causal import fit_uplift_binary | |
# Single shared bandit instance used by the "Bandit" tab. The keyword values are
# passed through unchanged; their meaning is defined by
# bandit.EmpiricalBayesHierarchicalThompson (not visible here).
BANDIT = EmpiricalBayesHierarchicalThompson(
    min_explore=0.05,
    margin=0.0,
    n_draws=10000,
)
# Columns an uploaded events CSV must contain (checked in ui_upload_csv and
# listed in the upload widget's label).
REQUIRED_COLS = [
    "date",
    "medium",
    "creative",
    "is_control",
    "impressions",
    "clicks",
    "conversions",
    "cost",
    "features_json",
]
| # ---------- helpers ---------- | |
| def _ok(msg: str) -> str: | |
| return f"✅ {msg}" | |
| def _warn(msg: str) -> str: | |
| return f"⚠️ {msg}" | |
| def _err(msg: str) -> str: | |
| return f"❌ {msg}" | |
| # ---------- data ops ---------- | |
def ui_refresh_tables():
    """Reload the raw events and their aggregate for the Data tab.

    Returns:
        A 3-tuple for the outputs ``[grid, grid_agg, status]``. The status
        output is a Markdown component, so a plain string is sufficient.
    """
    # Tuple items are evaluated left to right: events first, then aggregate.
    return read_events(), aggregate(), _ok("Refreshed.")
def ui_load_sample():
    """Append a fixed one-day sample of FB/GDN creatives, then reload the tables.

    Every call appends the same rows again via ``append_events``.

    Returns:
        A 3-tuple for the outputs ``[grid, grid_agg, status]``.
    """
    columns = [
        "date", "medium", "creative", "is_control",
        "impressions", "clicks", "conversions", "cost", "features_json",
    ]
    day = "2025-09-01"
    no_features = "{}"
    # medium, creative, is_control, impressions, clicks, conversions, cost
    table = [
        ("FB", "A1", 1, 1000, 25, 3, 480),
        ("FB", "B1", 0, 850, 24, 2, 395),
        ("FB", "C1", 0, 900, 21, 2, 420),
        ("FB", "D1", 0, 820, 20, 2, 380),
        ("FB", "E1", 0, 960, 31, 3, 490),
        ("GDN", "A2", 1, 1100, 25, 2, 545),
        ("GDN", "B2", 0, 990, 27, 3, 514),
        ("GDN", "C2", 0, 860, 19, 2, 450),
        ("GDN", "D2", 0, 1045, 33, 4, 570),
        ("GDN", "E2", 0, 905, 20, 2, 462),
    ]
    sample = pd.DataFrame(
        [(day, *row, no_features) for row in table],
        columns=columns,
    )
    append_events(sample)
    return read_events(), aggregate(), _ok("Sample data loaded.")
def _read_csv_robust(path: str) -> pd.DataFrame:
    """Read a CSV file, tolerating a UTF-8 BOM, cp932 (Shift_JIS) and non-comma delimiters.

    Raises:
        UnicodeDecodeError: if none of the candidate encodings can decode the file.
        Any other ``pd.read_csv`` error (empty file, malformed rows, ...) propagates unchanged.
    """
    import csv  # local import: only used to recognise delimiter-sniffing failures

    last_err: UnicodeDecodeError | None = None
    # "utf-8-sig" also reads plain UTF-8 (the BOM is optional), so a separate
    # "utf-8" attempt is unnecessary; cp932 covers Excel exports on Japanese Windows.
    for enc in ("utf-8-sig", "cp932"):
        try:
            df = pd.read_csv(path, encoding=enc)
        except UnicodeDecodeError as e:
            last_err = e
            continue
        if df.shape[1] == 1:
            # A single column usually means the real delimiter is not a comma
            # (TSV, semicolons, ...). With sep=None the python engine sniffs it.
            try:
                df = pd.read_csv(path, encoding=enc, sep=None, engine="python")
            except (ValueError, csv.Error):
                pass  # best effort: keep the single-column frame; the column check reports it
        return df
    raise last_err


def _sanitize_events(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with count/cost columns coerced to numbers and blanks defaulted.

    Raises:
        ValueError, TypeError: if a numeric column holds values that cannot be converted.
    """
    out = df.copy()
    for c in ("is_control", "impressions", "clicks", "conversions"):
        out[c] = out[c].fillna(0).astype(int)
    out["cost"] = out["cost"].fillna(0.0).astype(float)
    out["features_json"] = out["features_json"].fillna("{}").astype(str)
    return out


def ui_upload_csv(file: gr.File):
    """Import an uploaded events CSV, append it to the store and refresh both tables.

    Args:
        file: Value of the ``gr.File`` input. Depending on the Gradio version
            and ``type``, this is a tempfile-like object with ``.name`` or a
            plain filepath string. ``None`` when nothing was selected.

    Returns:
        A 3-tuple for the outputs ``[grid, grid_agg, status]``. On any failure
        both tables are left unchanged (``gr.update()``) and only the status
        message is set.
    """
    if file is None:
        # Leave the tables as they are; only update the status text.
        return gr.update(), gr.update(), _warn("CSVファイルを選択してください。")

    path = getattr(file, "name", file)
    try:
        df = _read_csv_robust(path)
    except Exception as e:  # UI boundary: report any decode/parse problem instead of crashing
        return gr.update(), gr.update(), _err(f"CSV読み込みに失敗: {e}")

    # Tolerate stray spaces around header names (e.g. "date, medium").
    df.columns = [str(c).strip() for c in df.columns]
    missing = [c for c in REQUIRED_COLS if c not in df.columns]
    if missing:
        return gr.update(), gr.update(), _err(f"必須列が不足: {missing}")

    try:
        df = _sanitize_events(df)
    except (ValueError, TypeError) as e:
        return gr.update(), gr.update(), _err(f"数値列の変換に失敗: {e}")

    append_events(df)
    return read_events(), aggregate(), _ok(f"取り込み完了({len(df)}行)。")
| # ---------- bandit / causal ---------- | |
def ui_recommend():
    """Ask the shared bandit for an allocation over the current aggregate.

    Returns whatever ``BANDIT.recommend`` returns (shown in a JSON output),
    or a ``{"message": ...}`` hint when there is no aggregated data yet.
    """
    agg = aggregate()
    if not agg.empty:
        return BANDIT.recommend(agg)
    return {"message": "No data yet. Dataタブで 'Load Sample Data' を押すか、CSVを取り込んでください。"}
def ui_plot_posteriors(medium: str):
    """Plot the Laplace-smoothed CTR of each creative for one medium.

    Args:
        medium: Medium name typed by the user. It is compared as a string,
            and surrounding whitespace is ignored.

    Returns:
        A 2-tuple for the outputs ``[plot, msg]``. On success the plot is
        made visible with the new figure and the message is cleared. When
        there is nothing to plot, the plot is hidden and the message says why.
    """
    agg = aggregate()
    if agg.empty:
        return gr.update(visible=False), "No data"
    medium = str(medium).strip()
    g = agg[agg["medium"].astype(str) == medium].copy()
    if g.empty:
        return gr.update(visible=False), f"No data for medium={medium}"
    # Add-one smoothing (posterior mean under a Beta(1, 1) prior), so creatives
    # with zero clicks or impressions still get a finite, nonzero bar.
    g["ctr"] = (g["clicks"] + 1) / (g["impressions"] + 2)
    # Plotly Express maps a numeric color column to a continuous colorbar, in
    # which case barmode="group" does nothing. Strings give discrete groups.
    g["is_control"] = g["is_control"].astype(str)
    fig = px.bar(g, x="creative", y="ctr", color="is_control", barmode="group",
                 title=f"CTR (Laplace) by creative @ {medium}")
    # The Plot component is created with visible=False (and is hidden again on
    # the "no data" paths), so visibility must be switched back on explicitly.
    return gr.update(value=fig, visible=True), ""
def ui_fit_uplift():
    """Fit the binary uplift model on the current aggregate.

    Returns the result of ``fit_uplift_binary`` (shown in a JSON output), or
    ``{"message": "No data"}`` when nothing has been aggregated yet.
    """
    agg = aggregate()
    if not agg.empty:
        return fit_uplift_binary(agg)
    return {"message": "No data"}
| # ---------- UI ---------- | |
def build_ui():
    """Assemble the Gradio app with Data, Bandit and Uplift (Causal) tabs.

    Returns:
        The ``gr.Blocks`` instance. It is not launched here.
    """
    with gr.Blocks(title="AdCopy MAB Optimizer Pro") as demo:
        gr.Markdown("# AdCopy MAB Optimizer Pro — Hierarchical TS + Uplift")

        with gr.Tab("Data"):
            # Markdown status line: the handlers return plain strings for it.
            status = gr.Markdown("")
            with gr.Row():
                btn_refresh = gr.Button("Refresh")
                btn_seed = gr.Button("Load Sample Data")
            with gr.Row():
                csv_file = gr.File(
                    label="Upload CSV (columns: " + ", ".join(REQUIRED_COLS) + ")",
                    file_types=[".csv"],
                )
                btn_upload = gr.Button("Import CSV")
            # Raw-event grid: a "ts" column followed by the required CSV columns.
            grid = gr.Dataframe(headers=["ts", *REQUIRED_COLS], wrap=True)
            grid_agg = gr.Dataframe()

            # All three data handlers return (events, aggregate, status).
            data_outputs = [grid, grid_agg, status]
            btn_refresh.click(ui_refresh_tables, outputs=data_outputs)
            btn_seed.click(ui_load_sample, outputs=data_outputs)
            btn_upload.click(ui_upload_csv, inputs=[csv_file], outputs=data_outputs)

        with gr.Tab("Bandit"):
            btn_suggest = gr.Button("Suggest Allocation (TS)")
            alloc_json = gr.JSON()
            btn_suggest.click(ui_recommend, outputs=alloc_json)
            with gr.Row():
                medium_box = gr.Textbox(label="Medium for Plot", value="FB")
            plot = gr.Plot(visible=False)
            plot_msg = gr.Markdown()
            btn_plot = gr.Button("Plot CTR by Creative")
            btn_plot.click(ui_plot_posteriors, inputs=[medium_box], outputs=[plot, plot_msg])

        with gr.Tab("Uplift (Causal)"):
            btn_uplift = gr.Button("Fit Uplift Model")
            uplift_json = gr.JSON()
            btn_uplift.click(ui_fit_uplift, outputs=uplift_json)

    return demo