tomo2chin2 commited on
Commit
64edcc1
·
verified ·
1 Parent(s): a0f17f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -56
app.py CHANGED
@@ -6,7 +6,7 @@
6
  # それ以外は 5.x 対応フルロジックを一切カットせず
7
  # ===============================================================
8
 
9
- import os, time, tempfile, logging, threading, queue
10
  from io import BytesIO
11
  from concurrent.futures import ThreadPoolExecutor
12
 
@@ -126,8 +126,69 @@ class ScreenshotRequest(BaseModel):
126
  trim_whitespace: bool = True
127
  style: str = "standard"
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  # ---------------------------------------------------------------
130
- # 補助関数(FontAwesome レイアウト / prompt 読み込み / Gemini 生成)
131
  # ---------------------------------------------------------------
132
  def enhance_font_awesome_layout(html_code: str) -> str:
133
  fa_preload = """
@@ -150,20 +211,6 @@ def enhance_font_awesome_layout(html_code: str) -> str:
150
  return html_code.replace('</head>', f'{fa_preload}{fa_css}</head>')
151
  return f'<html><head>{fa_preload}{fa_css}</head>{html_code}</html>'
152
 
153
- def load_system_instruction(style="standard") -> str:
154
- valid_styles = ["standard","cute","resort","cool","dental","school","KOKUGO"]
155
- if style not in valid_styles:
156
- style = "standard"
157
- local = os.path.join(os.path.dirname(__file__), style, "prompt.txt")
158
- if os.path.exists(local):
159
- return open(local, encoding="utf-8").read()
160
- try:
161
- f = hf_hub_download("tomo2chin2/GURAREKOstlyle", f"{style}/prompt.txt", repo_type="dataset")
162
- return open(f, encoding="utf-8").read()
163
- except Exception:
164
- f = hf_hub_download("tomo2chin2/GURAREKOstlyle", "prompt.txt", repo_type="dataset")
165
- return open(f, encoding="utf-8").read()
166
-
167
  def generate_html_from_text(text: str, temperature=0.5, style="standard") -> str:
168
  # Updated: Use the new Google Genai client API
169
  api_key = os.environ["GEMINI_API_KEY"]
@@ -214,7 +261,7 @@ def trim_image_whitespace(img: Image.Image, threshold=248, padding=20) -> Image.
214
  return img
215
 
216
  # ---------------------------------------------------------------
217
- # HTML → スクショ (完全版ロジック)
218
  # ---------------------------------------------------------------
219
  def render_fullpage_screenshot(html_code: str, extension_percentage=6.0,
220
  trim_whitespace=True, driver=None) -> Image.Image:
@@ -225,46 +272,62 @@ def render_fullpage_screenshot(html_code: str, extension_percentage=6.0,
225
  driver = driver_pool.get_driver()
226
  from_pool = True
227
 
228
- # HTML 保存
229
  with tempfile.NamedTemporaryFile(suffix=".html", delete=False, mode="w", encoding="utf-8") as tmp:
230
  tmp_path = tmp.name
231
  tmp.write(html_code)
232
-
233
  driver.set_window_size(1200, 1000)
234
  driver.get("file://" + tmp_path)
235
-
236
- # body 出現を待機
237
- WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
238
-
239
- # リソースロード確認ループ(詳細ロジックは元コード準拠)
240
- max_wait, inc, waited = 5, 0.2, 0.0
241
- while waited < max_wait:
242
- state = driver.execute_script("""
243
- return {complete: document.readyState==='complete',
244
- imgs: document.images.length,
245
- loaded: Array.from(document.images).filter(i=>i.complete).length};
246
- """)
247
- if state['complete'] and (state['imgs']==0 or state['imgs']==state['loaded']):
248
- break
249
- time.sleep(inc); waited += inc
250
-
251
- # スクロールレンダリング
 
 
 
 
 
252
  total_h = driver.execute_script("return Math.max(document.body.scrollHeight, document.documentElement.scrollHeight)")
253
  vh = driver.execute_script("return window.innerHeight")
 
 
254
  for i in range(max(1, min(5, total_h // vh))):
255
  driver.execute_script(f"window.scrollTo(0, {(vh-100)*i})")
256
  time.sleep(0.1)
257
  driver.execute_script("window.scrollTo(0,0)"); time.sleep(0.2)
258
 
259
- dims = driver.execute_script("""
260
- return {w: Math.max(document.body.scrollWidth, document.documentElement.scrollWidth),
261
- h: Math.max(document.body.scrollHeight, document.documentElement.scrollHeight)}
262
- """)
263
- w = min(max(dims['w'], 100), 2000)
264
- h = min(max(dims['h'], 100), 4000)
265
- h = int(h * (1 + extension_percentage / 100.0))
266
- driver.set_window_size(w, h); time.sleep(0.5)
267
-
 
 
 
 
 
 
 
 
 
268
  img = Image.open(BytesIO(driver.get_screenshot_as_png()))
269
  return trim_image_whitespace(img, padding=20) if trim_whitespace else img
270
 
@@ -279,18 +342,44 @@ def render_fullpage_screenshot(html_code: str, extension_percentage=6.0,
279
  except Exception: pass
280
 
281
  # ---------------------------------------------------------------
282
- # テキスト → スクショ (並列 API 呼び出し + ドライバ確保)
283
  # ---------------------------------------------------------------
284
- def text_to_screenshot_parallel(text, ext_perc, temp=0.5, trim_ws=True, style="standard") -> Image.Image:
285
- with ThreadPoolExecutor(max_workers=2) as exe:
 
 
 
286
  html_future = exe.submit(generate_html_from_text, text, temp, style)
287
  driver_future = exe.submit(driver_pool.get_driver)
 
 
 
288
  html_code = html_future.result()
289
  driver = driver_future.result()
 
 
290
  return render_fullpage_screenshot(html_code, ext_perc, trim_ws, driver)
291
 
292
- def text_to_screenshot(*args, **kwargs):
293
- return text_to_screenshot_parallel(*args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
 
295
  # ===============================================================
296
  # FastAPI (★ redirect_slashes=False で自動 307 を殺す)
@@ -305,7 +394,7 @@ app.add_middleware(
305
  allow_headers=["*"],
306
  )
307
 
308
- # -------- API エンドポイントはそのまま --------
309
  @app.post("/api/screenshot", response_class=StreamingResponse, tags=["Screenshot"])
310
  async def api_render_screenshot(req: ScreenshotRequest):
311
  img = render_fullpage_screenshot(req.html_code, req.extension_percentage, req.trim_whitespace)
@@ -314,17 +403,40 @@ async def api_render_screenshot(req: ScreenshotRequest):
314
 
315
  @app.post("/api/text-to-screenshot", response_class=StreamingResponse, tags=["Gemini","Screenshot"])
316
  async def api_text_to_screenshot(req: GeminiRequest):
317
- img = text_to_screenshot_parallel(
318
  req.text, req.extension_percentage, req.temperature, req.trim_whitespace, req.style)
319
  buf = BytesIO(); img.save(buf, format="PNG"); buf.seek(0)
320
  return StreamingResponse(buf, media_type="image/png")
321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  # ===============================================================
323
  # Gradio UI (完全版 UI 定義)
324
  # ===============================================================
325
  def process_input(mode, text, ext, temp, trim, style):
326
  return render_fullpage_screenshot(text, ext, trim) if mode == "HTML入力" else \
327
- text_to_screenshot_parallel(text, ext, temp, trim, style)
328
 
329
  with gr.Blocks(title="Full Page Screenshot (テキスト変換対応)", theme=gr.themes.Origin()) as iface:
330
  gr.Markdown("# HTMLビューア & テキスト→インフォグラフィック変換")
@@ -336,8 +448,8 @@ with gr.Blocks(title="Full Page Screenshot (テキスト変換対応)", theme=gr
336
  ["standard", "cute", "resort", "cool", "dental", "school", "KOKUGO"],
337
  value="standard", label="デザインスタイル", visible=False)
338
  with gr.Column(scale=2):
339
- ext = gr.Slider(0, 30, value=15, step=1, label="上下高さ拡張率(%)")
340
- temp = gr.Slider(0.0, 1.0, value=1.0, step=0.1,
341
  label="生成時の温度", visible=False)
342
  trim = gr.Checkbox(value=True, label="余白を自動トリミング")
343
  btn = gr.Button("生成")
@@ -353,7 +465,7 @@ with gr.Blocks(title="Full Page Screenshot (テキスト変換対応)", theme=gr
353
  model_name = os.getenv('GEMINI_MODEL', 'gemini-1.5-pro')
354
  thinking_status = ""
355
  if model_name == "gemini-2.5-flash-preview-04-17":
356
- thinking_status = "(思考モード: オフ)"
357
 
358
  gr.Markdown(f"**API** `/api/screenshot`, `/api/text-to-screenshot` &nbsp;&nbsp; "
359
  f"使用モデル: `{model_name}` {thinking_status}")
@@ -372,6 +484,13 @@ def _root(): return RedirectResponse(GRADIO_PATH + "/")
372
  @app.get(GRADIO_PATH)
373
  def _no_slash(): return RedirectResponse(GRADIO_PATH + "/")
374
 
 
 
 
 
 
 
 
375
  # ===============================================================
376
  # ローカルデバッグ
377
  # ===============================================================
 
6
  # それ以外は 5.x 対応フルロジックを一切カットせず
7
  # ===============================================================
8
 
9
+ import os, time, tempfile, logging, threading, queue, zipfile
10
  from io import BytesIO
11
  from concurrent.futures import ThreadPoolExecutor
12
 
 
126
  trim_whitespace: bool = True
127
  style: str = "standard"
128
 
129
+ # バッチ処理用の新しいモデル
130
+ class BatchGeminiRequest(BaseModel):
131
+ texts: list[str]
132
+ extension_percentage: float = 10.0
133
+ temperature: float = 0.5
134
+ trim_whitespace: bool = True
135
+ style: str = "standard"
136
+
137
+ # ---------------------------------------------------------------
138
+ # システム指示のキャッシュ実装
139
+ # ---------------------------------------------------------------
140
+ # プロンプトキャッシュ - 頻繁に使用されるプロンプトを保存
141
+ _prompt_cache = {}
142
+
143
+ def load_system_instruction(style="standard") -> str:
144
+ """システム指示をロード (キャッシュ機能付き)"""
145
+ # キャッシュに存在すればそれを返す
146
+ if style in _prompt_cache:
147
+ return _prompt_cache[style]
148
+
149
+ valid_styles = ["standard","cute","resort","cool","dental","school","KOKUGO"]
150
+ if style not in valid_styles:
151
+ style = "standard"
152
+
153
+ local = os.path.join(os.path.dirname(__file__), style, "prompt.txt")
154
+ prompt_text = ""
155
+
156
+ if os.path.exists(local):
157
+ prompt_text = open(local, encoding="utf-8").read()
158
+ else:
159
+ try:
160
+ f = hf_hub_download("tomo2chin2/GURAREKOstlyle", f"{style}/prompt.txt", repo_type="dataset")
161
+ prompt_text = open(f, encoding="utf-8").read()
162
+ except Exception:
163
+ f = hf_hub_download("tomo2chin2/GURAREKOstlyle", "prompt.txt", repo_type="dataset")
164
+ prompt_text = open(f, encoding="utf-8").read()
165
+
166
+ # キャッシュに保存
167
+ _prompt_cache[style] = prompt_text
168
+ return prompt_text
169
+
170
+ # ---------------------------------------------------------------
171
+ # 初期化時に全スタイルをキャッシュに先読み
172
+ # ---------------------------------------------------------------
173
+ def preload_all_prompts():
174
+ """アプリ起動時に全スタイルの指示を事前ロード"""
175
+ styles = ["standard", "cute", "resort", "cool", "dental", "school", "KOKUGO"]
176
+ logger.info("システム指示の先読み開始...")
177
+
178
+ with ThreadPoolExecutor(max_workers=len(styles)) as executor:
179
+ futures = {executor.submit(load_system_instruction, style): style for style in styles}
180
+ for future in futures:
181
+ style = futures[future]
182
+ try:
183
+ future.result() # 結果を取得
184
+ logger.info(f"スタイル '{style}' の指示を先読み完了")
185
+ except Exception as e:
186
+ logger.error(f"スタイル '{style}' の指示先読みに失敗: {e}")
187
+
188
+ logger.info(f"全 {len(_prompt_cache)} スタイルの指示先読み完了")
189
+
190
  # ---------------------------------------------------------------
191
+ # 補助関数(FontAwesome レイアウト / Gemini 生成)
192
  # ---------------------------------------------------------------
193
  def enhance_font_awesome_layout(html_code: str) -> str:
194
  fa_preload = """
 
