Spaces:
Paused
Paused
import asyncio
import logging
import os
import tempfile
import time
from io import BytesIO

import gradio as gr
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
# --- Gemini SDK (v1.x) ---------------------------------
from google import genai
from google.genai import types
# -------------------------------------------------------
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait
# Logging setup: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ---------- Data models ----------
class GeminiRequest(BaseModel):
    """Request body for the text -> infographic screenshot endpoint."""

    # Source text that is appended to the system prompt and sent to Gemini.
    text: str
    # Extra window height, in percent of the measured page height, added before capture.
    extension_percentage: float = 10.0
    # Sampling temperature forwarded to the Gemini generation config.
    temperature: float = 0.5
    # Whether near-white margins are cropped from the captured image.
    trim_whitespace: bool = True
    # Design style name used to select the system prompt.
    style: str = "standard"
class ScreenshotRequest(BaseModel):
    """Request body for the raw-HTML screenshot endpoint."""

    # Complete HTML document (or fragment) to render.
    html_code: str
    # Extra window height, in percent of the measured page height, added before capture.
    extension_percentage: float = 10.0
    # Whether near-white margins are cropped from the captured image.
    trim_whitespace: bool = True
    # Accepted for parity with GeminiRequest; the render handler in this file does not read it.
    style: str = "standard"
# ---------- Utilities ----------
def enhance_font_awesome_layout(html_code):
    """Inject a CSS block that stabilises Font Awesome icon layout.

    The ``<style>`` block is placed, in order of preference:

    1. immediately before the first ``</head>`` in the document;
    2. inside a newly created ``<head>`` right before ``<body`` when the
       document has an ``<html`` tag but no closing head tag;
    3. otherwise the input is wrapped in ``<html><head>...</head>`` and
       ``</html>``.

    Args:
        html_code: HTML text (full document or fragment).

    Returns:
        The HTML text with exactly one copy of the CSS block added.
    """
    fa_fix_css = """
<style>
[class*="fa-"]{display:inline-block!important;margin-right:8px!important;vertical-align:middle!important;}
h1 [class*="fa-"],h2 [class*="fa-"],h3 [class*="fa-"],h4 [class*="fa-"],h5 [class*="fa-"],h6 [class*="fa-"]{
vertical-align:middle!important;margin-right:10px!important;}
.fa+span,.fas+span,.far+span,.fab+span,span+.fa,span+.fas,span+.far,span+.fab{
display:inline-block!important;margin-left:5px!important;}
.card [class*="fa-"],.card-body [class*="fa-"]{float:none!important;clear:none!important;position:relative!important;}
li [class*="fa-"],p [class*="fa-"]{margin-right:10px!important;}
.inline-icon{display:inline-flex!important;align-items:center!important;justify-content:flex-start!important;}
[class*="fa-"]+span{display:inline-block!important;vertical-align:middle!important;}
</style>
"""
    # Insert before the FIRST closing head tag only. A blanket str.replace
    # would also hit "</head>" text inside scripts, and would silently do
    # nothing when the head is never closed.
    head_end = html_code.find('</head>')
    if head_end != -1:
        return html_code[:head_end] + fa_fix_css + html_code[head_end:]

    # Document shell without a closing head tag: create a head before <body.
    if '<html' in html_code:
        body_start = html_code.find('<body')
        if body_start != -1:  # index 0 is a valid match, so compare against -1
            return html_code[:body_start] + f'<head>{fa_fix_css}</head>' + html_code[body_start:]

    # Bare fragment (or no usable anchor): wrap it in a minimal document.
    return f'<html><head>{fa_fix_css}</head>' + html_code + '</html>'
def load_system_instruction(style="standard"):
    """Return the system prompt text for a design style.

    Lookup order:

    1. ``<this file's directory>/<style>/prompt.txt`` if it exists locally;
    2. ``<style>/prompt.txt`` downloaded from the Hugging Face dataset repo;
    3. the repo-root ``prompt.txt`` if step 2 fails for any reason.

    Args:
        style: One of "standard", "cute", "resort", "cool", "dental". Any
            other value is logged as a warning and replaced by "standard".

    Returns:
        The prompt file contents as a string.

    Raises:
        Whatever the final fallback download or read raises.
    """
    if style not in ("standard", "cute", "resort", "cool", "dental"):
        logger.warning(f"無効なスタイル '{style}'。'standard' を使用")
        style = "standard"

    def _read_text(path):
        # All prompt files are read as UTF-8 text.
        with open(path, encoding="utf-8") as handle:
            return handle.read()

    bundled_prompt = os.path.join(os.path.dirname(__file__), style, "prompt.txt")
    if os.path.exists(bundled_prompt):
        return _read_text(bundled_prompt)

    hub_repo = "tomo2chin2/GURAREKOstlyle"
    try:
        styled_prompt = hf_hub_download(
            repo_id=hub_repo,
            filename=f"{style}/prompt.txt",
            repo_type="dataset",
        )
        return _read_text(styled_prompt)
    except Exception:
        # Best-effort fallback to the generic prompt stored at the repo root.
        default_prompt = hf_hub_download(
            repo_id=hub_repo,
            filename="prompt.txt",
            repo_type="dataset",
        )
        return _read_text(default_prompt)
