| | import gradio as gr |
| | from fastapi import FastAPI, HTTPException |
| | from fastapi.responses import StreamingResponse |
| | from fastapi.staticfiles import StaticFiles |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from pydantic import BaseModel |
| | from selenium import webdriver |
| | from selenium.webdriver.chrome.options import Options |
| | from selenium.webdriver.common.by import By |
| | from selenium.webdriver.support.ui import WebDriverWait |
| | from selenium.webdriver.support import expected_conditions as EC |
| | from PIL import Image |
| | from io import BytesIO |
| | import tempfile |
| | import time |
| | import os |
| | import logging |
| | from huggingface_hub import hf_hub_download |
| |
|
| | |
| | from google import genai |
| | from google.genai import types |
| | |
| |
|
| | |
# Module-wide logging: install a root handler at INFO level (a no-op if one is
# already configured) and expose a logger named after this module for every
# function in the file.
logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(name=__name__)
| |
|
| | |
# Request body for POST /api/text-to-screenshot: free text that Gemini turns
# into HTML, plus the rendering options forwarded to the screenshot step.
class GeminiRequest(BaseModel):
    # Source text appended to the style's system prompt.
    text: str
    # Extra browser-window height, as a percentage of the page height.
    extension_percentage: float = 10.0
    # Sampling temperature passed to the Gemini generation config.
    temperature: float = 0.5
    # Crop near-white margins from the captured image.
    trim_whitespace: bool = True
    # Prompt style name; unknown names fall back to "standard".
    style: str = "standard"
| |
|
# Request body for POST /api/screenshot: raw HTML to render, plus rendering
# options.
class ScreenshotRequest(BaseModel):
    # Complete HTML document (or fragment) to load in headless Chrome.
    html_code: str
    # Extra browser-window height, as a percentage of the page height.
    extension_percentage: float = 10.0
    # Crop near-white margins from the captured image.
    trim_whitespace: bool = True
    # Accepted for symmetry with GeminiRequest; the /api/screenshot handler
    # does not read it.
    style: str = "standard"
| |
|
| | |
# Stylesheet injected into generated pages so Font Awesome icons line up
# inline with the text next to them (forced inline-block, middle alignment,
# fixed spacing, no floating inside cards).
_FA_FIX_CSS = """
<style>
[class*="fa-"]{display:inline-block!important;margin-right:8px!important;vertical-align:middle!important;}
h1 [class*="fa-"],h2 [class*="fa-"],h3 [class*="fa-"],h4 [class*="fa-"],h5 [class*="fa-"],h6 [class*="fa-"]{
vertical-align:middle!important;margin-right:10px!important;}
.fa+span,.fas+span,.far+span,.fab+span,span+.fa,span+.fas,span+.far,span+.fab{
display:inline-block!important;margin-left:5px!important;}
.card [class*="fa-"],.card-body [class*="fa-"]{float:none!important;clear:none!important;position:relative!important;}
li [class*="fa-"],p [class*="fa-"]{margin-right:10px!important;}
.inline-icon{display:inline-flex!important;align-items:center!important;justify-content:flex-start!important;}
[class*="fa-"]+span{display:inline-block!important;vertical-align:middle!important;}
</style>
"""


def enhance_font_awesome_layout(html_code):
    """Inject the Font Awesome layout fixes into an HTML document exactly once.

    The ``<style>`` block is placed, in order of preference:

    1. just before the first ``</head>``;
    2. right after an opening ``<head ...>`` whose (optional) end tag is missing;
    3. if there is no ``<html>`` tag, the input is treated as a fragment and
       wrapped as ``<html><head>CSS</head>…</html>``;
    4. otherwise in a new ``<head>`` inserted before ``<body ...>``;
    5. otherwise in a new ``<head>`` right after the opening ``<html ...>`` tag.

    Tag matching is case-insensitive, so ``<HTML><HEAD>`` documents are handled
    like lower-case ones instead of being wrapped in a second ``<html>``.

    Args:
        html_code: HTML document or fragment.

    Returns:
        The HTML with the stylesheet inserted.
    """
    import re

    def _insert(index, snippet):
        return html_code[:index] + snippet + html_code[index:]

    head_close = re.search(r"</head\s*>", html_code, re.IGNORECASE)
    if head_close:
        return _insert(head_close.start(), _FA_FIX_CSS)

    # </head> may legally be omitted, so an open <head> can exist without it;
    # put the CSS inside that head rather than silently skipping injection.
    head_open = re.search(r"<head(?:\s[^>]*)?>", html_code, re.IGNORECASE)
    if head_open:
        return _insert(head_open.end(), _FA_FIX_CSS)

    html_open = re.search(r"<html(?:\s[^>]*)?>", html_code, re.IGNORECASE)
    if html_open is None:
        return f"<html><head>{_FA_FIX_CSS}</head>" + html_code + "</html>"

    new_head = f"<head>{_FA_FIX_CSS}</head>"
    body_open = re.search(r"<body[\s>]", html_code, re.IGNORECASE)
    if body_open:
        return _insert(body_open.start(), new_head)
    # <html> with neither head nor body: add the head inside the existing
    # <html> element instead of nesting the whole document in another <html>.
    return _insert(html_open.end(), new_head)
