tomo2chin2 committed on
Commit
5a6269f
·
verified ·
1 Parent(s): b93750a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +169 -137
app.py CHANGED
@@ -1,11 +1,9 @@
1
- # app.py (Gradio 5.x 対応フルバージョン)
2
- # =========================================
3
- # 主要変更点
4
- # gradio==4.19.2 → gradio>=5.29.0 (requirements.txt で指定)
5
- # Blocks theme を Origin() に変更(4.x の外観を継承)
6
- # mount_gradio_app に ssr_mode=False を追加
7
- # • Gradio 内部アセットの手動マウントを撤去(5.x では不要)
8
- # =========================================
9
 
10
  import os
11
  import time
@@ -16,12 +14,12 @@ import queue
16
  from io import BytesIO
17
  from concurrent.futures import ThreadPoolExecutor
18
 
19
- import numpy as np # 画像トリミング高速化
20
  from PIL import Image
21
 
22
- import gradio as gr # ★ 5.x
23
  from fastapi import FastAPI, HTTPException
24
- from fastapi.responses import StreamingResponse
25
  from fastapi.middleware.cors import CORSMiddleware
26
  from pydantic import BaseModel
27
 
@@ -31,55 +29,61 @@ from selenium.webdriver.common.by import By
31
  from selenium.webdriver.support.ui import WebDriverWait
32
  from selenium.webdriver.support import expected_conditions as EC
33
 
34
- import google.generativeai as genai # Gemini
35
  from huggingface_hub import hf_hub_download
36
 
37
- # -------------------- ロガー設定 --------------------
 
 
 
38
  logging.basicConfig(level=logging.INFO)
39
  logger = logging.getLogger(__name__)
40
 
41
- # ====================================================
42
- # WebDriver プール実装
43
- # ====================================================
 
44
  class WebDriverPool:
 
45
  def __init__(self, max_drivers: int = 3):
46
- self.driver_queue = queue.Queue()
47
  self.max_drivers = max_drivers
48
  self.lock = threading.Lock()
49
  self.count = 0
50
  logger.info(f"WebDriver プール初期化: 最大 {max_drivers}")
51
 
52
- def get_driver(self):
53
- # 既存
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  if not self.driver_queue.empty():
55
- logger.info("既存 WebDriver を取得")
56
  return self.driver_queue.get()
57
 
58
- # 新規
59
  with self.lock:
60
  if self.count < self.max_drivers:
61
  self.count += 1
62
  logger.info(f"新規 WebDriver 作成 ({self.count}/{self.max_drivers})")
63
- options = Options()
64
- options.add_argument("--headless")
65
- options.add_argument("--no-sandbox")
66
- options.add_argument("--disable-dev-shm-usage")
67
- options.add_argument("--force-device-scale-factor=1")
68
- options.add_argument("--disable-features=NetworkService")
69
- options.add_argument("--dns-prefetch-disable")
70
-
71
- driver_path = os.environ.get("CHROMEDRIVER_PATH")
72
- if driver_path and os.path.exists(driver_path):
73
- logger.info(f"環境変数 CHROMEDRIVER_PATH 使用: {driver_path}")
74
- service = webdriver.ChromeService(executable_path=driver_path)
75
- return webdriver.Chrome(service=service, options=options)
76
- return webdriver.Chrome(options=options)
77
-
78
- # プール満杯
79
- logger.info("プール満杯。返却待ち…")
80
  return self.driver_queue.get()
81
 
82
- def release_driver(self, driver):
83
  if driver:
84
  try:
85
  driver.get("about:blank")
@@ -90,18 +94,18 @@ class WebDriverPool:
90
  self.driver_queue.put(driver)
91
  logger.info("WebDriver をプールに返却")
92
  except Exception as e:
93
- logger.error(f"返却時エラー: {e}")
94
  driver.quit()
95
  with self.lock:
96
  self.count -= 1
97
 
98
  def close_all(self):
99
- logger.info("WebDriver 終了処理")
100
  closed = 0
101
  while not self.driver_queue.empty():
102
  try:
103
- driver = self.driver_queue.get(block=False)
104
- driver.quit()
105
  closed += 1
106
  except queue.Empty:
107
  break
@@ -111,12 +115,14 @@ class WebDriverPool:
111
  self.count = 0
112
  logger.info(f"{closed} 個の WebDriver を終了")
113
 
 
114
  # グローバルプール
115
  driver_pool = WebDriverPool(max_drivers=int(os.environ.get("MAX_WEBDRIVERS", "3")))
116
 
117
- # ====================================================
 
118
  # Pydantic モデル
119
- # ====================================================
120
  class GeminiRequest(BaseModel):
121
  text: str
122
  extension_percentage: float = 10.0
