Spaces:
Paused
Paused
| import gradio as gr | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from selenium import webdriver | |
| from selenium.webdriver.chrome.options import Options | |
| from selenium.webdriver.common.by import By | |
| from selenium.webdriver.support.ui import WebDriverWait | |
| from selenium.webdriver.support import expected_conditions as EC | |
| from PIL import Image | |
| from io import BytesIO | |
| import tempfile, time, os, logging | |
| from huggingface_hub import hf_hub_download | |
| # ---------- Gemini SDK (v1.x) ---------- | |
| from google import genai # :contentReference[oaicite:4]{index=4} | |
| from google.genai import types | |
| # --------------------------------------- | |
# Configure root logging at INFO and create this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
| # ---------- Pydantic モデル ---------- | |
class GeminiRequest(BaseModel):
    """Request body for text -> infographic rendering (text is turned into HTML by Gemini)."""

    # Source text appended to the style's system prompt.
    text: str
    # Extra window height, in percent of the page height, used when taking the screenshot.
    extension_percentage: float = 10.0
    # Sampling temperature passed to the Gemini generation config.
    temperature: float = 0.5
    # Crop near-white margins from the final image.
    trim_whitespace: bool = True
    # Prompt style name; unknown names fall back to "standard" in load_system_instruction.
    style: str = "standard"


class ScreenshotRequest(BaseModel):
    """Request body for rendering caller-supplied HTML directly to a PNG."""

    # HTML document or fragment to render.
    html_code: str
    # Extra window height, in percent of the page height, used when taking the screenshot.
    extension_percentage: float = 10.0
    # Crop near-white margins from the final image.
    trim_whitespace: bool = True
| # -------------------------------------- | |
| # ---------- ユーティリティ ---------- | |
def enhance_font_awesome_layout(html_code: str) -> str:
    """Inject CSS that normalises Font Awesome icon spacing/alignment into HTML.

    The <style> block is inserted exactly once, at the first position that applies:
      1. just before the first closing ``</head>`` tag;
      2. right after an opening ``<head ...>`` tag whose end tag was omitted;
      3. inside a new ``<head>`` right after the opening ``<html ...>`` tag;
      4. otherwise the input is treated as a fragment and wrapped in
         ``<html><head>...</head>...</html>``.
    Tag matching is case-insensitive and tolerates attributes, so inputs such as
    ``<HEAD>`` or ``<head lang="ja">`` are no longer wrapped in a second <html>.

    Args:
        html_code: HTML document or fragment.

    Returns:
        The HTML with the style block inserted.
    """
    import re

    fix_css = """
<style>[class*="fa-"]{display:inline-block!important;margin-right:8px!important;vertical-align:middle!important;}
h1 [class*="fa-"],h2 [class*="fa-"],h3 [class*="fa-"],h4 [class*="fa-"],h5 [class*="fa-"],h6 [class*="fa-"]{
vertical-align:middle!important;margin-right:10px!important;}
.fa+span,.fas+span,.far+span,.fab+span,span+.fa,span+.fas,span+.far,span+.fab{
display:inline-block!important;margin-left:5px!important;}
li [class*="fa-"],p [class*="fa-"]{margin-right:10px!important;}</style>"""

    # `(?:\s[^>]*)?` allows attributes but keeps <header> from matching <head>.
    closing_head = re.search(r"</head\s*>", html_code, re.IGNORECASE)
    if closing_head:
        pos = closing_head.start()
        return html_code[:pos] + fix_css + html_code[pos:]

    opening_head = re.search(r"<head(?:\s[^>]*)?>", html_code, re.IGNORECASE)
    if opening_head:
        pos = opening_head.end()
        return html_code[:pos] + fix_css + html_code[pos:]

    opening_html = re.search(r"<html(?:\s[^>]*)?>", html_code, re.IGNORECASE)
    if opening_html:
        # Keep any doctype/<html> prefix in place instead of nesting a second <html>.
        pos = opening_html.end()
        return f"{html_code[:pos]}<head>{fix_css}</head>{html_code[pos:]}"

    return f"<html><head>{fix_css}</head>{html_code}</html>"
