HTMLviewer_Dev / app.py
tomo2chin2's picture
Update app.py
42e5ffa verified
raw
history blame
31.8 kB
# app.py
import os
import time
import tempfile
import threading
import queue
import logging
import numpy as np # 追加: 画像処理の最適化用
from io import BytesIO
from PIL import Image
from pydantic import BaseModel
from fastapi import FastAPI, HTTPException, Body
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
import gradio as gr
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from concurrent.futures import ThreadPoolExecutor # 追加: 並列処理用
from huggingface_hub import hf_hub_download
# 既存の Gemini ライブラリ
import google.generativeai as genai_old
# 新しい Gemini ライブラリ(2.5系モデル用)
from google import genai as genai_new
from google.genai import types
# Logging setup: configure the root logger at INFO level and create the
# module-level logger used throughout this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- WebDriverプールの実装 ---
class WebDriverPool:
    """Bounded pool of reusable headless Chrome WebDriver instances.

    Idle drivers wait in ``driver_queue``. ``count`` is the number of drivers
    this pool has created and not yet discarded; it is guarded by ``lock`` and
    never exceeds ``max_drivers``.
    """

    def __init__(self, max_drivers=3):
        """Create an empty pool that will create at most *max_drivers* drivers."""
        self.driver_queue = queue.Queue()
        self.max_drivers = max_drivers
        self.lock = threading.Lock()
        self.count = 0
        logger.info(f"WebDriverプールを初期化: 最大 {max_drivers} ドライバー")

    def _create_driver(self):
        """Start a new headless Chrome driver with the pool's standard options."""
        options = Options()
        options.add_argument("--headless")
        options.add_argument("--no-sandbox")
        options.add_argument("--disable-dev-shm-usage")
        options.add_argument("--force-device-scale-factor=1")
        options.add_argument("--disable-features=NetworkService")
        options.add_argument("--dns-prefetch-disable")
        # An explicit chromedriver binary is used only when CHROMEDRIVER_PATH
        # is set and points at an existing file.
        webdriver_path = os.environ.get("CHROMEDRIVER_PATH")
        if webdriver_path and os.path.exists(webdriver_path):
            logger.info(f"CHROMEDRIVER_PATH使用: {webdriver_path}")
            service = webdriver.ChromeService(executable_path=webdriver_path)
            return webdriver.Chrome(service=service, options=options)
        logger.info("デフォルトのChromeDriverを使用")
        return webdriver.Chrome(options=options)

    def _discard_driver(self, driver):
        """Quit *driver* and free its slot, even if quitting itself fails."""
        try:
            driver.quit()
        except Exception as quit_error:
            logger.error(f"ドライバー終了中にエラー: {quit_error}")
        finally:
            with self.lock:
                self.count = max(0, self.count - 1)

    def get_driver(self):
        """Return an idle driver, create one if capacity allows, else wait.

        Blocks without timeout when the pool is full and no driver is idle.
        Raises whatever ``webdriver.Chrome`` raises if a new driver cannot be
        started; in that case the reserved slot is given back.
        """
        # A single non-blocking get avoids the empty()/get() race in which
        # another thread takes the idle driver between the two calls.
        try:
            driver = self.driver_queue.get_nowait()
        except queue.Empty:
            driver = None
        if driver is not None:
            logger.info("既存のWebDriverをプールから取得")
            return driver

        # Reserve a slot under the lock, but start Chrome outside it so a slow
        # browser start does not block other threads.
        with self.lock:
            slot_reserved = self.count < self.max_drivers
            if slot_reserved:
                self.count += 1
                total = self.count
        if slot_reserved:
            logger.info(f"新しいWebDriverを作成 (合計: {total}/{self.max_drivers})")
            try:
                return self._create_driver()
            except Exception:
                # Creation failed: release the reserved slot so the pool does
                # not shrink permanently.
                with self.lock:
                    self.count -= 1
                raise

        logger.info("WebDriverプールがいっぱいです。利用可能なドライバーを待機中...")
        return self.driver_queue.get()

    def release_driver(self, driver):
        """Reset *driver* to a blank page and return it to the pool.

        If the reset fails the driver is quit and its slot is freed instead.
        A falsy *driver* is ignored.
        """
        if not driver:
            return
        try:
            driver.get("about:blank")
            driver.execute_script("""
document.documentElement.style.overflow = '';
document.body.style.overflow = '';
""")
            self.driver_queue.put(driver)
            logger.info("WebDriverをプールに戻しました")
        except Exception as e:
            logger.error(f"ドライバーをプールに戻す際にエラー: {e}")
            self._discard_driver(driver)

    def close_all(self):
        """Quit every idle driver in the queue and reset the driver count.

        Drivers that are currently checked out are not in the queue and are
        therefore not closed here.
        """
        logger.info("WebDriverプールを終了します")
        closed = 0
        while True:
            try:
                driver = self.driver_queue.get(block=False)
            except queue.Empty:
                break
            try:
                driver.quit()
                closed += 1
            except Exception as e:
                logger.error(f"ドライバー終了中にエラー: {e}")
        logger.info(f"{closed}個のWebDriverを終了しました")
        with self.lock:
            self.count = 0