@@ -124,17 +130,19 @@ class GeminiRequest(BaseModel):
124
  trim_whitespace: bool = True
125
  style: str = "standard"
126
 
 
127
  class ScreenshotRequest(BaseModel):
128
  html_code: str
129
  extension_percentage: float = 10.0
130
  trim_whitespace: bool = True
131
  style: str = "standard"
132
 
133
- # ====================================================
 
134
  # 補助関数
135
- # ====================================================
136
  def enhance_font_awesome_layout(html_code: str) -> str:
137
- """Font Awesome レイアウトを調整し preload タグを付与"""
138
  fa_preload = """
139
  <link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-solid-900.woff2" as="font" type="font/woff2" crossorigin>
140
  <link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-regular-400.woff2" as="font" type="font/woff2" crossorigin>
@@ -162,8 +170,9 @@ def enhance_font_awesome_layout(html_code: str) -> str:
162
  return html_code[:body_start] + f'<head>{fa_preload}{fa_css}</head>' + html_code[body_start:]
163
  return f'<html><head>{fa_preload}{fa_css}</head>{html_code}</html>'
164
 
 
165
  def load_system_instruction(style: str = "standard") -> str:
166
- """テーマ別 prompt.txt を読み込み"""
167
  valid = ["standard", "cute", "resort", "cool", "dental", "school", "KOKUGO"]
168
  if style not in valid:
169
  logger.warning(f"無効 style '{style}' → 'standard'")
@@ -183,7 +192,7 @@ def load_system_instruction(style: str = "standard") -> str:
183
  with open(file_path, encoding="utf-8") as f:
184
  return f.read()
185
  except Exception as e:
186
- logger.warning(f"HuggingFace 取得失敗 ({e}) → デフォルト prompt.txt")
187
  file_path = hf_hub_download(
188
  repo_id="tomo2chin2/GURAREKOstlyle",
189
  filename="prompt.txt",
@@ -192,21 +201,23 @@ def load_system_instruction(style: str = "standard") -> str:
192
  with open(file_path, encoding="utf-8") as f:
193
  return f.read()
194
 
 
195
  def generate_html_from_text(text: str, temperature: float = 0.5, style: str = "standard") -> str:
196
- """Gemini HTML 生成"""
197
  api_key = os.environ.get("GEMINI_API_KEY")
198
  if not api_key:
199
  raise ValueError("環境変数 GEMINI_API_KEY が未設定")
200
-
201
  model_name = os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
 
202
  genai.configure(api_key=api_key)
 
203
 
204
  system_instruction = load_system_instruction(style)
205
- model = genai.GenerativeModel(model_name)
206
 
207
  generation_config = dict(
208
- temperature=temperature, top_p=0.7, top_k=20, max_output_tokens=8192,
209
- candidate_count=1
210
  )
211
  safety_settings = [
212
  {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
@@ -214,67 +225,62 @@ def generate_html_from_text(text: str, temperature: float = 0.5, style: str = "s
214
  {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
215
  {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
216
  ]
217
- prompt = f"{system_instruction}\n\n{text}"
218
- response = model.generate_content(prompt, generation_config=generation_config, safety_settings=safety_settings)
219
- raw = response.text
220
 
221
- html_start = raw.find("```html")
222
- html_end = raw.rfind("```")
223
- if html_start != -1 and html_end != -1 and html_start < html_end:
224
- html_code = raw[html_start + 7:html_end].strip()
225
- else:
226
- html_code = raw
227
 
228
- html_code = enhance_font_awesome_layout(html_code)
229
- return html_code
230
 
231
  def trim_image_whitespace(image: Image.Image, threshold: int = 250, padding: int = 10) -> Image.Image:
232
- """白余白トリミング(NumPy 高速化)"""
233
  gray = image.convert("L")
234
  arr = np.array(gray)
235
  mask = arr < threshold
236
- rows = np.any(mask, axis=1)
237
- cols = np.any(mask, axis=0)
238
  if np.any(rows) and np.any(cols):
239
  y_min, y_max = np.where(rows)[0][[0, -1]]
240
  x_min, x_max = np.where(cols)[0][[0, -1]]
241
- y_min = max(0, y_min - padding)
242
- x_min = max(0, x_min - padding)
243
- y_max = min(image.height - 1, y_max + padding)
244
- x_max = min(image.width - 1, x_max + padding)
245
- return image.crop((x_min, y_min, x_max + 1, y_max + 1))
246
  return image
247
 
248
- # ----------------------------------------------------
 
249
  # HTML → スクリーンショット
250
- # ----------------------------------------------------
251
  def render_fullpage_screenshot(html_code: str, extension_percentage: float = 6.0,
252
- trim_whitespace: bool = True,
253
- driver=None) -> Image.Image:
254
- driver_from_pool = False
255
  tmp_path = None
 
256
  try:
257
  if driver is None:
258
  driver = driver_pool.get_driver()
259
  driver_from_pool = True
260
 
 
261
  with tempfile.NamedTemporaryFile(suffix=".html", delete=False, mode="w", encoding="utf-8") as tmp:
262
  tmp_path = tmp.name
263
  tmp.write(html_code)
264
 
265
- initial_w, initial_h = 1200, 1000
266
- driver.set_window_size(initial_w, initial_h)
267
- driver.get(f"file://{tmp_path}")
268
 
 
269
  WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
270
 
271
- total_h = driver.execute_script("return Math.max(document.body.scrollHeight, document.documentElement.scrollHeight)")
272
- viewport_h = driver.execute_script("return window.innerHeight")
273
- scrolls = max(1, min(5, total_h // viewport_h))
274
- for i in range(scrolls):
275
- driver.execute_script(f"window.scrollTo(0, {i * (viewport_h - 100)})")
276
  time.sleep(0.1)
277
- driver.execute_script("window.scrollTo(0, 0)")
278
  time.sleep(0.2)
279
 
280
  dims = driver.execute_script("""
