Spaces:
Paused
Paused
| # =============================================================== | |
| # app.py ― Gradio 5.x + FastAPI + Gemini + Selenium + 307対策 | |
| # (1) FastAPI(redirect_slashes=False) ←★追加 | |
| # (2) Gradio を /gradio にマウント | |
| # (3) / と /gradio を /gradio/ へリダイレクト ←★追加 | |
| # それ以外は 5.x 対応フルロジックを一切カットせず | |
| # =============================================================== | |
| import os, time, tempfile, logging, threading, queue | |
| from io import BytesIO | |
| from concurrent.futures import ThreadPoolExecutor | |
| import numpy as np | |
| from PIL import Image | |
| import gradio as gr | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import StreamingResponse, RedirectResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from selenium import webdriver | |
| from selenium.webdriver.chrome.options import Options | |
| from selenium.webdriver.common.by import By | |
| from selenium.webdriver.support.ui import WebDriverWait | |
| from selenium.webdriver.support import expected_conditions as EC | |
| # Updated: Use the new Google Genai library | |
| from google import genai | |
| from google.genai import types | |
| from huggingface_hub import hf_hub_download | |
| # --------------------------------------------------------------- | |
| # ロギング | |
| # --------------------------------------------------------------- | |
# Root logging is configured once at INFO; every message in this module goes
# through the module-scoped logger created below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| # --------------------------------------------------------------- | |
| # WebDriverPool ― 以前提示した完全版と同一 | |
| # --------------------------------------------------------------- | |
class WebDriverPool:
    """Pool of reusable headless Chrome WebDrivers.

    Idle drivers are kept in ``driver_queue``.  ``count`` is the number of
    drivers currently alive (idle + checked out) and never exceeds
    ``max_drivers``; it is only modified while holding ``lock``.
    """

    # How long a waiter sleeps on the queue before re-checking whether a slot
    # was freed by a discarded (broken) driver.
    _WAIT_POLL_SECONDS = 1.0

    def __init__(self, max_drivers: int = 3):
        self.driver_queue = queue.Queue()
        self.max_drivers = max_drivers
        self.lock = threading.Lock()
        self.count = 0
        logger.info(f"WebDriver プール初期化: 最大 {max_drivers}")

    def _create_driver(self):
        """Start a new headless Chrome instance.

        Uses the chromedriver binary named by ``CHROMEDRIVER_PATH`` when that
        variable is set and the file exists; otherwise lets Selenium locate
        the driver itself.
        """
        options = Options()
        for flag in (
            "--headless",
            "--no-sandbox",
            "--disable-dev-shm-usage",
            "--force-device-scale-factor=1",
            "--disable-features=NetworkService",
            "--dns-prefetch-disable",
        ):
            options.add_argument(flag)

        chromedriver_path = os.environ.get("CHROMEDRIVER_PATH")
        if chromedriver_path and os.path.exists(chromedriver_path):
            logger.info(f"CHROMEDRIVER_PATH 使用: {chromedriver_path}")
            service = webdriver.ChromeService(executable_path=chromedriver_path)
            return webdriver.Chrome(service=service, options=options)
        return webdriver.Chrome(options=options)

    def get_driver(self):
        """Return an idle driver, create one if a slot is free, else wait.

        Raises whatever ``_create_driver`` raises; in that case the reserved
        slot is given back so the pool does not shrink permanently.
        """
        announced_wait = False
        while True:
            # get_nowait() avoids the empty()/get() race in which two threads
            # both see a non-empty queue and one of them blocks forever.
            try:
                driver = self.driver_queue.get_nowait()
            except queue.Empty:
                pass
            else:
                logger.info("既存 WebDriver 取得")
                return driver

            # Reserve a slot under the lock, but start Chrome outside it so a
            # slow browser launch does not serialize other callers.
            with self.lock:
                reserved = self.count < self.max_drivers
                if reserved:
                    self.count += 1
                    logger.info(f"新規 WebDriver 作成 ({self.count}/{self.max_drivers})")
            if reserved:
                try:
                    return self._create_driver()
                except Exception:
                    with self.lock:
                        self.count -= 1
                    raise

            if not announced_wait:
                logger.info("プール満杯、空き待機中…")
                announced_wait = True
            # Wait with a timeout: release_driver() may discard a broken
            # driver, which frees a slot without ever putting to the queue.
            try:
                return self.driver_queue.get(timeout=self._WAIT_POLL_SECONDS)
            except queue.Empty:
                continue

    def release_driver(self, driver):
        """Reset ``driver`` and return it to the pool.

        If the reset fails the driver is considered broken: it is quit
        (best effort) and its slot is freed.  Falsy ``driver`` is ignored.
        """
        if not driver:
            return
        try:
            driver.get("about:blank")
            driver.execute_script("""
                document.documentElement.style.overflow='';
                document.body.style.overflow='';
            """)
        except Exception as e:
            logger.error(f"返却エラー: {e}")
            try:
                driver.quit()
            except Exception as quit_error:
                logger.error(f"終了エラー: {quit_error}")
            finally:
                # Free the slot even when quit() itself fails.
                with self.lock:
                    self.count = max(0, self.count - 1)
        else:
            self.driver_queue.put(driver)
            logger.info("WebDriver をプールに返却")

    def close_all(self):
        """Quit every idle driver in the queue and reset ``count`` to zero.

        Drivers that are currently checked out are not touched.
        """
        logger.info("プール全 WebDriver 終了")
        closed = 0
        while True:
            try:
                drv = self.driver_queue.get_nowait()
            except queue.Empty:
                break
            try:
                drv.quit()
                closed += 1
            except Exception as e:
                logger.error(f"終了エラー: {e}")
        with self.lock:
            self.count = 0
        logger.info(f"{closed} 個終了")