# Global WebDriver pool used by the screenshot functions in this module; its
# size comes from the MAX_WEBDRIVERS environment variable (default "3").
driver_pool = WebDriverPool(max_drivers=int(os.environ.get("MAX_WEBDRIVERS", "3")))
# --- リクエストモデル ---
class GeminiRequest(BaseModel):
    """Request body for the text-to-infographic (Gemini) endpoint."""

    # Source text that is sent to Gemini.
    text: str
    # Extra vertical space added to the screenshot, in percent (default 10%).
    extension_percentage: float = 10.0
    # Sampling temperature for generation (default 0.5).
    temperature: float = 0.5
    # Whether to trim surrounding whitespace from the image (enabled by default).
    trim_whitespace: bool = True
    # Design style name (default "standard").
    style: str = "standard"
class ScreenshotRequest(BaseModel):
    """Request body for the HTML-to-screenshot endpoint."""

    # HTML document (or fragment) to render.
    html_code: str
    # Extra vertical space added to the screenshot, in percent.
    extension_percentage: float = 10.0
    # Whether to trim surrounding whitespace from the image.
    trim_whitespace: bool = True
    # Design style name; defaults to "standard".
    style: str = "standard"
# --- Font Awesomeレイアウト改善 ---
def enhance_font_awesome_layout(html_code):
    """Inject Font Awesome preload links and layout-fix CSS into *html_code*.

    The snippet is inserted exactly once, at the first location that applies:

    1. immediately before the first ``</head>``;
    2. immediately after an opening ``<head ...>`` tag that is never closed;
    3. as a new ``<head>`` before ``<body`` (or, failing that, right after the
       opening ``<html ...>`` tag) when the document has an ``<html`` tag;
    4. otherwise the input is treated as a fragment and wrapped in
       ``<html><head>...</head>`` + fragment + ``</html>``.

    Tag matching is case-insensitive.

    Args:
        html_code: HTML document or fragment as a string.

    Returns:
        The HTML string with the preload tags and CSS added.
    """
    import re  # local import: the module does not import ``re`` at file level

    fa_preload = """
<link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-solid-900.woff2" as="font" type="font/woff2" crossorigin>
<link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-regular-400.woff2" as="font" type="font/woff2" crossorigin>
<link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/webfonts/fa-brands-400.woff2" as="font" type="font/woff2" crossorigin>
"""
    fa_fix_css = """
<style>
[class*="fa-"] {
display: inline-block !important;
margin-right: 8px !important;
vertical-align: middle !important;
}
h1 [class*="fa-"], h2 [class*="fa-"], h3 [class*="fa-"],
h4 [class*="fa-"], h5 [class*="fa-"], h6 [class*="fa-"] {
vertical-align: middle !important;
margin-right: 10px !important;
}
.fa + span, .fas + span, .far + span, .fab + span,
span + .fa, span + .fas, span + .far + span {
display: inline-block !important;
margin-left: 5px !important;
}
.card [class*="fa-"], .card-body [class*="fa-"] {
float: none !important;
clear: none !important;
position: relative !important;
}
li [class*="fa-"], p [class*="fa-"] {
margin-right: 10px !important;
}
.inline-icon {
display: inline-flex !important;
align-items: center !important;
justify-content: flex-start !important;
}
[class*="fa-"] + span {
display: inline-block !important;
vertical-align: middle !important;
}
</style>
"""
    injection = fa_preload + fa_fix_css
    new_head = f'<head>{injection}</head>'

    def insert_at(index, snippet):
        return html_code[:index] + snippet + html_code[index:]

    # 1. Normal case: a closed <head>. Only the first </head> is used so that
    #    a later literal "</head>" (e.g. inside a script) is left alone.
    head_close = re.search(r'</head\s*>', html_code, re.IGNORECASE)
    if head_close:
        return insert_at(head_close.start(), injection)

    # 2. <head> opened but never closed: inject right after the opening tag.
    #    The pattern requires whitespace or '>' after "head", so <header> does
    #    not match.
    head_open = re.search(r'<head(?:\s[^>]*)?>', html_code, re.IGNORECASE)
    if head_open:
        return insert_at(head_open.end(), injection)

    # 3. Document without any <head>: create one in the right place.
    if re.search(r'<html[\s>]', html_code, re.IGNORECASE):
        body_open = re.search(r'<body[\s>]', html_code, re.IGNORECASE)
        if body_open:
            return insert_at(body_open.start(), new_head)
        html_open = re.search(r'<html(?:\s[^>]*)?>', html_code, re.IGNORECASE)
        if html_open:
            return insert_at(html_open.end(), new_head)

    # 4. Bare fragment: wrap it in a minimal document.
    return f'<html>{new_head}' + html_code + '</html>'