@@ -283,17 +289,14 @@ def render_fullpage_screenshot(html_code: str, extension_percentage: float = 6.0
283
  height: Math.max(document.body.scrollHeight, document.documentElement.scrollHeight)
284
  }
285
  """)
286
- w = min(max(dims["width"], 100), 2000)
287
- h = min(max(dims["height"], 100), 4000)
288
- h = int(h * (1 + extension_percentage / 100.0))
289
- driver.set_window_size(w, h)
290
  time.sleep(0.5)
291
 
292
- png = driver.get_screenshot_as_png()
293
- img = Image.open(BytesIO(png))
294
- if trim_whitespace:
295
- img = trim_image_whitespace(img, threshold=248, padding=20)
296
- return img
297
  except Exception as e:
298
  logger.error(f"Screenshot error: {e}", exc_info=True)
299
  return Image.new("RGB", (1, 1), (0, 0, 0))
@@ -306,28 +309,32 @@ def render_fullpage_screenshot(html_code: str, extension_percentage: float = 6.0
306
  except Exception:
307
  pass
308
 
309
- # ----------------------------------------------------
 
310
  # テキスト → スクリーンショット(並列)
311
- # ----------------------------------------------------
312
  def text_to_screenshot_parallel(text: str, extension_percentage: float, temperature: float = 0.5,
313
  trim_whitespace: bool = True, style: str = "standard") -> Image.Image:
314
  start = time.time()
315
- with ThreadPoolExecutor(max_workers=2) as exe:
316
- html_fut = exe.submit(generate_html_from_text, text, temperature, style)
317
- driver_fut = exe.submit(driver_pool.get_driver)
318
 
319
- html_code = html_fut.result()
320
- driver = driver_fut.result()
321
  img = render_fullpage_screenshot(html_code, extension_percentage, trim_whitespace, driver)
322
- logger.info(f"並列処理 完了 {time.time() - start:.2f}s")
323
  return img
324
 
325
- def text_to_screenshot(*args, **kwargs) -> Image.Image:
 
 
326
  return text_to_screenshot_parallel(*args, **kwargs)
327
 
328
- # ====================================================
329
- # FastAPI セットアップ
330
- # ====================================================
 
331
  app = FastAPI()
332
  app.add_middleware(
333
  CORSMiddleware,
@@ -335,75 +342,100 @@ app.add_middleware(
335
  allow_methods=["*"], allow_headers=["*"]
336
  )
337
 
338
- # ------------ API エンドポイント ---------------
339
  @app.post("/api/screenshot", response_class=StreamingResponse, tags=["Screenshot"])
340
  async def api_render_screenshot(req: ScreenshotRequest):
341
  img = render_fullpage_screenshot(req.html_code, req.extension_percentage, req.trim_whitespace)
342
- buf = BytesIO()
343
- img.save(buf, format="PNG"); buf.seek(0)
344
  return StreamingResponse(buf, media_type="image/png")
345
 
 
346
  @app.post("/api/text-to-screenshot", response_class=StreamingResponse, tags=["Screenshot", "Gemini"])
347
  async def api_text_to_screenshot(req: GeminiRequest):
348
- img = text_to_screenshot_parallel(req.text, req.extension_percentage,
349
- req.temperature, req.trim_whitespace, req.style)
350
- buf = BytesIO()
351
- img.save(buf, format="PNG"); buf.seek(0)
352
  return StreamingResponse(buf, media_type="image/png")
353
 
354
- # ====================================================
 
355
  # Gradio UI
356
- # ====================================================
357
- def process_input(mode, input_text, ext_perc, temp, trim_ws, style):
358
  if mode == "HTML入力":
359
- return render_fullpage_screenshot(input_text, ext_perc, trim_ws)
360
- return text_to_screenshot_parallel(input_text, ext_perc, temp, trim_ws, style)
 
361
 
362
  with gr.Blocks(title="Full Page Screenshot (テキスト変換対応)",
363
  theme=gr.themes.Origin()) as iface:
364
- gr.Markdown("# HTMLビューア & テキスト→インフォグラフィック")
365
  with gr.Row():
366
  input_mode = gr.Radio(["HTML入力", "テキスト入力"], value="HTML入力", label="入力モード")
 
367
  input_text = gr.Textbox(lines=15, label="入力")
 
368
  with gr.Row():
369
- style_dropdown = gr.Dropdown(["standard", "cute", "resort", "cool", "dental", "school", "KOKUGO"],
370
- value="standard", label="デザインスタイル", visible=False)
 
371
  with gr.Column(scale=2):
372
  ext_slider = gr.Slider(0, 30, value=10, step=1, label="上下高さ拡張率(%)")
373
  temp_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.1,
374
  label="生成時の温度", visible=False)
375
- trim_ws_chk = gr.Checkbox(value=True, label="余白を自動トリミング")
 
 
376
  submit_btn = gr.Button("生成")
377
  out_img = gr.Image(type="pil", label="スクリーンショット")
378
 
379
- def toggle_controls(mode):
380
  is_text = mode == "テキスト入力"
381
  return [gr.update(visible=is_text), gr.update(visible=is_text)]
382
- input_mode.change(toggle_controls, input_mode, [temp_slider, style_dropdown])
383
- submit_btn.click(process_input,
384
- [input_mode, input_text, ext_slider, temp_slider, trim_ws_chk, style_dropdown],
385
- out_img)
 
 
386
 
387
- gemini_model = os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
388
  gr.Markdown(f"""
