tomo2chin2 committed on
Commit
26eec55
·
verified ·
1 Parent(s): 5a6269f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -344
app.py CHANGED
@@ -1,16 +1,10 @@
1
  # ===============================================================
2
- # app.py: Gradio 5.x + FastAPI + Gemini + Selenium
3
- # ・Gradio 5.29.0 以上を前提に最適化
4
- # ・UI/API を維持したまま 307 ループを解消
5
- # ・/gradio にサブマウントし / はリダイレクトのみ
6
  # ===============================================================
7
 
8
- import os
9
- import time
10
- import tempfile
11
- import logging
12
- import threading
13
- import queue
14
  from io import BytesIO
15
  from concurrent.futures import ThreadPoolExecutor
16
 
@@ -32,97 +26,54 @@ from selenium.webdriver.support import expected_conditions as EC
32
  import google.generativeai as genai
33
  from huggingface_hub import hf_hub_download
34
 
35
-
36
- # ===============================================================
37
- # ロガー
38
- # ===============================================================
39
  logging.basicConfig(level=logging.INFO)
40
  logger = logging.getLogger(__name__)
41
 
42
-
43
- # ===============================================================
44
- # WebDriver プール
45
- # ===============================================================
46
  class WebDriverPool:
47
- """複数の WebDriver を使い回すシンプルなプール"""
48
  def __init__(self, max_drivers: int = 3):
49
- self.driver_queue: "queue.Queue[webdriver.Chrome]" = queue.Queue()
50
  self.max_drivers = max_drivers
51
  self.lock = threading.Lock()
52
  self.count = 0
53
  logger.info(f"WebDriver プール初期化: 最大 {max_drivers}")
54
 
55
- def _create_driver(self) -> webdriver.Chrome:
56
- options = Options()
57
- options.add_argument("--headless")
58
- options.add_argument("--no-sandbox")
59
- options.add_argument("--disable-dev-shm-usage")
60
- options.add_argument("--force-device-scale-factor=1")
61
- options.add_argument("--disable-features=NetworkService")
62
- options.add_argument("--dns-prefetch-disable")
63
-
64
- chromedriver_path = os.environ.get("CHROMEDRIVER_PATH")
65
- if chromedriver_path and os.path.exists(chromedriver_path):
66
- logger.info(f"環境変数 CHROMEDRIVER_PATH を使用: {chromedriver_path}")
67
- service = webdriver.ChromeService(executable_path=chromedriver_path)
68
- return webdriver.Chrome(service=service, options=options)
69
- return webdriver.Chrome(options=options)
70
-
71
- def get_driver(self) -> webdriver.Chrome:
72
  if not self.driver_queue.empty():
73
- logger.info("プールから既存 WebDriver を取得")
74
  return self.driver_queue.get()
75
-
76
  with self.lock:
77
  if self.count < self.max_drivers:
78
  self.count += 1
79
- logger.info(f"新規 WebDriver 作成 ({self.count}/{self.max_drivers})")
80
  return self._create_driver()
81
-
82
- # ここに来るのはプール満杯時
83
- logger.info("WebDriver プール満杯。空きを待機")
84
  return self.driver_queue.get()
85
 
86
- def release_driver(self, driver: webdriver.Chrome):
87
- if driver:
88
- try:
89
- driver.get("about:blank")
90
- driver.execute_script("""
91
- document.documentElement.style.overflow='';
92
- document.body.style.overflow='';
93
- """)
94
- self.driver_queue.put(driver)
95
- logger.info("WebDriver をプールに返却")
96
- except Exception as e:
97
- logger.error(f"返却エラー: {e}")
98
- driver.quit()
99
- with self.lock:
100
- self.count -= 1
101
 
102
  def close_all(self):
103
- logger.info("WebDriver 全終了")
104
- closed = 0
105
  while not self.driver_queue.empty():
106
- try:
107
- drv = self.driver_queue.get(block=False)
108
- drv.quit()
109
- closed += 1
110
- except queue.Empty:
111
- break
112
- except Exception as e:
113
- logger.error(f"終了エラー: {e}")
114
- with self.lock:
115
- self.count = 0
116
- logger.info(f"{closed} 個の WebDriver を終了")
117
-
118
 
119
- # グローバルプール
120
# Module-level pool instance; its size is read from the MAX_WEBDRIVERS
# environment variable and defaults to 3.
- driver_pool = WebDriverPool(max_drivers=int(os.environ.get("MAX_WEBDRIVERS", "3")))
121
 
122
-
123
- # ===============================================================
124
- # Pydantic モデル
125
- # ===============================================================
126
  class GeminiRequest(BaseModel):
127
  text: str
128
  extension_percentage: float = 10.0
@@ -130,312 +81,153 @@ class GeminiRequest(BaseModel):
130
  trim_whitespace: bool = True
131
  style: str = "standard"
132
 
133
-
134
  # Request body of POST /api/screenshot.
  # NOTE(review): `style` is declared here but the /api/screenshot handler in
  # this file does not read it.
  class ScreenshotRequest(BaseModel):
135
  html_code: str  # HTML document that gets rendered
136
  extension_percentage: float = 10.0  # extra page height, in percent