# --- システムインストラクション読み込み ---
def load_system_instruction(style="standard"):
    """Return the system-instruction prompt text for the given style.

    Unknown styles fall back to "standard" with a warning. Lookup order:

    1. ``<directory of this file>/<style>/prompt.txt`` if it exists;
    2. ``<style>/prompt.txt`` downloaded from the Hugging Face dataset repo;
    3. the repo's top-level ``prompt.txt`` if step 2 fails for any reason.

    An exception raised by step 3 propagates to the caller.
    """
    def read_utf8(path):
        with open(path, 'r', encoding='utf-8') as handle:
            return handle.read()

    if style not in ("standard", "cute", "resort", "cool", "dental", "school", "KOKUGO"):
        logger.warning(f"無効なスタイル '{style}' が指定されました。デフォルトの 'standard' を使用します。")
        style = "standard"
    logger.info(f"スタイル '{style}' のシステムインストラクションを読み込みます")

    # A prompt file next to this module takes precedence over the hub.
    local_path = os.path.join(os.path.dirname(__file__), style, "prompt.txt")
    if os.path.exists(local_path):
        logger.info(f"ローカルファイルを使用: {local_path}")
        return read_utf8(local_path)

    # Otherwise fetch the style-specific prompt from Hugging Face.
    try:
        downloaded = hf_hub_download(
            repo_id="tomo2chin2/GURAREKOstlyle",
            filename=f"{style}/prompt.txt",
            repo_type="dataset",
        )
        logger.info(f"HuggingFace から読み込み: {downloaded}")
        return read_utf8(downloaded)
    except Exception as style_error:
        # Download or read failed: fall back to the repo's default prompt.
        logger.warning(f"スタイル '{style}' の読み込み失敗: {style_error}")
        logger.info("デフォルトの prompt.txt を読み込みます")
        fallback = hf_hub_download(
            repo_id="tomo2chin2/GURAREKOstlyle",
            filename="prompt.txt",
            repo_type="dataset",
        )
        return read_utf8(fallback)
# --- テキストからHTML生成 ---
def _extract_html_from_markdown(raw):
    """Return the contents of a ```html fenced block in *raw*, or *raw* itself."""
    fence = "```html"
    start = raw.find(fence)
    if start == -1:
        return raw
    body_start = start + len(fence)
    end = raw.rfind("```")
    if end >= body_start:
        return raw[body_start:end].strip()
    # Opening fence without a closing one (e.g. truncated output): keep the
    # rest of the text rather than leaving the fence marker in the HTML.
    return raw[body_start:].strip()


def generate_html_from_text(text, temperature=0.5, style="standard"):
    """Generate infographic HTML from *text* with the Gemini API.

    The model name is read from GEMINI_MODEL (default "gemini-1.5-pro") and
    the API key from GEMINI_API_KEY. The model
    "gemini-2.5-flash-preview-04-17" is called through the new ``google.genai``
    client with ``thinking_budget=0``; every other model goes through the
    older ``google.generativeai`` library. In both cases the style's system
    instruction is prepended to *text* and *temperature* is applied.

    Args:
        text: Source text to turn into an infographic.
        temperature: Sampling temperature passed to the model.
        style: Style name passed to ``load_system_instruction``.

    Returns:
        HTML string, with the Font Awesome layout fixes already injected.

    Raises:
        Exception: wraps any failure (missing API key, API error, empty
            response); the original exception is chained as ``__cause__``.
    """
    try:
        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            logger.error("GEMINI_API_KEY 環境変数が設定されていません")
            raise ValueError("GEMINI_API_KEY が設定されていません")
        model_name = os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
        logger.info(f"使用する Gemini モデル: {model_name}")

        # Both code paths use the same style-specific prompt; previously the
        # new-library path sent the bare text and ignored *style*.
        system_instruction = load_system_instruction(style)
        prompt = f"{system_instruction}\n\n{text}"

        if model_name == "gemini-2.5-flash-preview-04-17":
            # New library (genai_new) with thinkingBudget=0.
            client = genai_new.Client(api_key=api_key)
            logger.info("新ライブラリ genai_new を使用 (thinkingBudget=0)")
            cfg = types.GenerateContentConfig(
                temperature=temperature,
                thinking_config=types.ThinkingConfig(thinking_budget=0),
            )
            response = client.models.generate_content(
                model=model_name,
                contents=prompt,
                config=cfg,
            )
        else:
            # Existing library (genai_old).
            genai_old.configure(api_key=api_key)
            response = genai_old.GenerativeModel(model_name).generate_content(
                prompt,
                generation_config={
                    "temperature": temperature,
                    "top_p": 0.7,
                    "top_k": 20,
                    "max_output_tokens": 8192,
                    "candidate_count": 1,
                },
                safety_settings=[
                    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
                    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
                    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
                    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
                ],
            )

        raw = response.text
        if not raw:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError("Gemini API returned an empty response")

        # Pull the HTML out of a Markdown ```html fenced block, if present.
        html_code = _extract_html_from_markdown(raw)

        # Apply the Font Awesome layout fixes.
        html_code = enhance_font_awesome_layout(html_code)
        logger.info("Font Awesome レイアウトの最適化を適用しました")
        return html_code
    except Exception as e:
        logger.error(f"HTML生成中にエラー: {e}", exc_info=True)
        raise Exception(f"Gemini API での HTML 生成に失敗しました: {e}") from e
