File size: 18,168 Bytes
5a6269f
84f5d0d
 
 
 
 
5a6269f
697fed0
35a8897
697fed0
 
 
5a6269f
697fed0
 
5a6269f
697fed0
5a6269f
d9bd4cb
f39c339
697fed0
d9bd4cb
 
 
 
 
 
790339d
 
 
697fed0
d9bd4cb
84f5d0d
 
 
d9bd4cb
 
 
84f5d0d
 
 
4f99fb7
84f5d0d
697fed0
26eec55
4f99fb7
 
 
697fed0
 
26eec55
84f5d0d
 
 
 
 
 
 
 
 
 
 
 
 
 
26eec55
 
4f99fb7
84f5d0d
4f99fb7
84f5d0d
4f99fb7
 
 
84f5d0d
5a6269f
84f5d0d
 
4f99fb7
697fed0
26eec55
84f5d0d
 
 
 
 
 
 
 
 
 
 
 
 
 
697fed0
4f99fb7
84f5d0d
 
4f99fb7
84f5d0d
 
 
 
 
 
 
26eec55
84f5d0d
5a6269f
26eec55
4f99fb7
84f5d0d
 
 
d9bd4cb
 
697fed0
 
 
 
d9bd4cb
 
 
697fed0
 
 
 
64edcc1
35a8897
64edcc1
35a8897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64edcc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84f5d0d
790339d
84f5d0d
790339d
 
 
26eec55
790339d
 
 
 
 
 
 
 
 
 
abe973b
4882a8a
 
 
35a8897
abe973b
35a8897
4882a8a
790339d
 
 
 
 
 
 
 
 
84f5d0d
 
790339d
26eec55
 
84f5d0d
 
 
 
 
 
 
 
26eec55
5a6269f
84f5d0d
35a8897
84f5d0d
 
 
 
 
f39c339
697fed0
84f5d0d
 
 
35a8897
84f5d0d
 
 
35a8897
84f5d0d
 
35a8897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84f5d0d
 
 
 
 
26eec55
84f5d0d
35a8897
 
 
 
 
 
 
 
 
84f5d0d
 
 
aeb0d3b
84f5d0d
 
aeb0d3b
84f5d0d
 
 
 
 
5a6269f
84f5d0d
35a8897
84f5d0d
35a8897
 
84f5d0d
 
 
 
 
 
35a8897
 
5a6269f
 
84f5d0d
5a6269f
26eec55
5a6269f
84f5d0d
 
 
 
 
 
 
 
35a8897
84f5d0d
 
 
 
 
 
 
 
35a8897
84f5d0d
 
 
5a6269f
 
84f5d0d
5a6269f
84f5d0d
 
35a8897
84f5d0d
 
 
 
 
 
350e08a
84f5d0d
 
 
9421aab
35a8897
 
84f5d0d
 
 
 
5a6269f
84f5d0d
 
 
5a6269f
84f5d0d
5a6269f
4882a8a
 
 
 
35a8897
4882a8a
84f5d0d
4882a8a
5a6269f
84f5d0d
 
 
 
 
 
 
5a6269f
84f5d0d
5a6269f
84f5d0d
 
 
5a6269f
35a8897
64edcc1
 
35a8897
 
 
 
 
64edcc1
84f5d0d
26eec55
84f5d0d
 
 
5a6269f
84f5d0d
4f99fb7
84f5d0d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
# ===============================================================
# app.py  ― Gradio 5.x + FastAPI + Gemini + Selenium + 307対策
#   (1) FastAPI(redirect_slashes=False)               ←★追加
#   (2) Gradio を /gradio にマウント
#   (3) / と /gradio を /gradio/ へリダイレクト     ←★追加
#   それ以外は 5.x 対応フルロジックを一切カットせず
# ===============================================================

import os, time, tempfile, logging, threading, queue
import re
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from pathlib import Path

import numpy as np
from PIL import Image

import gradio as gr
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC

