Corin1998 commited on
Commit
70db9c9
·
verified ·
1 Parent(s): 1d5d67b

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +413 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ # Hugging Face Spaces-ready Gradio + FastAPI app
3
+ # 機能: クリエイティブ自動生成(OpenAI)、ブランドガイド従属性チェック、A/Bテスト(CTR/CVR)、多腕バンディット(ε-greedy)
4
+ # 追跡: /v/<variant> で表示、/click と /convert でクリック/コンバージョン計測。/r はローテータ(割当)URL。
5
+ # メモリ: SQLite(./ab.db)
6
+ # デプロイ: 「Spaces(Gradio)」で動作。環境変数 OPENAI_API_KEY を設定してください。
7
+
8
+ import os
9
+ import sqlite3
10
+ import json
11
+ import time
12
+ import random
13
+ import hashlib
14
+ from datetime import datetime
15
+ from typing import List, Dict, Any, Optional
16
+ from urllib.parse import quote, unquote
17
+
18
+ import gradio as gr
19
+ from fastapi import FastAPI, Request
20
+ from fastapi.responses import HTMLResponse, RedirectResponse, PlainTextResponse
21
+ from openai import OpenAI
22
+
23
+ DB_PATH = "ab.db"
24
+ TABLE_SQL = """
25
+ CREATE TABLE IF NOT EXISTS variants (
26
+ id TEXT PRIMARY KEY,
27
+ name TEXT,
28
+ channel TEXT,
29
+ target_metric TEXT,
30
+ kpi_target REAL,
31
+ brand_guidelines TEXT,
32
+ banned_words TEXT,
33
+ product TEXT,
34
+ audience TEXT,
35
+ cta_url TEXT,
36
+ headline TEXT,
37
+ body TEXT,
38
+ cta TEXT,
39
+ image_prompt TEXT,
40
+ created_at INTEGER,
41
+ impressions INTEGER DEFAULT 0,
42
+ clicks INTEGER DEFAULT 0,
43
+ conversions INTEGER DEFAULT 0
44
+ );
45
+
46
+ CREATE TABLE IF NOT EXISTS weights (
47
+ variant_id TEXT PRIMARY KEY,
48
+ weight REAL
49
+ );
50
+ """
51
+
52
+ def db():
53
+ conn = sqlite3.connect(DB_PATH, check_same_thread=False)
54
+ conn.row_factory = sqlite3.Row
55
+ return conn
56
+
57
+ def init_db():
58
+ conn = db()
59
+ cur = conn.cursor()
60
+ for stmt in TABLE_SQL.strip().split(";"):
61
+ s = stmt.strip()
62
+ if s:
63
+ cur.execute(s)
64
+ conn.commit()
65
+ conn.close()
66
+
67
+ def now_ts() -> int:
68
+ return int(time.time())
69
+
70
+ def gen_variant_id(seed: str) -> str:
71
+ return hashlib.sha256(f"{seed}-{time.time()}-{random.random()}".encode()).hexdigest()[:12]
72
+
73
+ def upsert_weight(conn, vid: str, w: float):
74
+ cur = conn.cursor()
75
+ cur.execute("INSERT INTO weights(variant_id, weight) VALUES(?, ?) ON CONFLICT(variant_id) DO UPDATE SET weight=excluded.weight", (vid, w))
76
+ conn.commit()
77
+
78
+ def get_weights(conn):
79
+ rows = conn.execute("SELECT variant_id, weight FROM weights").fetchall()
80
+ return {r["variant_id"]: r["weight"] for r in rows}
81
+
82
+ def epsilon_greedy_pick(conn, epsilon: float = 0.1) -> Optional[str]:
83
+ # CTR ベースのε-greedy
84
+ rows = conn.execute("SELECT id, impressions, clicks FROM variants").fetchall()
85
+ if not rows:
86
+ return None
87
+ # 探索
88
+ if random.random() < epsilon:
89
+ return random.choice(rows)["id"]
90
+ # 活用: CTR最大
91
+ best_vid = None
92
+ best_ctr = -1
93
+ for r in rows:
94
+ imps = r["impressions"]
95
+ clk = r["clicks"]
96
+ ctr = (clk / imps) if imps > 0 else 0.0
97
+ if ctr > best_ctr:
98
+ best_ctr = ctr
99
+ best_vid = r["id"]
100
+ return best_vid or rows[0]["id"]
101
+
102
+ def record_impression(conn, vid: str):
103
+ conn.execute("UPDATE variants SET impressions = impressions + 1 WHERE id=?", (vid,))
104
+ conn.commit()
105
+
106
+ def record_click(conn, vid: str):
107
+ conn.execute("UPDATE variants SET clicks = clicks + 1 WHERE id=?", (vid,))
108
+ conn.commit()
109
+
110
+ def record_conversion(conn, vid: str):
111
+ conn.execute("UPDATE variants SET conversions = conversions + 1 WHERE id=?", (vid,))
112
+ conn.commit()
113
+
114
+ def fetch_variant(conn, vid: str):
115
+ row = conn.execute("SELECT * FROM variants WHERE id=?", (vid,)).fetchone()
116
+ return row
117
+
118
+ def list_variants(conn):
119
+ rows = conn.execute("SELECT * FROM variants ORDER BY created_at DESC").fetchall()
120
+ return rows
121
+
122
+ def delete_all(conn):
123
+ conn.execute("DELETE FROM variants")
124
+ conn.execute("DELETE FROM weights")
125
+ conn.commit()
126
+
127
+ # ---- Brand compliance (ローカル簡易チェック) ----
128
+ def check_compliance(text: str, guidelines: str, banned_words: List[str]) -> Dict[str, Any]:
129
+ violations = []
130
+ lw = text.lower()
131
+ for w in banned_words:
132
+ ww = w.strip().lower()
133
+ if ww and ww in lw:
134
+ violations.append(f"banned_word:{w}")
135
+ # 文字数の過不足などの軽いヒューリスティック
136
+ if len(text) < 10:
137
+ violations.append("too_short")
138
+ if len(text) > 1200:
139
+ violations.append("too_long")
140
+ ok = len(violations) == 0
141
+ return {"ok": ok, "violations": violations}
142
+
143
+ # ---- OpenAIでクリエイティブ生成 ----
144
+ def openai_client():
145
+ api_key = os.getenv("OPENAI_API_KEY")
146
+ if not api_key:
147
+ raise RuntimeError("OPENAI_API_KEY が未設定です(SpacesのSecretsに設定してください)。")
148
+ return OpenAI(api_key=api_key)
149
+
150
+ GEN_SYS_PROMPT = """You are a world-class performance marketer and copywriter.
151
+ You must output STRICT JSON encoded in UTF-8 with this schema:
152
+ {
153
+ "variants": [
154
+ {
155
+ "name": "string short nickname",
156
+ "headline": "string 60 chars max",
157
+ "body": "string 280-600 chars",
158
+ "cta": "string <= 30 chars",
159
+ "image_prompt": "string concise visual prompt for an image generator"
160
+ }, ...
161
+ ]
162
+ }
163
+ Rules:
164
+ - Obey brand guidelines: tone, banned phrases, do-not-claim, legal constraints.
165
+ - Optimize for the stated goal (CTR or CVR) and KPI target.
166
+ - Avoid spammy claims, avoid ALL CAPS.
167
+ - Keep culturally appropriate Japanese copy if input is Japanese.
168
+ """
169
+
170
+ def generate_variants_with_openai(goal: str, kpi_target: float, channel: str, product: str, audience: str, brand_guidelines: str, banned_words: List[str], n_variants: int = 3) -> List[Dict[str, str]]:
171
+ client = openai_client()
172
+ user_prompt = {
173
+ "goal": goal,
174
+ "kpi_target": kpi_target,
175
+ "channel": channel,
176
+ "product": product,
177
+ "audience": audience,
178
+ "brand_guidelines": brand_guidelines,
179
+ "banned_words": banned_words,
180
+ "n_variants": n_variants
181
+ }
182
+ # Chat Completions (安定): モデルは適宜変更可
183
+ resp = client.chat.completions.create(
184
+ model="gpt-4o-mini",
185
+ temperature=0.8,
186
+ messages=[
187
+ {"role": "system", "content": GEN_SYS_PROMPT},
188
+ {"role": "user", "content": json.dumps(user_prompt, ensure_ascii=False)}
189
+ ]
190
+ )
191
+ content = resp.choices[0].message.content.strip()
192
+ # JSON抽出
193
+ try:
194
+ data = json.loads(content)
195
+ variants = data.get("variants", [])
196
+ except Exception:
197
+ # 失敗時フォールバック: JSONのコードブロックを剥がす
198
+ import re
199
+ m = re.search(r"\{[\s\S]*\}", content)
200
+ if not m:
201
+ raise RuntimeError("モデル出力のJSON解析に失敗しました。出力:\n" + content[:500])
202
+ data = json.loads(m.group(0))
203
+ variants = data.get("variants", [])
204
+ # banned_words 簡易フィルタ
205
+ filtered = []
206
+ for v in variants[:n_variants]:
207
+ body = v.get("body", "")
208
+ head = v.get("headline", "")
209
+ cta = v.get("cta", "")
210
+ comp = "\n".join([head, body, cta])
211
+ compliance = check_compliance(comp, brand_guidelines, banned_words)
212
+ if compliance["ok"]:
213
+ filtered.append(v)
214
+ else:
215
+ # 軽微な違反(長さなど)は通す。禁止語のみ除外。
216
+ if not any(vi.startswith("banned_word") for vi in compliance["violations"]):
217
+ filtered.append(v)
218
+ if not filtered:
219
+ filtered = variants[:n_variants]
220
+ return filtered[:n_variants]
221
+
222
+ # ---- FastAPI ルーティング(追跡/配信) ----
223
+ fastapi_app = FastAPI()
224
+
225
+ @fastapi_app.get("/")
226
+ def root():
227
+ return PlainTextResponse("This Space hosts the Campaign Generator + A/B tester. Visit /gradio for UI.")
228
+
229
+ @fastapi_app.get("/r")
230
+ def rotator():
231
+ conn = db()
232
+ try:
233
+ vid = epsilon_greedy_pick(conn, epsilon=0.1)
234
+ if not vid:
235
+ return PlainTextResponse("No variants yet.", status_code=404)
236
+ return RedirectResponse(url=f"/v/{vid}")
237
+ finally:
238
+ conn.close()
239
+
240
+ @fastapi_app.get("/v/{vid}")
241
+ def serve_variant(vid: str):
242
+ conn = db()
243
+ try:
244
+ row = fetch_variant(conn, vid)
245
+ if not row:
246
+ return PlainTextResponse("Variant not found", status_code=404)
247
+ record_impression(conn, vid)
248
+ headline = row["headline"]
249
+ body = row["body"]
250
+ cta = row["cta"]
251
+ cta_url = row["cta_url"] or "/convert?vid=" + vid
252
+ html = f"""
253
+ <html>
254
+ <head>
255
+ <meta charset="utf-8" />
256
+ <meta name="viewport" content="width=device-width, initial-scale=1" />
257
+ <title>{headline}</title>
258
+ <style>
259
+ body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Noto Sans JP', sans-serif; padding: 24px; line-height: 1.6; }}
260
+ .card {{ max-width: 720px; margin: 0 auto; border: 1px solid #eee; border-radius: 16px; padding: 24px; box-shadow: 0 2px 12px rgba(0,0,0,.04); }}
261
+ .headline {{ font-size: 1.8rem; font-weight: 700; margin-bottom: 12px; }}
262
+ .body {{ font-size: 1rem; white-space: pre-wrap; }}
263
+ .cta {{ display: inline-block; margin-top: 20px; padding: 12px 18px; border-radius: 999px; text-decoration: none; background: black; color: white; }}
264
+ .meta {{ margin-top: 16px; color: #777; font-size: .85rem; }}
265
+ </style>
266
+ </head>
267
+ <body>
268
+ <div class="card">
269
+ <div class="headline">{headline}</div>
270
+ <div class="body">{body}</div>
271
+ <a class="cta" href="/click?vid={vid}&to={quote(cta_url, safe='')}" rel="nofollow">{cta}</a>
272
+ <div class="meta">Variant ID: {vid}</div>
273
+ </div>
274
+ </body>
275
+ </html>
276
+ """
277
+ return HTMLResponse(html)
278
+ finally:
279
+ conn.close()
280
+
281
+ @fastapi_app.get("/click")
282
+ def click(vid: str, to: str = "/convert"):
283
+ conn = db()
284
+ try:
285
+ row = fetch_variant(conn, vid)
286
+ if not row:
287
+ return PlainTextResponse("Variant not found", status_code=404)
288
+ record_click(conn, vid)
289
+ # 安全のため内部相対URLのみ許可(外部リダイレクトを避ける)
290
+ target = unquote(to)
291
+ if not target.startswith("/"):
292
+ target = "/convert?vid=" + vid
293
+ return RedirectResponse(url=target)
294
+ finally:
295
+ conn.close()
296
+
297
+ @fastapi_app.get("/convert")
298
+ def convert(vid: str = ""):
299
+ if not vid:
300
+ return PlainTextResponse("Missing vid", status_code=400)
301
+ conn = db()
302
+ try:
303
+ row = fetch_variant(conn, vid)
304
+ if not row:
305
+ return PlainTextResponse("Variant not found", status_code=404)
306
+ record_conversion(conn, vid)
307
+ html = f"""
308
+ <html><head><meta charset="utf-8"><title>Thank you</title></head>
309
+ <body style="font-family: -apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,'Noto Sans JP',sans-serif;padding:24px;">
310
+ <h1>Thanks!</h1>
311
+ <p>Conversion recorded for Variant {vid}.</p>
312
+ <p><a href="/v/{vid}">Back</a> | <a href="/r">Rotator</a></p>
313
+ </body></html>
314
+ """
315
+ return HTMLResponse(html)
316
+ finally:
317
+ conn.close()
318
+
319
+ # ---- Gradio UI ----
320
+ def create_rows_table() -> str:
321
+ conn = db()
322
+ try:
323
+ rows = list_variants(conn)
324
+ if not rows:
325
+ return "まだバリアントがありません。"
326
+ # Markdownテーブル
327
+ header = "|Variant|Channel|Impr|Clicks|Conv|CTR|CVR|Headline|\n|---|---:|---:|---:|---:|---:|---:|---|\n"
328
+ lines = []
329
+ for r in rows:
330
+ imps = r["impressions"]
331
+ clk = r["clicks"]
332
+ conv = r["conversions"]
333
+ ctr = (clk / imps) if imps > 0 else 0.0
334
+ cvr = (conv / clk) if clk > 0 else 0.0
335
+ link_v = f"/v/{r['id']}"
336
+ lines.append(f"|`{r['id']}`|{r['channel']}|{imps}|{clk}|{conv}|{ctr:.3f}|{cvr:.3f}|[{r['headline']}]({link_v})|")
337
+ return header + "\n".join(lines)
338
+ finally:
339
+ conn.close()
340
+
341
+ def generate_and_save(goal, kpi_target, channel, product, audience, brand_guidelines, banned_words_csv, n_variants, cta_url):
342
+ banned_words = [w.strip() for w in (banned_words_csv or "").split(",") if w.strip()]
343
+ variants = generate_variants_with_openai(goal, float(kpi_target or 0), channel, product, audience, brand_guidelines, banned_words, int(n_variants or 3))
344
+ conn = db()
345
+ try:
346
+ for v in variants:
347
+ vid = gen_variant_id(v.get("name","variant"))
348
+ conn.execute(
349
+ """INSERT INTO variants(id, name, channel, target_metric, kpi_target, brand_guidelines, banned_words, product, audience, cta_url, headline, body, cta, image_prompt, created_at)
350
+ VALUES(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)""",
351
+ (vid, v.get("name","variant"), channel, goal, float(kpi_target or 0), brand_guidelines, ",".join(banned_words),
352
+ product, audience, cta_url, v.get("headline",""), v.get("body",""), v.get("cta",""), v.get("image_prompt",""),
353
+ now_ts())
354
+ )
355
+ # 初期重みを均等に
356
+ upsert_weight(conn, vid, 1.0)
357
+ table_md = create_rows_table()
358
+ info = {
359
+ "rotator_url": "/r",
360
+ "variant_urls": [f"/v/{row['id']}" for row in list_variants(conn)]
361
+ }
362
+ return json.dumps(info, ensure_ascii=False, indent=2), table_md
363
+ finally:
364
+ conn.commit()
365
+ conn.close()
366
+
367
+ def reset_all():
368
+ conn = db()
369
+ try:
370
+ delete_all(conn)
371
+ return "全バリアントを削除しました。", create_rows_table()
372
+ finally:
373
+ conn.close()
374
+
375
+ def metrics_report():
376
+ return create_rows_table()
377
+
378
+ with gr.Blocks(title="キャンペーン生成器+自動ABテスト") as gradio_app:
379
+ gr.Markdown("# キャンペーン生成器+自動A/Bテスト(SNS/LP)\nOpenAIでクリエイティブ案を生成し、Space内で配信・計測(CTR/CVR)・最適化(ε-greedy)を行います。")
380
+ with gr.Row():
381
+ with gr.Column():
382
+ goal = gr.Dropdown(choices=["CTR","CVR"], label="目標指標", value="CTR")
383
+ kpi_target = gr.Number(label="KPI目標値(%表記は小数に:0.12など)", value=0.05)
384
+ channel = gr.Dropdown(choices=["SNS","LP"], label="チャネル", value="SNS")
385
+ product = gr.Textbox(label="商品・訴求ポイント", lines=3, placeholder="例:HbA1c測定アプリ。無料トライアルあり。")
386
+ audience = gr.Textbox(label="想定ターゲット", lines=2, placeholder="例:20-40代、健康意識の高いビジネスパーソン")
387
+ brand_guidelines = gr.Textbox(label="ブランドガイド(トーン/禁止表現/法務注意など)", lines=6, placeholder="例:医療効果を断定しない。誇大広告NG。丁寧で安心感のあるトーン。")
388
+ banned_words_csv = gr.Textbox(label="禁止ワード(カンマ区切り)", placeholder="例:無料、100%改善、絶対 など")
389
+ n_variants = gr.Slider(label="生成する案数", minimum=1, maximum=6, step=1, value=3)
390
+ cta_url = gr.Textbox(label="CTAリンク(相対パス推奨。未指定なら /convert?vid=...)", placeholder="/convert?vid=<自動付与> もしくは /thanks")
391
+ btn_gen = gr.Button("生成 → 反映")
392
+ btn_reset = gr.Button("全データ削除")
393
+ with gr.Column():
394
+ out_info = gr.Code(label="配信用URL(JSON)")
395
+ out_table = gr.Markdown(create_rows_table())
396
+
397
+ btn_gen.click(fn=generate_and_save, inputs=[goal, kpi_target, channel, product, audience, brand_guidelines, banned_words_csv, n_variants, cta_url], outputs=[out_info, out_table])
398
+ btn_reset.click(fn=reset_all, inputs=None, outputs=[out_info, out_table])
399
+
400
+ gr.Markdown("## メトリクス\n下記URLをシェアしてトラフィックを集めると、CTR/CVRが更新されます。/r はローテータ(多腕バンディット)です。")
401
+ rep_btn = gr.Button("最新レポートに更新")
402
+ rep_out = gr.Markdown(create_rows_table())
403
+ rep_btn.click(fn=metrics_report, inputs=None, outputs=rep_out)
404
+
405
+ # ---- 起動準備 ----
406
+ init_db()
407
+
408
+ # FastAPI ←→ Gradio のマウント
409
+ # Spacesのランタイムでは app という変数をエントリーポイントにするのが一般的。
410
+ app = fastapi_app
411
+ # Gradioを /gradio パスにマウント
412
+ from gradio.routes import mount_gradio_app
413
+ app = mount_gradio_app(app, gradio_app, path="/gradio")
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=4.39.0
2
+ fastapi>=0.115.0
3
+ uvicorn>=0.30.0
4
+ openai>=1.42.0
5
+ pydantic>=2.7.0
6
+ pandas>=2.2.2
7
+ numpy>=1.26.4