HTMLviewer2_API / app.py
tomo2chin2's picture
Update app.py
45b1af9 verified
raw
history blame
9.83 kB
import gradio as gr
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from PIL import Image
from io import BytesIO
import tempfile, time, os, logging
from huggingface_hub import hf_hub_download
# ---------- Gemini SDK (v1.x) ----------
from google import genai  # google-genai SDK entry point (provides genai.Client)
from google.genai import types
# ---------------------------------------
# Configure the root logger at import time so INFO-level records are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# ---------- Pydantic モデル ----------
class GeminiRequest(BaseModel):
    """Request body accepted by the ``/api/text-to-screenshot`` endpoint."""

    # Free-form text that is appended to the system prompt and sent to Gemini.
    text: str
    # Extra window height, in percent of the page height, added before capture.
    extension_percentage: float = 10.0
    # Sampling temperature forwarded to the Gemini generation config.
    temperature: float = 0.5
    # When true, the screenshot is cropped to its non-white content.
    trim_whitespace: bool = True
    # Name of the prompt style to load.
    style: str = "standard"
class ScreenshotRequest(BaseModel):
    """Request body accepted by the ``/api/screenshot`` endpoint."""

    # Complete HTML document (or fragment) to render in the headless browser.
    html_code: str
    # Extra window height, in percent of the page height, added before capture.
    extension_percentage: float = 10.0
    # When true, the screenshot is cropped to its non-white content.
    trim_whitespace: bool = True
# --------------------------------------
# ---------- ユーティリティ ----------
def enhance_font_awesome_layout(html_code: str) -> str:
    """Inject CSS that normalises Font Awesome icon spacing into an HTML string.

    The stylesheet is placed immediately before the first closing ``</head>``
    tag. The tag is matched case-insensitively, and the opening ``<head>`` may
    carry attributes. When the input has no closing head tag, the input is
    wrapped in a minimal ``<html><head>...</head>...</html>`` shell instead.

    Args:
        html_code: HTML document or fragment.

    Returns:
        The HTML with the extra ``<style>`` element added.
    """
    import re  # local import: ``re`` is not among the file's top-level imports

    fix_css = (
        '\n<style>[class*="fa-"]{display:inline-block!important;'
        'margin-right:8px!important;vertical-align:middle!important;}\n'
        'h1 [class*="fa-"],h2 [class*="fa-"],h3 [class*="fa-"],'
        'h4 [class*="fa-"],h5 [class*="fa-"],h6 [class*="fa-"]{\n'
        'vertical-align:middle!important;margin-right:10px!important;}\n'
        '.fa+span,.fas+span,.far+span,.fab+span,'
        'span+.fa,span+.fas,span+.far,span+.fab{\n'
        'display:inline-block!important;margin-left:5px!important;}\n'
        'li [class*="fa-"],p [class*="fa-"]{margin-right:10px!important;}</style>'
    )

    # Key the decision on the *closing* tag, because that is the insertion
    # point. Testing for a literal "<head>" missed "<head lang=...>" and
    # "<HEAD>", and silently injected nothing when "</head>" was absent.
    closing_head = re.search(r"</head\s*>", html_code, flags=re.IGNORECASE)
    if closing_head is None:
        return f"<html><head>{fix_css}</head>{html_code}</html>"

    cut = closing_head.start()
    return f"{html_code[:cut]}{fix_css}{html_code[cut:]}"
def load_system_instruction(style="standard") -> str:
    """Return the system prompt text for the requested design style.

    A style name outside the known set is treated as ``"standard"``. The
    prompt is read from ``<this file's directory>/<style>/prompt.txt`` when
    that file exists; otherwise it is fetched with ``hf_hub_download`` from
    the ``tomo2chin2/GURAREKOstlyle`` dataset repo, where the standard prompt
    sits at the repo root and every other style in its own sub-folder.

    Args:
        style: Requested style name.

    Returns:
        The prompt file's contents decoded as UTF-8.
    """
    known_styles = ("standard", "cute", "resort", "cool", "dental")
    chosen = style if style in known_styles else "standard"

    prompt_path = os.path.join(os.path.dirname(__file__), chosen, "prompt.txt")
    if not os.path.exists(prompt_path):
        # No bundled copy next to this file: fall back to the dataset repo.
        if chosen == "standard":
            remote_name = "prompt.txt"
        else:
            remote_name = f"{chosen}/prompt.txt"
        prompt_path = hf_hub_download(
            repo_id="tomo2chin2/GURAREKOstlyle",
            filename=remote_name,
            repo_type="dataset",
        )

    with open(prompt_path, encoding="utf-8") as handle:
        return handle.read()
