"""Full-page HTML screenshot service with optional Gemini-generated HTML.

Exposes two FastAPI endpoints and a Gradio UI:

* ``/api/screenshot``          - render caller-supplied HTML to a PNG.
* ``/api/text-to-screenshot``  - ask Gemini to turn text into HTML, then
  render that HTML to a PNG.
"""

# ────────────────────────────────────
# Standard library
# ────────────────────────────────────
import os
import time
import tempfile
import logging
from io import BytesIO
from typing import List, Optional

# ────────────────────────────────────
# Third party
# ────────────────────────────────────
import gradio as gr
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from PIL import Image
from huggingface_hub import hf_hub_download

# New Gemini SDK
from google import genai  # google-genai >= 1.11.0
from google.genai import types  # typed request/config objects

# ────────────────────────────────────
# Logging
# ────────────────────────────────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Design styles that have a ``<style>/prompt.txt`` system prompt.
# Shared by the prompt loader and the Gradio dropdown so they cannot drift.
STYLES = ["standard", "cute", "resort", "cool", "dental"]

# The only model for which thinking is explicitly disabled.
THINKING_BUDGET_MODEL = "gemini-2.5-flash-preview-04-17"


# ────────────────────────────────────
# Pydantic models
# ────────────────────────────────────
class GeminiRequest(BaseModel):
    """Request body for ``/api/text-to-screenshot``."""

    text: str
    extension_percentage: float = 10.0
    temperature: float = 0.5
    trim_whitespace: bool = True
    style: str = "standard"


class ScreenshotRequest(BaseModel):
    """Request body for ``/api/screenshot``."""

    html_code: str
    extension_percentage: float = 10.0
    trim_whitespace: bool = True
    # Accepted for symmetry with GeminiRequest; not read by the endpoint.
    style: str = "standard"


# ────────────────────────────────────
# 1. Font Awesome layout adjustment
# ────────────────────────────────────
def enhance_font_awesome_layout(html_code: str) -> str:
    """Insert extra CSS that corrects Font Awesome icon misalignment.

    The CSS is placed just before ``</head>`` when a head element exists,
    inside a new ``<head>`` before ``<body`` when only a body exists, and
    otherwise the whole input is wrapped in a minimal HTML document.

    Args:
        html_code: HTML document or fragment.

    Returns:
        The HTML with the corrective ``<style>`` block added.
    """
    # NOTE(review): the original stylesheet text was lost when this file was
    # extracted (all HTML-like literals were stripped). The rules below are a
    # conservative reconstruction - confirm against the deployed version.
    fa_fix_css = """
<style>
[class*="fa-"] {
    display: inline-block !important;
    margin-right: 8px !important;
    vertical-align: middle !important;
}
h1 [class*="fa-"], h2 [class*="fa-"], h3 [class*="fa-"],
h4 [class*="fa-"], h5 [class*="fa-"], h6 [class*="fa-"] {
    vertical-align: middle !important;
    margin-right: 10px !important;
}
[class*="fa-"] + span {
    display: inline-block !important;
    vertical-align: middle !important;
}
</style>
"""
    if "<head>" in html_code:
        return html_code.replace("</head>", f"{fa_fix_css}</head>")

    if "<html" in html_code:
        # <head> may carry attributes, so look for the closing tag instead.
        head_end = html_code.find("</head>")
        if head_end != -1:
            return html_code[:head_end] + fa_fix_css + html_code[head_end:]
        body_start = html_code.find("<body")
        if body_start != -1:
            return (
                html_code[:body_start]
                + f"<head>{fa_fix_css}</head>"
                + html_code[body_start:]
            )

    # Bare fragment: wrap it in a minimal document.
    return f"<html><head>{fa_fix_css}</head>{html_code}</html>"


