| import gradio as gr |
| from fastapi import FastAPI, HTTPException, Body |
| from fastapi.responses import StreamingResponse |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel |
| from selenium import webdriver |
| from selenium.webdriver.chrome.options import Options |
| from selenium.webdriver.common.by import By |
| from selenium.webdriver.support.ui import WebDriverWait |
| from selenium.webdriver.support import expected_conditions as EC |
| from PIL import Image |
| from io import BytesIO |
| import tempfile |
| import time |
| import os |
| import logging |
|
|
| |
# Module-scoped logger used by every function in this file.
logger = logging.getLogger(__name__)

# Configure the root logger so INFO-level (and higher) records are emitted.
logging.basicConfig(level=logging.INFO)
|
|
| |
def render_fullpage_screenshot(html_code: str, extension_percentage: float) -> Image.Image:
    """
    Renders HTML code to a full-page screenshot using Selenium.

    The HTML is written to a temporary file, opened in headless Chrome, the
    page's scroll dimensions are measured, the browser window is resized to
    those dimensions (plus the requested vertical extension) and a single
    screenshot is captured.

    Args:
        html_code: The HTML source code string.
        extension_percentage: Percentage of extra space to add vertically (e.g., 4 means 4% total).

    Returns:
        A PIL Image object of the screenshot. Returns a 1x1 black image on error.
    """
    # Local import so this block stays self-contained.
    from pathlib import Path

    tmp_path = None
    driver = None

    # Persist the HTML so the browser can load it through a file:// URL.
    try:
        with tempfile.NamedTemporaryFile(suffix=".html", delete=False, mode='w', encoding='utf-8') as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(html_code)
        logger.info(f"HTML saved to temporary file: {tmp_path}")
    except Exception as e:
        logger.error(f"Error writing temporary HTML file: {e}")
        # The file is created with delete=False, so a failed write would
        # otherwise leave it behind on disk.
        _remove_temp_file(tmp_path)
        return Image.new('RGB', (1, 1), color=(0, 0, 0))

    options = Options()
    options.add_argument("--headless")
    options.add_argument("--no-sandbox")
    options.add_argument("--disable-dev-shm-usage")
    options.add_argument("--force-device-scale-factor=1")

    try:
        logger.info("Initializing WebDriver...")
        driver = webdriver.Chrome(options=options)
        logger.info("WebDriver initialized.")

        # Initial window size; the window is resized once the page is measured.
        driver.set_window_size(1200, 800)
        # as_uri() yields a well-formed, percent-encoded file URL on every
        # platform, unlike naive "file://" + path concatenation.
        file_url = Path(tmp_path).as_uri()
        logger.info(f"Navigating to {file_url}")
        driver.get(file_url)

        logger.info("Waiting for body element...")
        WebDriverWait(driver, 15).until(
            EC.presence_of_element_located((By.TAG_NAME, "body"))
        )
        logger.info("Body element found. Waiting for potential resource loading...")
        time.sleep(3)

        # Best effort: hide scrollbars before measuring and capturing.
        try:
            driver.execute_script(
                "document.documentElement.style.overflow = 'hidden';"
                "document.body.style.overflow = 'hidden';"
            )
            logger.info("Scrollbars hidden via JS.")
        except Exception as e:
            logger.warning(f"Could not hide scrollbars via JS: {e}")

        scroll_width, scroll_height = _measure_page(driver)

        # Add the requested vertical margin, never going below the measured
        # height (guards against negative percentages) or 100 px.
        adjusted_height = int(scroll_height * (1 + extension_percentage / 100.0))
        adjusted_height = max(adjusted_height, scroll_height, 100)
        logger.info(f"Adjusted height calculated: {adjusted_height} (extension: {extension_percentage}%)")

        logger.info(f"Resizing window to: width={scroll_width}, height={adjusted_height}")
        driver.set_window_size(scroll_width, adjusted_height)
        logger.info("Waiting for layout stabilization after resize...")
        time.sleep(3)

        # Best effort: make sure the capture starts at the top of the page.
        try:
            driver.execute_script("window.scrollTo(0, 0)")
            time.sleep(1)
            logger.info("Scrolled to top.")
        except Exception as e:
            logger.warning(f"Could not scroll to top: {e}")

        logger.info("Taking screenshot...")
        png = driver.get_screenshot_as_png()
        logger.info("Screenshot taken successfully.")

        return Image.open(BytesIO(png))

    except Exception as e:
        logger.error(f"An error occurred during screenshot generation: {e}", exc_info=True)
        return Image.new('RGB', (1, 1), color=(0, 0, 0))
    finally:
        logger.info("Cleaning up...")
        if driver:
            try:
                driver.quit()
                logger.info("WebDriver quit successfully.")
            except Exception as e:
                logger.error(f"Error quitting WebDriver: {e}")
        _remove_temp_file(tmp_path)


def _measure_page(driver):
    """Return (width, height) of the loaded page, each at least 100; (1200, 800) if probing fails."""
    try:
        scroll_width = driver.execute_script(
            "return Math.max(document.body.scrollWidth, document.documentElement.scrollWidth, document.body.offsetWidth, document.documentElement.offsetWidth)"
        )
        scroll_height = driver.execute_script(
            "return Math.max(document.body.scrollHeight, document.documentElement.scrollHeight, document.body.offsetHeight, document.documentElement.offsetHeight)"
        )
        logger.info(f"Detected dimensions: width={scroll_width}, height={scroll_height}")
        return max(scroll_width, 100), max(scroll_height, 100)
    except Exception as e:
        logger.error(f"Error getting page dimensions: {e}")
        scroll_width = 1200
        scroll_height = 800
        logger.warning(f"Falling back to dimensions: width={scroll_width}, height={scroll_height}")
        return scroll_width, scroll_height


