Spaces:
Paused
Paused
| import gradio as gr | |
| from fastapi import FastAPI, HTTPException, Body | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from selenium import webdriver | |
| from selenium.webdriver.chrome.options import Options | |
| from selenium.webdriver.common.by import By | |
| from selenium.webdriver.support.ui import WebDriverWait | |
| from selenium.webdriver.support import expected_conditions as EC | |
| from PIL import Image | |
| from io import BytesIO | |
| import tempfile | |
| import time | |
| import os | |
| import logging | |
# Logging configuration: the root logger reports INFO and above, and this
# module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Core Screenshot Logic ---
def render_fullpage_screenshot(html_code: str, extension_percentage: float) -> Image.Image:
    """Render HTML source to a single full-page screenshot using Selenium.

    The HTML is written to a temporary file, opened in headless Chrome, and
    the browser window is resized to the measured page size (plus an optional
    vertical margin) before the screenshot is taken. The WebDriver and the
    temporary file are always cleaned up, on success and on failure.

    Args:
        html_code: The HTML source code string.
        extension_percentage: Extra vertical space to add, as a percentage of
            the measured page height (e.g. 4 makes the window 4% taller).
            Values that would shrink the page are clamped, so the window is
            never shorter than the measured page height.

    Returns:
        A PIL Image of the screenshot. On any error a 1x1 black image is
        returned instead of raising.
    """
    # Local import keeps this block self-contained; only needed for building
    # a correct file:// URL below.
    from pathlib import Path

    tmp_path = None
    driver = None

    # 1) Save HTML code to a temporary file
    try:
        with tempfile.NamedTemporaryFile(
            suffix=".html", delete=False, mode='w', encoding='utf-8'
        ) as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(html_code)
        logger.info(f"HTML saved to temporary file: {tmp_path}")
    except Exception as e:
        logger.error(f"Error writing temporary HTML file: {e}")
        # The file is created before write() runs (delete=False), so a failed
        # write would otherwise leave it behind: this early return bypasses
        # the cleanup in the finally block further down.
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
                logger.info(f"Temporary file {tmp_path} removed.")
            except Exception as cleanup_error:
                logger.error(f"Error removing temporary file {tmp_path}: {cleanup_error}")
        return Image.new('RGB', (1, 1), color=(0, 0, 0))  # black image on error

    # 2) Headless Chrome (Chromium) options
    options = Options()
    options.add_argument("--headless")
    options.add_argument("--no-sandbox")
    options.add_argument("--disable-dev-shm-usage")
    options.add_argument("--force-device-scale-factor=1")
    # Increase logging verbosity for debugging if needed
    # options.add_argument("--enable-logging")
    # options.add_argument("--v=1")

    try:
        logger.info("Initializing WebDriver...")
        driver = webdriver.Chrome(options=options)
        logger.info("WebDriver initialized.")

        # 3) Load page with initial window size
        driver.set_window_size(1200, 800)
        # as_uri() yields a well-formed, percent-encoded file URL on every
        # platform; plain "file://" + path breaks on Windows paths and on
        # paths containing characters such as spaces or '#'.
        file_url = Path(tmp_path).as_uri()
        logger.info(f"Navigating to {file_url}")
        driver.get(file_url)

        # 4) Wait for page load
        logger.info("Waiting for body element...")
        WebDriverWait(driver, 15).until(
            EC.presence_of_element_located((By.TAG_NAME, "body"))
        )
        logger.info("Body element found. Waiting for potential resource loading...")
        time.sleep(3)  # Wait a bit longer for external resources/scripts

        # 5) Hide scrollbars via CSS (best effort; a failure is only logged)
        try:
            driver.execute_script(
                "document.documentElement.style.overflow = 'hidden';"
                "document.body.style.overflow = 'hidden';"
            )
            logger.info("Scrollbars hidden via JS.")
        except Exception as e:
            logger.warning(f"Could not hide scrollbars via JS: {e}")

        # 6) Get full page dimensions accurately
        try:
            scroll_width = driver.execute_script(
                "return Math.max(document.body.scrollWidth, document.documentElement.scrollWidth, document.body.offsetWidth, document.documentElement.offsetWidth)"
            )
            scroll_height = driver.execute_script(
                "return Math.max(document.body.scrollHeight, document.documentElement.scrollHeight, document.body.offsetHeight, document.documentElement.offsetHeight)"
            )
            logger.info(f"Detected dimensions: width={scroll_width}, height={scroll_height}")
            # Ensure minimum dimensions to avoid errors
            scroll_width = max(scroll_width, 100)
            scroll_height = max(scroll_height, 100)
        except Exception as e:
            logger.error(f"Error getting page dimensions: {e}")
            # Fall back to the initial window size
            scroll_width = 1200
            scroll_height = 800
            logger.warning(f"Falling back to dimensions: width={scroll_width}, height={scroll_height}")

        # 7) Calculate adjusted height with user-specified margin
        adjusted_height = int(scroll_height * (1 + extension_percentage / 100.0))
        # Never go below the measured page height (or 100px), so a negative
        # percentage cannot crop the page.
        adjusted_height = max(adjusted_height, scroll_height, 100)
        logger.info(f"Adjusted height calculated: {adjusted_height} (extension: {extension_percentage}%)")

        # 8) Set window size to full page width and adjusted height
        logger.info(f"Resizing window to: width={scroll_width}, height={adjusted_height}")
        driver.set_window_size(scroll_width, adjusted_height)
        logger.info("Waiting for layout stabilization after resize...")
        time.sleep(3)  # Wait longer for layout stabilization

        # Scroll to top just in case (best effort)
        try:
            driver.execute_script("window.scrollTo(0, 0)")
            time.sleep(1)
            logger.info("Scrolled to top.")
        except Exception as e:
            logger.warning(f"Could not scroll to top: {e}")

        # 9) Take screenshot
        logger.info("Taking screenshot...")
        png = driver.get_screenshot_as_png()
        logger.info("Screenshot taken successfully.")

        # Convert to PIL Image
        img = Image.open(BytesIO(png))
        return img
    except Exception as e:
        logger.error(f"An error occurred during screenshot generation: {e}", exc_info=True)
        return Image.new('RGB', (1, 1), color=(0, 0, 0))  # Return black 1x1 image on error
    finally:
        logger.info("Cleaning up...")
        if driver:
            try:
                driver.quit()
                logger.info("WebDriver quit successfully.")
            except Exception as e:
                logger.error(f"Error quitting WebDriver: {e}")
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
                logger.info(f"Temporary file {tmp_path} removed.")
            except Exception as e:
                logger.error(f"Error removing temporary file {tmp_path}: {e}")