# ────────────────────────────────────
# 2. System prompt loading
# ────────────────────────────────────
def load_system_instruction(style: str = "standard") -> str:
    """Return the contents of ``<style>/prompt.txt``.

    A file next to this module is preferred; otherwise the file is
    downloaded from the Hugging Face Hub dataset.

    Args:
        style: One of ``STYLES``. Unknown values fall back to ``"standard"``
            with a warning.

    Returns:
        The prompt text.

    Raises:
        Exception: Whatever the Hub download or file read raises; it is
            logged and re-raised.
    """
    if style not in STYLES:
        logger.warning(f"無効なスタイル '{style}' → 'standard' を使用")
        style = "standard"

    # Local file first
    local_path = os.path.join(os.path.dirname(__file__), style, "prompt.txt")
    if os.path.exists(local_path):
        with open(local_path, encoding="utf-8") as f:
            return f.read()

    # Hugging Face Hub fallback
    try:
        file_path = hf_hub_download(
            repo_id="tomo2chin2/GURAREKOstlyle",
            filename=f"{style}/prompt.txt",
            repo_type="dataset",
        )
        with open(file_path, encoding="utf-8") as f:
            return f.read()
    except Exception as e:
        logger.error(f"prompt.txt 取得失敗: {e}")
        raise


# ────────────────────────────────────
# 3. Image whitespace trimming
# ────────────────────────────────────
def trim_image_whitespace(
    image: Image.Image, threshold: int = 250, padding: int = 10
) -> Image.Image:
    """Crop near-white margins, keeping ``padding`` pixels around content.

    A pixel counts as content when its grayscale value is below
    ``threshold``.

    Args:
        image: Source image.
        threshold: Grayscale cut-off (0-255); darker pixels are content.
        padding: Pixels of margin to keep on every side, clamped to the
            image bounds.

    Returns:
        The cropped image, or ``image`` unchanged when no content pixel
        is found.
    """
    gray = image.convert("L")
    # Build a binary mask (content = 255) and let Pillow compute the bounding
    # box natively instead of scanning every pixel in Python.
    mask = gray.point(lambda p: 255 if p < threshold else 0)
    bbox = mask.getbbox()
    if bbox is None:
        return image

    left, top, right, bottom = bbox  # right/bottom are exclusive
    w, h = gray.size
    return image.crop(
        (
            max(0, left - padding),
            max(0, top - padding),
            min(w, right + padding),
            min(h, bottom + padding),
        )
    )


# ────────────────────────────────────
# 4. Full-page screenshot via Selenium
# ────────────────────────────────────
def render_fullpage_screenshot(
    html_code: str, extension_percentage: float = 6.0, trim_whitespace: bool = True
) -> Image.Image:
    """Render an HTML string in headless Chrome and return it as an image.

    Args:
        html_code: HTML to render.
        extension_percentage: Extra height added to the measured page
            height, in percent.
        trim_whitespace: Whether to crop near-white margins afterwards.

    Returns:
        The screenshot, or a 1x1 black image when anything fails (the
        error is logged).
    """
    tmp_path: Optional[str] = None
    driver: Optional[webdriver.Chrome] = None
    page_height_js = (
        "return Math.max(document.body.scrollHeight, "
        "document.documentElement.scrollHeight);"
    )

    try:
        with tempfile.NamedTemporaryFile(
            suffix=".html", delete=False, mode="w", encoding="utf-8"
        ) as tmp:
            tmp.write(html_code)
            tmp_path = tmp.name

        options = Options()
        options.add_argument("--headless")
        options.add_argument("--no-sandbox")
        options.add_argument("--disable-dev-shm-usage")
        options.add_argument("--force-device-scale-factor=1")

        driver = webdriver.Chrome(options=options)
        driver.set_window_size(1200, 1000)
        driver.get(f"file://{tmp_path}")

        WebDriverWait(driver, 15).until(
            EC.presence_of_element_located((By.TAG_NAME, "body"))
        )
        time.sleep(3)  # wait for the initial load

        # Scroll through the page so lazily rendered content is laid out.
        total = int(driver.execute_script(page_height_js))
        vp = int(driver.execute_script("return window.innerHeight;"))
        # Overlap successive viewports by 200px. Iterating over offsets (not a
        # count derived from total // vp) guarantees the bottom is reached,
        # and the lower bound keeps the step positive for tiny viewports.
        step = max(vp - 200, 100)
        for offset in range(0, max(total, 1), step):
            driver.execute_script(f"window.scrollTo(0, {offset});")
            time.sleep(0.2)
        driver.execute_script("window.scrollTo(0,0);")
        time.sleep(1)

        # Re-measure, then add the requested extra height.
        total = driver.execute_script(page_height_js)
        height = int(total * (1 + extension_percentage / 100))
        width = driver.execute_script(
            "return Math.max(document.documentElement.scrollWidth, "
            "document.body.scrollWidth);"
        )
        # Clamp the window to a sane size.
        height = min(max(height, 100), 4000)
        width = min(max(width, 100), 2000)
        driver.set_window_size(width, height)
        time.sleep(0.5)

        png = driver.get_screenshot_as_png()
        img = Image.open(BytesIO(png))
        if trim_whitespace:
            img = trim_image_whitespace(img, threshold=248, padding=20)
        return img

    except Exception as e:
        logger.error(f"Screenshot Error: {e}", exc_info=True)
        return Image.new("RGB", (1, 1), (0, 0, 0))
    finally:
        if driver:
            try:
                driver.quit()
            except Exception:
                # Best effort: a failed quit must not mask the real result.
                logger.warning("driver.quit() failed", exc_info=True)
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)