def load_system_instruction(style="standard") -> str:
    """Return the system-prompt text for a design style.

    Unrecognised style names are treated as "standard". A local file at
    <directory of this module>/<style>/prompt.txt wins when it exists; otherwise
    the prompt is downloaded from the "tomo2chin2/GURAREKOstlyle" dataset on the
    Hugging Face Hub (``prompt.txt`` at the root for "standard",
    ``<style>/prompt.txt`` for the other styles).
    """
    known_styles = ("standard", "cute", "resort", "cool", "dental")
    chosen = style if style in known_styles else "standard"

    prompt_path = os.path.join(os.path.dirname(__file__), chosen, "prompt.txt")
    if not os.path.exists(prompt_path):
        # No local copy: fall back to the Hub dataset.
        remote_name = "prompt.txt" if chosen == "standard" else f"{chosen}/prompt.txt"
        prompt_path = hf_hub_download(
            repo_id="tomo2chin2/GURAREKOstlyle",
            filename=remote_name,
            repo_type="dataset",
        )

    with open(prompt_path, encoding="utf-8") as prompt_file:
        return prompt_file.read()
| # -------------------------------------- | |
| # ---------- Gemini → HTML ---------- | |
def _extract_html_block(raw: str) -> str:
    """Return the contents of the ```html fenced block in `raw`, or `raw` unchanged if none."""
    fence = "```html"
    start, end = raw.find(fence), raw.rfind("```")
    # `end` is the last fence in the reply; `start < end` also rejects an unterminated
    # ```html fence (where rfind would land on the opening fence itself).
    if start != -1 and start < end:
        return raw[start + len(fence):end].strip()
    return raw


def generate_html_from_text(text, temperature=0.3, style="standard") -> str:
    """Ask Gemini to turn `text` into an HTML page and return that HTML.

    The prompt is the style's system instruction (load_system_instruction)
    followed by `text`. The model name comes from GEMINI_MODEL (default
    "gemini-1.5-pro"). If the reply contains a ```html fenced block only its
    contents are kept; in every case the Font Awesome layout CSS from
    enhance_font_awesome_layout is applied, so unfenced HTML replies get the same
    layout fix as fenced ones.

    Args:
        text: Source text to visualise.
        temperature: Sampling temperature for generation.
        style: Prompt style name.

    Raises:
        ValueError: GEMINI_API_KEY is not set.
        RuntimeError: the response carried no text (``response.text`` is None).
    """
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("GEMINI_API_KEY 未設定")
    model_name = os.getenv("GEMINI_MODEL", "gemini-1.5-pro")
    client = genai.Client(api_key=api_key)
    prompt = f"{load_system_instruction(style)}\n\n{text}"

    cfg_kwargs = dict(
        temperature=temperature,
        top_p=0.7,
        top_k=20,
        max_output_tokens=8192,
        candidate_count=1,
    )
    if model_name == "gemini-2.5-flash-preview-04-17":
        # Turn "thinking" off for this preview model (thinking budget of 0 tokens).
        cfg_kwargs["thinking_config"] = types.ThinkingConfig(thinking_budget=0)
    cfg = types.GenerateContentConfig(**cfg_kwargs)

    response = client.models.generate_content(model=model_name, contents=prompt, config=cfg)
    raw = response.text
    if raw is None:
        # Previously this crashed with AttributeError on raw.find(...).
        raise RuntimeError(f"Gemini returned no text (model={model_name})")
    return enhance_font_awesome_layout(_extract_html_block(raw))
| # -------------------------------------- | |
| # ---------- 画像トリミング ---------- | |
def trim_image_whitespace(img: Image.Image, threshold=248, padding=20):
    """Crop near-white margins from `img`, keeping `padding` pixels around the content.

    A pixel counts as content when its grayscale ("L" mode) value is below
    `threshold`. The crop box is the bounding box of all content pixels grown by
    `padding` on each side and clamped to the image bounds.

    Returns:
        `img` itself when no pixel is below the threshold; otherwise a new cropped image.
    """
    gray = img.convert("L")
    width, height = gray.size
    # Binary mask (255 = content) built with a 256-entry lookup table applied in C,
    # instead of materialising every pixel and looping over it in Python.
    mask = gray.point([255 if value < threshold else 0 for value in range(256)])
    # (left, upper, right, lower) of the non-zero region; right/lower are exclusive.
    bbox = mask.getbbox()
    if bbox is None:
        return img
    left, upper, right, lower = bbox
    return img.crop((
        max(0, left - padding),
        max(0, upper - padding),
        min(width, right + padding),
        min(height, lower + padding),
    ))