389
  **API**
390
  - `/api/screenshot` – HTML → PNG
391
  - `/api/text-to-screenshot` – テキスト → インフォグラフィック PNG
392
 
393
  **設定**
394
- - 使用モデル: `{gemini_model}`
395
- - スタイル: standard / cute / resort / cool / dental / school / KOKUGO
396
  - WebDriver 最大数: {driver_pool.max_drivers}
397
  """)
398
 
399
- # FastAPI へマウント(SSR 無効)
400
- app = gr.mount_gradio_app(app, iface, path="/", ssr_mode=False)
401
 
402
- # ローカル実行用
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  if __name__ == "__main__":
404
  import uvicorn
 
405
  uvicorn.run(app, host="0.0.0.0", port=7860)
406
 
407
- # 終了時のクリーンアップ
 
 
 
408
  import atexit
409
  atexit.register(driver_pool.close_all)
 
1
+ # ===============================================================
2
+ # app.py ― Gradio 5.x + FastAPI + Gemini + Selenium
3
+ # ・Gradio 5.29.0 以上を前提に最適化
4
+ # ・UI/API を維持したまま 307 ループを解消
5
+ # ・/gradio にサブマウントし / はリダイレクトのみ
6
+ # ===============================================================
 
 
7
 
8
  import os
9
  import time
 
14
  from io import BytesIO
15
  from concurrent.futures import ThreadPoolExecutor
16
 
17
+ import numpy as np
18
  from PIL import Image
19
 
20
+ import gradio as gr
21
  from fastapi import FastAPI, HTTPException
22
+ from fastapi.responses import StreamingResponse, RedirectResponse
23
  from fastapi.middleware.cors import CORSMiddleware
24
  from pydantic import BaseModel
25
 
 
29
  from selenium.webdriver.support.ui import WebDriverWait
30
  from selenium.webdriver.support import expected_conditions as EC
31
 
32
+ import google.generativeai as genai
33
  from huggingface_hub import hf_hub_download
34
 
35
+
36
+ # ===============================================================
37
+ # ロガー
38
+ # ===============================================================
39
  logging.basicConfig(level=logging.INFO)
40
  logger = logging.getLogger(__name__)
41
 
42
+
43
+ # ===============================================================
44
+ # WebDriver プール
45
+ # ===============================================================
46
  class WebDriverPool:
47
+ """複数の WebDriver を使い回すシンプルなプール"""
48
def __init__(self, max_drivers: int = 3):
    """Set up an empty pool limited to ``max_drivers`` WebDriver instances.

    Args:
        max_drivers: Upper bound compared against ``self.count`` before a
            new driver is created.
    """
    # Number of drivers created so far, and the lock used when changing it.
    self.count = 0
    self.lock = threading.Lock()
    self.max_drivers = max_drivers
    # Idle drivers are parked here until a caller asks for one.
    self.driver_queue: "queue.Queue[webdriver.Chrome]" = queue.Queue()
    logger.info(f"WebDriver プール初期化: 最大 {max_drivers}")
54
 
55
def _create_driver(self) -> webdriver.Chrome:
    """Build a new headless Chrome WebDriver.

    When the ``CHROMEDRIVER_PATH`` environment variable names an existing
    file, that executable is passed to Selenium through a ``ChromeService``;
    otherwise ``webdriver.Chrome`` is constructed with the options only.

    Returns:
        The freshly started Chrome driver.
    """
    chrome_flags = (
        "--headless",
        "--no-sandbox",
        "--disable-dev-shm-usage",
        "--force-device-scale-factor=1",
        "--disable-features=NetworkService",
        "--dns-prefetch-disable",
    )
    opts = Options()
    for flag in chrome_flags:
        opts.add_argument(flag)

    path = os.environ.get("CHROMEDRIVER_PATH")
    if not (path and os.path.exists(path)):
        return webdriver.Chrome(options=opts)

    logger.info(f"環境変数 CHROMEDRIVER_PATH を使用: {path}")
    return webdriver.Chrome(
        service=webdriver.ChromeService(executable_path=path),
        options=opts,
    )
70
+
71
def get_driver(self) -> webdriver.Chrome:
    """Return a WebDriver, reusing an idle one when one is available.

    Order of preference:
      1. An idle driver already sitting in ``self.driver_queue``.
      2. A newly created driver, while ``self.count`` is below
         ``self.max_drivers``.
      3. Otherwise block until another caller puts a driver back on the queue.

    Returns:
        A Chrome WebDriver the caller now owns.

    Raises:
        Exception: Whatever ``_create_driver`` raises; the reserved pool
            slot is given back before the error propagates.
    """
    # get_nowait() replaces the empty()/get() pair: between those two calls
    # another thread could take the last idle driver, leaving this caller
    # blocked even though the pool still had room to create a new one.
    try:
        driver = self.driver_queue.get_nowait()
    except queue.Empty:
        pass
    else:
        logger.info("プールから既存 WebDriver を取得")
        return driver

    # Reserve a slot under the lock so the cap is never exceeded.
    with self.lock:
        may_create = self.count < self.max_drivers
        if may_create:
            self.count += 1
            logger.info(f"新規 WebDriver 作成 ({self.count}/{self.max_drivers})")

    if may_create:
        try:
            return self._create_driver()
        except Exception:
            # Creation failed: release the reserved slot, otherwise the pool
            # permanently loses capacity and later callers block forever.
            with self.lock:
                self.count -= 1
            raise

    # Pool is at capacity: wait for a driver to be returned.
    logger.info("WebDriver プール満杯。空きを待機")
    return self.driver_queue.get()
85
 
86
+ def release_driver(self, driver: webdriver.Chrome):
87
  if driver:
88
  try:
89
  driver.get("about:blank")
 
94
  self.driver_queue.put(driver)
95
  logger.info("WebDriver をプールに返却")
96
  except Exception as e:
97
+ logger.error(f"返却エラー: {e}")
98
  driver.quit()
99
  with self.lock:
100
  self.count -= 1
101
 
102
  def close_all(self):
103
+ logger.info("WebDriver 全終了")
104
  closed = 0
105
  while not self.driver_queue.empty():
106
  try:
107
+ drv = self.driver_queue.get(block=False)
108
+ drv.quit()
109
  closed += 1
110
  except queue.Empty:
111
  break
 
115
  self.count = 0
116
  logger.info(f"{closed} 個の WebDriver を終了")
117
 
118
+
119
  # グローバルプール
120
  driver_pool = WebDriverPool(max_drivers=int(os.environ.get("MAX_WEBDRIVERS", "3")))
121
 
122
+
123
+ # ===============================================================
124
  # Pydantic モデル
125
+ # ===============================================================
126
  class GeminiRequest(BaseModel):
127
  text: str
128
  extension_percentage: float = 10.0
 
130
  trim_whitespace: bool = True
131
  style: str = "standard"
132
 
133
+
134
class ScreenshotRequest(BaseModel):
    # Request body accepted by POST /api/screenshot.
    html_code: str  # HTML source that will be rendered
    extension_percentage: float = 10.0  # forwarded to render_fullpage_screenshot
    trim_whitespace: bool = True  # forwarded to render_fullpage_screenshot
    style: str = "standard"  # NOTE(review): api_render_screenshot does not read this field
139
 
140
+
141
+ # ===============================================================
142
  # 補助関数
143
+ # ===============================================================
144
  def enhance_font_awesome_layout(html_code: str) -> str:
145
+ """Font Awesome をプリロード + レイアウト微調整 CSS を注入"""
146
  fa_preload = """