# ────────────────────────────────────
# 5. Gemini -> HTML generation
# ────────────────────────────────────
def _genai_client(api_key: str) -> genai.Client:
    """Create a Gemini client for ``api_key``."""
    return genai.Client(api_key=api_key)


def _default_safety() -> List[types.SafetySetting]:
    """Return BLOCK_MEDIUM_AND_ABOVE settings for the four harm categories."""
    categories = (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
    return [
        types.SafetySetting(category=category, threshold="BLOCK_MEDIUM_AND_ABOVE")
        for category in categories
    ]


def generate_html_from_text(
    text: str, temperature: float = 0.3, style: str = "standard"
) -> str:
    """Ask a Gemini model to produce HTML for ``text``.

    The model is taken from the ``GEMINI_MODEL`` environment variable
    (default ``gemini-1.5-pro``). When it equals
    ``gemini-2.5-flash-preview-04-17`` the request carries
    ``thinking_budget=0``.

    Args:
        text: User text appended after the style's system prompt.
        temperature: Sampling temperature.
        style: Prompt style passed to ``load_system_instruction``.

    Returns:
        The content of the fenced ``html`` block in the response, with the
        Font Awesome CSS fix applied; or the raw response text when no
        such block is found.

    Raises:
        ValueError: If ``GEMINI_API_KEY`` is not set.
    """
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("GEMINI_API_KEY が設定されていません")

    model_name = os.getenv("GEMINI_MODEL", "gemini-1.5-pro")
    client = _genai_client(api_key)

    # GenerateContentConfig takes the sampling parameters directly; it has no
    # ``generation_config`` field, so wrapping them in GenerationConfig fails
    # validation.
    cfg_kwargs = dict(
        temperature=temperature,
        top_p=0.7,
        top_k=20,
        max_output_tokens=8192,
        candidate_count=1,
        safety_settings=_default_safety(),
    )
    thinking_budget = 0 if model_name == THINKING_BUDGET_MODEL else None
    if thinking_budget is not None:
        cfg_kwargs["thinking_config"] = types.ThinkingConfig(
            thinking_budget=thinking_budget
        )
    req_cfg = types.GenerateContentConfig(**cfg_kwargs)

    prompt = f"{load_system_instruction(style)}\n\n{text}"
    logger.info(
        f"Gemini request → model={model_name}, temp={temperature}, thinking_budget={thinking_budget}"
    )

    rsp = client.models.generate_content(
        model=model_name, contents=prompt, config=req_cfg
    )
    raw = rsp.text or ""

    # Take everything between the opening html fence and the last fence.
    start = raw.find("```html")
    end = raw.rfind("```")
    if 0 <= start < end:
        html_code = raw[start + 7 : end].strip()
        return enhance_font_awesome_layout(html_code)

    logger.warning("```html``` ブロックなし。生レスポンスを返します")
    return raw


# ────────────────────────────────────
# 6. Text -> screenshot pipeline
# ────────────────────────────────────
def text_to_screenshot(
    text: str,
    extension_percentage: float,
    temperature: float = 0.3,
    trim_whitespace: bool = True,
    style: str = "standard",
) -> Image.Image:
    """Generate HTML from ``text`` with Gemini and render it to an image.

    Returns:
        The screenshot, or a 1x1 black image when generation or rendering
        raises (the error is logged).
    """
    try:
        html = generate_html_from_text(text, temperature, style)
        return render_fullpage_screenshot(html, extension_percentage, trim_whitespace)
    except Exception as e:
        logger.error(e, exc_info=True)
        return Image.new("RGB", (1, 1), (0, 0, 0))


# ────────────────────────────────────
# 7. FastAPI setup
# ────────────────────────────────────
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount Gradio's static asset directories (only those that exist).
gradio_dir = os.path.dirname(gr.__file__)
for mount_name, rel_path in (
    ("static", "templates/frontend/static"),
    ("_app", "templates/frontend/_app"),
    ("assets", "templates/frontend/assets"),
    ("cdn", "templates/cdn"),
):
    target = os.path.join(gradio_dir, rel_path)
    if os.path.exists(target):
        app.mount(f"/{mount_name}", StaticFiles(directory=target), name=mount_name)
        logger.info(f"Mounted /{mount_name} → {target}")


# ────────────────────────────────────
# 8. API endpoints
# ────────────────────────────────────
def _png_response(img: Image.Image) -> StreamingResponse:
    """Encode ``img`` as PNG and wrap it in a streaming response."""
    buf = BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")


# The handlers are plain ``def`` on purpose: rendering blocks for seconds
# (Selenium, time.sleep), and FastAPI runs sync handlers in a threadpool
# instead of stalling the event loop.
@app.post(
    "/api/screenshot",
    response_class=StreamingResponse,
    tags=["Screenshot"],
    summary="HTML → Full‑page Screenshot",
)
def api_render_screenshot(req: ScreenshotRequest):
    """Render the posted HTML and return it as a PNG."""
    img = render_fullpage_screenshot(
        req.html_code, req.extension_percentage, req.trim_whitespace
    )
    return _png_response(img)


@app.post(
    "/api/text-to-screenshot",
    response_class=StreamingResponse,
    tags=["Screenshot", "Gemini"],
    summary="Text → Gemini → Infographic Screenshot",
)
def api_text_to_screenshot(req: GeminiRequest):
    """Generate HTML from the posted text and return its screenshot as a PNG."""
    img = text_to_screenshot(
        req.text,
        req.extension_percentage,
        req.temperature,
        req.trim_whitespace,
        req.style,
    )
    return _png_response(img)


# ────────────────────────────────────
# 9. Gradio UI
# ────────────────────────────────────
def process_input(
    input_mode, input_text, extension_percentage, temperature, trim_whitespace, style
):
    """Dispatch the UI request: render HTML directly, or go through Gemini."""
    if input_mode == "HTML入力":
        return render_fullpage_screenshot(
            input_text, extension_percentage, trim_whitespace
        )
    return text_to_screenshot(
        input_text, extension_percentage, temperature, trim_whitespace, style
    )


with gr.Blocks(title="Full Page Screenshot + Gemini 2.5 Flash") as iface:
    gr.Markdown("## HTML ビューア & テキスト → インフォグラフィック")
    input_mode = gr.Radio(["HTML入力", "テキスト入力"], value="HTML入力", label="入力モード")
    input_text = gr.Textbox(lines=15, label="入力")

    with gr.Row():
        style_dd = gr.Dropdown(
            STYLES,
            value="standard",
            label="デザインスタイル",
            visible=False,
        )
        extension_slider = gr.Slider(0, 30, 10, label="上下高さ拡張率(%)")
        temperature_slider = gr.Slider(
            0.0,
            1.0,
            0.5,
            step=0.1,
            label="生成温度",
            visible=False,
        )

    trim_cb = gr.Checkbox(value=True, label="余白自動トリミング")
    btn = gr.Button("生成")
    out_img = gr.Image(type="pil", label="スクリーンショット")

    def _vis(mode):
        """Show the Gemini-only controls only in text-input mode."""
        is_text = mode == "テキスト入力"
        return [gr.update(visible=is_text), gr.update(visible=is_text)]

    input_mode.change(_vis, input_mode, [temperature_slider, style_dd])

    btn.click(
        process_input,
        [
            input_mode,
            input_text,
            extension_slider,
            temperature_slider,
            trim_cb,
            style_dd,
        ],
        out_img,
    )

    gr.Markdown(
        f"""
        ### 環境
        * 使用モデル: `{os.getenv('GEMINI_MODEL', 'gemini-1.5-pro')}`
        * thinking_budget=0 は `gemini-2.5-flash-preview-04-17` 使用時のみ自動付与
        """
    )

# ────────────────────────────────────
# 10. Mount Gradio on FastAPI
# ────────────────────────────────────
app = gr.mount_gradio_app(app, iface, path="/")

# ────────────────────────────────────
# 11. Direct execution
# ────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    logger.info("Starting dev server at http://localhost:7860")
    uvicorn.run(app, host="0.0.0.0", port=7860)