# --------------------------------------
# ---------- Gemini → HTML ----------
def _extract_fenced_html(raw):
    """Return the body of the ```html fenced block in *raw*, or None if absent."""
    opening = "```html"
    start = raw.find(opening)
    if start == -1:
        return None
    body_start = start + len(opening)
    end = raw.rfind("```")
    if end < body_start:
        # rfind landed on the opening fence itself, so the closing fence is
        # missing (for example output truncated at the token limit). Take
        # everything after the opening fence rather than leaking it.
        end = len(raw)
    return raw[body_start:end].strip()


def generate_html_from_text(text, temperature=0.3, style="standard") -> str:
    """Ask Gemini to turn *text* into HTML and return the resulting markup.

    The API key is read from ``GEMINI_API_KEY`` and the model name from
    ``GEMINI_MODEL`` (default ``gemini-1.5-pro``). The prompt is the style's
    system instruction followed by a blank line and *text*.

    Args:
        text: User text appended to the system instruction.
        temperature: Sampling temperature passed to the generation config.
        style: Style name passed to ``load_system_instruction``.

    Returns:
        The content of the ```html fenced block, passed through
        ``enhance_font_awesome_layout``. When the response contains no such
        fence, the raw response text is returned unchanged.

    Raises:
        ValueError: If ``GEMINI_API_KEY`` is not set, or the response carries
            no text at all.
    """
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("GEMINI_API_KEY 未設定")
    model_name = os.getenv("GEMINI_MODEL", "gemini-1.5-pro")
    client = genai.Client(api_key=api_key)
    prompt = f"{load_system_instruction(style)}\n\n{text}"

    # Shared generation settings; only the thinking budget differs per model.
    cfg_kwargs = {
        "temperature": temperature,
        "top_p": 0.7,
        "top_k": 20,
        "max_output_tokens": 8192,
        "candidate_count": 1,
    }
    if model_name == "gemini-2.5-flash-preview-04-17":
        # Thinking is switched off for this model.
        cfg_kwargs["thinking_config"] = types.ThinkingConfig(thinking_budget=0)
    cfg = types.GenerateContentConfig(**cfg_kwargs)

    raw = client.models.generate_content(
        model=model_name,
        contents=prompt,
        config=cfg,
    ).text
    if raw is None:
        # Without this guard the str methods below fail with AttributeError.
        raise ValueError("Gemini response contained no text")

    html = _extract_fenced_html(raw)
    if html is None:
        return raw
    return enhance_font_awesome_layout(html)
# --------------------------------------
# ---------- 画像トリミング ----------
def trim_image_whitespace(img: Image.Image, threshold=248, padding=20):
    """Crop *img* to its non-white content plus a margin.

    A pixel counts as content when its grayscale value is below *threshold*.
    The crop box is the bounding box of all content pixels, grown by
    *padding* pixels on every side and clamped to the image bounds.

    Args:
        img: Image to crop; it is not modified.
        threshold: Grayscale cut-off (0-255 scale) below which a pixel is
            treated as content.
        padding: Margin in pixels kept around the content.

    Returns:
        The cropped image, or *img* itself when no pixel is below the
        threshold.
    """
    gray = img.convert("L")
    width, height = gray.size

    # Mark content pixels as non-zero and let Pillow compute the bounding box
    # natively. This replaces a per-pixel Python loop that also accumulated
    # one list entry per dark pixel.
    content_mask = gray.point(lambda level: 255 if level < threshold else 0)
    bbox = content_mask.getbbox()
    if bbox is None:
        return img

    # getbbox() returns exclusive right/lower edges, so padding is added
    # directly (equivalent to "max index + padding + 1").
    left, top, right, bottom = bbox
    return img.crop((
        max(0, left - padding),
        max(0, top - padding),
        min(width, right + padding),
        min(height, bottom + padding),
    ))