137
  trim_whitespace: bool = True  # crop near-white margins from the result
138
  style: str = "standard"
139
 
140
-
141
- # ===============================================================
142
- # 補助関数
143
- # ===============================================================
144
  def enhance_font_awesome_layout(html_code: str) -> str:
145
- """Font Awesome をプリロード + レイアウト微調整 CSS を注入"""
146
- fa_preload = """
147
  <link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-solid-900.woff2" as="font" type="font/woff2" crossorigin>
148
  <link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-regular-400.woff2" as="font" type="font/woff2" crossorigin>
149
  <link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-brands-400.woff2" as="font" type="font/woff2" crossorigin>
150
  """
151
- fa_css = """
152
- <style>
153
- [class*="fa-"]{display:inline-block!important;margin-right:8px!important;vertical-align:middle!important;}
154
- h1 [class*="fa-"],h2 [class*="fa-"],h3 [class*="fa-"],h4 [class*="fa-"],h5 [class*="fa-"],h6 [class*="fa-"]{vertical-align:middle!important;margin-right:10px!important;}
155
- .fa+span,.fas+span,.far+span,.fab+span,span+.fa,span+.fas,span+.far+span{display:inline-block!important;margin-left:5px!important;}
156
- .card [class*="fa-"],.card-body [class*="fa-"]{float:none!important;clear:none!important;position:relative!important;}
157
- li [class*="fa-"],p [class*="fa-"]{margin-right:10px!important;}
158
- .inline-icon{display:inline-flex!important;align-items:center!important;justify-content:flex-start!important;}
159
- [class*="fa-"]+span{display:inline-block!important;vertical-align:middle!important;}
160
- </style>
161
- """
162
  if '<head>' in html_code:
163
- return html_code.replace('</head>', f'{fa_preload}{fa_css}</head>')
164
- elif '<html' in html_code:
165
- head_end = html_code.find('</head>')
166
- if head_end > 0:
167
- return html_code[:head_end] + fa_preload + fa_css + html_code[head_end:]
168
- body_start = html_code.find('<body')
169
- if body_start > 0:
170
- return html_code[:body_start] + f'<head>{fa_preload}{fa_css}</head>' + html_code[body_start:]
171
- return f'<html><head>{fa_preload}{fa_css}</head>{html_code}</html>'
172
-
173
-
174
def load_system_instruction(style: str = "standard") -> str:
    """Return the prompt text for ``style``.

    Lookup order: ``<module dir>/<style>/prompt.txt`` on disk, then
    ``<style>/prompt.txt`` from the Hugging Face dataset repo, then the
    repo-level ``prompt.txt`` if that download or read fails. An unknown
    style is replaced by ``"standard"`` after a warning.
    """
    repo_id = "tomo2chin2/GURAREKOstlyle"
    valid = ["standard", "cute", "resort", "cool", "dental", "school", "KOKUGO"]
    if style not in valid:
        logger.warning(f"無効 style '{style}' → 'standard'")
        style = "standard"

    # A prompt file next to this module takes precedence over the hub copy.
    local_path = os.path.join(os.path.dirname(__file__), style, "prompt.txt")
    if os.path.exists(local_path):
        with open(local_path, encoding="utf-8") as local_file:
            return local_file.read()

    try:
        styled_path = hf_hub_download(
            repo_id=repo_id, filename=f"{style}/prompt.txt", repo_type="dataset"
        )
        with open(styled_path, encoding="utf-8") as styled_file:
            return styled_file.read()
    except Exception as e:
        logger.warning(f"HF 取得失敗 ({e}) → デフォルト prompt.txt")

    # Fall back to the repo-level default prompt.
    default_path = hf_hub_download(
        repo_id=repo_id, filename="prompt.txt", repo_type="dataset"
    )
    with open(default_path, encoding="utf-8") as default_file:
        return default_file.read()
203
-
204
-
205
def generate_html_from_text(text: str, temperature: float = 0.5, style: str = "standard") -> str:
    """Ask Gemini to turn ``text`` into HTML and return it with Font Awesome tweaks.

    Args:
        text: source text appended after the style's system instruction.
        temperature: sampling temperature passed to the model.
        style: prompt style name forwarded to ``load_system_instruction``.

    Returns:
        The HTML taken from the model's ```` ```html ```` fence (or the raw
        reply when no such fence exists), passed through
        ``enhance_font_awesome_layout``.

    Raises:
        ValueError: if the GEMINI_API_KEY environment variable is not set.
    """
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("環境変数 GEMINI_API_KEY が未設定")
    model_name = os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")

    genai.configure(api_key=api_key)
    model = genai.GenerativeModel(model_name)

    system_instruction = load_system_instruction(style)
    prompt = f"{system_instruction}\n\n{text}"

    generation_config = dict(
        temperature=temperature, top_p=0.7, top_k=20,
        max_output_tokens=8192, candidate_count=1,
    )
    safety_settings = [
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    ]
    resp = model.generate_content(
        prompt, generation_config=generation_config, safety_settings=safety_settings
    )
    raw = resp.text

    # Extract the body of the ```html ... ``` fence.
    fence = "```html"
    start = raw.find(fence)
    if start == -1:
        html_code = raw
    else:
        body_start = start + len(fence)
        end = raw.rfind("```")
        # If the reply was cut off before the closing fence, rfind() lands on
        # the opening fence itself; take everything after the opener instead
        # of slicing to an empty string.
        if end >= body_start:
            html_code = raw[body_start:end].strip()
        else:
            html_code = raw[body_start:].strip()
    return enhance_font_awesome_layout(html_code)