# Shared pool; size is configurable through the MAX_WEBDRIVERS variable.
driver_pool = WebDriverPool(max_drivers=int(os.getenv("MAX_WEBDRIVERS", "3")))
| # --------------------------------------------------------------- | |
| # Pydantic モデル | |
| # --------------------------------------------------------------- | |
class GeminiRequest(BaseModel):
    """Request body for the text -> infographic screenshot endpoint."""

    # Source text that is sent to Gemini to be turned into HTML.
    text: str
    # Extra window height, in percent of the measured page height.
    extension_percentage: float = 10.0
    # Sampling temperature forwarded to the Gemini generation config.
    temperature: float = 0.5
    # Crop near-white margins from the final image when true.
    trim_whitespace: bool = True
    # Name of the prompt style used to pick the system instruction.
    style: str = "standard"
class ScreenshotRequest(BaseModel):
    """Request body for the HTML -> screenshot endpoint."""

    # HTML document (or fragment) to render.
    html_code: str
    # Extra window height, in percent of the measured page height.
    extension_percentage: float = 10.0
    # Crop near-white margins from the final image when true.
    trim_whitespace: bool = True
    # Declared on the model; the HTML endpoint shown in this file does not read it.
    style: str = "standard"
| # --------------------------------------------------------------- | |
| # 補助関数(FontAwesome レイアウト / prompt 読み込み / Gemini 生成) | |
| # --------------------------------------------------------------- | |
def enhance_font_awesome_layout(html_code: str) -> str:
    """Inject Font Awesome preload links and icon-layout CSS into ``html_code``.

    Placement rules, in order:

    1. If the document has a closing ``</head>`` tag (any letter case), the
       markup is inserted right before the first one.
    2. Otherwise, if it has an opening ``<html ...>`` tag, a new ``<head>``
       holding the markup is inserted right after it.
    3. Otherwise the input is treated as a fragment and wrapped in
       ``<html><head>...</head>{fragment}</html>``.

    Returns the resulting HTML string; the input is not modified.
    """
    import re

    fa_preload = """
<link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-solid-900.woff2" as="font" type="font/woff2" crossorigin>
<link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-regular-400.woff2" as="font" type="font/woff2" crossorigin>
<link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-brands-400.woff2" as="font" type="font/woff2" crossorigin>
"""
    fa_css = """
<style>
[class*="fa-"]{display:inline-block!important;margin-right:8px!important;vertical-align:middle!important;}
h1 [class*="fa-"],h2 [class*="fa-"],h3 [class*="fa-"],h4 [class*="fa-"],h5 [class*="fa-"],h6 [class*="fa-"]{vertical-align:middle!important;margin-right:10px!important;}
.fa+span,.fas+span,.far+span,.fab+span,span+.fa,span+.fas,span+.far+span{display:inline-block!important;margin-left:5px!important;}
.card [class*="fa-"],.card-body [class*="fa-"]{float:none!important;clear:none!important;position:relative!important;}
li [class*="fa-"],p [class*="fa-"]{margin-right:10px!important;}
.inline-icon{display:inline-flex!important;align-items:center!important;justify-content:flex-start!important;}
[class*="fa-"]+span{display:inline-block!important;vertical-align:middle!important;}
</style>
"""
    injection = f"{fa_preload}{fa_css}"

    # Slice-and-insert instead of str.replace so only the first </head> is
    # touched and tags with attributes or upper-case letters are recognised.
    closing_head = re.search(r"</head\s*>", html_code, re.IGNORECASE)
    if closing_head:
        cut = closing_head.start()
        return f"{html_code[:cut]}{injection}{html_code[cut:]}"

    # A document without a head: add one rather than nesting a second <html>.
    opening_html = re.search(r"<html\b[^>]*>", html_code, re.IGNORECASE)
    if opening_html:
        cut = opening_html.end()
        return f"{html_code[:cut]}<head>{injection}</head>{html_code[cut:]}"

    return f"<html><head>{injection}</head>{html_code}</html>"