# --------------------------------------
# ---------- Selenium スクショ ----------
def render_fullpage_screenshot(html, ext=6.0, trim=True) -> Image.Image:
    """Render *html* in headless Chrome and return a full-page screenshot.

    The markup is written to a temporary ``.html`` file and opened through a
    ``file://`` URL. The page is scrolled top to bottom once, then the window
    is resized to the document's scroll size (height grown by *ext* percent)
    and captured.

    Args:
        html: HTML source to render.
        ext: Extra window height as a percentage of the document height.
        trim: When true, the capture is passed through
            ``trim_image_whitespace``.

    Returns:
        The screenshot as a PIL image. Errors raised during rendering are
        logged and a 1x1 RGB placeholder image is returned instead.
    """
    from pathlib import Path  # local import: only needed to build the file URL

    tmp, driver = None, None
    try:
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=".html", mode="w", encoding="utf-8"
        ) as f:
            # Record the name first so the file is still removed if the
            # write itself fails.
            tmp = f.name
            f.write(html)

        opts = Options()
        opts.add_argument("--headless=new")  # new headless mode flag
        opts.add_argument("--no-sandbox")
        opts.add_argument("--disable-dev-shm-usage")
        driver = webdriver.Chrome(options=opts)
        driver.set_window_size(1200, 1000)
        # as_uri() yields a well-formed, percent-encoded file:// URL on every
        # platform; plain "file://" + path breaks on Windows-style paths.
        driver.get(Path(tmp).as_uri())
        WebDriverWait(driver, 15).until(
            EC.presence_of_element_located((By.TAG_NAME, "body"))
        )
        time.sleep(1)

        total = int(driver.execute_script(
            "return Math.max(document.body.scrollHeight,document.documentElement.scrollHeight)"
        ))
        vp = int(driver.execute_script("return window.innerHeight"))
        # Scroll in viewport-sized steps with up to 200px of overlap. The
        # overlap is capped at half the viewport and the step floored at 1,
        # so a small viewport can no longer produce a zero or negative step.
        step = max(1, vp - min(200, vp // 2))
        for y in range(0, total, step):
            driver.execute_script(f"window.scrollTo(0,{y})")
            time.sleep(0.1)
        driver.execute_script("window.scrollTo(0,0)")
        driver.execute_script("document.documentElement.style.overflow='hidden'")

        w = driver.execute_script("return document.documentElement.scrollWidth")
        h = driver.execute_script("return document.documentElement.scrollHeight")
        driver.set_window_size(w, int(h * (1 + ext / 100)))
        time.sleep(0.5)

        img = Image.open(BytesIO(driver.get_screenshot_as_png()))
        return trim_image_whitespace(img) if trim else img
    except Exception as e:
        logger.exception(e)
        return Image.new("RGB", (1, 1))
    finally:
        # Nested so the temp file is removed even if quitting the driver fails.
        try:
            if driver:
                driver.quit()
        finally:
            if tmp and os.path.exists(tmp):
                os.remove(tmp)
# --------------------------------------
def text_to_screenshot(text, ext, temp, trim, style):
    """Turn *text* into HTML and return a full-page screenshot of it.

    Args:
        text: Input passed to ``generate_html_from_text``.
        ext: Height extension percentage passed to the renderer.
        temp: Generation temperature.
        trim: Whether the renderer trims whitespace from the capture.
        style: Prompt style name.

    Returns:
        Whatever ``render_fullpage_screenshot`` returns for the generated HTML.
    """
    generated_markup = generate_html_from_text(text, temperature=temp, style=style)
    screenshot = render_fullpage_screenshot(generated_markup, ext=ext, trim=trim)
    return screenshot
# ---------- FastAPI ----------
# Application object on which the API routes below are registered.
app = FastAPI()
# Allow cross-origin requests from any origin, with any method and header.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site send credentialed requests — confirm this is intended.
app.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_methods=["*"],
    allow_headers=["*"], allow_credentials=True
)
@app.post("/api/screenshot", response_class=StreamingResponse)
async def api_screen(req: ScreenshotRequest):
    """Render the posted HTML and stream the screenshot back as a PNG.

    Args:
        req: Request body with the HTML source and rendering options.

    Returns:
        A ``StreamingResponse`` with media type ``image/png``.
    """
    from fastapi.concurrency import run_in_threadpool

    # The render drives a browser and sleeps for seconds. Run it on a worker
    # thread so this coroutine does not stall the event loop, which would
    # freeze every other request served by the same app.
    img = await run_in_threadpool(
        render_fullpage_screenshot,
        req.html_code,
        req.extension_percentage,
        req.trim_whitespace,
    )
    buf = BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
@app.post("/api/text-to-screenshot", response_class=StreamingResponse)
async def api_text(req: GeminiRequest):
    """Generate HTML from the posted text and stream its screenshot as a PNG.

    Args:
        req: Request body with the text, generation and rendering options.

    Returns:
        A ``StreamingResponse`` with media type ``image/png``.
    """
    from fastapi.concurrency import run_in_threadpool

    # Both the model call and the browser render block for seconds. Run them
    # on a worker thread so this coroutine does not stall the event loop.
    img = await run_in_threadpool(
        text_to_screenshot,
        req.text,
        req.extension_percentage,
        req.temperature,
        req.trim_whitespace,
        req.style,
    )
    buf = BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
# --------------------------------------
# ---------- Gradio UI ----------
def process(mode, txt, ext, temp, trim, style):
    """Dispatch a UI request to the right pipeline based on the input mode.

    In "HTML入力" mode *txt* is rendered directly as HTML; in any other mode
    it is treated as plain text and routed through ``text_to_screenshot``.

    Args:
        mode: Selected input mode label.
        txt: Content of the input textbox.
        ext: Height extension percentage.
        temp: Generation temperature (used only outside HTML mode).
        trim: Whether to trim whitespace from the capture.
        style: Prompt style name (used only outside HTML mode).

    Returns:
        The image produced by the chosen pipeline.
    """
    if mode != "HTML入力":
        return text_to_screenshot(txt, ext, temp, trim, style)
    return render_fullpage_screenshot(txt, ext, trim)
# Gradio front end: mode selector, input panel with options, and image preview.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a copy of the file that had lost its indentation — in particular,
# confirm whether ext_sl / trim_cb belong inside the inner gr.Row().
with gr.Blocks(theme=gr.themes.Base(), title="HTML Viewer & Text→Infographic") as demo:
    gr.Markdown("## HTMLビューア & テキスト→インフォグラフィック変換")  # page heading
    with gr.Row():  # single horizontal row
        mode = gr.Radio(["HTML入力", "テキスト入力"], value="HTML入力", label="入力モード")
    with gr.Row():  # input panel & output image
        with gr.Column(scale=5):
            txt = gr.Textbox(lines=15, label="入力")
            with gr.Row():
                # Style and temperature start hidden; _toggle below controls them.
                style_dd = gr.Dropdown(["standard","cute","resort","cool","dental"],
                                       value="standard", label="デザインスタイル", visible=False)
                temp_sl = gr.Slider(0,1,step=0.1,value=0.5,label="生成温度",visible=False)
            ext_sl = gr.Slider(0,30,step=1,value=10,label="高さ拡張率(%)")
            trim_cb = gr.Checkbox(value=True,label="余白トリミング")
            gen_btn = gr.Button("生成", variant="primary")
        with gr.Column(scale=7):
            out_img = gr.Image(type="pil", label="プレビュー", height=540)
    # Show the temperature and style controls only while text-input mode is selected.
    def _toggle(m): vis = m=="テキスト入力"; return [gr.update(visible=vis), gr.update(visible=vis)]
    mode.change(_toggle, mode, [temp_sl, style_dd])
    # Run the selected pipeline and show the resulting image in the preview.
    gen_btn.click(process, [mode, txt, ext_sl, temp_sl, trim_cb, style_dd], out_img)
    # Footer: configured model name (read from GEMINI_MODEL at build time) and the API routes.
    gr.Markdown(
        f"""
**使用モデル** : `{os.getenv('GEMINI_MODEL','gemini-1.5-pro')}`
`/api/screenshot` ・ `/api/text-to-screenshot`
"""
    )
# ---------- マウント ----------
# Mount the Gradio UI on the FastAPI app at the root path.
demo_app = gr.mount_gradio_app(app, demo, path="/")
if __name__ == "__main__":
    # Imported here, so uvicorn is only required when the file is run as a script.
    import uvicorn
    # Serve on all interfaces, port 7860.
    uvicorn.run(demo_app, host="0.0.0.0", port=7860)