236
-
237
-
238
def trim_image_whitespace(image: "Image.Image", threshold: int = 250, padding: int = 10) -> "Image.Image":
    """Crop near-white margins from ``image``.

    Args:
        image: image to trim; it is converted to greyscale for the test.
        threshold: grey values below this count as content.
        padding: pixels kept around the content on every side, clamped to
            the image bounds.

    Returns:
        The cropped image, or ``image`` itself when no pixel is darker than
        ``threshold``.
    """
    mask = np.array(image.convert("L")) < threshold
    if not mask.any():
        return image

    rows = np.flatnonzero(mask.any(axis=1))
    cols = np.flatnonzero(mask.any(axis=0))
    left = max(0, int(cols[0]) - padding)
    upper = max(0, int(rows[0]) - padding)
    # The crop box's right/lower edges are exclusive, so add 1 to keep the
    # last content row/column and clamp to width/height, not width-1/height-1.
    right = min(image.width, int(cols[-1]) + padding + 1)
    lower = min(image.height, int(rows[-1]) + padding + 1)
    return image.crop((left, upper, right, lower))
252
-
253
 
254
- # ===============================================================
255
- # HTML スクリーンショット
256
- # ===============================================================
257
def render_fullpage_screenshot(html_code: str, extension_percentage: float = 6.0,
                               trim_whitespace: bool = True, driver=None) -> "Image.Image":
    """Render ``html_code`` in Chrome and return a full-page screenshot.

    Args:
        html_code: HTML document to render (written to a temporary file).
        extension_percentage: extra height added to the measured page
            height, in percent.
        trim_whitespace: crop near-white margins from the result.
        driver: WebDriver to use. When ``None`` one is taken from
            ``driver_pool`` and released again before returning; a driver
            passed in by the caller is left for the caller to release.

    Returns:
        The screenshot, or a 1x1 black image if anything fails (the error is
        logged, not raised).
    """
    tmp_path = None
    driver_from_pool = False
    try:
        if driver is None:
            driver = driver_pool.get_driver()
            driver_from_pool = True

        # Save the HTML to a temporary file so Chrome can load it via file://
        with tempfile.NamedTemporaryFile(suffix=".html", delete=False, mode="w", encoding="utf-8") as tmp:
            tmp_path = tmp.name
            tmp.write(html_code)

        driver.set_window_size(1200, 1000)
        driver.get("file://" + tmp_path)

        # Wait for <body>
        WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.TAG_NAME, "body")))

        total_height = driver.execute_script(
            "return Math.max(document.body.scrollHeight, document.documentElement.scrollHeight)")
        viewport_height = driver.execute_script("return window.innerHeight")
        # A zero viewport height would raise ZeroDivisionError and turn the
        # whole render into the 1x1 fallback; do a single scroll step instead.
        if viewport_height:
            steps = max(1, min(5, total_height // viewport_height))
        else:
            steps = 1
        for i in range(steps):
            driver.execute_script(f"window.scrollTo(0, {i * (viewport_height - 100)})")
            time.sleep(0.1)
        driver.execute_script("window.scrollTo(0,0)")
        time.sleep(0.2)

        dims = driver.execute_script("""
return {
width: Math.max(document.body.scrollWidth, document.documentElement.scrollWidth),
height: Math.max(document.body.scrollHeight, document.documentElement.scrollHeight)
}
""")
        # Clamp the window to 100..2000 px wide and 100..4000 px tall before
        # applying the height extension.
        width = min(max(dims["width"], 100), 2000)
        height = min(max(dims["height"], 100), 4000)
        height = int(height * (1 + extension_percentage / 100.0))
        driver.set_window_size(width, height)
        time.sleep(0.5)

        img = Image.open(BytesIO(driver.get_screenshot_as_png()))
        return trim_image_whitespace(img, 248, 20) if trim_whitespace else img
    except Exception as e:
        logger.error(f"Screenshot error: {e}", exc_info=True)
        return Image.new("RGB", (1, 1), (0, 0, 0))
    finally:
        # Nested try/finally: the temp file is removed even if returning the
        # driver to the pool raises.
        try:
            if driver_from_pool:
                driver_pool.release_driver(driver)
        finally:
            if tmp_path and os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except Exception:
                    pass
311
-
312
 
313
- # ===============================================================
314
- # テキスト スクリーンショット(並列)
315
- # ===============================================================
316
def text_to_screenshot_parallel(text: str, extension_percentage: float, temperature: float = 0.5,
                                trim_whitespace: bool = True, style: str = "standard") -> "Image.Image":
    """Generate HTML from ``text`` with Gemini and screenshot it.

    HTML generation and WebDriver acquisition run concurrently on two worker
    threads. The driver taken from ``driver_pool`` is always returned to the
    pool, whether rendering succeeds or HTML generation fails.

    Args:
        text: source text for the model.
        extension_percentage: extra page height in percent.
        temperature: sampling temperature for the model.
        trim_whitespace: crop near-white margins from the screenshot.
        style: prompt style name.

    Returns:
        The image produced by ``render_fullpage_screenshot``.
    """
    start = time.time()
    with ThreadPoolExecutor(max_workers=2) as ex:
        html_future = ex.submit(generate_html_from_text, text, temperature, style)
        driver_future = ex.submit(driver_pool.get_driver)

        driver = driver_future.result()
        try:
            html_code = html_future.result()
            img = render_fullpage_screenshot(html_code, extension_percentage, trim_whitespace, driver)
        finally:
            # render_fullpage_screenshot does not release a driver it was
            # handed, so without this the pool drains after max_drivers calls
            # and get_driver() blocks forever.
            driver_pool.release_driver(driver)
    logger.info(f"並列生成完了: {time.time() - start:.2f}s")
    return img
328
-
329
-
330
# Backward-compatibility alias: forwards all positional and keyword arguments
# unchanged to text_to_screenshot_parallel and returns its result.
- def text_to_screenshot(*args, **kwargs):
331
- """後方互換用エイリアス"""
332
- return text_to_screenshot_parallel(*args, **kwargs)
333
-
334
 
335
  # ===============================================================
336
- # FastAPI
337
  # ===============================================================
338
# FastAPI application with CORS middleware configured to allow every origin,
# method and header, with credentials enabled.
- app = FastAPI()
339
- app.add_middleware(
340
- CORSMiddleware,
341
- allow_origins=["*"], allow_credentials=True,
342
- allow_methods=["*"], allow_headers=["*"]
343
- )
344
 
 
 
 
 
345
 
346
@app.post("/api/screenshot", response_class=StreamingResponse, tags=["Screenshot"])
async def api_render_screenshot(req: ScreenshotRequest):
    """Render the posted HTML and stream the screenshot back as PNG."""
    import asyncio  # local import keeps this block self-contained

    # Rendering drives a browser and sleeps; run it on a worker thread so the
    # event loop keeps serving other requests in the meantime.
    img = await asyncio.to_thread(
        render_fullpage_screenshot,
        req.html_code, req.extension_percentage, req.trim_whitespace,
    )
    buf = BytesIO()
    img.save(buf, "PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
351
-
352
-
353
@app.post("/api/text-to-screenshot", response_class=StreamingResponse, tags=["Screenshot", "Gemini"])
async def api_text_to_screenshot(req: GeminiRequest):
    """Turn the posted text into an infographic and stream it back as PNG."""
    import asyncio  # local import keeps this block self-contained

    # The model call and the browser render both block; run them on a worker
    # thread so the event loop stays responsive.
    img = await asyncio.to_thread(
        text_to_screenshot_parallel,
        req.text, req.extension_percentage, req.temperature, req.trim_whitespace, req.style,
    )
    buf = BytesIO()
    img.save(buf, "PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
359
-
360
 
361
  # ===============================================================
362
  # Gradio UI
363
  # ===============================================================
364
def process_input(mode, text, ext, temp, trim, style):
    """Gradio submit handler: render HTML directly or go through Gemini first.

    In "HTML入力" mode ``text`` is rendered as HTML; in any other mode it is
    sent through ``text_to_screenshot_parallel`` with ``temp`` and ``style``.
    """
    html_mode = mode == "HTML入力"
    return (
        render_fullpage_screenshot(text, ext, trim)
        if html_mode
        else text_to_screenshot_parallel(text, ext, temp, trim, style)
    )
368
-
369
 
370
  # Gradio UI definition, bound to `iface`.
  # Components: input-mode radio, 15-line input textbox, style dropdown and
  # temperature slider (both created hidden), height-extension slider,
  # trim checkbox, submit button and PIL image output.
  # NOTE(review): indentation was lost in this view, so which components sit
  # inside the Row/Column containers cannot be confirmed from here.
  with gr.Blocks(title="Full Page Screenshot (テキスト変換対応)",
371
- theme=gr.themes.Origin()) as iface:
372
- gr.Markdown("# HTMLビューア & テキスト→インフォグラフィック変換")
373
- with gr.Row():
374
- input_mode = gr.Radio(["HTML入力", "テキスト入力"], value="HTML入力", label="入力モード")
375
-
376
- input_text = gr.Textbox(lines=15, label="入力")
377
-
378
  with gr.Row():
379
- style_dropdown = gr.Dropdown(
380
- ["standard", "cute", "resort", "cool", "dental", "school", "KOKUGO"],
381
- value="standard", label="デザインスタイル", visible=False)
382
  with gr.Column(scale=2):
383
- ext_slider = gr.Slider(0, 30, value=10, step=1, label="上下高さ拡張率(%)")
384
- temp_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.1,
385
- label="生成時の温度", visible=False)
 
 
386
 
387
- trim_chk = gr.Checkbox(value=True, label="余白を自動トリミング")
 
388
 
389
- submit_btn = gr.Button("生成")
390
- out_img = gr.Image(type="pil", label="スクリーンショット")
391
-
392
# toggle() shows the temperature slider and style dropdown only while the
# mode is "テキスト入力"; it returns one visibility update per output.
- def toggle(mode):
393
- is_text = mode == "テキスト入力"
394
- return [gr.update(visible=is_text), gr.update(visible=is_text)]
395
- input_mode.change(toggle, input_mode, [temp_slider, style_dropdown])
396
-
397
# The submit button passes all six inputs to process_input and shows its
# return value in out_img.
- submit_btn.click(
398
- process_input,
399
- [input_mode, input_text, ext_slider, temp_slider, trim_chk, style_dropdown],
400
- out_img)
401
-
402
# Footer markdown listing the API routes, the configured model name, the
# supported styles and the pool's max_drivers value.
- model_name = os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
403
- gr.Markdown(f"""
404
- **API**
405
- - `/api/screenshot` – HTML → PNG
406
- - `/api/text-to-screenshot` – テキスト → インフォグラフィック PNG
407
-
408
- **設定**
409
- - 使用モデル: `{model_name}`
410
- - 対応スタイル: standard / cute / resort / cool / dental / school / KOKUGO
411
- - WebDriver 最大数: {driver_pool.max_drivers}
412
- """)
413
-
414
-
415
- # ===============================================================
416
- # FastAPI へマウント & ルートリダイレクト
417
- # ===============================================================
418
# Mount the Gradio UI on the FastAPI app under /gradio.
GRADIO_PATH = "/gradio"
app = gr.mount_gradio_app(app, iface, path=GRADIO_PATH, ssr_mode=False)


@app.get("/")
def _root():
    """Redirect the site root to the mounted Gradio UI."""
    # Target the trailing-slash URL directly: the UI is mounted under
    # /gradio/, so redirecting to the bare /gradio costs an additional
    # redirect hop before the UI is reached.
    return RedirectResponse(GRADIO_PATH + "/")
426
 
 
 
 
427
 
428
- # ===============================================================
429
- # ローカルデバッグ用
430
- # ===============================================================
431
# Quit pooled WebDrivers at interpreter exit. This is registered before the
# local server start because uvicorn.run() blocks: registering after it means
# the hook is missing whenever the run ends with an exception.
import atexit

atexit.register(driver_pool.close_all)

# Local debugging entry point
if __name__ == "__main__":
    import uvicorn

    logger.info("Uvicorn 起動 (ローカル)")
    uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  # ===============================================================
2
+ # app.py: Gradio 5.x + FastAPI + Gemini + Selenium
3
+ # FastAPI(redirect_slashes=False) 307 ループを根絶
4
+ # サブアプリを /gradio にマウントし、自前でリダイレクト制御
 
5
  # ===============================================================
6
 
7
+ import os, time, tempfile, logging, threading, queue
 
 
 
 
 
8
  from io import BytesIO
9
  from concurrent.futures import ThreadPoolExecutor
10
 
 
26
  import google.generativeai as genai
27
  from huggingface_hub import hf_hub_download
28
 
29
+ # ---------- ログ ----------
 
 
 
30
  logging.basicConfig(level=logging.INFO)
31
  logger = logging.getLogger(__name__)
32
 
33
+ # ---------- WebDriver プール(前回と同じ) ----------
 
 
 
34
  class WebDriverPool:
 
35
  def __init__(self, max_drivers: int = 3):
36
+ self.driver_queue = queue.Queue()
37
  self.max_drivers = max_drivers
38
  self.lock = threading.Lock()
39
  self.count = 0
40
  logger.info(f"WebDriver プール初期化: 最大 {max_drivers}")
41
 
42
+ def _create_driver(self):
43
+ opts = Options()
44
+ opts.add_argument("--headless"); opts.add_argument("--no-sandbox")
45
+ opts.add_argument("--disable-dev-shm-usage")
46
+ drv_path = os.getenv("CHROMEDRIVER_PATH")
47
+ if drv_path and os.path.exists(drv_path):
48
+ return webdriver.Chrome(service=webdriver.ChromeService(executable_path=drv_path), options=opts)
49
+ return webdriver.Chrome(options=opts)
50
+
51
+ def get_driver(self):
 
 
 
 
 
 
 
52
  if not self.driver_queue.empty():
 
53
  return self.driver_queue.get()
 
54
  with self.lock:
55
  if self.count < self.max_drivers:
56
  self.count += 1
 
57
  return self._create_driver()
 
 
 
58
  return self.driver_queue.get()
59
 
60
+ def release_driver(self, driver):
61
+ try:
62
+ driver.get("about:blank")
63
+ self.driver_queue.put(driver)
64
+ except Exception:
65
+ driver.quit()
66
+ with self.lock: self.count -= 1
 
 
 
 
 
 
 
 
67
 
68
  def close_all(self):
 
 
69
  while not self.driver_queue.empty():
70
+ try: self.driver_queue.get(block=False).quit()
71
+ except queue.Empty: break
72
+ with self.lock: self.count = 0
 
 
 
 
 
 
 
 
 
73
 
74
# Module-level pool instance; its size is read from the MAX_WEBDRIVERS
# environment variable and defaults to 3.
+ driver_pool = WebDriverPool(max_drivers=int(os.getenv("MAX_WEBDRIVERS", "3")))
 
75
 
76
+ # ---------- Pydantic モデル ----------
 
 
 
77
  class GeminiRequest(BaseModel):
78
  text: str
79
  extension_percentage: float = 10.0
 
81
  trim_whitespace: bool = True
82
  style: str = "standard"
83
 
 
84
  # Request body of POST /api/screenshot.
  # NOTE(review): `style` is declared here but the /api/screenshot handler in
  # this file does not read it.
  class ScreenshotRequest(BaseModel):
85
  html_code: str  # HTML document that gets rendered
86
  extension_percentage: float = 10.0  # extra page height, in percent
87
  trim_whitespace: bool = True  # crop near-white margins from the result
88
  style: str = "standard"
89
 
90
+ # ---------- 補助関数(前回と同じ:FontAwesome, prompt 取得, Gemini, トリミング) ----------
 
 
 
91
  def enhance_font_awesome_layout(html_code: str) -> str:
92
+ preload = """
 
93
  <link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-solid-900.woff2" as="font" type="font/woff2" crossorigin>
94
  <link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-regular-400.woff2" as="font" type="font/woff2" crossorigin>
95
  <link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-brands-400.woff2" as="font" type="font/woff2" crossorigin>
96
  """
97
+ css = '<style>[class*="fa-"]{display:inline-block;margin-right:8px;vertical-align:middle}</style>'
 
 
 
 
 
 
 
 
 
 
98
  if '<head>' in html_code:
99
+ return html_code.replace('</head>', f'{preload}{css}</head>')
100
+ return f'<html><head>{preload}{css}</head>{html_code}</html>'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
def load_system_instruction(style="standard"):
    """Return the prompt text for ``style``.

    Lookup order: ``<module dir>/<style>/prompt.txt`` on disk, then
    ``<style>/prompt.txt`` from the Hugging Face dataset repo, then the
    repo-level ``prompt.txt`` if that download or read fails. An unknown
    style is silently replaced by ``"standard"``.
    """
    valid_styles = ("standard", "cute", "resort", "cool", "dental", "school", "KOKUGO")
    if style not in valid_styles:
        style = "standard"

    def _read(path):
        # Context manager closes the handle deterministically instead of
        # leaving it to the garbage collector.
        with open(path, encoding="utf-8") as fh:
            return fh.read()

    # A prompt file next to this module takes precedence over the hub copy.
    local = os.path.join(os.path.dirname(__file__), style, "prompt.txt")
    if os.path.exists(local):
        return _read(local)
    try:
        path = hf_hub_download("tomo2chin2/GURAREKOstlyle", f"{style}/prompt.txt", repo_type="dataset")
        return _read(path)
    except Exception as exc:
        # Record why the style-specific prompt was unavailable before falling
        # back to the repo-level default.
        logger.warning(f"Style prompt unavailable ({exc}); using default prompt.txt")
        path = hf_hub_download("tomo2chin2/GURAREKOstlyle", "prompt.txt", repo_type="dataset")
        return _read(path)
114
+
115
def generate_html_from_text(text, temperature=0.5, style="standard"):
    """Ask Gemini to turn ``text`` into HTML and return it with Font Awesome tweaks.

    Args:
        text: source text appended after the style's system instruction.
        temperature: sampling temperature passed to the model.
        style: prompt style name forwarded to ``load_system_instruction``.

    Returns:
        The HTML taken from the model's ```` ```html ```` fence (or the raw
        reply when no such fence exists), passed through
        ``enhance_font_awesome_layout``.

    Raises:
        ValueError: if the GEMINI_API_KEY environment variable is not set.
    """
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        # Fail with a clear message instead of an opaque client error later.
        raise ValueError("GEMINI_API_KEY environment variable is not set")
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel(os.getenv("GEMINI_MODEL", "gemini-1.5-pro"))

    prompt = f"{load_system_instruction(style)}\n\n{text}"
    cfg = dict(
        temperature=temperature, top_p=0.7, top_k=20,
        max_output_tokens=8192, candidate_count=1,
    )
    raw = model.generate_content(prompt, generation_config=cfg).text

    fence = "```html"
    start = raw.find(fence)
    if start == -1:
        html = raw
    else:
        body_start = start + len(fence)
        end = raw.rfind("```")
        # If the reply was cut off before the closing fence, rfind() lands on
        # the opening fence itself; take everything after the opener instead
        # of slicing to an empty string.
        if end >= body_start:
            html = raw[body_start:end].strip()
        else:
            html = raw[body_start:].strip()
    return enhance_font_awesome_layout(html)
124
+
125
def trim_image_whitespace(img: "Image.Image", th=248, pad=20) -> "Image.Image":
    """Crop near-white margins from ``img``.

    Args:
        img: image to trim; it is converted to greyscale for the test.
        th: grey values below this count as content.
        pad: pixels kept around the content on every side, clamped to the
            image bounds.

    Returns:
        The cropped image, or ``img`` itself when no pixel is darker than
        ``th``.
    """
    mask = np.array(img.convert("L")) < th
    if not mask.any():
        return img

    rows = np.flatnonzero(mask.any(axis=1))
    cols = np.flatnonzero(mask.any(axis=0))
    left = max(0, int(cols[0]) - pad)
    upper = max(0, int(rows[0]) - pad)
    # The crop box's right/lower edges are exclusive, so add 1 to keep the
    # last content row/column and clamp to width/height, not width-1/height-1.
    right = min(img.width, int(cols[-1]) + pad + 1)
    lower = min(img.height, int(rows[-1]) + pad + 1)
    return img.crop((left, upper, right, lower))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ # ---------- HTML → PNG ----------
134
def render_fullpage_screenshot(html, ext=6.0, trim_ws=True, driver=None):
    """Render ``html`` in Chrome and return a full-page screenshot.

    Args:
        html: HTML document to render (written to a temporary file).
        ext: extra height added to the measured page height, in percent.
        trim_ws: crop near-white margins from the result.
        driver: WebDriver to use. When ``None`` one is taken from
            ``driver_pool`` and released again before returning; a driver
            passed in by the caller is left for the caller to release.

    Returns:
        The screenshot, or a 1x1 black image if anything fails (the error is
        logged, not raised).
    """
    tmp = None
    from_pool = False
    try:
        if driver is None:
            driver = driver_pool.get_driver()
            from_pool = True

        with tempfile.NamedTemporaryFile(suffix=".html", delete=False, mode="w", encoding="utf-8") as f:
            tmp = f.name
            f.write(html)

        driver.set_window_size(1200, 1000)
        driver.get("file://" + tmp)
        WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.TAG_NAME, "body")))

        page_height = driver.execute_script(
            "return Math.max(document.body.scrollHeight, document.documentElement.scrollHeight)")
        vh = driver.execute_script("return window.innerHeight")
        # A zero viewport height would raise ZeroDivisionError and turn the
        # whole render into the 1x1 fallback; do a single scroll step instead.
        steps = max(1, min(5, page_height // vh)) if vh else 1
        for i in range(steps):
            driver.execute_script(f"window.scrollTo(0,{i*(vh-100)})")
            time.sleep(0.1)
        driver.execute_script("window.scrollTo(0,0)")
        time.sleep(0.2)

        dims = driver.execute_script("return {w:Math.max(document.body.scrollWidth,document.documentElement.scrollWidth),h:Math.max(document.body.scrollHeight,document.documentElement.scrollHeight)}")
        # Clamp the window to 100..2000 px wide and 100..4000 px tall before
        # applying the height extension.
        w = min(max(dims['w'], 100), 2000)
        h = min(max(dims['h'], 100), 4000)
        h = int(h * (1 + ext / 100))
        driver.set_window_size(w, h)
        time.sleep(0.4)

        img = Image.open(BytesIO(driver.get_screenshot_as_png()))
        return trim_image_whitespace(img, 248, 20) if trim_ws else img
    except Exception as e:
        logger.error(f"Screenshot err: {e}", exc_info=True)
        return Image.new("RGB", (1, 1), (0, 0, 0))
    finally:
        # Nested try/finally: the temp file is removed even if returning the
        # driver to the pool raises, and a failed delete cannot replace the
        # function's result with an exception.
        try:
            if from_pool:
                driver_pool.release_driver(driver)
        finally:
            if tmp and os.path.exists(tmp):
                try:
                    os.remove(tmp)
                except OSError as cleanup_error:
                    logger.warning(f"Temp file cleanup failed: {cleanup_error}")
 
 
 
 
 
 
160
 
161
+ # ---------- テキスト → PNG ----------
162
def text_to_screenshot_parallel(text, ext, temp=0.5, trim_ws=True, style="standard"):
    """Generate HTML from ``text`` with Gemini and screenshot it.

    HTML generation and WebDriver acquisition run concurrently on two worker
    threads. The driver taken from ``driver_pool`` is always returned to the
    pool, whether rendering succeeds or HTML generation fails.

    Args:
        text: source text for the model.
        ext: extra page height in percent.
        temp: sampling temperature for the model.
        trim_ws: crop near-white margins from the screenshot.
        style: prompt style name.

    Returns:
        The image produced by ``render_fullpage_screenshot``.
    """
    with ThreadPoolExecutor(max_workers=2) as ex:
        # Submit both tasks before waiting on either; calling .result() right
        # after the first submit would run them one after the other.
        html_future = ex.submit(generate_html_from_text, text, temp, style)
        driver_future = ex.submit(driver_pool.get_driver)

        drv = driver_future.result()
        try:
            html = html_future.result()
            return render_fullpage_screenshot(html, ext, trim_ws, drv)
        finally:
            # render_fullpage_screenshot does not release a driver it was
            # handed, so without this the pool drains after max_drivers calls
            # and get_driver() blocks forever.
            driver_pool.release_driver(drv)
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  # ===============================================================
169
+ # FastAPI (★ リダイレクト無効化)
170
  # ===============================================================
171
# FastAPI application created with redirect_slashes=False (redirects are
# handled by explicit routes further down). CORS middleware allows every
# origin, method and header, with credentials enabled.
+ app = FastAPI(redirect_slashes=False)
172
+ app.add_middleware(CORSMiddleware,allow_origins=["*"],allow_credentials=True,
173
+ allow_methods=["*"],allow_headers=["*"])
 
 
 
174
 
175
@app.post("/api/screenshot", response_class=StreamingResponse, tags=["Screenshot"])
async def api_screenshot(r: ScreenshotRequest):
    """Render the posted HTML and stream the screenshot back as PNG."""
    import asyncio  # local import keeps this block self-contained

    # Rendering drives a browser and sleeps; run it on a worker thread so the
    # event loop keeps serving other requests in the meantime.
    img = await asyncio.to_thread(
        render_fullpage_screenshot,
        r.html_code, r.extension_percentage, r.trim_whitespace,
    )
    buf = BytesIO()
    img.save(buf, "PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
179
 
180
@app.post("/api/text-to-screenshot", response_class=StreamingResponse, tags=["Gemini", "Screenshot"])
async def api_text_to_screenshot(r: GeminiRequest):
    """Turn the posted text into an infographic and stream it back as PNG."""
    import asyncio  # local import keeps this block self-contained

    # The model call and the browser render both block; run them on a worker
    # thread so the event loop stays responsive.
    img = await asyncio.to_thread(
        text_to_screenshot_parallel,
        r.text, r.extension_percentage, r.temperature, r.trim_whitespace, r.style,
    )
    buf = BytesIO()
    img.save(buf, "PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
 
 
 
 
 
 
 
 
 
 
184
 
185
  # ===============================================================
186
  # Gradio UI
187
  # ===============================================================
188
def process(mode, txt, ext, temp, trim, style):
    """Gradio submit handler: render HTML directly or go through Gemini first.

    In "HTML入力" mode ``txt`` is rendered as HTML; in any other mode it is
    sent through ``text_to_screenshot_parallel`` with ``temp`` and ``style``.
    """
    if mode == "HTML入力":
        return render_fullpage_screenshot(txt, ext, trim)
    return text_to_screenshot_parallel(txt, ext, temp, trim, style)
 
 
191
 
192
  # Gradio UI definition, bound to `gui`.
  # Components: input-mode radio, 15-line input textbox, style dropdown and
  # temperature slider (both created hidden), height-extension slider,
  # trim checkbox, submit button and PIL image output.
  # NOTE(review): indentation was lost in this view, so which components sit
  # inside the Row/Column containers cannot be confirmed from here.
  with gr.Blocks(title="Full Page Screenshot (テキスト変換対応)",
193
+ theme=gr.themes.Origin()) as gui:
194
+ gr.Markdown("# HTMLビューア & テキスト→インフォグラフィック")
195
+ mode=gr.Radio(["HTML入力","テキスト入力"],value="HTML入力",label="入力モード")
196
+ txt=gr.Textbox(lines=15,label="入力")
 
 
 
197
  with gr.Row():
198
+ style=gr.Dropdown(["standard","cute","resort","cool","dental","school","KOKUGO"],
199
+ value="standard",label="デザインスタイル",visible=False)
 
200
  with gr.Column(scale=2):
201
+ ext=gr.Slider(0,30,10,1,label="上下高さ拡張率(%)")
202
+ temp=gr.Slider(0.0,1.0,0.5,0.1,label="生成時の温度",visible=False)
203
+ trim=gr.Checkbox(True,label="余白を自動トリミング")
204
+ btn=gr.Button("生成")
205
+ out=gr.Image(type="pil",label="スクリーンショット")
206
 
207
# Changing the mode sends the same visibility update to both `temp` and
# `style`: visible only while the mode is "テキスト入力".
+ mode.change(lambda m:[gr.update(visible=m=="テキスト入力")]*2,mode,[temp,style])
208
# The button passes all six inputs to process() and shows the result in `out`.
+ btn.click(process,[mode,txt,ext,temp,trim,style],out)
209
 
210
# Footer markdown naming the two API routes and the configured model.
+ gr.Markdown(f"**API** `/api/screenshot`, `/api/text-to-screenshot` \n使用モデル: `{os.getenv('GEMINI_MODEL','gemini-1.5-pro')}`")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
+ # ---------------------------------------------------------------
213
+ # Gradio を /gradio にマウント
214
+ # ---------------------------------------------------------------
215
# Mount the Gradio UI on the FastAPI app under /gradio.
GRADIO_BASE = "/gradio"
app = gr.mount_gradio_app(app, gui, path=GRADIO_BASE, ssr_mode=False)


@app.get("/")
def _root():
    """Redirect the site root to /gradio/ in a single hop."""
    target = GRADIO_BASE + "/"
    return RedirectResponse(target)


@app.get(GRADIO_BASE)
def _no_slash():
    """Redirect the bare /gradio path to /gradio/."""
    target = GRADIO_BASE + "/"
    return RedirectResponse(target)
225
 
226
+ # ローカルデバッグ
227
# Quit pooled WebDrivers at interpreter exit. This is registered before the
# local server start because uvicorn.run() blocks: registering after it means
# the hook is missing whenever the run ends with an exception.
import atexit

atexit.register(driver_pool.close_all)

# Local debugging entry point
if __name__ == "__main__":
    # Only uvicorn is needed here; the previously imported httpx was unused.
    import uvicorn

    logger.info("Uvicorn 起動 (ローカル)")
    uvicorn.run(app, host="0.0.0.0", port=7860)