147
  <link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-solid-900.woff2" as="font" type="font/woff2" crossorigin>
148
  <link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-regular-400.woff2" as="font" type="font/woff2" crossorigin>
 
170
  return html_code[:body_start] + f'<head>{fa_preload}{fa_css}</head>' + html_code[body_start:]
171
  return f'<html><head>{fa_preload}{fa_css}</head>{html_code}</html>'
172
 
173
+
174
  def load_system_instruction(style: str = "standard") -> str:
175
+ """スタイル別の prompt.txt をローカル or HF から読み込み"""
176
  valid = ["standard", "cute", "resort", "cool", "dental", "school", "KOKUGO"]
177
  if style not in valid:
178
  logger.warning(f"無効 style '{style}' → 'standard'")
 
192
  with open(file_path, encoding="utf-8") as f:
193
  return f.read()
194
  except Exception as e:
195
+ logger.warning(f"HF 取得失敗 ({e}) → デフォルト prompt.txt")
196
  file_path = hf_hub_download(
197
  repo_id="tomo2chin2/GURAREKOstlyle",
198
  filename="prompt.txt",
 
201
  with open(file_path, encoding="utf-8") as f:
202
  return f.read()
203
 
204
+
205
  def generate_html_from_text(text: str, temperature: float = 0.5, style: str = "standard") -> str:
206
+ """Gemini で与えられたテキストを HTML に整形して返す"""
207
  api_key = os.environ.get("GEMINI_API_KEY")
208
  if not api_key:
209
  raise ValueError("環境変数 GEMINI_API_KEY が未設定")
 
210
  model_name = os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
211
+
212
  genai.configure(api_key=api_key)
213
+ model = genai.GenerativeModel(model_name)
214
 
215
  system_instruction = load_system_instruction(style)
216
+ prompt = f"{system_instruction}\n\n{text}"
217
 
218
  generation_config = dict(
219
+ temperature=temperature, top_p=0.7, top_k=20,
220
+ max_output_tokens=8192, candidate_count=1
221
  )
222
  safety_settings = [
223
  {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
 
225
  {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
226
  {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
227
  ]
228
+ resp = model.generate_content(prompt, generation_config=generation_config, safety_settings=safety_settings)
229
+ raw = resp.text
 
230
 
231
+ # ```html ``` 抜き取り
232
+ start = raw.find("```html")
233
+ end = raw.rfind("```")
234
+ html_code = raw[start + 7:end].strip() if start != -1 and end != -1 else raw
235
+ return enhance_font_awesome_layout(html_code)
 
236
 
 
 
237
 
238
def trim_image_whitespace(image: Image.Image, threshold: int = 250, padding: int = 10) -> Image.Image:
    """Crop away the near-white margin surrounding the content of ``image``.

    The image is converted to grayscale; every pixel darker than
    ``threshold`` counts as content. The bounding box of the content is
    grown by ``padding`` pixels on each side and clamped to the image.

    Args:
        image: Image to trim.
        threshold: Grayscale value below which a pixel is treated as content.
        padding: Number of pixels to keep around the content box.

    Returns:
        The cropped image, or ``image`` itself when no pixel is darker than
        ``threshold``.
    """
    gray = image.convert("L")
    arr = np.array(gray)
    mask = arr < threshold
    rows, cols = np.any(mask, axis=1), np.any(mask, axis=0)
    if not (np.any(rows) and np.any(cols)):
        return image

    # First and last row/column index that contains content (inclusive).
    y_min, y_max = np.where(rows)[0][[0, -1]]
    x_min, x_max = np.where(cols)[0][[0, -1]]

    left = max(0, int(x_min) - padding)
    top = max(0, int(y_min) - padding)
    # crop() treats the right/lower coordinates as exclusive, so add 1 to the
    # inclusive indices; without it the last padded row and column are lost.
    right = min(image.width, int(x_max) + padding + 1)
    bottom = min(image.height, int(y_max) + padding + 1)
    return image.crop((left, top, right, bottom))
252
 
253
+
254
+ # ===============================================================
255
  # HTML → スクリーンショット
256
+ # ===============================================================
257
  def render_fullpage_screenshot(html_code: str, extension_percentage: float = 6.0,
258
+ trim_whitespace: bool = True, driver=None) -> Image.Image:
 
 
259
  tmp_path = None
260
+ driver_from_pool = False
261
  try:
262
  if driver is None:
263
  driver = driver_pool.get_driver()
264
  driver_from_pool = True
265
 
266
+ # 一時 HTML 保存
267
  with tempfile.NamedTemporaryFile(suffix=".html", delete=False, mode="w", encoding="utf-8") as tmp:
268
  tmp_path = tmp.name
269
  tmp.write(html_code)
270
 
271
+ driver.set_window_size(1200, 1000)
272
+ driver.get("file://" + tmp_path)
 
273
 
274
+ # body 待機
275
  WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
276
 
277
+ total_height = driver.execute_script(
278
+ "return Math.max(document.body.scrollHeight, document.documentElement.scrollHeight)")
279
+ viewport_height = driver.execute_script("return window.innerHeight")
280
+ for i in range(max(1, min(5, total_height // viewport_height))):
281
+ driver.execute_script(f"window.scrollTo(0, {i * (viewport_height - 100)})")
282
  time.sleep(0.1)
283
+ driver.execute_script("window.scrollTo(0,0)")
284
  time.sleep(0.2)
285
 
286
  dims = driver.execute_script("""
 
289
  height: Math.max(document.body.scrollHeight, document.documentElement.scrollHeight)
290
  }
291
  """)
292
+ width = min(max(dims["width"], 100), 2000)
293
+ height = min(max(dims["height"], 100), 4000)
294
+ height = int(height * (1 + extension_percentage / 100.0))
295
+ driver.set_window_size(width, height)
296
  time.sleep(0.5)
297
 
298
+ img = Image.open(BytesIO(driver.get_screenshot_as_png()))
299
+ return trim_image_whitespace(img, 248, 20) if trim_whitespace else img
 
 
 
300
  except Exception as e:
301
  logger.error(f"Screenshot error: {e}", exc_info=True)
302
  return Image.new("RGB", (1, 1), (0, 0, 0))
 
309
  except Exception:
310
  pass
311
 
312
+
313
+ # ===============================================================
314
  # テキスト → スクリーンショット(並列)
315
+ # ===============================================================
316
def text_to_screenshot_parallel(text: str, extension_percentage: float, temperature: float = 0.5,
                                trim_whitespace: bool = True, style: str = "standard") -> Image.Image:
    """Turn ``text`` into an infographic screenshot.

    HTML generation and WebDriver acquisition are submitted to a two-worker
    thread pool so they overlap; the resulting HTML is then rendered with the
    acquired driver.

    Args:
        text: Source text handed to ``generate_html_from_text``.
        extension_percentage: Forwarded to ``render_fullpage_screenshot``.
        temperature: Forwarded to ``generate_html_from_text``.
        trim_whitespace: Forwarded to ``render_fullpage_screenshot``.
        style: Forwarded to ``generate_html_from_text``.

    Returns:
        The image returned by ``render_fullpage_screenshot``.

    Raises:
        Exception: Anything raised by HTML generation or driver acquisition.
    """
    start = time.time()
    with ThreadPoolExecutor(max_workers=2) as ex:
        html_future = ex.submit(generate_html_from_text, text, temperature, style)
        driver_future = ex.submit(driver_pool.get_driver)

    # Both futures have completed once the executor context has exited.
    try:
        html_code = html_future.result()
    except Exception:
        # The renderer is never reached on this path, so the driver checked
        # out in parallel must be handed back here or the pool loses it.
        if driver_future.exception() is None:
            driver_pool.release_driver(driver_future.result())
        raise

    driver = driver_future.result()
    # NOTE(review): this relies on render_fullpage_screenshot's cleanup (not
    # visible in this view) to deal with the driver it is given — confirm.
    img = render_fullpage_screenshot(html_code, extension_percentage, trim_whitespace, driver)
    logger.info(f"並列生成完了: {time.time() - start:.2f}s")
    return img
328
 
329
+
330
def text_to_screenshot(*args, **kwargs):
    """Backward-compatible alias that forwards to text_to_screenshot_parallel."""
    image = text_to_screenshot_parallel(*args, **kwargs)
    return image
333
 
334
+
335
+ # ===============================================================
336
+ # FastAPI
337
+ # ===============================================================
338
  app = FastAPI()
339
  app.add_middleware(
340
  CORSMiddleware,
 
342
  allow_methods=["*"], allow_headers=["*"]
343
  )
344
 
345
+
346
@app.post("/api/screenshot", response_class=StreamingResponse, tags=["Screenshot"])
async def api_render_screenshot(req: ScreenshotRequest):
    """Render the posted HTML and stream the screenshot back as a PNG.

    Args:
        req: Request body carrying ``html_code``, ``extension_percentage``
            and ``trim_whitespace``.

    Returns:
        A ``StreamingResponse`` with media type ``image/png``.
    """
    # Imported locally so this handler does not depend on a file-level import.
    from fastapi.concurrency import run_in_threadpool

    # Rendering drives Selenium and sleeps; running it directly inside an
    # async handler would stall the event loop for every other request.
    img = await run_in_threadpool(
        render_fullpage_screenshot, req.html_code, req.extension_percentage, req.trim_whitespace)
    buf = BytesIO()
    img.save(buf, "PNG")
    buf.seek(0)  # rewind so the response streams from the first byte
    return StreamingResponse(buf, media_type="image/png")
351
 
352
+
353
@app.post("/api/text-to-screenshot", response_class=StreamingResponse, tags=["Screenshot", "Gemini"])
async def api_text_to_screenshot(req: GeminiRequest):
    """Generate an infographic from the posted text and stream it as a PNG.

    Args:
        req: Request body carrying ``text``, ``extension_percentage``,
            ``temperature``, ``trim_whitespace`` and ``style``.

    Returns:
        A ``StreamingResponse`` with media type ``image/png``.
    """
    # Imported locally so this handler does not depend on a file-level import.
    from fastapi.concurrency import run_in_threadpool

    # The pipeline blocks on the Gemini call and on Selenium; keep that work
    # off the event loop so other requests are still served meanwhile.
    img = await run_in_threadpool(
        text_to_screenshot_parallel,
        req.text, req.extension_percentage, req.temperature, req.trim_whitespace, req.style)
    buf = BytesIO()
    img.save(buf, "PNG")
    buf.seek(0)  # rewind so the response streams from the first byte
    return StreamingResponse(buf, media_type="image/png")
359
 
360
+
361
+ # ===============================================================
362
  # Gradio UI
363
+ # ===============================================================
364
def process_input(mode, text, ext, temp, trim, style):
    """Gradio submit handler: pick the pipeline that matches the input mode.

    In "HTML入力" mode ``text`` is rendered directly as HTML; in any other
    mode it goes through the text-to-infographic pipeline.
    """
    if mode != "HTML入力":
        return text_to_screenshot_parallel(text, ext, temp, trim, style)
    return render_fullpage_screenshot(text, ext, trim)
368
+
369
 
370
  with gr.Blocks(title="Full Page Screenshot (テキスト変換対応)",
371
  theme=gr.themes.Origin()) as iface:
372
+ gr.Markdown("# HTMLビューア & テキスト→インフォグラフィック変換")
373
  with gr.Row():
374
  input_mode = gr.Radio(["HTML入力", "テキスト入力"], value="HTML入力", label="入力モード")
375
+
376
  input_text = gr.Textbox(lines=15, label="入力")
377
+
378
  with gr.Row():
379
+ style_dropdown = gr.Dropdown(
380
+ ["standard", "cute", "resort", "cool", "dental", "school", "KOKUGO"],
381
+ value="standard", label="デザインスタイル", visible=False)
382
  with gr.Column(scale=2):
383
  ext_slider = gr.Slider(0, 30, value=10, step=1, label="上下高さ拡張率(%)")
384
  temp_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.1,
385
  label="生成時の温度", visible=False)
386
+
387
+ trim_chk = gr.Checkbox(value=True, label="余白を自動トリミング")
388
+
389
  submit_btn = gr.Button("生成")
390
  out_img = gr.Image(type="pil", label="スクリーンショット")
391
 
392
+ def toggle(mode):
393
  is_text = mode == "テキスト入力"
394
  return [gr.update(visible=is_text), gr.update(visible=is_text)]
395
+ input_mode.change(toggle, input_mode, [temp_slider, style_dropdown])
396
+
397
+ submit_btn.click(
398
+ process_input,
399
+ [input_mode, input_text, ext_slider, temp_slider, trim_chk, style_dropdown],
400
+ out_img)
401
 
402
+ model_name = os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
403
  gr.Markdown(f"""
404
  **API**
405
  - `/api/screenshot` – HTML → PNG
406
  - `/api/text-to-screenshot` – テキスト → インフォグラフィック PNG
407
 
408
  **設定**
409
+ - 使用モデル: `{model_name}`
410
+ - 対応スタイル: standard / cute / resort / cool / dental / school / KOKUGO
411
  - WebDriver 最大数: {driver_pool.max_drivers}
412
  """)
413
 
 
 
414
 
415
+ # ===============================================================
416
+ # FastAPI へマウント & ルートリダイレクト
417
+ # ===============================================================
418
+ GRADIO_PATH = "/gradio"
419
+ app = gr.mount_gradio_app(app, iface, path=GRADIO_PATH, ssr_mode=False)
420
+
421
+
422
@app.get("/")
def _root():
    """Send requests for the site root to the path where Gradio is mounted."""
    target = GRADIO_PATH
    return RedirectResponse(url=target)
426
+
427
+
428
+ # ===============================================================
429
+ # ローカルデバッグ用
430
+ # ===============================================================
431
  if __name__ == "__main__":
432
  import uvicorn
433
+ logger.info("Uvicorn 起動 (ローカル)")
434
  uvicorn.run(app, host="0.0.0.0", port=7860)
435
 
436
+
437
+ # ===============================================================
438
+ # 終了時 WebDriver 後始末
439
+ # ===============================================================
440
  import atexit
441
  atexit.register(driver_pool.close_all)