# --- 画像トリミング ---
def trim_image_whitespace(image, threshold=250, padding=10):
    """Crop near-white margins from *image*.

    The image is converted to greyscale; pixels darker than *threshold* count
    as content. The bounding box of all content pixels is expanded by
    *padding* pixels on each side (clamped to the image) and cropped.

    Args:
        image: Image object providing ``convert``, ``crop``, ``width`` and
            ``height`` (a PIL image in this module).
        threshold: Greyscale value (0-255) below which a pixel is content.
        padding: Margin in pixels kept around the content.

    Returns:
        The cropped image, or the original image when no content pixel is
        found or any error occurs (the error is logged, not raised).
    """
    try:
        gray = image.convert('L')
        mask = np.array(gray) < threshold
        # Indices of the rows / columns that contain at least one content pixel.
        row_indices = np.flatnonzero(mask.any(axis=1))
        col_indices = np.flatnonzero(mask.any(axis=0))
        if row_indices.size and col_indices.size:
            min_x = max(0, int(col_indices[0]) - padding)
            min_y = max(0, int(row_indices[0]) - padding)
            max_x = min(image.width - 1, int(col_indices[-1]) + padding)
            max_y = min(image.height - 1, int(row_indices[-1]) + padding)
            # crop() takes an exclusive right/bottom edge, hence the +1.
            trimmed = image.crop((min_x, min_y, max_x + 1, max_y + 1))
            # The two sizes were previously fused together with no separator.
            logger.info(f"画像をトリミングしました: {image.width}x{image.height} → {trimmed.width}x{trimmed.height}")
            return trimmed
        logger.warning("トリミング領域が見つかりません。元の画像を返します。")
        return image
    except Exception as e:
        logger.error(f"画像トリミング中にエラー: {e}", exc_info=True)
        return image