| |
|
def _read_hub_prompt(filename):
    """Download ``filename`` from the prompt dataset repo and return its UTF-8 text."""
    file_path = hf_hub_download(
        repo_id="tomo2chin2/GURAREKOstlyle",
        filename=filename,
        repo_type="dataset",
    )
    with open(file_path, encoding="utf-8") as f:
        return f.read()


def load_system_instruction(style="standard"):
    """Return the system prompt text for ``style``.

    Lookup order:

    1. ``<directory of this file>/<style>/prompt.txt`` if it exists locally;
    2. ``<style>/prompt.txt`` from the ``tomo2chin2/GURAREKOstlyle`` HF dataset;
    3. the dataset's root ``prompt.txt`` if step 2 raises (the failure is logged).

    Args:
        style: One of "standard", "cute", "resort", "cool", "dental". Any other
            value is logged and replaced by "standard".

    Returns:
        The prompt file contents as a string.

    Raises:
        Whatever ``hf_hub_download`` or ``open`` raise when the root-prompt
        fallback also fails.
    """
    valid_styles = ["standard", "cute", "resort", "cool", "dental"]
    if style not in valid_styles:
        logger.warning(f"無効なスタイル '{style}'。'standard' を使用")
        style = "standard"

    local_path = os.path.join(os.path.dirname(__file__), style, "prompt.txt")
    if os.path.isfile(local_path):
        with open(local_path, encoding="utf-8") as f:
            return f.read()

    try:
        return _read_hub_prompt(f"{style}/prompt.txt")
    except Exception:
        # Keep the best-effort fallback, but record why the style-specific
        # prompt could not be used instead of swallowing the error silently.
        logger.warning(
            "Could not load '%s/prompt.txt' from the Hub; falling back to the root prompt.txt",
            style,
            exc_info=True,
        )
    return _read_hub_prompt("prompt.txt")
| |
|
| | |
def _build_generation_config(model_name, temperature):
    """Build the Gemini generation config, safety settings included."""
    safety_settings = [
        types.SafetySetting(category=category, threshold="BLOCK_MEDIUM_AND_ABOVE")
        for category in (
            "HARM_CATEGORY_HARASSMENT",
            "HARM_CATEGORY_HATE_SPEECH",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "HARM_CATEGORY_DANGEROUS_CONTENT",
        )
    ]
    options = dict(
        temperature=temperature,
        top_p=0.7,
        top_k=20,
        max_output_tokens=8192,
        candidate_count=1,
        safety_settings=safety_settings,
    )
    if model_name == "gemini-2.5-flash-preview-04-17":
        # Thinking off (budget 0) for the 2.5 Flash preview model.
        options["thinking_config"] = types.ThinkingConfig(thinking_budget=0)
    return types.GenerateContentConfig(**options)


def _extract_html(raw):
    """Return the contents of the ```html fenced block in ``raw``, or None if absent.

    The closing fence is the last ``` after the opening one, so stray fences
    inside the HTML are kept. If no closing fence exists (e.g. the reply was cut
    off at max_output_tokens), everything after the opening fence is used rather
    than leaking the "```html" marker into the page.
    """
    start = raw.find("```html")
    if start == -1:
        return None
    body_start = start + len("```html")
    end = raw.rfind("```", body_start)
    return raw[body_start:end if end != -1 else len(raw)].strip()