# Simple in-process prompt cache: normalised style name -> prompt text.
_prompt_cache = {}


def load_system_instruction(style="standard") -> str:
    """Return the system prompt for ``style``, loading it at most once.

    Unknown styles fall back to ``"standard"``.  The prompt is read from
    ``<this dir>/<style>/prompt.txt`` when that file exists; otherwise it is
    downloaded from the ``tomo2chin2/GURAREKOstlyle`` dataset, first as
    ``<style>/prompt.txt`` and, if that fails, as the top-level ``prompt.txt``.

    Raises whatever the final fallback download/read raises.
    """
    valid_styles = ("standard", "cute", "resort", "cool", "dental", "school", "KOKUGO")
    # Normalise BEFORE consulting the cache so that an unknown style is served
    # from the cached "standard" entry instead of being reloaded every call.
    if style not in valid_styles:
        style = "standard"

    cached = _prompt_cache.get(style)
    if cached is not None:
        return cached

    def read_text(path):
        # Context manager so the file handle is closed deterministically.
        with open(path, encoding="utf-8") as handle:
            return handle.read()

    local = os.path.join(os.path.dirname(__file__), style, "prompt.txt")
    if os.path.exists(local):
        prompt_text = read_text(local)
    else:
        try:
            downloaded = hf_hub_download(
                "tomo2chin2/GURAREKOstlyle", f"{style}/prompt.txt", repo_type="dataset")
            prompt_text = read_text(downloaded)
        except Exception as e:
            # Best effort: fall back to the generic prompt, but say so.
            logger.warning(f"Prompt for style '{style}' unavailable ({e}); using default prompt.txt")
            downloaded = hf_hub_download(
                "tomo2chin2/GURAREKOstlyle", "prompt.txt", repo_type="dataset")
            prompt_text = read_text(downloaded)

    _prompt_cache[style] = prompt_text
    return prompt_text
def _extract_html_block(raw: str) -> str:
    """Return the contents of the ```html fenced block in ``raw``.

    If there is no ```html fence the whole text is returned unchanged.  If
    the fence is opened but never closed (e.g. truncated output), everything
    after the opening fence is returned.
    """
    fence = "```html"
    start = raw.find(fence)
    if start == -1:
        return raw
    body_start = start + len(fence)
    end = raw.rfind("```")
    # With no closing fence, rfind() lands on the opening fence itself;
    # slicing up to it would silently yield an empty document.
    if end < body_start:
        return raw[body_start:].strip()
    return raw[body_start:end].strip()


def generate_html_from_text(text: str, temperature=0.5, style="standard") -> str:
    """Ask Gemini to turn ``text`` into an HTML infographic.

    The system prompt for ``style`` is prepended to ``text``.  The model name
    comes from ``GEMINI_MODEL`` (default ``gemini-1.5-pro``) and the API key
    from ``GEMINI_API_KEY``.  The HTML found in the response is passed through
    ``enhance_font_awesome_layout`` before being returned.

    Raises:
        KeyError: ``GEMINI_API_KEY`` is not set.
        ValueError: the model returned no text.
    """
    api_key = os.environ["GEMINI_API_KEY"]
    client = genai.Client(api_key=api_key)
    model_name = os.getenv("GEMINI_MODEL", "gemini-1.5-pro")
    prompt = f"{load_system_instruction(style)}\n\n{text}"

    config = types.GenerateContentConfig(
        temperature=temperature,
        top_p=0.7,
        top_k=20,
        max_output_tokens=8192,
        candidate_count=1,
    )

    # Special handling for the Gemini 2.5 Flash preview model: thinking is
    # switched off and the output budget is raised to 10000 tokens.
    if model_name == "gemini-2.5-flash-preview-04-17":
        logger.info("gemini-2.5-flash-preview-04-17 モデル検出: 思考モードをオフに設定")
        config.thinking_config = types.ThinkingConfig(thinking_budget=0)
        logger.info("gemini-2.5-flash-preview-04-17 モデル検出: max_output_tokens を 10000 に設定")
        config.max_output_tokens = 10000

    response = client.models.generate_content(
        model=model_name,
        contents=prompt,
        config=config,
    )

    raw = response.text
    if not raw:
        # Fail with a clear message instead of AttributeError on None.
        raise ValueError("Gemini returned no text content")
    return enhance_font_awesome_layout(_extract_html_block(raw))