# --- FastAPI Setup ---
app = FastAPI()

# CORS: accept every origin, method and header.
# NOTE(review): wildcard origins together with allow_credentials=True is worth
# confirming against the FastAPI/Starlette CORS documentation.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# Static file serving: look inside the installed Gradio package for its
# frontend assets.
gradio_dir = os.path.dirname(gr.__file__)
logger.info(f"Gradio version: {gr.__version__}")
logger.info(f"Gradio directory: {gradio_dir}")

_frontend_dir = os.path.join(gradio_dir, "templates", "frontend")
static_dir = os.path.join(_frontend_dir, "static")   # basic static files
app_dir = os.path.join(_frontend_dir, "_app")        # newer SvelteKit-based frontend
assets_dir = os.path.join(_frontend_dir, "assets")   # frontend assets
cdn_dir = os.path.join(gradio_dir, "templates", "cdn")  # optional cdn directory

# Mount each directory that exists at "/<name>", in the order listed here;
# directories missing from this Gradio installation are skipped.
for _mount_name, _mount_dir in (
    ("static", static_dir),
    ("_app", app_dir),
    ("assets", assets_dir),
    ("cdn", cdn_dir),
):
    if os.path.exists(_mount_dir):
        logger.info(f"Mounting {_mount_name} directory: {_mount_dir}")
        app.mount(f"/{_mount_name}", StaticFiles(directory=_mount_dir), name=_mount_name)
# Pydantic model that validates the JSON body of a screenshot API request.
class ScreenshotRequest(BaseModel):
    # Raw HTML source to render.
    html_code: str
    # Extra vertical space in percent; 8.0 matches the Gradio slider's default.
    extension_percentage: float = 8.0
# API endpoint for screenshot generation.
# Registered explicitly: without the route decorator the handler is never
# reachable, even though the UI description advertises /api/screenshot.
@app.post("/api/screenshot")
async def api_render_screenshot(request: ScreenshotRequest):
    """Render the posted HTML and return the screenshot as a PNG stream.

    Args:
        request: Validated request body carrying the HTML source and the
            vertical extension percentage.

    Returns:
        A StreamingResponse with media type "image/png". If rendering failed,
        the renderer's 1x1 fallback image is logged as an error and still
        returned.

    Raises:
        HTTPException: Status 500 if rendering or PNG encoding raises.
    """
    import asyncio  # local import keeps this block self-contained

    try:
        logger.info(f"API request received. Extension: {request.extension_percentage}%")
        # The renderer is blocking (Selenium plus several sleeps). Because
        # this handler is `async def`, FastAPI does NOT offload it to a
        # thread automatically, so run it in the default executor to keep
        # the event loop responsive for other requests.
        loop = asyncio.get_running_loop()
        pil_image = await loop.run_in_executor(
            None,
            render_fullpage_screenshot,
            request.html_code,
            request.extension_percentage,
        )
        if pil_image.size == (1, 1):
            logger.error("Screenshot generation failed, returning 1x1 image.")
            # Optionally return a proper error response instead of 1x1 image
            # raise HTTPException(status_code=500, detail="Failed to generate screenshot")

        # Convert PIL Image to PNG bytes
        img_byte_arr = BytesIO()
        pil_image.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)  # Rewind so the response streams from the start
        logger.info("Returning screenshot as PNG stream.")
        return StreamingResponse(img_byte_arr, media_type="image/png")
    except Exception as e:
        logger.error(f"API Error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {e}")
# --- Gradio Interface Definition ---
# Components are created in the same order as the interface arguments:
# HTML input, extension slider, then the image output.
_html_input = gr.Textbox(lines=15, label="HTMLコード入力")
_extension_slider = gr.Slider(
    minimum=0,
    maximum=20,
    step=1.0,
    value=8,
    label="上下高さ拡張率(%)",
)
_screenshot_output = gr.Image(type="pil", label="ページ全体のスクリーンショット")

iface = gr.Interface(
    fn=render_fullpage_screenshot,
    inputs=[_html_input, _extension_slider],
    outputs=_screenshot_output,
    title="Full Page Screenshot (高さ拡張調整可能)",
    description="HTMLをヘッドレスブラウザでレンダリングし、ページ全体を1枚の画像として取得します。上下のみユーザー指定の余裕(%)を追加します。APIエンドポイントは /api/screenshot で利用可能です。",
    allow_flagging="never",
    theme=gr.themes.Base(),  # specify the theme explicitly
)

# --- Mount Gradio App onto FastAPI ---
# The Gradio UI is served from the root path of the FastAPI application.
app = gr.mount_gradio_app(app, iface, path="/")

# --- Run with Uvicorn (for local testing) ---
if __name__ == "__main__":
    import uvicorn

    logger.info("Starting Uvicorn server for local development...")
    uvicorn.run(app, host="0.0.0.0", port=7860)