def generate_html_from_text(text, temperature=0.3, style="standard"):
    """Generate an HTML page for ``text`` with Gemini.

    The model is taken from the GEMINI_MODEL environment variable (default
    "gemini-1.5-pro"). The style's system prompt is prepended to ``text``. For
    "gemini-2.5-flash-preview-04-17", thinking is turned off.

    Args:
        text: User text to turn into an infographic page.
        temperature: Sampling temperature.
        style: Prompt style passed to load_system_instruction.

    Returns:
        The HTML from the reply's ```html block with the Font Awesome fixes
        applied, or the raw reply text (logged) if no such block exists.

    Raises:
        ValueError: GEMINI_API_KEY is unset, or the reply contains no text
            (for example when the prompt was blocked).
    """
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("GEMINI_API_KEY が未設定")
    model_name = os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")

    client = genai.Client(api_key=api_key)
    prompt = f"{load_system_instruction(style)}\n\n{text}"

    # Safety settings belong inside the config: Models.generate_content takes
    # only model/contents/config, so a separate safety_settings= keyword
    # raised TypeError on every call.
    response = client.models.generate_content(
        model=model_name,
        contents=prompt,
        config=_build_generation_config(model_name, temperature),
    )
    raw = response.text
    if not raw:
        # .text is None when no candidate text came back; fail with a clear
        # message instead of an AttributeError further down.
        raise ValueError(
            f"Gemini returned no text (prompt_feedback={getattr(response, 'prompt_feedback', None)})"
        )

    html_code = _extract_html(raw)
    if html_code is None:
        logger.warning("```html``` ブロックが見つからず全文返却")
        return raw
    return enhance_font_awesome_layout(html_code)
| |
|
| | |
def trim_image_whitespace(image, threshold=250, padding=10):
    """Crop the near-white margins around an image's content.

    A pixel counts as content when its grayscale value is below ``threshold``.
    The bounding box of all content pixels is grown by ``padding`` on every side,
    clamped to the image bounds, and cropped from the original (non-gray) image.

    Args:
        image: PIL image in any mode convertible to "L".
        threshold: Grayscale cut-off; values strictly below it are content.
        padding: Pixels of margin kept around the content.

    Returns:
        The cropped image, or ``image`` unchanged when no pixel is below
        ``threshold``.
    """
    gray = image.convert("L")
    # Map content pixels to 255 and background to 0 through a lookup table and
    # let PIL find the non-zero bounding box in C. This replaces a per-pixel
    # Python double loop that also copied every pixel into nested lists.
    mask = gray.point([255 if value < threshold else 0 for value in range(256)])
    bbox = mask.getbbox()
    if bbox is None:
        return image
    left, top, right, bottom = bbox  # right/bottom are exclusive
    width, height = gray.size
    return image.crop((
        max(0, left - padding),
        max(0, top - padding),
        min(width, right + padding),
        min(height, bottom + padding),
    ))
| |
|
| | |
def _trigger_lazy_content(driver, overlap=200, pause=0.2):
    """Scroll through the whole page so lazily rendered content loads, then return to the top.

    Consecutive scroll positions overlap by ``overlap`` pixels (or advance one
    full viewport when the viewport is not taller than ``overlap``).
    """
    total_height = driver.execute_script(
        "return Math.max(document.body.scrollHeight, document.documentElement.scrollHeight)"
    )
    viewport = driver.execute_script("return window.innerHeight") or 0
    if total_height <= 0 or viewport <= 0:
        return  # nothing meaningful to scroll through
    step = viewport - overlap if viewport > overlap else viewport
    # Visit positions up to the full document height; the old fixed iteration
    # count (height // viewport + 1 steps of viewport - 200) stopped short of
    # the bottom on tall pages.
    for y in range(0, total_height, step):
        driver.execute_script(f"window.scrollTo(0, {y})")
        time.sleep(pause)
    driver.execute_script("window.scrollTo(0, 0)")