def trim_image_whitespace(img: Image.Image, threshold=248, padding=20) -> Image.Image:
    """Crop near-white margins from ``img``.

    A pixel counts as content when its grayscale value is below
    ``threshold``.  The crop box is the bounding box of all content pixels,
    grown by ``padding`` pixels on every side and clamped to the image.  When
    no pixel qualifies, ``img`` is returned unchanged.
    """
    gray = np.asarray(img.convert("L"))
    content = gray < threshold
    if not content.any():
        return img

    rows = np.flatnonzero(content.any(axis=1))
    cols = np.flatnonzero(content.any(axis=0))

    left = max(int(cols[0]) - padding, 0)
    top = max(int(rows[0]) - padding, 0)
    # crop() treats right/lower as exclusive, so add 1 to include the last
    # content row/column and clamp to width/height (not width-1/height-1).
    right = min(int(cols[-1]) + 1 + padding, img.width)
    bottom = min(int(rows[-1]) + 1 + padding, img.height)
    return img.crop((left, top, right, bottom))
| # --------------------------------------------------------------- | |
| # HTML → スクショ (完全版ロジック) | |
| # --------------------------------------------------------------- | |
def render_fullpage_screenshot(html_code: str, extension_percentage=6.0,
                               trim_whitespace=True, driver=None) -> Image.Image:
    """Render ``html_code`` in headless Chrome and return a full-page image.

    Args:
        html_code: HTML to render; written to a temporary ``.html`` file.
        extension_percentage: extra window height, in percent of the measured
            page height.
        trim_whitespace: crop near-white margins from the screenshot.
        driver: WebDriver to use.  When ``None`` one is taken from
            ``driver_pool`` and released again here; a driver passed in by the
            caller is NOT released by this function.

    Returns:
        The screenshot, or a 1x1 black image if anything fails (the error is
        logged).
    """
    from pathlib import Path

    tmp_path = None
    from_pool = False
    try:
        if driver is None:
            driver = driver_pool.get_driver()
            from_pool = True

        # Persist the HTML; the file must be closed before Chrome loads it.
        with tempfile.NamedTemporaryFile(suffix=".html", delete=False, mode="w", encoding="utf-8") as tmp:
            tmp_path = tmp.name
            tmp.write(html_code)

        driver.set_window_size(1200, 1000)
        # as_uri() yields a correctly escaped file:// URL on every platform.
        driver.get(Path(tmp_path).resolve().as_uri())

        # Wait for <body> to exist.
        WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.TAG_NAME, "body")))

        # Poll (up to 5 s) until the document is complete and all images loaded.
        max_wait, inc, waited = 5, 0.2, 0.0
        while waited < max_wait:
            state = driver.execute_script("""
                return {complete: document.readyState==='complete',
                        imgs: document.images.length,
                        loaded: Array.from(document.images).filter(i=>i.complete).length};
            """)
            if state['complete'] and (state['imgs'] == 0 or state['imgs'] == state['loaded']):
                break
            time.sleep(inc)
            waited += inc

        # Scroll through the page (at most 5 steps) so the browser renders it.
        total_h = driver.execute_script("return Math.max(document.body.scrollHeight, document.documentElement.scrollHeight)")
        vh = driver.execute_script("return window.innerHeight")
        # Guard against a zero viewport height, which would divide by zero.
        passes = max(1, min(5, int(total_h // max(vh, 1))))
        for i in range(passes):
            driver.execute_script(f"window.scrollTo(0, {(vh-100)*i})")
            time.sleep(0.1)
        driver.execute_script("window.scrollTo(0,0)")
        time.sleep(0.2)

        # Resize the window to the page size (clamped), plus the extension.
        dims = driver.execute_script("""
            return {w: Math.max(document.body.scrollWidth, document.documentElement.scrollWidth),
                    h: Math.max(document.body.scrollHeight, document.documentElement.scrollHeight)}
        """)
        w = min(max(dims['w'], 100), 2000)
        h = min(max(dims['h'], 100), 4000)
        h = int(h * (1 + extension_percentage / 100.0))
        driver.set_window_size(w, h)
        time.sleep(0.5)

        img = Image.open(BytesIO(driver.get_screenshot_as_png()))
        return trim_image_whitespace(img, padding=20) if trim_whitespace else img
    except Exception as e:
        logger.error(f"Screenshot error: {e}", exc_info=True)
        return Image.new("RGB", (1, 1), (0, 0, 0))
    finally:
        if from_pool:
            driver_pool.release_driver(driver)
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError as e:
                # Cleanup stays best effort, but is no longer silent.
                logger.warning(f"Temp file cleanup failed: {e}")
| # --------------------------------------------------------------- | |
| # テキスト → スクショ (並列 API 呼び出し + ドライバ確保) | |
| # --------------------------------------------------------------- | |
def text_to_screenshot_parallel(text, ext_perc, temp=0.5, trim_ws=True, style="standard") -> Image.Image:
    """Generate HTML with Gemini and screenshot it.

    The Gemini call and the WebDriver acquisition run concurrently.  The
    driver obtained here is always handed back to ``driver_pool`` - also when
    HTML generation fails - because ``render_fullpage_screenshot`` does not
    release a driver that was passed in by its caller.

    Raises whatever ``generate_html_from_text`` or ``driver_pool.get_driver``
    raises.
    """
    with ThreadPoolExecutor(max_workers=2) as exe:
        html_future = exe.submit(generate_html_from_text, text, temp, style)
        driver_future = exe.submit(driver_pool.get_driver)
        # If acquiring the driver fails there is nothing to release; the
        # executor still waits for the HTML task before the error propagates.
        driver = driver_future.result()
    try:
        html_code = html_future.result()
        return render_fullpage_screenshot(html_code, ext_perc, trim_ws, driver)
    finally:
        driver_pool.release_driver(driver)


def text_to_screenshot(*args, **kwargs):
    """Alias that forwards all arguments to ``text_to_screenshot_parallel``."""
    return text_to_screenshot_parallel(*args, **kwargs)
| # =============================================================== | |
| # FastAPI (★ redirect_slashes=False で自動 307 を殺す) | |
| # =============================================================== | |
# redirect_slashes=False disables the automatic trailing-slash 307 redirects.
app = FastAPI(redirect_slashes=False)

# CORS is fully open: any origin, method and header, with credentials allowed.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# -------- API endpoints --------
def _png_response(img) -> StreamingResponse:
    """Serialise a PIL image to PNG and wrap it in a streaming response."""
    buf = BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")


@app.post("/api/screenshot")
async def api_render_screenshot(req: ScreenshotRequest):
    """Render the posted HTML and return the screenshot as ``image/png``."""
    import asyncio

    # Selenium calls block; run them in a worker thread so the event loop
    # keeps serving other requests.
    img = await asyncio.to_thread(
        render_fullpage_screenshot,
        req.html_code, req.extension_percentage, req.trim_whitespace)
    return _png_response(img)


@app.post("/api/text-to-screenshot")
async def api_text_to_screenshot(req: GeminiRequest):
    """Turn the posted text into an infographic and return it as ``image/png``."""
    import asyncio

    # The Gemini request and the Selenium rendering both block; keep them off
    # the event loop.
    img = await asyncio.to_thread(
        text_to_screenshot_parallel,
        req.text, req.extension_percentage, req.temperature,
        req.trim_whitespace, req.style)
    return _png_response(img)
| # =============================================================== | |
| # Gradio UI (完全版 UI 定義) | |
| # =============================================================== | |
def process_input(mode, text, ext, temp, trim, style):
    """Gradio click handler: dispatch on the selected input mode.

    In "HTML入力" mode ``text`` is rendered directly as HTML; in any other
    mode it is first converted to HTML by Gemini using ``temp`` and ``style``.
    Returns the resulting PIL image.
    """
    if mode == "HTML入力":
        return render_fullpage_screenshot(text, ext, trim)
    return text_to_screenshot_parallel(text, ext, temp, trim, style)
| # Gradio UI definition. NOTE(review): indentation was lost in this copy, so | |
| # which components sit inside which Row/Column cannot be confirmed from here. | |
| with gr.Blocks(title="Full Page Screenshot (テキスト変換対応)", theme=gr.themes.Origin()) as iface: | |
| gr.Markdown("# HTMLビューア & テキスト→インフォグラフィック変換") | |
| with gr.Row(): | |
| # Input mode selector; the default is raw HTML input. | |
| mode = gr.Radio(["HTML入力", "テキスト入力"], value="HTML入力", label="入力モード") | |
| text = gr.Textbox(lines=15, label="入力") | |
| with gr.Row(): | |
| # Design style choice; created hidden and shown only in text-input mode (see toggle). | |
| style = gr.Dropdown( | |
| ["standard", "cute", "resort", "cool", "dental", "school", "KOKUGO"], | |
| value="standard", label="デザインスタイル", visible=False) | |
| with gr.Column(scale=2): | |
| ext = gr.Slider(0, 30, value=15, step=1, label="上下高さ拡張率(%)") | |
| # Generation temperature; created hidden and shown only in text-input mode (see toggle). | |
| temp = gr.Slider(0.0, 1.0, value=1.0, step=0.1, | |
| label="生成時の温度", visible=False) | |
| trim = gr.Checkbox(value=True, label="余白を自動トリミング") | |
| btn = gr.Button("生成") | |
| out = gr.Image(type="pil", label="スクリーンショット") | |
| # Visibility control: show the temperature slider and style dropdown only when text-input mode is selected. | |
| def toggle(m): vis = m == "テキスト入力"; return [gr.update(visible=vis), gr.update(visible=vis)] | |
| mode.change(toggle, mode, [temp, style]) | |
| # Clicking the button runs process_input with all inputs and puts its result in the image output. | |
| btn.click(process_input, [mode, text, ext, temp, trim, style], out) | |
| # Show the model name; for the Gemini 2.5 Flash preview model also show the thinking-mode status. | |
| model_name = os.getenv('GEMINI_MODEL', 'gemini-1.5-pro') | |
| thinking_status = "" | |
| if model_name == "gemini-2.5-flash-preview-04-17": | |
| thinking_status = "(思考モード: オフ、最大トークン: 10000)" | |
| gr.Markdown(f"**API** `/api/screenshot`, `/api/text-to-screenshot` " | |
| f"使用モデル: `{model_name}` {thinking_status}") | |
| # =============================================================== | |
| # Gradio を /gradio にマウントし、明示リダイレクトを追加 | |
| # =============================================================== | |
GRADIO_PATH = "/gradio"
app = gr.mount_gradio_app(app, iface, path=GRADIO_PATH, ssr_mode=False)


# Automatic slash redirects are disabled on the app, so both redirects below
# have to be registered explicitly as routes.
@app.get("/")
def _root():
    """Redirect the site root to the mounted Gradio UI."""
    return RedirectResponse(GRADIO_PATH + "/")


@app.get(GRADIO_PATH)
def _no_slash():
    """Redirect the mount path without a trailing slash to the one with it."""
    return RedirectResponse(GRADIO_PATH + "/")
# Warm the prompt cache at startup so the first request does not pay for it.
@app.on_event("startup")
async def startup_event():
    """Preload the system prompt of every known style into the cache.

    Preloading is best effort: a style that cannot be loaded is logged and
    skipped, so it does not prevent the server from starting.
    """
    styles = ["standard", "cute", "resort", "cool", "dental", "school", "KOKUGO"]
    for style in styles:
        try:
            load_system_instruction(style)
        except Exception as e:
            logger.warning(f"Prompt preload failed for style '{style}': {e}")
    logger.info("システムプロンプトのキャッシュを準備完了")
| # =============================================================== | |
| # ローカルデバッグ | |
| # =============================================================== | |
# ===============================================================
# WebDriver cleanup at interpreter exit
# ===============================================================
# Registered BEFORE the server is started: uvicorn.run() blocks, and if it
# ends through an exception or a re-raised signal, code placed after it
# would never run and the Chrome processes would be left behind.
import atexit
atexit.register(driver_pool.close_all)

# ===============================================================
# Local debugging
# ===============================================================
if __name__ == "__main__":
    import uvicorn
    logger.info("Uvicorn 起動 (ローカル)")
    uvicorn.run(app, host="0.0.0.0", port=7860)