# --- スクリーンショット生成 ---
def render_fullpage_screenshot(html_code: str, extension_percentage: float = 6.0,
                               trim_whitespace: bool = True, driver=None) -> Image.Image:
    """Render *html_code* in headless Chrome and return a full-page screenshot.

    The HTML is written to a temporary file, loaded in the browser, scrolled
    through once, measured, and the window is resized to the page size
    (clamped to 2000x4000) plus *extension_percentage* extra height before
    the screenshot is taken.

    Args:
        html_code: HTML document to render.
        extension_percentage: Extra height to add, in percent of page height.
        trim_whitespace: If true, crop near-white margins from the result.
        driver: Optional WebDriver to use. When omitted, a driver is taken
            from the global pool and returned to it afterwards.

    Returns:
        The screenshot as a PIL image. On any rendering error a 1x1 black
        image is returned instead of raising; callers detect failure by that
        size.
    """
    from pathlib import Path  # local import: pathlib is not imported at file level

    tmp_path = None
    driver_from_pool = False
    if driver is None:
        driver = driver_pool.get_driver()
        driver_from_pool = True
        logger.info("WebDriverプールからドライバーを取得しました")
    try:
        # Save the HTML to a temporary file so it can be opened via file://.
        with tempfile.NamedTemporaryFile(suffix=".html", delete=False, mode='w', encoding='utf-8') as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(html_code)
        logger.info(f"HTML saved to temporary file: {tmp_path}")

        # Initial window size.
        initial_width = 1200
        initial_height = 1000
        driver.set_window_size(initial_width, initial_height)
        # as_uri() builds a well-formed, percent-encoded file URL on every
        # platform instead of concatenating "file://" with a raw path.
        driver.get(Path(tmp_path).as_uri())

        # Wait for the body element to exist.
        WebDriverWait(driver, 10).until(
            EC.presence_of_element_located((By.TAG_NAME, "body"))
        )

        # Poll until the document and all images are loaded (at most ~5 s).
        max_wait = 5
        wait_increment = 0.2
        wait_time = 0
        while wait_time < max_wait:
            state = driver.execute_script("""
return {
complete: document.readyState === 'complete',
imgCount: document.images.length,
imgLoaded: Array.from(document.images).filter(img => img.complete).length,
faElements: document.querySelectorAll('.fa, .fas, .far, .fab, [class*="fa-"]').length
};
""")
            if state['complete'] and (state['imgCount'] == 0 or state['imgLoaded'] == state['imgCount']):
                break
            time.sleep(wait_increment)
            wait_time += wait_increment

        # Give icon fonts a little extra time when there are many FA elements.
        if state.get('faElements', 0) > 30:
            time.sleep(min(1.0, state['faElements'] / 100))

        # Scroll through the page (at most 5 steps), then back to the top.
        total_height = driver.execute_script("return Math.max(document.body.scrollHeight, document.documentElement.scrollHeight);")
        viewport_height = driver.execute_script("return window.innerHeight;")
        if not viewport_height or viewport_height <= 0:
            # Guard against a zero/missing innerHeight, which would otherwise
            # raise ZeroDivisionError below.
            viewport_height = initial_height
        scrolls_needed = max(1, min(5, total_height // viewport_height))
        for i in range(scrolls_needed):
            scroll_pos = i * (viewport_height - 100)
            driver.execute_script(f"window.scrollTo(0, {scroll_pos});")
            time.sleep(0.1)
        driver.execute_script("window.scrollTo(0, 0);")
        time.sleep(0.2)

        # Hide scrollbars so they do not appear in the screenshot.
        driver.execute_script("""
document.documentElement.style.overflow = 'hidden';
document.body.style.overflow = 'hidden';
""")

        # Measure the page.
        dims = driver.execute_script("""
return {
width: Math.max(
document.documentElement.scrollWidth,
document.documentElement.offsetWidth,
document.documentElement.clientWidth,
document.body ? document.body.scrollWidth : 0,
document.body ? document.body.offsetWidth : 0,
document.body ? document.body.clientWidth : 0
),
height: Math.max(
document.documentElement.scrollHeight,
document.documentElement.offsetHeight,
document.documentElement.clientHeight,
document.body ? document.body.scrollHeight : 0,
document.body ? document.body.offsetHeight : 0,
document.body ? document.body.clientHeight : 0
)
};
""")
        # Clamp to 100..2000 px wide and 100..4000 px high.
        scroll_width = min(max(dims['width'], 100), 2000)
        scroll_height = min(max(dims['height'], 100), 4000)

        # Add the requested extra height; never shrink below the page height.
        adjusted_height = int(scroll_height * (1 + extension_percentage / 100.0))
        adjusted_height = max(adjusted_height, scroll_height, 100)
        driver.set_window_size(scroll_width, adjusted_height)
        time.sleep(0.5)

        # Take the screenshot.
        png = driver.get_screenshot_as_png()
        img = Image.open(BytesIO(png))
        logger.info(f"Screenshot dimensions: {img.width}x{img.height}")

        # Optionally trim surrounding whitespace.
        if trim_whitespace:
            img = trim_image_whitespace(img, threshold=248, padding=20)
            logger.info(f"Trimmed dimensions: {img.width}x{img.height}")
        return img
    except Exception as e:
        logger.error(f"Error during screenshot generation: {e}", exc_info=True)
        return Image.new('RGB', (1, 1), color=(0, 0, 0))
    finally:
        # Only drivers taken from the pool here are given back; a driver
        # supplied by the caller stays under the caller's control.
        if driver_from_pool:
            driver_pool.release_driver(driver)
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
                logger.info(f"Temporary file {tmp_path} removed.")
            except Exception as e:
                logger.error(f"Error removing temporary file {tmp_path}: {e}")
# --- 並列処理版スクリーンショット生成 ---
def text_to_screenshot_parallel(text: str, extension_percentage: float, temperature: float = 0.5,
                                trim_whitespace: bool = True, style: str = "standard") -> Image.Image:
    """Generate infographic HTML from *text* and screenshot it.

    HTML generation (Gemini) and WebDriver acquisition run concurrently in
    two worker threads; the page is then rendered and captured.

    Args:
        text: Source text for the infographic.
        extension_percentage: Extra height to add, in percent of page height.
        temperature: Sampling temperature for HTML generation.
        trim_whitespace: If true, crop near-white margins from the result.
        style: Design style name passed to HTML generation.

    Returns:
        The screenshot as a PIL image. On any error a 1x1 black image is
        returned instead of raising; callers detect failure by that size.
    """
    from pathlib import Path  # local import: pathlib is not imported at file level

    start_time = time.time()
    drv = None
    tmp_path = None
    try:
        with ThreadPoolExecutor(max_workers=2) as executor:
            html_future = executor.submit(generate_html_from_text, text, temperature, style)
            driver_future = executor.submit(driver_pool.get_driver)
            # Collect the driver first: if HTML generation fails, `drv` is
            # already bound and the finally block returns it to the pool.
            # Previously a failed generation leaked the driver and its slot.
            drv = driver_future.result()
            html_code = html_future.result()

        # Write the HTML to a temporary file.
        with tempfile.NamedTemporaryFile(suffix=".html", delete=False, mode='w', encoding='utf-8') as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(html_code)
        logger.info(f"HTMLを一時ファイルに保存: {tmp_path}")

        # Load the page.
        drv.set_window_size(1200, 1000)
        drv.get(Path(tmp_path).as_uri())
        WebDriverWait(drv, 10).until(EC.presence_of_element_located((By.TAG_NAME, "body")))

        # Poll until the document and all images are loaded (at most ~3 s).
        max_wait = 3
        wait_increment = 0.2
        wait_time = 0
        while wait_time < max_wait:
            state = drv.execute_script("""
return {
complete: document.readyState==='complete',
imgCount: document.images.length,
imgLoaded: Array.from(document.images).filter(img=>img.complete).length,
faElements: document.querySelectorAll('.fa, .fas, .far, .fab, [class*="fa-"]').length
};
""")
            if state['complete'] and (state['imgCount'] == 0 or state['imgLoaded'] == state['imgCount']):
                break
            time.sleep(wait_increment)
            wait_time += wait_increment
        # Give icon fonts a little extra time when there are many FA elements.
        if state.get('faElements', 0) > 30:
            time.sleep(min(1.0, state['faElements'] / 100))

        # Quick scroll to the bottom and back, then hide scrollbars.
        drv.execute_script("window.scrollTo(0, document.body.scrollHeight);")
        time.sleep(0.2)
        drv.execute_script("window.scrollTo(0, 0);")
        time.sleep(0.2)
        drv.execute_script("document.documentElement.style.overflow='hidden';document.body.style.overflow='hidden';")

        # Measure the page.
        dims = drv.execute_script("""
return {
width: Math.max(
document.documentElement.scrollWidth,
document.documentElement.offsetWidth,
document.documentElement.clientWidth,
document.body ? document.body.scrollWidth : 0,
document.body ? document.body.offsetWidth : 0,
document.body ? document.body.clientWidth : 0
),
height: Math.max(
document.documentElement.scrollHeight,
document.documentElement.offsetHeight,
document.documentElement.clientHeight,
document.body ? document.body.scrollHeight : 0,
document.body ? document.body.offsetHeight : 0,
document.body ? document.body.clientHeight : 0
)
};
""")
        # Clamp to 100..2000 px wide and 100..4000 px high.
        w = min(max(dims['width'], 100), 2000)
        h = min(max(dims['height'], 100), 4000)
        # Add the requested extra height; never shrink below the page height.
        adjusted_h = int(h * (1 + extension_percentage / 100.0))
        adjusted_h = max(adjusted_h, h, 100)
        drv.set_window_size(w, adjusted_h)
        time.sleep(0.2)

        # Take the screenshot.
        png = drv.get_screenshot_as_png()
        img = Image.open(BytesIO(png))
        if trim_whitespace:
            img = trim_image_whitespace(img, threshold=248, padding=20)
        elapsed = time.time() - start_time
        logger.info(f"Parallel generation 完了 (所要時間: {elapsed:.2f}秒)")
        return img
    except Exception as e:
        logger.error(f"Parallel generation error: {e}", exc_info=True)
        return Image.new('RGB', (1, 1), color=(0, 0, 0))
    finally:
        if drv:
            driver_pool.release_driver(drv)
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except Exception as e:
                # Best-effort cleanup: report the failure instead of hiding it.
                logger.error(f"Error removing temporary file {tmp_path}: {e}")
# --- レガシー版スクリーンショット生成 ---
def text_to_screenshot(text: str, extension_percentage: float, temperature: float = 0.3,
                       trim_whitespace: bool = True, style: str = "standard") -> Image.Image:
    """Legacy entry point that forwards to ``text_to_screenshot_parallel``.

    Note that the default *temperature* here is 0.3, not the 0.5 used by the
    parallel function.
    """
    return text_to_screenshot_parallel(
        text=text,
        extension_percentage=extension_percentage,
        temperature=temperature,
        trim_whitespace=trim_whitespace,
        style=style,
    )
# --- 入力モード切り替え用関数 ---
def process_input(input_mode, input_text, extension_percentage, temperature, trim_whitespace, style):
    """Dispatch a UI submission according to the selected input mode.

    In "HTML入力" mode *input_text* is rendered directly as HTML; in any other
    mode it is treated as plain text and converted to an infographic first.
    Returns the resulting image.
    """
    if input_mode != "HTML入力":
        return text_to_screenshot_parallel(
            input_text, extension_percentage, temperature, trim_whitespace, style
        )
    # HTML mode: temperature and style do not apply.
    return render_fullpage_screenshot(input_text, extension_percentage, trim_whitespace)
# --- FastAPI Setup ---
app = FastAPI()
# CORS is fully open: every origin, method and header is allowed.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is a
# very permissive setup - confirm that credentialed cross-origin requests from
# arbitrary sites are really intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Static file serving: expose the frontend directories bundled inside the
# installed gradio package. Each directory is mounted only if it exists.
gradio_dir = os.path.dirname(gr.__file__)
static_dir = os.path.join(gradio_dir, "templates", "frontend", "static")
app_dir = os.path.join(gradio_dir, "templates", "frontend", "_app")
assets_dir = os.path.join(gradio_dir, "templates", "frontend", "assets")
cdn_dir = os.path.join(gradio_dir, "templates", "cdn")
for _mount_path, _directory in (
    ("/static", static_dir),
    ("/_app", app_dir),
    ("/assets", assets_dir),
    ("/cdn", cdn_dir),
):
    if os.path.exists(_directory):
        # The mount name is the URL path without its leading slash.
        app.mount(_mount_path, StaticFiles(directory=_directory), name=_mount_path[1:])
# --- API Endpoint for HTML→Screenshot ---
@app.post("/api/screenshot",
          response_class=StreamingResponse,
          tags=["Screenshot"],
          summary="Render HTML to Full Page Screenshot",
          description="Takes HTML code and an optional vertical extension percentage, renders it using a headless browser, and returns the full-page screenshot as a PNG image.")
async def api_render_screenshot(request: ScreenshotRequest):
    """Render the posted HTML and stream the screenshot back as PNG.

    When rendering fails, ``render_fullpage_screenshot`` returns a 1x1 image;
    that image is still returned (and the failure is logged). Unexpected
    errors in this handler become HTTP 500.
    """
    # Local import: fastapi is already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        logger.info(f"API request received. Extension: {request.extension_percentage}%")
        # Rendering drives Selenium and sleeps; run it in a worker thread so
        # this coroutine does not block the event loop for other requests.
        pil_image = await run_in_threadpool(
            render_fullpage_screenshot,
            request.html_code,
            request.extension_percentage,
            request.trim_whitespace,
        )
        if pil_image.size == (1, 1):
            logger.error("Screenshot generation failed, returning 1x1 error image.")
        img_byte_arr = BytesIO()
        pil_image.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)
        logger.info("Returning screenshot as PNG stream.")
        return StreamingResponse(img_byte_arr, media_type="image/png")
    except Exception as e:
        logger.error(f"API Error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {e}")
# --- API Endpoint for Text→Infographic Screenshot ---
@app.post("/api/text-to-screenshot",
          response_class=StreamingResponse,
          tags=["Screenshot", "Gemini"],
          summary="テキストからインフォグラフィックを生成",
          description="テキストをGemini APIを使ってHTMLインフォグラフィックに変換し、スクリーンショットとして返します。")
async def api_text_to_screenshot(request: GeminiRequest):
    """Convert the posted text to an infographic and stream it back as PNG.

    When generation or rendering fails, ``text_to_screenshot_parallel``
    returns a 1x1 image; that image is still returned (and the failure is
    logged). Unexpected errors in this handler become HTTP 500.
    """
    # Local import: fastapi is already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        logger.info(
            f"テキスト→スクリーンショットAPIリクエスト受信。"
            f"テキスト長さ: {len(request.text)}, 拡張率: {request.extension_percentage}%, "
            f"温度: {request.temperature}, スタイル: {request.style}"
        )
        # The Gemini call and the Selenium rendering are blocking; run them in
        # a worker thread so this coroutine does not block the event loop.
        pil_image = await run_in_threadpool(
            text_to_screenshot_parallel,
            request.text,
            request.extension_percentage,
            request.temperature,
            request.trim_whitespace,
            request.style,
        )
        if pil_image.size == (1, 1):
            logger.error("スクリーンショット生成に失敗しました。1x1エラー画像を返します。")
        img_byte_arr = BytesIO()
        pil_image.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)
        logger.info("スクリーンショットをPNGストリームとして返します。")
        return StreamingResponse(img_byte_arr, media_type="image/png")
    except Exception as e:
        logger.error(f"API Error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {e}")