def render_fullpage_screenshot(html_code, extension_percentage=6.0, trim_whitespace=True):
    """Render ``html_code`` in headless Chrome and capture the entire page.

    The HTML is written to a temporary file and loaded at 1200x1000. After a
    fixed settle time and a scroll pass, the window is resized to the
    document's scroll size, with its height enlarged by ``extension_percentage``
    percent, and a screenshot is taken.

    Args:
        html_code: HTML to render.
        extension_percentage: Extra window height, in percent of the page height.
        trim_whitespace: Crop near-white margins (threshold 248, padding 20).

    Returns:
        A PIL image. On any error the exception is logged and a 1x1 RGB
        placeholder is returned instead of raising.
    """
    from pathlib import Path

    tmp_path, driver = None, None
    try:
        with tempfile.NamedTemporaryFile(suffix=".html", delete=False, mode="w", encoding="utf-8") as f:
            tmp_path = f.name
            f.write(html_code)

        options = Options()
        for arg in ("--headless", "--no-sandbox", "--disable-dev-shm-usage"):
            options.add_argument(arg)
        driver = webdriver.Chrome(options=options)
        driver.set_window_size(1200, 1000)
        # as_uri() builds a valid file:// URL on every OS; "file://" + path is
        # malformed for Windows drive paths.
        driver.get(Path(tmp_path).as_uri())
        WebDriverWait(driver, 15).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
        time.sleep(3)  # fixed settle time before measuring the page

        _trigger_lazy_content(driver)

        # overflow:hidden removes scrollbars before the window is grown to the
        # full document size.
        driver.execute_script("document.documentElement.style.overflow='hidden';document.body.style.overflow='hidden'")
        dims = driver.execute_script("return {w:document.documentElement.scrollWidth,h:document.documentElement.scrollHeight}")
        driver.set_window_size(dims["w"], int(dims["h"] * (1 + extension_percentage / 100)))
        time.sleep(1)
        img = Image.open(BytesIO(driver.get_screenshot_as_png()))
        return trim_image_whitespace(img, 248, 20) if trim_whitespace else img
    except Exception as e:
        logger.error(f"Selenium error: {e}", exc_info=True)
        return Image.new("RGB", (1, 1))
    finally:
        if driver:
            try:
                driver.quit()
            except Exception:
                # An exception escaping finally would discard a screenshot
                # that was already captured.
                logger.warning("Failed to quit Chrome driver", exc_info=True)
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
| |
|
| | |
def text_to_screenshot(text, ext, temp=0.3, trim=True, style="standard"):
    """Turn text into a page screenshot: Gemini generates the HTML, Chrome renders it.

    Args:
        text: Source text for generate_html_from_text.
        ext: Window-height extension percentage for the capture.
        temp: Gemini sampling temperature.
        trim: Whether to crop near-white margins.
        style: Prompt style name.

    Returns:
        The image produced by render_fullpage_screenshot for the generated HTML.
    """
    return render_fullpage_screenshot(generate_html_from_text(text, temp, style), ext, trim)
| |
|
| | |
app = FastAPI()

# Permissive CORS: any origin, method and header, with credentials allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# effectively "credentialed requests from any site" (Starlette reflects the
# request Origin in this case, as far as I know; confirm). Tighten
# allow_origins if cookies or auth are ever added.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| |
|
| | |
# Serve Gradio's bundled frontend static files under /static when they exist.
gradio_dir = os.path.dirname(gr.__file__)
_gradio_static_dir = os.path.join(gradio_dir, "templates", "frontend", "static")
if os.path.isdir(_gradio_static_dir):
    app.mount("/static", StaticFiles(directory=_gradio_static_dir), name="static")
else:
    # StaticFiles raises RuntimeError at construction for a missing directory,
    # and this sub-path may not exist in the installed Gradio version. Skip the
    # mount instead of crashing the whole app at import time.
    logger.warning("Gradio static directory not found; /static not mounted: %s", _gradio_static_dir)
| |
|
| | |
@app.post("/api/screenshot", response_class=StreamingResponse)
def api_render(req: ScreenshotRequest):
    """Render ``req.html_code`` and stream the full-page screenshot back as a PNG.

    Declared as a plain ``def`` on purpose: FastAPI runs sync endpoints in its
    threadpool. As ``async def``, the blocking Selenium session (seconds of
    sleeps and browser I/O) ran directly on the event loop and stalled every
    other request served by this process.
    """
    img = render_fullpage_screenshot(req.html_code, req.extension_percentage, req.trim_whitespace)
    buf = BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
| |
|
@app.post("/api/text-to-screenshot", response_class=StreamingResponse)
def api_text_screenshot(req: GeminiRequest):
    """Generate HTML from ``req.text`` with Gemini, render it, and stream a PNG.

    A plain ``def`` so FastAPI runs the blocking Gemini call and Selenium
    session in its threadpool instead of on the event loop.

    Raises:
        HTTPException: status 500 with the error message when HTML generation
            raises ValueError (e.g. GEMINI_API_KEY is not set).
    """
    try:
        img = text_to_screenshot(
            req.text, req.extension_percentage, req.temperature, req.trim_whitespace, req.style
        )
    except ValueError as e:
        # Same 500 status as before, but the client now gets the reason
        # instead of a bare "Internal Server Error".
        raise HTTPException(status_code=500, detail=str(e)) from e
    buf = BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
