Spaces:
Running
on
Zero
Running
on
Zero
| # app.py | |
| from collections import Counter | |
| from functools import lru_cache | |
| from typing import Tuple, Dict, Any, List | |
| import os | |
| import yaml | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| from PIL import Image, ImageDraw, ImageFont | |
| from detection import run_inference, Detection | |
| # Gradio 5.0.x と JSON コンポーネントの組み合わせで | |
| # /info 生成時に json_schema_to_python_type が bool を dict とみなして落ちる | |
| # 既知バグがあるため、bool を安全に処理するようにパッチを当てる。 | |
| try: | |
| from gradio_client import utils as grc_utils | |
| _orig_json_schema_to_python_type = grc_utils._json_schema_to_python_type # type: ignore[attr-defined] | |
| _orig_json_schema_to_python_type_public = grc_utils.json_schema_to_python_type | |
| def _json_schema_to_python_type_safe(schema, defs=None): # type: ignore[override] | |
| if isinstance(schema, bool): | |
| return "Any" | |
| return _orig_json_schema_to_python_type(schema, defs) | |
| grc_utils._json_schema_to_python_type = _json_schema_to_python_type_safe # type: ignore[attr-defined] | |
| def _json_schema_to_python_type_safe_public(schema): | |
| if isinstance(schema, bool): | |
| return "Any" | |
| return _orig_json_schema_to_python_type_public(schema) | |
| grc_utils.json_schema_to_python_type = _json_schema_to_python_type_safe_public | |
| except Exception: | |
| # パッチが失敗してもアプリ起動は継続する | |
| pass | |
| def load_class_names() -> List[str]: | |
| """ | |
| 設定ファイルからクラスリストを読み込む | |
| """ | |
| config_path = "configs/deimv2_floorplan.yaml" | |
| try: | |
| with open(config_path, 'r', encoding='utf-8') as f: | |
| config = yaml.safe_load(f) | |
| # Modelセクションからclass_namesを取得 | |
| if 'Model' in config and 'class_names' in config['Model']: | |
| return config['Model']['class_names'] | |
| else: | |
| # フォールバック: デフォルトのクラスリスト | |
| return ["kanki", "kanki_shikaku", "kanki_regisuta", "window1", "window2", | |
| "door1", "door2", "bathtub1", "konro1", "sink1", "toilet1", | |
| "kasaikeihou1", "kasaikeihou2", "houi1", "houi2", "houi3"] | |
| except Exception as e: | |
| # エラー時はデフォルトのクラスリストを返す | |
| print(f"Warning: Failed to load class names from config: {e}") | |
| return ["kanki", "kanki_shikaku", "kanki_regisuta", "window1", "window2", | |
| "door1", "door2", "bathtub1", "konro1", "sink1", "toilet1", | |
| "kasaikeihou1", "kasaikeihou2", "houi1", "houi2", "houi3"] | |
| def pil_to_np(img: Image.Image) -> np.ndarray: | |
| return np.array(img.convert("RGB")) | |
| def draw_detections( | |
| image_pil: Image.Image, | |
| detections: List[Detection], | |
| ) -> Image.Image: | |
| """検出結果を図面に重ねて描画""" | |
| draw = ImageDraw.Draw(image_pil) | |
| # クラスごとの色マッピング | |
| color_map = { | |
| "kanki": (255, 0, 0), # 赤 | |
| "door1": (0, 0, 255), # 青 | |
| "door2": (255, 255, 0), # 黄 | |
| } | |
| default_color = (0, 255, 0) # デフォルト色(緑) | |
| try: | |
| font = ImageFont.truetype("DejaVuSans.ttf", 24) | |
| except Exception: | |
| font = ImageFont.load_default() | |
| for (x1, y1, x2, y2, label, score) in detections: | |
| # ラベルに応じた色を取得 | |
| color = color_map.get(label, default_color) | |
| # bbox | |
| draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=3) | |
| # ラベル+スコア | |
| text = f"{label} {score:.2f}" | |
| # textsizeは非推奨のため、textbboxを使用(互換性のためフォールバックあり) | |
| try: | |
| bbox = draw.textbbox((0, 0), text, font=font) | |
| tw = bbox[2] - bbox[0] | |
| th = bbox[3] - bbox[1] | |
| except AttributeError: | |
| # 古いPillowバージョン用のフォールバック | |
| tw, th = draw.textsize(text, font=font) | |
| draw.rectangle( | |
| [(x1, y1 - th - 2), (x1 + tw + 2, y1)], | |
| fill=color, | |
| ) | |
| draw.text((x1 + 1, y1 - th - 2), text, fill=(0, 0, 0), font=font) | |
| return image_pil | |
| def summarize_detections(detections: List[Detection]) -> List[List[Any]]: | |
| """ラベルごとの個数を集計してDataframe用のデータ形式にする""" | |
| labels = [d[4] for d in detections] | |
| counter = Counter(labels) | |
| # ヘッダー行 | |
| data = [["記号名称", "個数"]] | |
| # データ行(個数の降順、ラベル名の昇順でソート) | |
| for label, count in sorted(counter.items(), key=lambda x: (-x[1], x[0])): | |
| data.append([label, count]) | |
| return data | |
| def inference_pipeline( | |
| image: Image.Image, | |
| score_thresh: float = 0.8, | |
| selected_classes: List[str] = None, | |
| ) -> Tuple[Image.Image, List[List[Any]]]: | |
| """Gradio から呼ばれるメイン処理""" | |
| if image is None: | |
| raise gr.Error("PNG形式の図面をアップロードしてください。") | |
| try: | |
| # 画像が文字列(ファイルパス)の場合はPIL Imageに変換 | |
| if isinstance(image, str): | |
| try: | |
| img_pil = Image.open(image).convert("RGB") | |
| except Exception as e: | |
| raise gr.Error(f"画像ファイルの読み込みに失敗しました: {str(e)}") | |
| else: | |
| # 既にPIL Imageオブジェクトの場合 | |
| img_pil = image.convert("RGB") | |
| img_np = pil_to_np(img_pil) | |
| # DEIMv2 推論 | |
| detections = run_inference(img_np, score_thresh=score_thresh) | |
| # クラスフィルタリング: 選択されたクラスのみを残す | |
| if selected_classes is not None and len(selected_classes) > 0: | |
| # 選択されたクラスリストに含まれる検出結果のみをフィルタリング | |
| filtered_detections = [ | |
| det for det in detections | |
| if det[4] in selected_classes # det[4]はlabel_name | |
| ] | |
| detections = filtered_detections | |
| # 描画 | |
| vis_pil = draw_detections(img_pil.copy(), detections) | |
| # 集計 | |
| summary = summarize_detections(detections) | |
| return vis_pil, summary | |
| except Exception as e: | |
| error_msg = f"エラーが発生しました: {str(e)}" | |
| raise gr.Error(error_msg) | |
| def gpu_inference( | |
| image: Image.Image, | |
| score_thresh: float = 0.9, # UIのデフォルト値と統一 | |
| selected_classes: List[str] = None, | |
| ): | |
| """Spaces ZeroGPU が検出できるようにデコレータ付きの推論関数を用意""" | |
| return inference_pipeline(image, score_thresh, selected_classes) | |
| # ========================= | |
| # Gradio UI | |
| # ========================= | |
| with gr.Blocks(title="DEIMv2 Floorplan Symbol Detection") as demo: | |
| gr.Markdown( | |
| """ | |
| # 図面記号検出デモ(by AItech) | |
| 1. 左側に **PNG図面** をアップロード | |
| 2. 「検出を実行」を押す | |
| 3. 中央に **検出結果付き図面**、右側に **記号名称+個数** が表示されます。 | |
| """ | |
| ) | |
| # クラスリストを読み込む | |
| class_names = load_class_names() | |
| with gr.Row(): | |
| # 左: 入力 | |
| with gr.Column(scale=1): | |
| input_image = gr.Image( | |
| label="入力図面 (PNG)", | |
| type="pil", | |
| image_mode="RGB", | |
| ) | |
| # 詳細設定タブ(デフォルトは閉じた状態) | |
| with gr.Accordion("詳細設定", open=False): | |
| score_thresh = gr.Slider( | |
| minimum=0.0, | |
| maximum=1.0, | |
| value=0.9, | |
| step=0.05, | |
| label="スコア閾値", | |
| ) | |
| selected_classes = gr.CheckboxGroup( | |
| choices=class_names, | |
| value=["door1"], # デフォルトでdoor1のみ選択 | |
| label="検出するクラス", | |
| info="選択したクラスの検出結果のみが表示されます", | |
| ) | |
| run_button = gr.Button("検出を実行", variant="primary") | |
| # 中央: 出力画像 | |
| with gr.Column(scale=2): | |
| output_image = gr.Image( | |
| label="検出結果付き図面", | |
| type="pil", | |
| ) | |
| # 右: サイドバー(記号名称+個数) | |
| with gr.Column(scale=1): | |
| summary_dataframe = gr.Dataframe( | |
| label="検出サマリ (記号名称と個数)", | |
| headers=["記号名称", "個数"], | |
| interactive=False, | |
| ) | |
| # ボタンの動作 | |
| run_button.click( | |
| fn=gpu_inference, | |
| inputs=[input_image, score_thresh, selected_classes], | |
| outputs=[output_image, summary_dataframe], | |
| ) | |
| # Gradio 5では、Spaces上ではdemoオブジェクトを直接エクスポートするだけで動作します | |
| # ローカルテスト時のみdemo.launch()を呼び出します | |
| if __name__ == "__main__": | |
| demo.launch(server_name="0.0.0.0", server_port=7860) | |
| # Spaces上では、demoオブジェクトを直接エクスポートします | |
| # Gradio 5は自動的にdemoオブジェクトを検出して起動します | |