# Updated: Use the new Google Genai library
from google import genai
from google.genai import types
from huggingface_hub import hf_hub_download

# ---------------------------------------------------------------
# ロギング
# ---------------------------------------------------------------
# Configure root logging at INFO so the pool / rendering / startup messages
# emitted throughout this file are visible.
logging.basicConfig(level=logging.INFO)
# Module logger shared by every function and class in this file.
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------
# WebDriverPool  ― 再利用可能な headless Chrome WebDriver のプール
# ---------------------------------------------------------------
class WebDriverPool:
    """Pool of reusable headless Chrome WebDriver instances.

    Drivers are created lazily, with at most ``max_drivers`` alive at once.
    Callers hand them back with :meth:`release_driver` so later screenshots
    skip Chrome's start-up cost. A driver that cannot be reset on release is
    quit and its slot freed, so a waiting caller can create a replacement.
    """

    # How often a caller blocked on a full pool re-checks for a slot freed by
    # a discarded driver (freeing a slot puts nothing on the queue).
    _WAIT_POLL_SECONDS = 1.0

    def __init__(self, max_drivers: int = 3):
        self.driver_queue = queue.Queue()  # idle drivers ready for reuse
        self.max_drivers = max_drivers     # cap on drivers alive at once
        self.lock = threading.Lock()       # guards `count`
        self.count = 0                     # drivers created and not yet discarded
        logger.info(f"WebDriver プール初期化: 最大 {max_drivers}")

    def _create_driver(self):
        """Launch a new headless Chrome; use $CHROMEDRIVER_PATH when that file exists."""
        options = Options()
        options.add_argument("--headless")
        options.add_argument("--no-sandbox")
        options.add_argument("--disable-dev-shm-usage")
        options.add_argument("--force-device-scale-factor=1")
        options.add_argument("--disable-features=NetworkService")
        options.add_argument("--dns-prefetch-disable")

        chromedriver_path = os.environ.get("CHROMEDRIVER_PATH")
        if chromedriver_path and os.path.exists(chromedriver_path):
            logger.info(f"CHROMEDRIVER_PATH 使用: {chromedriver_path}")
            service = webdriver.ChromeService(executable_path=chromedriver_path)
            return webdriver.Chrome(service=service, options=options)
        return webdriver.Chrome(options=options)

    def _reserve_slot(self) -> bool:
        """Atomically claim capacity for one new driver; return False if the pool is full."""
        with self.lock:
            if self.count >= self.max_drivers:
                return False
            self.count += 1
            logger.info(f"新規 WebDriver 作成 ({self.count}/{self.max_drivers})")
            return True

    def _free_slot(self):
        """Give back the capacity held by a driver that no longer exists."""
        with self.lock:
            # Clamp at 0: close_all() resets the count while drivers may still be checked out.
            self.count = max(self.count - 1, 0)

    def get_driver(self):
        """Return a WebDriver: reuse an idle one, create a new one, or wait.

        Blocks while the pool is at capacity and no driver is idle.

        Raises:
            Exception: whatever Selenium/Chrome raises when a new driver cannot
                be started (the reserved slot is released first).
        """
        announced_wait = False
        while True:
            # EAFP: a separate empty()/get() pair can race with another thread
            # and then block even though a new driver could have been created.
            try:
                driver = self.driver_queue.get_nowait()
            except queue.Empty:
                pass
            else:
                logger.info("既存 WebDriver 取得")
                return driver

            if self._reserve_slot():
                try:
                    return self._create_driver()
                except Exception:
                    # A failed launch must not permanently shrink the pool.
                    self._free_slot()
                    raise

            if not announced_wait:
                logger.info("プール満杯、空き待機中…")
                announced_wait = True
            try:
                return self.driver_queue.get(timeout=self._WAIT_POLL_SECONDS)
            except queue.Empty:
                continue  # re-check: a broken driver may have freed a slot

    def release_driver(self, driver):
        """Reset *driver* to a blank page and put it back in the pool.

        If the reset fails, the driver is treated as broken: it is quit
        (errors logged) and its slot is freed. Falsy drivers are ignored.
        """
        if not driver:
            return
        try:
            driver.get("about:blank")
            driver.execute_script("""
                document.documentElement.style.overflow='';
                document.body.style.overflow='';
            """)
        except Exception as e:
            logger.error(f"返却エラー: {e}")
            self._discard(driver)
            return
        self.driver_queue.put(driver)
        logger.info("WebDriver をプールに返却")

    def _discard(self, driver):
        """Quit a broken driver (best effort) and always free its slot."""
        try:
            driver.quit()
        except Exception as e:
            logger.error(f"終了エラー: {e}")
        finally:
            self._free_slot()

    def close_all(self):
        """Quit every idle driver in the queue and reset the driver count.

        Drivers currently checked out are not in the queue, so they are not closed here.
        """
        logger.info("プール全 WebDriver 終了")
        closed = 0
        while not self.driver_queue.empty():
            try:
                drv = self.driver_queue.get(block=False)
                drv.quit()
                closed += 1
            except queue.Empty:
                break
            except Exception as e:
                logger.error(f"終了エラー: {e}")
        with self.lock:
            self.count = 0
        logger.info(f"{closed} 個終了")

