File size: 6,681 Bytes
8b4a5e6
 
 
eca36cb
 
8b4a5e6
 
 
 
 
eca36cb
 
 
 
 
46cda09
 
 
 
 
 
 
eca36cb
46cda09
8b4a5e6
 
 
46cda09
eca36cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46cda09
eca36cb
46cda09
eca36cb
 
46cda09
 
 
 
 
eca36cb
 
 
46cda09
 
 
 
 
 
 
eca36cb
 
 
 
46cda09
eca36cb
 
46cda09
 
 
eca36cb
 
 
 
8b4a5e6
46cda09
8b4a5e6
 
 
eca36cb
0735cb3
8b4a5e6
 
 
 
 
 
 
 
 
0735cb3
 
8b4a5e6
 
 
 
 
 
0735cb3
8b4a5e6
46cda09
8b4a5e6
 
 
eca36cb
8b4a5e6
46cda09
eca36cb
 
 
 
 
 
8b4a5e6
 
eca36cb
 
 
 
 
8b4a5e6
 
 
 
 
 
 
 
0735cb3
eca36cb
8b4a5e6
 
 
 
eca36cb
8b4a5e6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from __future__ import annotations
import gradio as gr
import plotly.express as px
import pandas as pd
from data import read_events, aggregate, append_events
from bandit import EmpiricalBayesHierarchicalThompson
from causal import fit_uplift_binary

# Single module-level allocator instance; every "Suggest Allocation" click reuses it.
BANDIT = EmpiricalBayesHierarchicalThompson(
    min_explore=0.05,
    margin=0.0,
    n_draws=10_000,
)

# Columns an uploaded events CSV must contain (checked by the CSV importer and
# shown in the upload widget's label).
REQUIRED_COLS = [
    "date",
    "medium",
    "creative",
    "is_control",
    "impressions",
    "clicks",
    "conversions",
    "cost",
    "features_json",
]

# ---------- helpers ----------
def _ok(msg: str) -> str:
    return f"✅ {msg}"
def _warn(msg: str) -> str:
    return f"⚠️ {msg}"
def _err(msg: str) -> str:
    return f"❌ {msg}"

# ---------- data ops ----------
def ui_refresh_tables():
    """Reload the raw events and their aggregate for the Data tab.

    Returns:
        A 3-tuple ``(events, aggregate, status)`` where ``status`` is a plain
        string; the status output is a Markdown component, so no wrapping is needed.
    """
    return read_events(), aggregate(), _ok("Refreshed.")

def ui_load_sample():
    """Submit a fixed demo dataset to ``append_events`` and return the refreshed tables.

    The sample covers five creatives on each of two media (FB, GDN), one of
    them flagged ``is_control=1`` per medium, all dated 2025-09-01 with empty
    ``features_json``. This function does not de-duplicate, so every click
    submits the same rows again (whether ``append_events`` dedupes is not
    visible here).

    Returns:
        A 3-tuple ``(events, aggregate, status_markdown)``.
    """
    fields = ("medium", "creative", "is_control", "impressions", "clicks", "conversions", "cost")
    sample = (
        ("FB", "A1", 1, 1000, 25, 3, 480),
        ("FB", "B1", 0, 850, 24, 2, 395),
        ("FB", "C1", 0, 900, 21, 2, 420),
        ("FB", "D1", 0, 820, 20, 2, 380),
        ("FB", "E1", 0, 960, 31, 3, 490),
        ("GDN", "A2", 1, 1100, 25, 2, 545),
        ("GDN", "B2", 0, 990, 27, 3, 514),
        ("GDN", "C2", 0, 860, 19, 2, 450),
        ("GDN", "D2", 0, 1045, 33, 4, 570),
        ("GDN", "E2", 0, 905, 20, 2, 462),
    )
    # Key order (date, <fields>, features_json) fixes the DataFrame column order.
    records = [
        {"date": "2025-09-01", **dict(zip(fields, values)), "features_json": "{}"}
        for values in sample
    ]
    append_events(pd.DataFrame(records))
    return read_events(), aggregate(), _ok("Sample data loaded.")