# --- Gradio Interface Definition ---
# Gradio UI: one text box that is interpreted either as HTML or as plain text
# depending on the selected input mode, plus rendering options.
# NOTE(review): the nesting of the layout containers (Row/Column) below was
# reconstructed from a copy of this file whose indentation had been lost -
# confirm it against the intended layout.
with gr.Blocks(title="Full Page Screenshot (テキスト変換対応)", theme=gr.themes.Base()) as iface:
    gr.Markdown("# HTMLビューア & テキスト→インフォグラフィック変換")
    gr.Markdown("HTMLコードをレンダリングするか、テキストをGemini APIでインフォグラフィックに変換して画像として取得します。")
    gr.Markdown("**パフォーマンス向上版**: 並列処理と最適化により処理時間を短縮しています")
    with gr.Row():
        # Input mode selector; HTML mode is the default.
        input_mode = gr.Radio(
            ["HTML入力", "テキスト入力"],
            label="入力モード",
            value="HTML入力"
        )
    # Shared text box used for both HTML and plain-text input.
    input_text = gr.Textbox(
        lines=15,
        label="入力",
        placeholder="HTMLコードまたはテキストを入力してください。入力モードに応じて処理されます。"
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Style selector; created hidden and toggled by
            # update_controls_visibility below.
            style_dropdown = gr.Dropdown(
                choices=["standard", "cute", "resort", "cool", "dental", "school", "KOKUGO"],
                value="standard",
                label="デザインスタイル",
                info="テキスト→HTML変換時のデザインテーマを選択します",
                visible=False
            )
        with gr.Column(scale=2):
            # Extra vertical space for the screenshot, 0-30 %.
            extension_percentage = gr.Slider(
                minimum=0,
                maximum=30,
                step=1.0,
                value=10,
                label="上下高さ拡張率(%)"
            )
    # Generation temperature; created hidden and toggled together with the
    # style selector.
    temperature = gr.Slider(
        minimum=0.0,
        maximum=1.0,
        step=0.1,
        value=0.5,
        label="生成時の温度(低い=一貫性高、高い=創造性高)",
        visible=False
    )
    # Whitespace trimming option, enabled by default.
    trim_whitespace = gr.Checkbox(
        label="余白を自動トリミング",
        value=True,
        info="生成される画像から余分な空白領域を自動的に削除します"
    )
    submit_btn = gr.Button("生成")
    output_image = gr.Image(type="pil", label="ページ全体のスクリーンショット")

    def update_controls_visibility(mode):
        """Show the temperature and style controls only in text-input mode."""
        is_text_mode = (mode == "テキスト入力")
        # The order must match the `outputs` list of input_mode.change below.
        return [
            gr.update(visible=is_text_mode), # temperature
            gr.update(visible=is_text_mode), # style_dropdown
        ]

    # Toggle the text-only controls whenever the input mode changes.
    input_mode.change(
        fn=update_controls_visibility,
        inputs=input_mode,
        outputs=[temperature, style_dropdown]
    )
    # Run the conversion; the input order matches process_input's parameters.
    submit_btn.click(
        fn=process_input,
        inputs=[input_mode, input_text, extension_percentage, temperature, trim_whitespace, style_dropdown],
        outputs=output_image
    )
    # Footer: API endpoints and current configuration. The model name is read
    # once, when this module is loaded.
    gemini_model = os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
    gr.Markdown(f"""
## APIエンドポイント
- `/api/screenshot` - HTMLコードからスクリーンショットを生成
- `/api/text-to-screenshot` - テキストからインフォグラフィックスクリーンショットを生成
## 設定情報
- 使用モデル: {gemini_model} (環境変数 GEMINI_MODEL で変更可能)
- 対応スタイル: standard, cute, resort, cool, dental, school, KOKUGO
- WebDriverプール最大数: {driver_pool.max_drivers} (環境変数 MAX_WEBDRIVERS で変更可能)
""")
# --- Mount Gradio App onto FastAPI ---
# Serve the Gradio UI at the root path; `app` is rebound to the object
# returned by mount_gradio_app.
app = gr.mount_gradio_app(app, iface, path="/")
# Clean up the WebDriver pool when the interpreter exits.
# This is registered BEFORE the server is started: uvicorn.run() blocks until
# the server stops, and if it exits via an exception (e.g. KeyboardInterrupt)
# any statement placed after it would never run, so the cleanup would never
# be registered.
import atexit
atexit.register(driver_pool.close_all)

# --- Start Uvicorn for local development ---
if __name__ == "__main__":
    import uvicorn
    logger.info("Starting Uvicorn server for local development...")
    uvicorn.run(app, host="0.0.0.0", port=7860)