# Process-wide pool shared by every screenshot path; its size is read from
# the MAX_WEBDRIVERS environment variable (default "3").
driver_pool = WebDriverPool(
    max_drivers=int(os.environ.get("MAX_WEBDRIVERS", "3")),
)

# ---------------------------------------------------------------
# Pydantic モデル
# ---------------------------------------------------------------
# Request body for POST /api/text-to-screenshot: `text` is turned into HTML by
# Gemini, and the HTML is then rendered to a PNG screenshot.
class GeminiRequest(BaseModel):
    text: str  # source text appended to the style's system prompt
    extension_percentage: float = 10.0  # extra window height (%) added before capture
    temperature: float = 0.5  # Gemini sampling temperature
    trim_whitespace: bool = True  # crop near-white margins from the screenshot
    style: str = "standard"  # prompt style; unknown values fall back to "standard"

# Request body for POST /api/screenshot: `html_code` is rendered directly.
class ScreenshotRequest(BaseModel):
    html_code: str  # HTML document or fragment to render
    extension_percentage: float = 10.0  # extra window height (%) added before capture
    trim_whitespace: bool = True  # crop near-white margins from the screenshot
    style: str = "standard"  # accepted but not used by /api/screenshot

# ---------------------------------------------------------------
# 補助関数(FontAwesome レイアウト / prompt 読み込み / Gemini 生成)
# ---------------------------------------------------------------
def enhance_font_awesome_layout(html_code: str) -> str:
    """Inject Font Awesome font preloads and icon-spacing CSS into an HTML page.

    Placement rules:

    * If the document has a ``</head>`` (matched case-insensitively), the
      snippet goes just before the first one.
    * Otherwise, if it has an ``<html ...>`` start tag, a ``<head>`` holding the
      snippet is inserted right after that tag, so a leading ``<!DOCTYPE>``
      stays first and the page is not pushed into quirks mode.
    * Otherwise the input is treated as a fragment and wrapped in
      ``<html><head>...</head>...</html>``.

    Args:
        html_code: HTML document or fragment.

    Returns:
        The HTML with the preload links and ``<style>`` block added.
    """
    fa_preload = """
    <link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-solid-900.woff2" as="font" type="font/woff2" crossorigin>
    <link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-regular-400.woff2" as="font" type="font/woff2" crossorigin>
    <link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-brands-400.woff2" as="font" type="font/woff2" crossorigin>
    """
    # The `span+.far,span+.fab` pair completes the icon-after-span selectors
    # (previously mistyped as `span+.far+span`).
    fa_css = """
    <style>
      [class*="fa-"]{display:inline-block!important;margin-right:8px!important;vertical-align:middle!important;}
      h1 [class*="fa-"],h2 [class*="fa-"],h3 [class*="fa-"],h4 [class*="fa-"],h5 [class*="fa-"],h6 [class*="fa-"]{vertical-align:middle!important;margin-right:10px!important;}
      .fa+span,.fas+span,.far+span,.fab+span,span+.fa,span+.fas,span+.far,span+.fab{display:inline-block!important;margin-left:5px!important;}
      .card [class*="fa-"],.card-body [class*="fa-"]{float:none!important;clear:none!important;position:relative!important;}
      li [class*="fa-"],p [class*="fa-"]{margin-right:10px!important;}
      .inline-icon{display:inline-flex!important;align-items:center!important;justify-content:flex-start!important;}
      [class*="fa-"]+span{display:inline-block!important;vertical-align:middle!important;}
    </style>
    """
    injected = f"{fa_preload}{fa_css}"

    head_close = re.search(r"</head\s*>", html_code, flags=re.IGNORECASE)
    if head_close:
        pos = head_close.start()
        return f"{html_code[:pos]}{injected}{html_code[pos:]}"

    html_open = re.search(r"<html(?:\s[^>]*)?>", html_code, flags=re.IGNORECASE)
    if html_open:
        pos = html_open.end()
        return f"{html_code[:pos]}<head>{injected}</head>{html_code[pos:]}"

    return f'<html><head>{injected}</head>{html_code}</html>'