| |
|
| | |
def process_input(mode, txt, ext, temp, trim, style):
    """Gradio submit handler: render HTML directly or convert text via Gemini.

    Args:
        mode: "HTML入力" renders ``txt`` as HTML; any other value treats it as text.
        txt: Contents of the input box.
        ext: Window-height extension percentage.
        temp: Gemini temperature (text mode only).
        trim: Whether to crop near-white margins.
        style: Prompt style (text mode only).

    Returns:
        The screenshot image for the output component.

    Raises:
        gr.Error: the input is empty or whitespace-only. The UI shows a message
            instead of rendering a blank page or sending an empty prompt to Gemini.
    """
    if not txt or not txt.strip():
        raise gr.Error("入力が空です。HTMLコードまたはテキストを入力してください。")
    if mode == "HTML入力":
        return render_fullpage_screenshot(txt, ext, trim)
    return text_to_screenshot(txt, ext, temp, trim, style)
| |
|
| | |
with gr.Blocks(title="Full Page Screenshot (テキスト変換対応)", theme=gr.themes.Base()) as iface:
    # Header. `inline` is not a documented gr.Markdown parameter, so it is not
    # passed; the call no longer depends on Gradio tolerating unknown keywords.
    gr.Markdown(
        "<h1 style='text-align:center;margin:0.2em 0'>HTMLビューア & テキスト→インフォグラフィック変換</h1>",
        elem_id="title",
    )
    gr.Markdown(
        "HTML を直接レンダリングするか、テキストを Gemini API でインフォグラフィックに変換して画像取得できます。",
        elem_id="subtitle",
    )

    # Input mode selector: raw HTML vs. text for Gemini.
    with gr.Row():
        input_mode = gr.Radio(
            ["HTML入力", "テキスト入力"],
            value="HTML入力",
            label="入力モード",
        )

    # Main input box, interpreted according to the selected mode.
    input_text = gr.Textbox(
        lines=15,
        label="入力",
        placeholder="HTMLコードまたはテキストを入力してください(モードに応じて処理)。",
    )

    with gr.Row():
        # Left column: design style (hidden until text mode is selected).
        with gr.Column(scale=1, min_width=180):
            style_dropdown = gr.Dropdown(
                choices=["standard", "cute", "resort", "cool", "dental"],
                value="standard",
                label="デザインスタイル",
                info="テキスト→HTML 変換時のテーマ",
                visible=False,
            )
        # Right column: rendering and generation sliders.
        with gr.Column(scale=3):
            extension_percentage = gr.Slider(
                0, 30, value=10, step=1,
                label="上下高さ拡張率(%)"
            )
            temperature = gr.Slider(
                0.0, 1.0, value=0.5, step=0.1,
                label="生成時の温度(低い=一貫性高、高い=創造性高)",
                visible=False,
            )

    # Options and submit button.
    trim_whitespace = gr.Checkbox(
        value=True,
        label="余白を自動トリミング",
    )
    submit_btn = gr.Button("生成", variant="primary", size="lg")

    # Output.
    output_image = gr.Image(
        type="pil",
        label="ページ全体のスクリーンショット",
        show_label=True,
        show_download_button=True,
    )

    def toggle_controls(mode):
        """Show the temperature slider and style dropdown only in text mode."""
        is_text = mode == "テキスト入力"
        return (
            gr.update(visible=is_text),
            gr.update(visible=is_text),
        )

    input_mode.change(
        fn=toggle_controls,
        inputs=input_mode,
        outputs=[temperature, style_dropdown],
    )

    submit_btn.click(
        fn=process_input,
        inputs=[
            input_mode,
            input_text,
            extension_percentage,
            temperature,
            trim_whitespace,
            style_dropdown,
        ],
        outputs=output_image,
    )

    # Footnote. The blank line ("\n\n") makes the model and API entries
    # separate paragraphs; two plain consecutive lines would merge into one in
    # Markdown.
    gr.Markdown(
        f"**使用モデル** : `{os.getenv('GEMINI_MODEL', 'gemini-1.5-pro')}`\n\n"
        "**API** : `/api/screenshot` / `/api/text-to-screenshot`",
        elem_id="footnote",
    )
| |
|
| | |
# Attach the Gradio UI to the FastAPI app at the site root; the API routes and
# /static mount above are registered on the same app before this call.
app = gr.mount_gradio_app(app, blocks=iface, path="/")
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: serve the combined FastAPI + Gradio app on all
    # interfaces, port 7860.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=7860)
| |
|