def _read_csv_flexible(path) -> pd.DataFrame:
    """Read *path* as CSV, tolerating UTF-8 (with or without BOM), cp932 and non-comma delimiters.

    Encodings are tried in order: ``utf-8-sig`` (accepts plain UTF-8 and
    BOM-prefixed UTF-8), then ``cp932`` (Shift-JIS as written by Japanese Excel).
    If the parse yields a single column, the delimiter is re-detected with the
    python engine's sniffer; when sniffing fails the first parse is kept.

    Raises:
        UnicodeDecodeError: if the file decodes with none of the encodings tried.
        Other errors from ``pd.read_csv`` (e.g. an empty file) propagate unchanged.
    """
    import csv  # only used to recognise delimiter-sniffing failures

    last_decode_error = None
    for encoding in ("utf-8-sig", "cp932"):
        try:
            df = pd.read_csv(path, encoding=encoding)
        except UnicodeDecodeError as e:
            last_decode_error = e
            continue
        if df.shape[1] == 1:
            # Everything landed in one column, so the delimiter was probably not
            # a comma (TSV, ';', ...). Re-read with the same encoding and let the
            # python engine sniff the separator; best effort only.
            try:
                df = pd.read_csv(path, encoding=encoding, sep=None, engine="python")
            except (csv.Error, ValueError):
                pass
        return df
    raise last_decode_error