211
  return html_code.replace('</head>', f'{fa_preload}{fa_css}</head>')
212
  return f'<html><head>{fa_preload}{fa_css}</head>{html_code}</html>'
213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  def generate_html_from_text(text: str, temperature=0.5, style="standard") -> str:
215
  # Updated: Use the new Google Genai client API
216
  api_key = os.environ["GEMINI_API_KEY"]
 
261
  return img
262
 
263
  # ---------------------------------------------------------------
264
+ # HTML → スクショ 最適化版 (並列処理強化)
265
  # ---------------------------------------------------------------
266
  def render_fullpage_screenshot(html_code: str, extension_percentage=6.0,
267
  trim_whitespace=True, driver=None) -> Image.Image:
 
272
  driver = driver_pool.get_driver()
273
  from_pool = True
274
 
275
+ # HTML 保存と読み込みを並列化
276
  with tempfile.NamedTemporaryFile(suffix=".html", delete=False, mode="w", encoding="utf-8") as tmp:
277
  tmp_path = tmp.name
278
  tmp.write(html_code)
279
+
280
  driver.set_window_size(1200, 1000)
281
  driver.get("file://" + tmp_path)
282
+
283
+ # 非同期でリソースロード待機とスクリプト実行を行う
284
+ def wait_for_resources():
285
+ WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
286
+ max_wait, inc, waited = 5, 0.2, 0.0
287
+ while waited < max_wait:
288
+ state = driver.execute_script("""
289
+ return {complete: document.readyState==='complete',
290
+ imgs: document.images.length,
291
+ loaded: Array.from(document.images).filter(i=>i.complete).length};
292
+ """)
293
+ if state['complete'] and (state['imgs']==0 or state['imgs']==state['loaded']):
294
+ break
295
+ time.sleep(inc); waited += inc
296
+ return True
297
+
298
+ # リソース待機をスレッドプールで実行
299
+ with ThreadPoolExecutor(max_workers=1) as executor:
300
+ resource_wait = executor.submit(wait_for_resources)
301
+ resource_wait.result() # 待機完了を確認
302
+
303
+ # スクロールレンダリングを最適化
304
  total_h = driver.execute_script("return Math.max(document.body.scrollHeight, document.documentElement.scrollHeight)")