# シンプルなプロンプトキャッシュを実装
# In-memory cache of loaded system prompts, keyed by normalised style name.
_prompt_cache = {}

def load_system_instruction(style="standard") -> str:
    """Return the system prompt text for *style*, caching it in memory.

    Unknown styles fall back to "standard". The prompt is read from
    ``<directory of this file>/<style>/prompt.txt`` when that file exists.
    Otherwise it is downloaded from the ``tomo2chin2/GURAREKOstlyle`` dataset
    on the Hugging Face Hub, falling back to the dataset's root
    ``prompt.txt`` if the style-specific download fails.

    Args:
        style: Style name; may come straight from API clients.

    Returns:
        The prompt text.
    """
    valid_styles = ["standard","cute","resort","cool","dental","school","KOKUGO"]
    if style not in valid_styles:
        style = "standard"

    # Look up only after normalising: `style` is client-controlled, and keying
    # on the raw value would add one cache entry per distinct unknown string.
    cached = _prompt_cache.get(style)
    if cached is not None:
        return cached

    local = os.path.join(os.path.dirname(__file__), style, "prompt.txt")
    if os.path.exists(local):
        path = local
    else:
        try:
            path = hf_hub_download("tomo2chin2/GURAREKOstlyle", f"{style}/prompt.txt", repo_type="dataset")
        except Exception:
            path = hf_hub_download("tomo2chin2/GURAREKOstlyle", "prompt.txt", repo_type="dataset")

    with open(path, encoding="utf-8") as fh:
        prompt_text = fh.read()

    _prompt_cache[style] = prompt_text
    return prompt_text

def _extract_html_block(raw: str) -> str:
    """Return the contents of the first ```html fenced block in *raw*.

    The closing fence is the last ``` that follows the opening fence. If
    there is no closing fence (e.g. the reply was cut off at the token limit),
    everything after the opening fence is returned. If there is no ```html
    fence at all, *raw* is returned unchanged.
    """
    opening = "```html"
    start = raw.find(opening)
    if start == -1:
        return raw
    body_start = start + len(opening)
    end = raw.rfind("```", body_start)
    if end == -1:
        return raw[body_start:].strip()
    return raw[body_start:end].strip()