def _sanitize_events(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with numeric columns coerced and blank cells filled.

    The count columns become ``int`` and ``cost`` becomes ``float``, with missing
    cells treated as 0. Blank ``features_json`` cells become ``"{}"``.

    Raises:
        ValueError: if a numeric column holds a non-numeric value.
        TypeError: if a column name is duplicated (selection yields a frame).
    """
    out = df.copy()
    for col in ("is_control", "impressions", "clicks", "conversions"):
        out[col] = pd.to_numeric(out[col]).fillna(0).astype(int)
    out["cost"] = pd.to_numeric(out["cost"]).fillna(0.0).astype(float)
    out["features_json"] = out["features_json"].fillna("{}").astype(str)
    return out


def ui_upload_csv(file: gr.File):
    """Import an uploaded events CSV and return the refreshed tables.

    Outputs map to ``[grid, grid_agg, status]``. On any failure both tables are
    left unchanged (a bare ``gr.update()`` means "no change") and only the
    status message is set.

    Args:
        file: Value of the ``gr.File`` input. Depending on the Gradio version it
            is either a path string or an object exposing the path as ``.name``.

    Returns:
        A 3-tuple ``(events, aggregate, status_markdown)``.
    """
    if file is None:
        return gr.update(), gr.update(), _warn("CSVファイルを選択してください。")

    # Accept both a plain path and a tempfile-like wrapper.
    path = getattr(file, "name", file)
    try:
        df = _read_csv_flexible(path)
    except Exception as e:  # UI boundary: surface any read/parse failure to the user
        return gr.update(), gr.update(), _err(f"CSV読み込みに失敗: {e}")

    # Hand-edited CSVs often carry stray spaces around header names.
    df.columns = [str(c).strip() for c in df.columns]
    missing = [c for c in REQUIRED_COLS if c not in df.columns]
    if missing:
        return gr.update(), gr.update(), _err(f"必須列が不足: {missing}")

    try:
        df = _sanitize_events(df)
    except (ValueError, TypeError) as e:
        return gr.update(), gr.update(), _err(f"数値列の変換に失敗: {e}")

    append_events(df)
    return read_events(), aggregate(), _ok(f"取り込み完了({len(df)}行)。")

# ---------- bandit / causal ----------
def ui_recommend():
    """Return ``BANDIT.recommend`` for the current aggregate, or a hint dict when there is no data yet."""
    agg = aggregate()
    if not agg.empty:
        return BANDIT.recommend(agg)
    return {"message": "No data yet. Dataタブで 'Load Sample Data' を押すか、CSVを取り込んでください。"}

def ui_plot_posteriors(medium: str):
    """Bar-plot the Laplace-smoothed CTR of each creative for one medium.

    Outputs map to ``[plot, msg]``. On success the plot is shown with the new
    figure and the message is cleared. Without matching data the plot is hidden
    and the message says why.

    Args:
        medium: Medium to filter on; compared as a string, surrounding whitespace ignored.
    """
    medium = str(medium).strip()
    agg = aggregate()
    if agg.empty:
        return gr.update(visible=False), "No data"
    g = agg[agg["medium"].astype(str) == medium].copy()
    if g.empty:
        return gr.update(visible=False), f"No data for medium={medium}"
    # Add-one smoothing: (clicks + 1) / (impressions + 2) avoids 0/0 for unseen creatives.
    g["ctr"] = (g["clicks"] + 1) / (g["impressions"] + 2)
    # A numeric colour column gives a continuous colour scale (no grouping or
    # legend); a string column gives discrete, grouped bars.
    g["is_control"] = g["is_control"].astype(str)
    fig = px.bar(g, x="creative", y="ctr", color="is_control", barmode="group",
                 title=f"CTR (Laplace) by creative @ {medium}")
    # The plot output may currently be hidden (initially or by a previous
    # "no data" call), so visibility has to be turned back on explicitly.
    return gr.update(value=fig, visible=True), ""

def ui_fit_uplift():
    """Return ``fit_uplift_binary`` on the current aggregate, or a "No data" dict when it is empty."""
    agg = aggregate()
    return {"message": "No data"} if agg.empty else fit_uplift_binary(agg)

# ---------- UI ----------
def build_ui():
    """Assemble the Gradio app with Data, Bandit and Uplift (Causal) tabs.

    Returns:
        The ``gr.Blocks`` instance (not launched).
    """
    with gr.Blocks(title="AdCopy MAB Optimizer Pro") as app:
        gr.Markdown("# AdCopy MAB Optimizer Pro — Hierarchical TS + Uplift")

        # Data tab: seed, import and inspect raw events plus their aggregate.
        with gr.Tab("Data"):
            status_md = gr.Markdown("")  # handlers write plain status strings here
            with gr.Row():
                refresh_btn = gr.Button("Refresh")
                seed_btn = gr.Button("Load Sample Data")
            with gr.Row():
                csv_file = gr.File(
                    label=f"Upload CSV (columns: {', '.join(REQUIRED_COLS)})",
                    file_types=[".csv"],
                )
                import_btn = gr.Button("Import CSV")
            events_grid = gr.Dataframe(headers=["ts", *REQUIRED_COLS], wrap=True)
            agg_grid = gr.Dataframe()

            # All three data handlers return (events, aggregate, status).
            table_outputs = [events_grid, agg_grid, status_md]
            refresh_btn.click(ui_refresh_tables, outputs=table_outputs)
            seed_btn.click(ui_load_sample, outputs=table_outputs)
            import_btn.click(ui_upload_csv, inputs=[csv_file], outputs=table_outputs)

        # Bandit tab: allocation suggestion plus a per-medium CTR chart.
        with gr.Tab("Bandit"):
            suggest_btn = gr.Button("Suggest Allocation (TS)")
            allocation_json = gr.JSON()
            suggest_btn.click(ui_recommend, outputs=allocation_json)
            with gr.Row():
                medium_box = gr.Textbox(label="Medium for Plot", value="FB")
                ctr_plot = gr.Plot(visible=False)
                plot_msg = gr.Markdown()
            plot_btn = gr.Button("Plot CTR by Creative")
            plot_btn.click(ui_plot_posteriors, inputs=[medium_box], outputs=[ctr_plot, plot_msg])

        # Uplift tab: fit the causal uplift model on the aggregate.
        with gr.Tab("Uplift (Causal)"):
            fit_btn = gr.Button("Fit Uplift Model")
            uplift_json = gr.JSON()
            fit_btn.click(ui_fit_uplift, outputs=uplift_json)

    return app