himipo commited on
Commit
afcbb41
·
1 Parent(s): 907530a
Files changed (1) hide show
  1. app.py +61 -9
app.py CHANGED
@@ -1,6 +1,9 @@
1
  # app.py
2
  from collections import Counter
 
3
  from typing import Tuple, Dict, Any, List
 
 
4
 
5
  import gradio as gr
6
  import numpy as np
@@ -36,6 +39,31 @@ except Exception:
36
  pass
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def pil_to_np(img: Image.Image) -> np.ndarray:
40
  return np.array(img.convert("RGB"))
41
 
@@ -104,6 +132,7 @@ def summarize_detections(detections: List[Detection]) -> List[List[Any]]:
104
  def inference_pipeline(
105
  image: Image.Image,
106
  score_thresh: float = 0.8,
 
107
  ) -> Tuple[Image.Image, List[List[Any]]]:
108
  """Gradio から呼ばれるメイン処理"""
109
  if image is None:
@@ -125,6 +154,15 @@ def inference_pipeline(
125
  # DEIMv2 推論
126
  detections = run_inference(img_np, score_thresh=score_thresh)
127
 
 
 
 
 
 
 
 
 
 
128
  # 描画
129
  vis_pil = draw_detections(img_pil.copy(), detections)
130
 
@@ -141,9 +179,10 @@ def inference_pipeline(
141
  def gpu_inference(
142
  image: Image.Image,
143
  score_thresh: float = 0.8, # UIのデフォルト値と統一
 
144
  ):
145
  """Spaces ZeroGPU が検出できるようにデコレータ付きの推論関数を用意"""
146
- return inference_pipeline(image, score_thresh)
147
 
148
 
149
  # =========================
@@ -160,6 +199,9 @@ with gr.Blocks(title="DEIMv2 Floorplan Symbol Detection") as demo:
160
  """
161
  )
162
 
 
 
 
163
  with gr.Row():
164
  # 左: 入力
165
  with gr.Column(scale=1):
@@ -168,13 +210,23 @@ with gr.Blocks(title="DEIMv2 Floorplan Symbol Detection") as demo:
168
  type="pil",
169
  image_mode="RGB",
170
  )
171
- score_thresh = gr.Slider(
172
- minimum=0.0,
173
- maximum=1.0,
174
- value=0.8,
175
- step=0.05,
176
- label="スコア閾値",
177
- )
 
 
 
 
 
 
 
 
 
 
178
  run_button = gr.Button("検出を実行", variant="primary")
179
 
180
  # 中央: 出力画像
@@ -195,7 +247,7 @@ with gr.Blocks(title="DEIMv2 Floorplan Symbol Detection") as demo:
195
  # ボタンの動作
196
  run_button.click(
197
  fn=gpu_inference,
198
- inputs=[input_image, score_thresh],
199
  outputs=[output_image, summary_dataframe],
200
  )
201
 
 
1
  # app.py
2
  from collections import Counter
3
+ from functools import lru_cache
4
  from typing import Tuple, Dict, Any, List
5
+ import os
6
+ import yaml
7
 
8
  import gradio as gr
9
  import numpy as np
 
39
  pass
40
 
41
 
42
+ @lru_cache(maxsize=1)
43
+ def load_class_names() -> List[str]:
44
+ """
45
+ 設定ファイルからクラスリストを読み込む
46
+ """
47
+ config_path = "configs/deimv2_floorplan.yaml"
48
+ try:
49
+ with open(config_path, 'r', encoding='utf-8') as f:
50
+ config = yaml.safe_load(f)
51
+ # Modelセクションからclass_namesを取得
52
+ if 'Model' in config and 'class_names' in config['Model']:
53
+ return config['Model']['class_names']
54
+ else:
55
+ # フォールバック: デフォルトのクラスリスト
56
+ return ["kanki", "kanki_shikaku", "kanki_regisuta", "window1", "window2",
57
+ "door1", "door2", "bathtub1", "konro1", "sink1", "toilet1",
58
+ "kasaikeihou1", "kasaikeihou2", "houi1", "houi2", "houi3"]
59
+ except Exception as e:
60
+ # エラー時はデフォルトのクラスリストを返す
61
+ print(f"Warning: Failed to load class names from config: {e}")
62
+ return ["kanki", "kanki_shikaku", "kanki_regisuta", "window1", "window2",
63
+ "door1", "door2", "bathtub1", "konro1", "sink1", "toilet1",
64
+ "kasaikeihou1", "kasaikeihou2", "houi1", "houi2", "houi3"]
65
+
66
+
67
  def pil_to_np(img: Image.Image) -> np.ndarray:
68
  return np.array(img.convert("RGB"))
69
 
 
132
  def inference_pipeline(
133
  image: Image.Image,
134
  score_thresh: float = 0.8,
135
+ selected_classes: List[str] = None,
136
  ) -> Tuple[Image.Image, List[List[Any]]]:
137
  """Gradio から呼ばれるメイン処理"""
138
  if image is None:
 
154
  # DEIMv2 推論
155
  detections = run_inference(img_np, score_thresh=score_thresh)
156
 
157
+ # クラスフィルタリング: 選択されたクラスのみを残す
158
+ if selected_classes is not None and len(selected_classes) > 0:
159
+ # 選択されたクラスリストに含まれる検出結果のみをフィルタリング
160
+ filtered_detections = [
161
+ det for det in detections
162
+ if det[4] in selected_classes # det[4]はlabel_name
163
+ ]
164
+ detections = filtered_detections
165
+
166
  # 描画
167
  vis_pil = draw_detections(img_pil.copy(), detections)
168
 
 
179
  def gpu_inference(
180
  image: Image.Image,
181
  score_thresh: float = 0.8, # UIのデフォルト値と統一
182
+ selected_classes: List[str] = None,
183
  ):
184
  """Spaces ZeroGPU が検出できるようにデコレータ付きの推論関数を用意"""
185
+ return inference_pipeline(image, score_thresh, selected_classes)
186
 
187
 
188
  # =========================
 
199
  """
200
  )
201
 
202
+ # クラスリストを読み込む
203
+ class_names = load_class_names()
204
+
205
  with gr.Row():
206
  # 左: 入力
207
  with gr.Column(scale=1):
 
210
  type="pil",
211
  image_mode="RGB",
212
  )
213
+
214
+ # 詳細設定タブ(デフォルトは閉じた状態)
215
+ with gr.Accordion("詳細設定", open=False):
216
+ score_thresh = gr.Slider(
217
+ minimum=0.0,
218
+ maximum=1.0,
219
+ value=0.8,
220
+ step=0.05,
221
+ label="スコア閾値",
222
+ )
223
+ selected_classes = gr.CheckboxGroup(
224
+ choices=class_names,
225
+ value=class_names, # デフォルトで全クラスを選択
226
+ label="検出するクラス",
227
+ info="選択したクラスの検出結果のみが表示されます",
228
+ )
229
+
230
  run_button = gr.Button("検出を実行", variant="primary")
231
 
232
  # 中央: 出力画像
 
247
  # ボタンの動作
248
  run_button.click(
249
  fn=gpu_inference,
250
+ inputs=[input_image, score_thresh, selected_classes],
251
  outputs=[output_image, summary_dataframe],
252
  )
253