| # -------------------------------------- | |
| # ---------- Selenium スクショ ---------- | |
def render_fullpage_screenshot(html, ext=6.0, trim=True) -> Image.Image:
    """Render `html` in headless Chrome and return a screenshot of the whole page.

    Args:
        html: HTML document or fragment to render.
        ext: Extra window height, as a percentage of the document's scroll height,
            added before capturing.
        trim: When true, crop near-white margins with trim_image_whitespace.

    Returns:
        The screenshot as a PIL image. On any failure the error is logged and a
        1x1 black RGB image is returned instead of raising (best-effort behaviour
        the UI and API handlers rely on).
    """
    from pathlib import Path  # only needed to build a well-formed file:// URI

    tmp_path, driver = None, None
    try:
        # delete=False so Chrome can open the file after the `with` closes it;
        # it is removed explicitly in `finally`.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".html", mode="w", encoding="utf-8") as f:
            f.write(html)
            tmp_path = f.name

        opts = Options()
        opts.add_argument("--headless=new")
        opts.add_argument("--no-sandbox")
        opts.add_argument("--disable-dev-shm-usage")
        driver = webdriver.Chrome(options=opts)
        driver.set_window_size(1200, 1000)
        # as_uri() percent-encodes the path and handles Windows drive letters,
        # unlike plain "file://" + path concatenation.
        driver.get(Path(tmp_path).as_uri())
        WebDriverWait(driver, 15).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
        time.sleep(1)

        # Scroll through the page once (steps overlap by 200px), presumably so
        # scroll-triggered / lazily loaded content is rendered before measuring.
        # The step is clamped: a viewport <= 200px would otherwise give range() a
        # zero step (ValueError) or a negative one (no scrolling at all).
        total = driver.execute_script("return Math.max(document.body.scrollHeight,document.documentElement.scrollHeight)")
        viewport = driver.execute_script("return window.innerHeight")
        step = max(viewport - 200, 100)
        for y in range(0, total, step):
            driver.execute_script(f"window.scrollTo(0,{y})")
            time.sleep(0.1)
        driver.execute_script("window.scrollTo(0,0)")

        # Hide scrollbars, then grow the window to the full document size (+ext%)
        # so a single screenshot covers the whole page.
        driver.execute_script("document.documentElement.style.overflow='hidden'")
        width = driver.execute_script("return document.documentElement.scrollWidth")
        height = driver.execute_script("return document.documentElement.scrollHeight")
        driver.set_window_size(width, int(height * (1 + ext / 100)))
        time.sleep(0.5)

        img = Image.open(BytesIO(driver.get_screenshot_as_png()))
        return trim_image_whitespace(img) if trim else img
    except Exception:
        logger.exception("render_fullpage_screenshot failed")
        return Image.new("RGB", (1, 1))
    finally:
        if driver:
            driver.quit()
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
| # -------------------------------------- | |
def text_to_screenshot(text, ext, temp, trim, style):
    """Generate HTML from `text` with Gemini and return the rendered screenshot.

    Args:
        text: Source text for the infographic.
        ext: Extra window height in percent, forwarded to render_fullpage_screenshot.
        temp: Sampling temperature for generation.
        trim: Whether to crop near-white margins from the screenshot.
        style: Prompt style name (see load_system_instruction).
    """
    generated_html = generate_html_from_text(text, temperature=temp, style=style)
    return render_fullpage_screenshot(generated_html, ext=ext, trim=trim)
| # ---------- FastAPI ---------- | |
app = FastAPI()

# Fully open CORS: every origin, method and header is allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True may let any
# site send credentialed cross-origin requests; the endpoints in this file take no
# auth, so allow_credentials=False is probably sufficient — confirm no client needs it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/api/screenshot")
async def api_screen(req: ScreenshotRequest):
    """Render the posted HTML and return the screenshot as a PNG (POST /api/screenshot).

    Previously this handler was never registered as a route. The Selenium render
    blocks for seconds (Chrome start-up plus explicit sleeps), so it now runs in
    FastAPI's threadpool instead of stalling the event loop that also serves the
    mounted Gradio UI.
    """
    from fastapi.concurrency import run_in_threadpool

    img = await run_in_threadpool(
        render_fullpage_screenshot,
        req.html_code,
        req.extension_percentage,
        req.trim_whitespace,
    )
    buf = BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