# ---------- Gemini HTML generation ----------
def generate_html_from_text(text, temperature=0.3, style="standard"):
    """Generate HTML from text with Gemini.

    The system prompt for ``style`` is prepended to ``text`` and sent as a
    single prompt. The model name comes from the ``GEMINI_MODEL`` environment
    variable (default "gemini-1.5-pro"); for
    "gemini-2.5-flash-preview-04-17" thinking is switched off via a zero
    thinking budget.

    Args:
        text: User text to convert.
        temperature: Sampling temperature for generation.
        style: Design style name passed to ``load_system_instruction``.

    Returns:
        The contents of the fenced html block in the reply, with the
        Font Awesome CSS fix injected. If no such block is found, the raw
        reply text is returned unchanged.

    Raises:
        ValueError: If ``GEMINI_API_KEY`` is unset, or the reply carries no
            text at all.
    """
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("GEMINI_API_KEY が未設定")
    model_name = os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
    client = genai.Client(api_key=api_key)
    system_instruction = load_system_instruction(style)
    prompt = f"{system_instruction}\n\n{text}"

    # Safety settings are part of the generation config in this SDK;
    # generate_content() itself only takes model / contents / config.
    blocked_categories = (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
    config_kwargs = {
        "temperature": temperature,
        "top_p": 0.7,
        "top_k": 20,
        "max_output_tokens": 8192,
        "candidate_count": 1,
        "safety_settings": [
            types.SafetySetting(category=category, threshold="BLOCK_MEDIUM_AND_ABOVE")
            for category in blocked_categories
        ],
    }
    # ---- Model-specific branch ----
    if model_name == "gemini-2.5-flash-preview-04-17":
        config_kwargs["thinking_config"] = types.ThinkingConfig(thinking_budget=0)  # thinking OFF

    response = client.models.generate_content(
        model=model_name,
        contents=prompt,
        config=types.GenerateContentConfig(**config_kwargs),
    )
    raw = response.text
    if raw is None:
        # No text in the reply; fail with a clear message instead of an
        # AttributeError on None below.
        raise ValueError("Gemini の応答にテキストが含まれていません")

    # Extract the ```html ... ``` block.
    fence = "```html"
    start = raw.find(fence)
    end = raw.rfind("```")
    if start != -1 and end != -1 and start < end:
        html_code = raw[start + len(fence):end].strip()
        return enhance_font_awesome_layout(html_code)
    logger.warning("```html``` ブロックが見つからず全文返却")
    return raw
# ---------- Image trimming ----------
def trim_image_whitespace(image, threshold=250, padding=10):
    """Crop near-white margins from an image.

    A pixel counts as content when its grayscale value is below
    ``threshold``. The bounding box of all content pixels is grown by
    ``padding`` pixels on every side (clamped to the image) and cropped.

    Args:
        image: PIL image to trim.
        threshold: Grayscale cut-off; values below it are content.
        padding: Margin in pixels kept around the content box.

    Returns:
        The cropped image, or ``image`` itself when no content pixel exists.
    """
    gray = image.convert("L")
    width, height = gray.size

    # Build a content mask (255 = content, 0 = background) and let Pillow
    # compute its bounding box instead of looping over pixels in Python.
    content_mask = gray.point(lambda value: 255 if value < threshold else 0)
    content_box = content_mask.getbbox()
    if content_box is None:
        return image

    # getbbox() returns (left, upper, right, lower) with right/lower exclusive.
    left, upper, right, lower = content_box
    return image.crop((
        max(0, left - padding),
        max(0, upper - padding),
        min(width, right + padding),
        min(height, lower + padding),
    ))
# ---------- HTML -> screenshot (Selenium) ----------
def render_fullpage_screenshot(html_code, extension_percentage=6.0, trim_whitespace=True):
    """Render HTML in headless Chrome and return a full-page screenshot.

    The HTML is written to a temporary ``.html`` file and opened through a
    ``file://`` URL. The page is scrolled top to bottom once, then the window
    is resized to the document size (height enlarged by
    ``extension_percentage``) and captured.

    Args:
        html_code: HTML text to render.
        extension_percentage: Percent by which the window height exceeds the
            measured document height.
        trim_whitespace: When true, the capture is passed through
            ``trim_image_whitespace`` with threshold 248 and padding 20.

    Returns:
        A PIL image. If anything inside the rendering steps raises, the error
        is logged and a 1x1 RGB placeholder image is returned instead.
    """
    scroll_overlap = 200  # pixels shared between consecutive scroll positions
    tmp_path, driver = None, None
    try:
        with tempfile.NamedTemporaryFile(suffix=".html", delete=False, mode="w", encoding="utf-8") as f:
            tmp_path = f.name
            f.write(html_code)

        options = Options()
        options.add_argument("--headless")
        options.add_argument("--no-sandbox")
        options.add_argument("--disable-dev-shm-usage")
        driver = webdriver.Chrome(options=options)
        driver.set_window_size(1200, 1000)
        driver.get("file://" + tmp_path)
        WebDriverWait(driver, 15).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
        time.sleep(3)

        total_height = driver.execute_script("return Math.max(document.body.scrollHeight, document.documentElement.scrollHeight)")
        viewport = driver.execute_script("return window.innerHeight")

        # Walk the page by the actual scroll step so the bottom is reached.
        # Deriving the iteration count from the full viewport height while
        # stepping by (viewport - overlap) stops short on long pages. The
        # lower bound keeps the step positive for tiny or zero viewports.
        step = max(int(viewport) - scroll_overlap, scroll_overlap)
        for offset in range(0, int(total_height), step):
            driver.execute_script(f"window.scrollTo(0, {offset})")
            time.sleep(0.2)
        driver.execute_script("window.scrollTo(0, 0)")

        # Hide scrollbars, then size the window to the whole document.
        driver.execute_script("document.documentElement.style.overflow='hidden';document.body.style.overflow='hidden'")
        dims = driver.execute_script("return {w:document.documentElement.scrollWidth,h:document.documentElement.scrollHeight}")
        driver.set_window_size(dims["w"], int(dims["h"] * (1 + extension_percentage / 100)))
        time.sleep(1)

        png = driver.get_screenshot_as_png()
        img = Image.open(BytesIO(png))
        return trim_image_whitespace(img, 248, 20) if trim_whitespace else img
    except Exception as e:
        logger.error(f"Selenium error: {e}", exc_info=True)
        return Image.new("RGB", (1, 1))
    finally:
        try:
            if driver:
                driver.quit()
        finally:
            # Remove the temp file even when quitting the driver fails.
            if tmp_path and os.path.exists(tmp_path):
                os.remove(tmp_path)
# ---------- Combined pipeline ----------
def text_to_screenshot(text, ext, temp=0.3, trim=True, style="standard"):
    """Generate HTML from text with Gemini, then render it to an image.

    Args:
        text: Source text for HTML generation.
        ext: Height extension percentage passed to the renderer.
        temp: Sampling temperature passed to the generator.
        trim: Whether the renderer trims near-white margins.
        style: Design style name passed to the generator.

    Returns:
        The image returned by ``render_fullpage_screenshot``.
    """
    generated_markup = generate_html_from_text(text, temp, style)
    screenshot = render_fullpage_screenshot(generated_markup, ext, trim)
    return screenshot
# ---------- FastAPI ----------
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the static frontend assets bundled with the installed gradio package.
gradio_dir = os.path.dirname(gr.__file__)
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(gradio_dir, "templates/frontend/static")),
    name="static",
)
# ---------- API ----------
@app.post("/api/screenshot")
async def api_render(req: ScreenshotRequest):
    """Render the posted HTML and stream the screenshot back as PNG.

    Registered at ``/api/screenshot``, the path shown in the UI footnote.
    Rendering drives a blocking Selenium session, so it runs in a worker
    thread to keep the event loop responsive.

    Args:
        req: Request body with the HTML and capture options.

    Returns:
        A ``StreamingResponse`` with media type ``image/png``.
    """
    img = await asyncio.to_thread(
        render_fullpage_screenshot,
        req.html_code,
        req.extension_percentage,
        req.trim_whitespace,
    )
    buf = BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