def _remove_temp_file(tmp_path):
    """Best-effort removal of the temporary HTML file; failures are logged, never raised."""
    if tmp_path and os.path.exists(tmp_path):
        try:
            os.remove(tmp_path)
            logger.info(f"Temporary file {tmp_path} removed.")
        except Exception as e:
            logger.error(f"Error removing temporary file {tmp_path}: {e}")
|
|
| |
# Application object; the API route and the Gradio UI are both attached to it.
app = FastAPI()

# CORS policy: any origin, method and header is accepted, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended for the target deployment.
_cors_policy = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_policy)
|
|
| |
| |
# Locate the installed Gradio package so its bundled frontend files can be served.
gradio_dir = os.path.dirname(gr.__file__)
logger.info(f"Gradio version: {gr.__version__}")
logger.info(f"Gradio directory: {gradio_dir}")

# Candidate asset directories inside the Gradio package.
_templates_dir = os.path.join(gradio_dir, "templates")
_frontend_dir = os.path.join(_templates_dir, "frontend")
static_dir = os.path.join(_frontend_dir, "static")
app_dir = os.path.join(_frontend_dir, "_app")
assets_dir = os.path.join(_frontend_dir, "assets")
cdn_dir = os.path.join(_templates_dir, "cdn")

# Mount each directory under "/<name>", skipping any that this Gradio
# installation does not ship.
for _mount_name, _directory in (
    ("static", static_dir),
    ("_app", app_dir),
    ("assets", assets_dir),
    ("cdn", cdn_dir),
):
    if not os.path.exists(_directory):
        continue
    logger.info(f"Mounting {_mount_name} directory: {_directory}")
    app.mount(f"/{_mount_name}", StaticFiles(directory=_directory), name=_mount_name)
|
|
| |
# Request body accepted by the POST /api/screenshot endpoint.
# (Plain comments are used instead of a class docstring so the generated
# schema description is left untouched.)
class ScreenshotRequest(BaseModel):
    # HTML source to render.
    html_code: str
    # Extra vertical space, as a percentage of the measured page height;
    # defaults to 8.0 when the client omits it.
    extension_percentage: float = 8.0
|
|
| |
@app.post("/api/screenshot",
          response_class=StreamingResponse,
          tags=["Screenshot"],
          summary="Render HTML to Full Page Screenshot",
          description="Takes HTML code and an optional vertical extension percentage, renders it using a headless browser, and returns the full-page screenshot as a PNG image.")
async def api_render_screenshot(request: ScreenshotRequest):
    """
    API endpoint to render HTML and return a screenshot.

    Args:
        request: Parsed request body carrying ``html_code`` and
            ``extension_percentage``.

    Returns:
        A ``StreamingResponse`` with media type ``image/png``. If rendering
        fails, the renderer's 1x1 placeholder image is still returned and the
        failure is only logged.

    Raises:
        HTTPException: With status 500 if rendering or PNG encoding raises.
    """
    # Local import so this block stays self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        logger.info(f"API request received. Extension: {request.extension_percentage}%")

        # The renderer drives Selenium synchronously and sleeps for several
        # seconds. Running it directly inside this coroutine would block the
        # event loop for every other request, so it is moved to a worker thread.
        pil_image = await run_in_threadpool(
            render_fullpage_screenshot,
            request.html_code,
            request.extension_percentage,
        )

        # A 1x1 image is the renderer's failure marker; it is logged but still
        # returned to the client.
        if pil_image.size == (1, 1):
            logger.error("Screenshot generation failed, returning 1x1 image.")

        # Encode the image as PNG into an in-memory buffer and rewind it for streaming.
        img_byte_arr = BytesIO()
        pil_image.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)

        logger.info("Returning screenshot as PNG stream.")
        return StreamingResponse(img_byte_arr, media_type="image/png")

    except Exception as e:
        logger.error(f"API Error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {e}") from e
|
|
| |
# UI components, created in the same order as the renderer's positional
# arguments (HTML text first, then the extension percentage).
_html_input = gr.Textbox(lines=15, label="HTMLコード入力")
_extension_slider = gr.Slider(minimum=0, maximum=20, step=1.0, value=8, label="上下高さ拡張率(%)")
_screenshot_output = gr.Image(type="pil", label="ページ全体のスクリーンショット")

# Web form that feeds the two inputs to render_fullpage_screenshot and
# displays the returned PIL image.
iface = gr.Interface(
    fn=render_fullpage_screenshot,
    inputs=[_html_input, _extension_slider],
    outputs=_screenshot_output,
    title="Full Page Screenshot (高さ拡張調整可能)",
    description="HTMLをヘッドレスブラウザでレンダリングし、ページ全体を1枚の画像として取得します。上下のみユーザー指定の余裕(%)を追加します。APIエンドポイントは /api/screenshot で利用可能です。",
    allow_flagging="never",
    theme=gr.themes.Base(),
)

# Attach the Gradio UI at the site root of the FastAPI application.
app = gr.mount_gradio_app(app, iface, path="/")
|
|
| |
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly.
    import uvicorn

    # Listen on all interfaces, port 7860.
    _server_options = {"host": "0.0.0.0", "port": 7860}
    logger.info("Starting Uvicorn server for local development...")
    uvicorn.run(app, **_server_options)