@app.post("/api/text-to-screenshot")
async def api_text(req: GeminiRequest):
    """Generate an infographic from text and return it as a PNG (POST /api/text-to-screenshot).

    Previously this handler was never registered as a route. Gemini generation and
    the Selenium render are both blocking, so they run in FastAPI's threadpool
    rather than on the event loop shared with the mounted Gradio UI.

    Raises:
        HTTPException: 500 with the error message when generation raises
            ValueError (e.g. GEMINI_API_KEY is not configured).
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        img = await run_in_threadpool(
            text_to_screenshot,
            req.text,
            req.extension_percentage,
            req.temperature,
            req.trim_whitespace,
            req.style,
        )
    except ValueError as exc:
        # Surface configuration errors with a readable detail instead of an opaque 500.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    buf = BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
| # -------------------------------------- | |
| # ---------- Gradio UI ---------- | |
def process(mode, txt, ext, temp, trim, style):
    """Gradio click handler: render `txt` according to the selected input mode.

    "HTML入力" renders `txt` directly as HTML; any other mode treats `txt` as source
    text and has Gemini generate the HTML first (`temp` and `style` are used only
    in that case).
    """
    if mode != "HTML入力":
        return text_to_screenshot(txt, ext, temp, trim, style)
    return render_fullpage_screenshot(txt, ext, trim)
with gr.Blocks(theme=gr.themes.Base(), title="HTML Viewer & Text→Infographic") as demo:
    # Page heading.
    gr.Markdown("## HTMLビューア & テキスト→インフォグラフィック変換")

    # Input-mode selector on its own row: raw HTML vs. free text handed to Gemini.
    with gr.Row():
        mode = gr.Radio(
            ["HTML入力", "テキスト入力"],
            value="HTML入力",
            label="入力モード",
        )

    # Main area: input panel (left, scale 5) and rendered preview (right, scale 7).
    with gr.Row():
        with gr.Column(scale=5):
            txt = gr.Textbox(lines=15, label="入力")
            # Gemini-only options; hidden until text mode is selected (see _toggle).
            with gr.Row():
                style_dd = gr.Dropdown(
                    ["standard", "cute", "resort", "cool", "dental"],
                    value="standard",
                    label="デザインスタイル",
                    visible=False,
                )
                temp_sl = gr.Slider(
                    0, 1, step=0.1, value=0.5, label="生成温度", visible=False
                )
            ext_sl = gr.Slider(0, 30, step=1, value=10, label="高さ拡張率(%)")
            trim_cb = gr.Checkbox(value=True, label="余白トリミング")
            gen_btn = gr.Button("生成", variant="primary")
        with gr.Column(scale=7):
            out_img = gr.Image(type="pil", label="プレビュー", height=540)

    def _toggle(selected_mode):
        """Return visibility updates for (temp_sl, style_dd): shown only in text mode."""
        show_gemini_opts = selected_mode == "テキスト入力"
        return [gr.update(visible=show_gemini_opts), gr.update(visible=show_gemini_opts)]

    mode.change(fn=_toggle, inputs=mode, outputs=[temp_sl, style_dd])
    gen_btn.click(
        fn=process,
        inputs=[mode, txt, ext_sl, temp_sl, trim_cb, style_dd],
        outputs=out_img,
    )

    # Footer: Gemini model name (read from GEMINI_MODEL when the UI is built) and API route names.
    gr.Markdown(
        "\n"
        f"**使用モデル** : `{os.getenv('GEMINI_MODEL','gemini-1.5-pro')}`\n"
        "`/api/screenshot` ・ `/api/text-to-screenshot`\n"
    )
| # ---------- マウント ---------- | |
# Mount the Gradio UI onto the FastAPI app at the root path.
demo_app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Local entry point: serve the combined app on all interfaces, port 7860.
    from uvicorn import run as serve_asgi

    serve_asgi(demo_app, port=7860, host="0.0.0.0")