@app.post("/api/text-to-screenshot")
async def api_text_screenshot(req: GeminiRequest):
    """Convert the posted text to an infographic and stream it back as PNG.

    Registered at ``/api/text-to-screenshot``, the path shown in the UI
    footnote. Generation and rendering are blocking calls, so they run in a
    worker thread to keep the event loop responsive.

    Args:
        req: Request body with the text and generation/capture options.

    Returns:
        A ``StreamingResponse`` with media type ``image/png``.

    Raises:
        HTTPException: Status 500 carrying the error message when HTML
            generation or rendering raises.
    """
    try:
        img = await asyncio.to_thread(
            text_to_screenshot,
            req.text,
            req.extension_percentage,
            req.temperature,
            req.trim_whitespace,
            req.style,
        )
    except Exception as exc:
        logger.error(f"text-to-screenshot failed: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    buf = BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
# ---------- Gradio handler ----------
def process_input(mode, txt, ext, temp, trim, style):
    """Dispatch a UI submission to the matching pipeline.

    Args:
        mode: Selected input mode label from the radio control.
        txt: Contents of the input textbox (HTML or plain text).
        ext: Height extension percentage.
        temp: Generation temperature (used only for text input).
        trim: Whether to trim near-white margins.
        style: Design style name (used only for text input).

    Returns:
        The resulting image from the renderer.
    """
    if mode != "HTML入力":
        # Any mode other than raw HTML goes through Gemini first.
        return text_to_screenshot(txt, ext, temp, trim, style)
    return render_fullpage_screenshot(txt, ext, trim)
# ---------- Gradio UI ----------
# NOTE(review): the widget nesting below is reconstructed from the section
# comments because the source had lost its indentation; confirm the layout.
with gr.Blocks(title="Full Page Screenshot (テキスト変換対応)", theme=gr.themes.Base()) as iface:
    # Heading. gr.Markdown defines no `inline` keyword, so none is passed.
    gr.Markdown(
        "<h1 style='text-align:center;margin:0.2em 0'>HTMLビューア & テキスト→インフォグラフィック変換</h1>",
        elem_id="title",
    )
    gr.Markdown(
        "HTML を直接レンダリングするか、テキストを Gemini API でインフォグラフィックに変換して画像取得できます。",
        elem_id="subtitle",
    )

    # === Input mode selection ===
    with gr.Row():
        input_mode = gr.Radio(
            ["HTML入力", "テキスト入力"],
            value="HTML入力",
            label="入力モード",
        )

    # === Shared input textbox ===
    input_text = gr.Textbox(
        lines=15,
        label="入力",
        placeholder="HTMLコードまたはテキストを入力してください(モードに応じて処理)。",
    )

    # === Style and sliders in two columns ===
    with gr.Row():
        # Left: style
        with gr.Column(scale=1, min_width=180):
            style_dropdown = gr.Dropdown(
                choices=["standard", "cute", "resort", "cool", "dental"],
                value="standard",
                label="デザインスタイル",
                info="テキスト→HTML 変換時のテーマ",
                visible=False,  # hidden initially
            )
        # Right: extension rate and temperature
        with gr.Column(scale=3):
            extension_percentage = gr.Slider(
                0, 30, value=10, step=1,
                label="上下高さ拡張率(%)",
            )
            temperature = gr.Slider(
                0.0, 1.0, value=0.5, step=0.1,
                label="生成時の温度(低い=一貫性高、高い=創造性高)",
                visible=False,  # hidden initially
            )

    # === Trimming and submit button ===
    trim_whitespace = gr.Checkbox(
        value=True,
        label="余白を自動トリミング",
    )
    submit_btn = gr.Button("生成", variant="primary", size="lg")

    # === Output ===
    output_image = gr.Image(
        type="pil",
        label="ページ全体のスクリーンショット",
        show_label=True,
        show_download_button=True,
    )

    # ---------- Change events ----------
    def toggle_controls(mode):
        """Show the temperature and style controls only in text-input mode."""
        is_text = mode == "テキスト入力"
        return (
            gr.update(visible=is_text),  # temperature
            gr.update(visible=is_text),  # style_dropdown
        )

    input_mode.change(
        fn=toggle_controls,
        inputs=input_mode,
        outputs=[temperature, style_dropdown],
    )
    submit_btn.click(
        fn=process_input,
        inputs=[
            input_mode,
            input_text,
            extension_percentage,
            temperature,
            trim_whitespace,
            style_dropdown,
        ],
        outputs=output_image,
    )

    # Footer: configured model name and API paths.
    gr.Markdown(
        f"""
**使用モデル** : `{os.getenv("GEMINI_MODEL", "gemini-1.5-pro")}`
**API** : `/api/screenshot` / `/api/text-to-screenshot`
""",
        elem_id="footnote",
    )
# --- Mount the Gradio UI onto the FastAPI app at the root path ---
app = gr.mount_gradio_app(app, iface, path="/")

if __name__ == "__main__":
    # Imported here so uvicorn is only needed when run as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)