gpu_symbol / app.py
himipo's picture
first
075afae
# app.py
from collections import Counter
from functools import lru_cache
from typing import Tuple, Dict, Any, List
import os
import yaml
import gradio as gr
import numpy as np
import spaces
from PIL import Image, ImageDraw, ImageFont
from detection import run_inference, Detection
# Gradio 5.0.x と JSON コンポーネントの組み合わせで
# /info 生成時に json_schema_to_python_type が bool を dict とみなして落ちる
# 既知バグがあるため、bool を安全に処理するようにパッチを当てる。
try:
from gradio_client import utils as grc_utils
_orig_json_schema_to_python_type = grc_utils._json_schema_to_python_type # type: ignore[attr-defined]
_orig_json_schema_to_python_type_public = grc_utils.json_schema_to_python_type
def _json_schema_to_python_type_safe(schema, defs=None): # type: ignore[override]
if isinstance(schema, bool):
return "Any"
return _orig_json_schema_to_python_type(schema, defs)
grc_utils._json_schema_to_python_type = _json_schema_to_python_type_safe # type: ignore[attr-defined]
def _json_schema_to_python_type_safe_public(schema):
if isinstance(schema, bool):
return "Any"
return _orig_json_schema_to_python_type_public(schema)
grc_utils.json_schema_to_python_type = _json_schema_to_python_type_safe_public
except Exception:
# パッチが失敗してもアプリ起動は継続する
pass
@lru_cache(maxsize=1)
def load_class_names() -> List[str]:
"""
設定ファイルからクラスリストを読み込む
"""
config_path = "configs/deimv2_floorplan.yaml"
try:
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
# Modelセクションからclass_namesを取得
if 'Model' in config and 'class_names' in config['Model']:
return config['Model']['class_names']
else:
# フォールバック: デフォルトのクラスリスト
return ["kanki", "kanki_shikaku", "kanki_regisuta", "window1", "window2",
"door1", "door2", "bathtub1", "konro1", "sink1", "toilet1",
"kasaikeihou1", "kasaikeihou2", "houi1", "houi2", "houi3"]
except Exception as e:
# エラー時はデフォルトのクラスリストを返す
print(f"Warning: Failed to load class names from config: {e}")
return ["kanki", "kanki_shikaku", "kanki_regisuta", "window1", "window2",
"door1", "door2", "bathtub1", "konro1", "sink1", "toilet1",
"kasaikeihou1", "kasaikeihou2", "houi1", "houi2", "houi3"]
def pil_to_np(img: Image.Image) -> np.ndarray:
return np.array(img.convert("RGB"))
def draw_detections(
image_pil: Image.Image,
detections: List[Detection],
) -> Image.Image:
"""検出結果を図面に重ねて描画"""
draw = ImageDraw.Draw(image_pil)
# クラスごとの色マッピング
color_map = {
"kanki": (255, 0, 0), # 赤
"door1": (0, 0, 255), # 青
"door2": (255, 255, 0), # 黄
}
default_color = (0, 255, 0) # デフォルト色(緑)
try:
font = ImageFont.truetype("DejaVuSans.ttf", 24)
except Exception:
font = ImageFont.load_default()
for (x1, y1, x2, y2, label, score) in detections:
# ラベルに応じた色を取得
color = color_map.get(label, default_color)
# bbox
draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=3)
# ラベル+スコア
text = f"{label} {score:.2f}"
# textsizeは非推奨のため、textbboxを使用(互換性のためフォールバックあり)
try:
bbox = draw.textbbox((0, 0), text, font=font)
tw = bbox[2] - bbox[0]
th = bbox[3] - bbox[1]
except AttributeError:
# 古いPillowバージョン用のフォールバック
tw, th = draw.textsize(text, font=font)
draw.rectangle(
[(x1, y1 - th - 2), (x1 + tw + 2, y1)],
fill=color,
)
draw.text((x1 + 1, y1 - th - 2), text, fill=(0, 0, 0), font=font)
return image_pil
def summarize_detections(detections: List[Detection]) -> List[List[Any]]:
"""ラベルごとの個数を集計してDataframe用のデータ形式にする"""
labels = [d[4] for d in detections]
counter = Counter(labels)
# ヘッダー行
data = [["記号名称", "個数"]]
# データ行(個数の降順、ラベル名の昇順でソート)
for label, count in sorted(counter.items(), key=lambda x: (-x[1], x[0])):
data.append([label, count])
return data
def inference_pipeline(
image: Image.Image,
score_thresh: float = 0.8,
selected_classes: List[str] = None,
) -> Tuple[Image.Image, List[List[Any]]]:
"""Gradio から呼ばれるメイン処理"""
if image is None:
raise gr.Error("PNG形式の図面をアップロードしてください。")
try:
# 画像が文字列(ファイルパス)の場合はPIL Imageに変換
if isinstance(image, str):
try:
img_pil = Image.open(image).convert("RGB")
except Exception as e:
raise gr.Error(f"画像ファイルの読み込みに失敗しました: {str(e)}")
else:
# 既にPIL Imageオブジェクトの場合
img_pil = image.convert("RGB")
img_np = pil_to_np(img_pil)
# DEIMv2 推論
detections = run_inference(img_np, score_thresh=score_thresh)
# クラスフィルタリング: 選択されたクラスのみを残す
if selected_classes is not None and len(selected_classes) > 0:
# 選択されたクラスリストに含まれる検出結果のみをフィルタリング
filtered_detections = [
det for det in detections
if det[4] in selected_classes # det[4]はlabel_name
]
detections = filtered_detections
# 描画
vis_pil = draw_detections(img_pil.copy(), detections)
# 集計
summary = summarize_detections(detections)
return vis_pil, summary
except Exception as e:
error_msg = f"エラーが発生しました: {str(e)}"
raise gr.Error(error_msg)
@spaces.GPU
def gpu_inference(
image: Image.Image,
score_thresh: float = 0.9, # UIのデフォルト値と統一
selected_classes: List[str] = None,
):
"""Spaces ZeroGPU が検出できるようにデコレータ付きの推論関数を用意"""
return inference_pipeline(image, score_thresh, selected_classes)
# =========================
# Gradio UI
# =========================
with gr.Blocks(title="DEIMv2 Floorplan Symbol Detection") as demo:
gr.Markdown(
"""
# 図面記号検出デモ(by AItech)
1. 左側に **PNG図面** をアップロード
2. 「検出を実行」を押す
3. 中央に **検出結果付き図面**、右側に **記号名称+個数** が表示されます。
"""
)
# クラスリストを読み込む
class_names = load_class_names()
with gr.Row():
# 左: 入力
with gr.Column(scale=1):
input_image = gr.Image(
label="入力図面 (PNG)",
type="pil",
image_mode="RGB",
)
# 詳細設定タブ(デフォルトは閉じた状態)
with gr.Accordion("詳細設定", open=False):
score_thresh = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.9,
step=0.05,
label="スコア閾値",
)
selected_classes = gr.CheckboxGroup(
choices=class_names,
value=["door1"], # デフォルトでdoor1のみ選択
label="検出するクラス",
info="選択したクラスの検出結果のみが表示されます",
)
run_button = gr.Button("検出を実行", variant="primary")
# 中央: 出力画像
with gr.Column(scale=2):
output_image = gr.Image(
label="検出結果付き図面",
type="pil",
)
# 右: サイドバー(記号名称+個数)
with gr.Column(scale=1):
summary_dataframe = gr.Dataframe(
label="検出サマリ (記号名称と個数)",
headers=["記号名称", "個数"],
interactive=False,
)
# ボタンの動作
run_button.click(
fn=gpu_inference,
inputs=[input_image, score_thresh, selected_classes],
outputs=[output_image, summary_dataframe],
)
# Gradio 5では、Spaces上ではdemoオブジェクトを直接エクスポートするだけで動作します
# ローカルテスト時のみdemo.launch()を呼び出します
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)
# Spaces上では、demoオブジェクトを直接エクスポートします
# Gradio 5は自動的にdemoオブジェクトを検出して起動します