305
  vh = driver.execute_script("return window.innerHeight")
306
+
307
+ # 並列スクロール処理は安定性の問題があるため、直列実行のままに
308
  for i in range(max(1, min(5, total_h // vh))):
309
  driver.execute_script(f"window.scrollTo(0, {(vh-100)*i})")
310
  time.sleep(0.1)
311
  driver.execute_script("window.scrollTo(0,0)"); time.sleep(0.2)
312
 
313
+ # サイズ計算と画像取得を並列化
314
+ def get_dimensions_and_resize():
315
+ dims = driver.execute_script("""
316
+ return {w: Math.max(document.body.scrollWidth, document.documentElement.scrollWidth),
317
+ h: Math.max(document.body.scrollHeight, document.documentElement.scrollHeight)}
318
+ """)
319
+ w = min(max(dims['w'], 100), 2000)
320
+ h = min(max(dims['h'], 100), 4000)
321
+ h = int(h * (1 + extension_percentage / 100.0))
322
+ driver.set_window_size(w, h)
323
+ time.sleep(0.5)
324
+ return w, h
325
+
326
+ with ThreadPoolExecutor(max_workers=1) as executor:
327
+ dims_future = executor.submit(get_dimensions_and_resize)
328
+ dims_future.result() # サイズ調整完了を確認
329
+
330
+ # スクリーンショット取得と画像処理
331
  img = Image.open(BytesIO(driver.get_screenshot_as_png()))
332
  return trim_image_whitespace(img, padding=20) if trim_whitespace else img
333
 
 
342
  except Exception: pass
343
 
344
  # ---------------------------------------------------------------
345
+ # テキスト → スクショ (並列処理強化版)
346
  # ---------------------------------------------------------------
347
+ def text_to_screenshot(text, ext_perc, temp=0.5, trim_ws=True, style="standard") -> Image.Image:
348
+ # 3つの並列タスク: HTML生成、ドライバ取得、必要なスタイルのシステム指示ロード
349
+ with ThreadPoolExecutor(max_workers=3) as exe:
350
+ # システム指示が未キャッシュの場合に備えて並列ロード
351
+ prompt_future = exe.submit(load_system_instruction, style)
352
  html_future = exe.submit(generate_html_from_text, text, temp, style)
353
  driver_future = exe.submit(driver_pool.get_driver)
354
+
355
+ # 結果取得
356
+ prompt_future.result() # プロンプトをキャッシュ確保
357
  html_code = html_future.result()
358
  driver = driver_future.result()
359
+
360
+ # 最適化されたスクリーンショット関数を使用
361
  return render_fullpage_screenshot(html_code, ext_perc, trim_ws, driver)
362
 
363
+ # ---------------------------------------------------------------
364
+ # テキスト → スクショ (複数同時処理版)
365
+ # ---------------------------------------------------------------
366
+ def batch_text_to_screenshot(texts, ext_perc, temp=0.5, trim_ws=True, style="standard") -> list:
367
+ """複数テキストを同時に処理"""
368
+ with ThreadPoolExecutor(max_workers=min(len(texts), 3)) as exe:
369
+ futures = [exe.submit(text_to_screenshot, text, ext_perc, temp, trim_ws, style)
370
+ for text in texts]
371
+ return [f.result() for f in futures]
372
+
373
+ # ---------------------------------------------------------------
374
+ # アプリ初期化時に実行する処理
375
+ # ---------------------------------------------------------------
376
+ def initialize_app():
377
+ """アプリケーション初期化処理"""
378
+ # システム指示を事前にキャッシュにロード
379
+ preload_all_prompts()
380
+
381
+ # その他の初期化処理があればここに追加
382
+ logger.info("アプリケーション初期化完了")
383
 
384
  # ===============================================================
385
  # FastAPI (★ redirect_slashes=False で自動 307 を殺す)
 
394
  allow_headers=["*"],
395
  )
396
 
397
+ # -------- API エンドポイント --------
398
  @app.post("/api/screenshot", response_class=StreamingResponse, tags=["Screenshot"])
399
  async def api_render_screenshot(req: ScreenshotRequest):
400
  img = render_fullpage_screenshot(req.html_code, req.extension_percentage, req.trim_whitespace)
 
403
 
404
  @app.post("/api/text-to-screenshot", response_class=StreamingResponse, tags=["Gemini","Screenshot"])
405
  async def api_text_to_screenshot(req: GeminiRequest):
406
+ img = text_to_screenshot(
407
  req.text, req.extension_percentage, req.temperature, req.trim_whitespace, req.style)
408
  buf = BytesIO(); img.save(buf, format="PNG"); buf.seek(0)
409
  return StreamingResponse(buf, media_type="image/png")
410
 
411
+ # バッチ処理用の新しいエンドポイント
412
+ @app.post("/api/batch-text-to-screenshot", tags=["Gemini","Screenshot"])
413
+ async def api_batch_text_to_screenshot(req: BatchGeminiRequest):
414
+ # 複数テキストを並列処理
415
+ images = batch_text_to_screenshot(
416
+ req.texts, req.extension_percentage, req.temperature, req.trim_whitespace, req.style)
417
+
418
+ # 結果をZIP形式で返す
419
+ buf = BytesIO()
420
+ with zipfile.ZipFile(buf, 'w') as zf:
421
+ for i, img in enumerate(images):
422
+ img_buf = BytesIO()
423
+ img.save(img_buf, format="PNG")
424
+ img_buf.seek(0)
425
+ zf.writestr(f"screenshot_{i+1}.png", img_buf.getvalue())
426
+
427
+ buf.seek(0)
428
+ return StreamingResponse(
429
+ buf,
430
+ media_type="application/zip",
431
+ headers={"Content-Disposition": "attachment; filename=screenshots.zip"}
432
+ )
433
+
434
  # ===============================================================
435
  # Gradio UI (完全版 UI 定義)
436
  # ===============================================================
437
  def process_input(mode, text, ext, temp, trim, style):
438
  return render_fullpage_screenshot(text, ext, trim) if mode == "HTML入力" else \
439
+ text_to_screenshot(text, ext, temp, trim, style)
440
 
441
  with gr.Blocks(title="Full Page Screenshot (テキスト変換対応)", theme=gr.themes.Origin()) as iface:
442
  gr.Markdown("# HTMLビューア & テキスト→インフォグラフィック変換")
 
448
  ["standard", "cute", "resort", "cool", "dental", "school", "KOKUGO"],
449
  value="standard", label="デザインスタイル", visible=False)
450
  with gr.Column(scale=2):
451
+ ext = gr.Slider(0, 30, value=10, step=1, label="上下高さ拡張率(%)")
452
+ temp = gr.Slider(0.0, 1.0, value=0.5, step=0.1,
453
  label="生成時の温度", visible=False)
454
  trim = gr.Checkbox(value=True, label="余白を自動トリミング")
455
  btn = gr.Button("生成")
 
465
  model_name = os.getenv('GEMINI_MODEL', 'gemini-1.5-pro')
466
  thinking_status = ""
467
  if model_name == "gemini-2.5-flash-preview-04-17":
468
+ thinking_status = "(思考モード: オフ、最大トークン: 50000)"
469
 
470
  gr.Markdown(f"**API** `/api/screenshot`, `/api/text-to-screenshot` &nbsp;&nbsp; "
471
  f"使用モデル: `{model_name}` {thinking_status}")
 
484
  @app.get(GRADIO_PATH)
485
  def _no_slash(): return RedirectResponse(GRADIO_PATH + "/")
486
 
487
+ # アプリケーション起動時の初期化
488
+ @app.on_event("startup")
489
+ async def startup_event():
490
+ # バックグラウンドで初期化処理を実行
491
+ threading.Thread(target=initialize_app).start()
492
+ logger.info("アプリケーション起動: 並列処理による最適化を適用")
493
+
494
  # ===============================================================
495
  # ローカルデバッグ
496
  # ===============================================================