def generate_html_from_text(text: str, temperature=0.5, style="standard") -> str:
    """Generate an HTML page from *text* with Gemini and apply Font Awesome fixes.

    Args:
        text: User text, appended after the style's system prompt.
        temperature: Sampling temperature passed to Gemini.
        style: Prompt style name (see load_system_instruction).

    Returns:
        The HTML from the model's ```html block (or the raw reply when there is
        no such block), passed through enhance_font_awesome_layout.

    Raises:
        KeyError: GEMINI_API_KEY is not set.
        ValueError: the model response contained no text (e.g. it was blocked).
    """
    client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

    model_name = os.getenv("GEMINI_MODEL", "gemini-1.5-pro")
    prompt = f"{load_system_instruction(style)}\n\n{text}"

    config = types.GenerateContentConfig(
        temperature=temperature,
        top_p=0.7,
        top_k=20,
        max_output_tokens=8192,
        candidate_count=1,
    )

    # Gemini 2.5 Flash Preview: turn thinking off and raise the output cap to
    # 10000 tokens (the value the UI footer advertises).
    if model_name == "gemini-2.5-flash-preview-04-17":
        logger.info("gemini-2.5-flash-preview-04-17 モデル検出: 思考モードをオフに設定")
        config.thinking_config = types.ThinkingConfig(thinking_budget=0)
        logger.info("gemini-2.5-flash-preview-04-17 モデル検出: max_output_tokens を 10000 に設定")
        config.max_output_tokens = 10000

    response = client.models.generate_content(
        model=model_name,
        contents=prompt,
        config=config,
    )

    raw = response.text
    if raw is None:
        raise ValueError(f"Gemini ({model_name}) returned no text content")
    return enhance_font_awesome_layout(_extract_html_block(raw))

def trim_image_whitespace(img: Image.Image, threshold=248, padding=20) -> Image.Image:
    """Crop near-white margins from *img*, keeping *padding* pixels around content.

    A pixel counts as content when its grayscale ("L") value is below
    *threshold*. An image with no content pixels is returned unchanged.

    Args:
        img: Source image.
        threshold: Grayscale cut-off (0-255) separating content from background.
        padding: Margin, in pixels, kept on every side of the content box.

    Returns:
        The cropped image (or *img* itself when nothing is below *threshold*).
    """
    content = np.asarray(img.convert("L")) < threshold
    if not content.any():
        return img
    rows = np.flatnonzero(content.any(axis=1))
    cols = np.flatnonzero(content.any(axis=0))
    # PIL crop boxes are (left, upper, right, lower) with right/lower EXCLUSIVE.
    # Add 1 so the last content pixel is kept, and clamp to the full size
    # rather than size-1 so padding is symmetric and edge content survives.
    box = (
        max(int(cols[0]) - padding, 0),
        max(int(rows[0]) - padding, 0),
        min(int(cols[-1]) + padding + 1, img.width),
        min(int(rows[-1]) + padding + 1, img.height),
    )
    return img.crop(box)

# ---------------------------------------------------------------
# HTML → スクショ  (完全版ロジック)
# ---------------------------------------------------------------
def render_fullpage_screenshot(html_code: str, extension_percentage=6.0,
                               trim_whitespace=True, driver=None) -> Image.Image:
    """Render *html_code* in headless Chrome and return a full-page screenshot.

    Args:
        html_code: HTML document or fragment to render.
        extension_percentage: Extra window height, as a percentage of the
            measured (clamped) page height, added before capturing.
        trim_whitespace: Crop near-white margins with trim_image_whitespace.
        driver: Optional WebDriver owned by the caller. When omitted, a driver
            is borrowed from ``driver_pool`` and returned afterwards. A driver
            supplied by the caller is NOT released here.

    Returns:
        The screenshot as a PIL image, or a 1x1 black image if rendering
        failed (the error is logged, not raised).
    """
    tmp_path = None
    from_pool = False
    try:
        if driver is None:
            driver = driver_pool.get_driver()
            from_pool = True

        # Write the HTML to a temporary file so Chrome can load it from disk.
        with tempfile.NamedTemporaryFile(suffix=".html", delete=False, mode="w", encoding="utf-8") as tmp:
            tmp_path = tmp.name
            tmp.write(html_code)

        driver.set_window_size(1200, 1000)
        # Path.as_uri() percent-encodes special characters and handles Windows
        # drive letters, which plain "file://" + path concatenation does not.
        driver.get(Path(tmp_path).as_uri())

        # Wait up to 10 s for <body> to exist.
        WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.TAG_NAME, "body")))

        # Poll (up to ~5 s) until readyState is 'complete' and every <img> has loaded.
        max_wait, inc, waited = 5, 0.2, 0.0
        while waited < max_wait:
            state = driver.execute_script("""
                return {complete: document.readyState==='complete',
                        imgs: document.images.length,
                        loaded: Array.from(document.images).filter(i=>i.complete).length};
            """)
            if state['complete'] and (state['imgs']==0 or state['imgs']==state['loaded']):
                break
            time.sleep(inc)
            waited += inc

        # Scroll down in at most 5 steps so scroll-triggered content renders,
        # then return to the top.
        total_h = driver.execute_script("return Math.max(document.body.scrollHeight, document.documentElement.scrollHeight)")
        # innerHeight can be 0 for a degenerate window; clamp so the division cannot fail.
        vh = max(driver.execute_script("return window.innerHeight") or 0, 1)
        for i in range(max(1, min(5, total_h // vh))):
            driver.execute_script(f"window.scrollTo(0, {(vh-100)*i})")
            time.sleep(0.1)
        driver.execute_script("window.scrollTo(0,0)")
        time.sleep(0.2)

        # Resize the window to the document size (width clamped to 100..2000,
        # height to 100..4000, then extended by extension_percentage) so a
        # single screenshot captures the page.
        dims = driver.execute_script("""
            return {w: Math.max(document.body.scrollWidth, document.documentElement.scrollWidth),
                    h: Math.max(document.body.scrollHeight, document.documentElement.scrollHeight)}
        """)
        w = min(max(dims['w'], 100), 2000)
        h = min(max(dims['h'], 100), 4000)
        h = int(h * (1 + extension_percentage / 100.0))
        driver.set_window_size(w, h)
        time.sleep(0.5)

        img = Image.open(BytesIO(driver.get_screenshot_as_png()))
        return trim_image_whitespace(img, padding=20) if trim_whitespace else img

    except Exception as e:
        # Best effort: log and return a 1x1 black placeholder instead of raising.
        logger.error(f"Screenshot error: {e}", exc_info=True)
        return Image.new("RGB", (1, 1), (0, 0, 0))
    finally:
        # Only return drivers this call borrowed; caller-supplied ones stay with the caller.
        if from_pool:
            driver_pool.release_driver(driver)
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError as e:
                logger.warning(f"Temp file cleanup failed ({tmp_path}): {e}")

# ---------------------------------------------------------------
# テキスト → スクショ (並列 API 呼び出し + ドライバ確保)
# ---------------------------------------------------------------
def text_to_screenshot_parallel(text, ext_perc, temp=0.5, trim_ws=True, style="standard") -> Image.Image:
    """Generate HTML from *text* with Gemini and return its screenshot.

    The Gemini call and the WebDriver checkout run concurrently, so Chrome
    start-up overlaps with generation. render_fullpage_screenshot does not
    release drivers it is handed, so this function always returns the borrowed
    driver to ``driver_pool``, including when HTML generation fails. Without
    that, every call would permanently consume one pool slot.

    Args:
        text: Source text for Gemini.
        ext_perc: Extra window height (%) added before capture.
        temp: Gemini sampling temperature.
        trim_ws: Crop near-white margins from the screenshot.
        style: Prompt style name.

    Returns:
        The screenshot image (see render_fullpage_screenshot).

    Raises:
        Whatever generate_html_from_text or driver_pool.get_driver raises.
    """
    with ThreadPoolExecutor(max_workers=2) as exe:
        html_future = exe.submit(generate_html_from_text, text, temp, style)
        driver_future = exe.submit(driver_pool.get_driver)
    # Leaving the with-block waits for both tasks, so both futures are done here.
    driver = driver_future.result()
    try:
        html_code = html_future.result()
        return render_fullpage_screenshot(html_code, ext_perc, trim_ws, driver)
    finally:
        driver_pool.release_driver(driver)

def text_to_screenshot(*args, **kwargs):
    """Backward-compatible alias: forward all arguments to text_to_screenshot_parallel."""
    image = text_to_screenshot_parallel(*args, **kwargs)
    return image

# ===============================================================
# FastAPI  (★ redirect_slashes=False で自動 307 を殺す)
# ===============================================================
# redirect_slashes=False stops the framework from answering a missing/extra
# trailing slash with an automatic 307 redirect; the redirects this app needs
# are registered explicitly further down.
app = FastAPI(redirect_slashes=False)

# Allow cross-origin requests from any origin, with any method and header.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True lets any
# site send credentialed requests. If no cookie/auth-based clients exist,
# consider allow_credentials=False or an explicit origin list (confirm with
# API consumers before changing).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# -------- API エンドポイントはそのまま --------
@app.post("/api/screenshot", response_class=StreamingResponse, tags=["Screenshot"])
def api_render_screenshot(req: ScreenshotRequest):
    """Render ``req.html_code`` in headless Chrome and stream the result as a PNG.

    This is deliberately a plain ``def``, because FastAPI runs such endpoints in
    its worker threadpool. The Selenium rendering blocks for seconds and may
    wait for a free pooled driver. Inside ``async def`` that would stall the
    event loop and with it every other request, including the mounted Gradio UI.
    ``req.style`` is accepted but not used by this endpoint.
    """
    img = render_fullpage_screenshot(req.html_code, req.extension_percentage, req.trim_whitespace)
    buf = BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")

@app.post("/api/text-to-screenshot", response_class=StreamingResponse, tags=["Gemini","Screenshot"])
def api_text_to_screenshot(req: GeminiRequest):
    """Generate HTML from ``req.text`` with Gemini, render it, and stream a PNG.

    This is deliberately a plain ``def``, so FastAPI runs it in its threadpool.
    The Gemini call and the Selenium rendering both block, and inside
    ``async def`` they would freeze the event loop for the whole server.
    """
    img = text_to_screenshot_parallel(
        req.text, req.extension_percentage, req.temperature, req.trim_whitespace, req.style)
    buf = BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")

# ===============================================================
# Gradio UI (完全版 UI 定義)
# ===============================================================
def process_input(mode, text, ext, temp, trim, style):
    """Gradio click handler: dispatch on the selected input mode.

    In "HTML入力" mode, *text* is rendered directly as HTML. In any other mode,
    *text* is prose: Gemini turns it into HTML (using *temp* and *style*) and
    the result is rendered.
    """
    if mode != "HTML入力":
        return text_to_screenshot_parallel(text, ext, temp, trim, style)
    return render_fullpage_screenshot(text, ext, trim)

# Gradio UI. A single text box is interpreted according to the selected mode:
# "HTML入力" = raw HTML, "テキスト入力" = prose that Gemini converts to HTML.
# The order in which components are created inside this context defines the page layout.
with gr.Blocks(title="Full Page Screenshot (テキスト変換対応)", theme=gr.themes.Origin()) as iface:
    gr.Markdown("# HTMLビューア & テキスト→インフォグラフィック変換")
    with gr.Row():
        mode = gr.Radio(["HTML入力", "テキスト入力"], value="HTML入力", label="入力モード")
    text = gr.Textbox(lines=15, label="入力")
    with gr.Row():
        # Style and temperature only matter in text mode. They start hidden
        # (the default mode is "HTML入力") and are shown by `toggle` below.
        style = gr.Dropdown(
            ["standard", "cute", "resort", "cool", "dental", "school", "KOKUGO"],
            value="standard", label="デザインスタイル", visible=False)
        with gr.Column(scale=2):
            ext = gr.Slider(0, 30, value=15, step=1, label="上下高さ拡張率(%)")
            # NOTE: the UI default temperature is 1.0, while the API request model defaults to 0.5.
            temp = gr.Slider(0.0, 1.0, value=1.0, step=0.1,
                             label="生成時の温度", visible=False)
    trim = gr.Checkbox(value=True, label="余白を自動トリミング")
    btn = gr.Button("生成")
    out = gr.Image(type="pil", label="スクリーンショット")

    # Visibility control: show the temperature and style inputs only in text-input mode.
    def toggle(m): vis = m == "テキスト入力"; return [gr.update(visible=vis), gr.update(visible=vis)]
    mode.change(toggle, mode, [temp, style])

    btn.click(process_input, [mode, text, ext, temp, trim, style], out)

    # Footer: list the API routes and the configured model name. For Gemini 2.5
    # Flash Preview, also show that thinking mode is off and the output-token cap.
    model_name = os.getenv('GEMINI_MODEL', 'gemini-1.5-pro')
    thinking_status = ""
    if model_name == "gemini-2.5-flash-preview-04-17":
        thinking_status = "(思考モード: オフ、最大トークン: 10000)"
    
    gr.Markdown(f"**API** `/api/screenshot`, `/api/text-to-screenshot` &nbsp;&nbsp; "
                f"使用モデル: `{model_name}` {thinking_status}")

# ===============================================================
# Gradio を /gradio にマウントし、明示リダイレクトを追加
# ===============================================================
# URL prefix under which the Gradio UI is served ("/" and "/gradio" are
# redirected to "/gradio/" by the routes below).
GRADIO_PATH = "/gradio"
# Mount the Blocks UI onto the FastAPI app and keep the returned app object.
# ssr_mode=False turns off Gradio's server-side rendering mode.
app = gr.mount_gradio_app(app, iface, path=GRADIO_PATH, ssr_mode=False)

# ルート → /gradio/ へ転送
@app.get("/")
def _root():
    """Send visitors of the site root to the Gradio UI."""
    target = f"{GRADIO_PATH}/"
    return RedirectResponse(target)

# 末尾スラッシュ無し → 有りへ転送
@app.get(GRADIO_PATH)
def _no_slash():
    """Redirect the slash-less Gradio path to its trailing-slash form.

    Needed because the app is created with redirect_slashes=False.
    """
    target = f"{GRADIO_PATH}/"
    return RedirectResponse(target)

# 起動時に頻繁に使用するプロンプトを先読み
@app.on_event("startup")
async def startup_event():
    """Pre-load every style's system prompt into the prompt cache.

    Each style is loaded independently and failures are only logged. An
    unreachable Hugging Face Hub or a missing prompt file therefore no longer
    aborts server start-up. Because load_system_instruction caches only
    successful loads, a style that fails here is retried on first use.
    """
    styles = ["standard", "cute", "resort", "cool", "dental", "school", "KOKUGO"]
    for style in styles:
        try:
            load_system_instruction(style)
        except Exception as e:
            logger.warning(f"システムプロンプトの先読みに失敗 ({style}): {e}")
    logger.info("システムプロンプトのキャッシュを準備完了")

# ===============================================================
# ローカルデバッグ
# ===============================================================
# Local debugging entry point: serve the combined FastAPI + Gradio app with
# Uvicorn on all interfaces, port 7860. uvicorn is imported here, so it is only
# needed when this file is run directly.
if __name__ == "__main__":
    import uvicorn
    logger.info("Uvicorn 起動 (ローカル)")
    uvicorn.run(app, host="0.0.0.0", port=7860)

# ===============================================================
# 終了時 WebDriver クリーンアップ
# ===============================================================
# Quit the pooled Chrome instances when the interpreter exits.
# NOTE(review): when this file is run directly, this line runs only after
# uvicorn.run() above has returned. close_all() only drains the idle-driver
# queue, so drivers still checked out at exit are not closed.
import atexit
atexit.register(driver_pool.close_all)