oggata committed on
Commit
9150cca
·
verified ·
1 Parent(s): 543ca9a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +421 -0
app.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import io
import json
import os
import tempfile
import zipfile

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from scipy import ndimage
from scipy.ndimage import binary_opening, binary_closing, sobel, binary_dilation, median_filter
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
13
+
14
# Global state: the SegFormer processor/model pair is loaded lazily by
# load_model() and cached here so repeated requests reuse it.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = None
model = None

# ADE20K class name -> city category key. Matching is by substring against
# the lowercased ADE20K label (see map_ade20k_to_city()).
ADE20K_TO_CITY_MAPPING = {
    'road': 'road', 'street': 'road', 'path': 'road', 'sidewalk': 'road',
    'building': 'building_c', 'house': 'building_a', 'skyscraper': 'building_e',
    'highrise': 'building_e', 'tower': 'building_e',
    'office': 'building_d', 'shop': 'building_b', 'store': 'building_b',
    'apartment': 'building_c', 'hotel': 'building_d',
    'tree': 'forest', 'plant': 'forest', 'palm': 'forest',
    'grass': 'park', 'field': 'park', 'flower': 'park',
    'water': 'water', 'sea': 'water', 'river': 'water', 'lake': 'water',
    'earth': 'bare_land', 'sand': 'bare_land', 'ground': 'bare_land',
    'parking lot': 'infrastructure', 'stadium': 'building_d',
}

# City categories used throughout the app:
#   label       - user-facing display text (Japanese, shown in the UI/stats)
#   color       - RGB triple for the colored segmentation image
#   height      - extrusion height for the 3D mesh (halved when meshing)
#   semantic_id - id written into the remapped segmentation map
CITY_CATEGORIES = {
    'road': {'label': '道路', 'color': (128, 64, 128), 'height': 0, 'semantic_id': 0},
    'forest': {'label': '森林', 'color': (34, 139, 34), 'height': 1.5, 'semantic_id': 1},
    'park': {'label': '公園/緑地', 'color': (144, 238, 144), 'height': 0.5, 'semantic_id': 2},
    'water': {'label': '水域', 'color': (30, 144, 255), 'height': 0, 'semantic_id': 3},
    'building_a': {'label': '建物A(小)', 'color': (255, 200, 150), 'height': 0.6, 'semantic_id': 4},
    'building_b': {'label': '建物B(中小)', 'color': (255, 160, 122), 'height': 1.0, 'semantic_id': 5},
    'building_c': {'label': '建物C(中)', 'color': (240, 120, 90), 'height': 1.5, 'semantic_id': 6},
    'building_d': {'label': '建物D(中大)', 'color': (220, 80, 60), 'height': 2.2, 'semantic_id': 7},
    'building_e': {'label': '建物E(大)', 'color': (200, 40, 40), 'height': 3.0, 'semantic_id': 8},
    'bare_land': {'label': '空き地', 'color': (210, 180, 140), 'height': 0.1, 'semantic_id': 9},
    'infrastructure': {'label': 'インフラ', 'color': (100, 100, 100), 'height': 0.8, 'semantic_id': 10},
    'other': {'label': 'その他/境界', 'color': (80, 80, 80), 'height': 0, 'semantic_id': 11}
}
47
+
48
def load_model():
    """Return the cached SegFormer processor/model pair, loading it on first use.

    The objects are stored in module globals so every subsequent call is a
    cheap cache hit; the model is moved to the module-level `device`.
    """
    global processor, model
    if processor is not None and model is not None:
        return processor, model
    checkpoint = "nvidia/segformer-b5-finetuned-ade-640-640"
    processor = SegformerImageProcessor.from_pretrained(checkpoint)
    model = SegformerForSemanticSegmentation.from_pretrained(checkpoint).to(device)
    return processor, model
54
+
55
def map_ade20k_to_city(class_id, id2label):
    """Translate an ADE20K class id into one of the CITY_CATEGORIES keys.

    Resolution order: unknown ids map straight to 'other'; otherwise the
    lowercased ADE20K label is substring-matched first against the explicit
    ADE20K_TO_CITY_MAPPING table, then against keyword fallbacks; anything
    still unmatched becomes 'other'.
    """
    if class_id not in id2label:
        return 'other'
    class_name = id2label[class_id].lower()

    for ade_name, city_cat in ADE20K_TO_CITY_MAPPING.items():
        if ade_name in class_name:
            return city_cat

    # Keyword fallbacks; the order matters — earlier entries win, so the
    # largest building classes are tried first.
    fallbacks = (
        (('skyscraper', 'highrise', 'tower'), 'building_e'),
        (('office', 'hotel', 'commercial', 'stadium'), 'building_d'),
        (('building', 'apartment'), 'building_c'),
        (('shop', 'store', 'market'), 'building_b'),
        (('house', 'home', 'shed', 'hut'), 'building_a'),
    )
    for keywords, category in fallbacks:
        if any(word in class_name for word in keywords):
            return category

    return 'other'
76
+
77
def segment_with_tiling(image, processor, model, tile_size=320, overlap=64, use_tiling=True):
    """Segment an RGB image and return an (H, W) map of ADE20K class ids.

    Small images (or use_tiling=False) go through the model in one pass.
    Larger images are split into overlapping square tiles; the per-tile
    softmax probabilities are averaged wherever tiles overlap before the
    final argmax, which suppresses seam artifacts at tile borders.

    Args:
        image: HxWx3 uint8 RGB ndarray.
        processor, model: SegFormer pair from load_model().
        tile_size: square tile edge length in pixels.
        overlap: overlap between adjacent tiles in pixels (must be < tile_size).
        use_tiling: when False, always run a single full-image pass.

    Returns:
        int ndarray of shape (H, W) with a class id per pixel.
    """
    h, w = image.shape[:2]

    if not use_tiling or (h <= tile_size and w <= tile_size):
        pil_image = Image.fromarray(image)
        inputs = processor(images=pil_image, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

        # The model emits a downscaled logit map; upsample back to image size.
        upsampled_logits = torch.nn.functional.interpolate(
            logits, size=image.shape[:2], mode="bilinear", align_corners=False
        )
        return upsampled_logits.argmax(dim=1)[0].cpu().numpy()

    stride = tile_size - overlap
    num_tiles_h = (h - overlap) // stride + (1 if (h - overlap) % stride > 0 else 0)
    num_tiles_w = (w - overlap) // stride + (1 if (w - overlap) % stride > 0 else 0)

    # Probability accumulator. Allocated lazily so the class count is taken
    # from the model's actual output instead of being hard-coded (the old
    # code fixed it at 150, which only worked for ADE20K checkpoints).
    votes = None
    counts = np.zeros((h, w), dtype=np.float32)

    for i in range(num_tiles_h):
        for j in range(num_tiles_w):
            y_start = i * stride
            x_start = j * stride
            y_end = min(y_start + tile_size, h)
            x_end = min(x_start + tile_size, w)

            tile = image[y_start:y_end, x_start:x_end]
            pil_tile = Image.fromarray(tile)
            inputs = processor(images=pil_tile, return_tensors="pt").to(device)

            with torch.no_grad():
                outputs = model(**inputs)
                logits = outputs.logits

            upsampled = torch.nn.functional.interpolate(
                logits, size=tile.shape[:2], mode="bilinear", align_corners=False
            )
            probs = torch.nn.functional.softmax(upsampled, dim=1)[0].cpu().numpy()

            if votes is None:
                votes = np.zeros((h, w, probs.shape[0]), dtype=np.float32)

            votes[y_start:y_end, x_start:x_end] += probs.transpose(1, 2, 0)
            counts[y_start:y_end, x_start:x_end] += 1

    # Guard against division by zero for any pixel no tile happened to cover.
    counts = np.maximum(counts, 1)
    final_votes = votes / counts[:, :, np.newaxis]
    return final_votes.argmax(axis=2)
126
+
127
def create_colored_segmentation(seg_map):
    """Render a semantic-id map as an RGB image using CITY_CATEGORIES colors.

    Pixels whose id matches no category stay black.
    """
    rgb = np.zeros((*seg_map.shape, 3), dtype=np.uint8)
    for info in CITY_CATEGORIES.values():
        rgb[seg_map == info['semantic_id']] = info['color']
    return rgb
134
+
135
def detect_boundaries(segmentation_map, thickness=5):
    """Stamp class-transition pixels into the map as the 'other' category.

    Transitions are found with a Sobel gradient along each axis (any nonzero
    gradient marks a boundary), optionally thickened by a square dilation.

    NOTE(review): mutates `segmentation_map` in place, matching existing
    caller usage; the (mutated) map and the boolean boundary mask are returned.
    """
    as_float = segmentation_map.astype(float)
    boundaries = (np.abs(sobel(as_float, axis=0)) > 0) | (np.abs(sobel(as_float, axis=1)) > 0)

    if thickness > 1:
        structure = np.ones((thickness, thickness), dtype=bool)
        boundaries = binary_dilation(boundaries, structure=structure)

    segmentation_map[boundaries] = CITY_CATEGORIES['other']['semantic_id']
    return segmentation_map, boundaries
147
+
148
def create_3d_mesh(segments, image, resolution=2):
    """Convert segment masks into simple extruded quad meshes.

    Each segment's mask is sampled on a grid of `resolution`-pixel cells.
    Every covered cell contributes a flat roof quad (two triangles) at half
    the category height, plus four matching ground-level vertices; colors
    come from the underlying image pixel, with walls a darkened (0.7x) copy
    of the roof color. Pixel coordinates map to world units at 0.1/pixel,
    centered on the image.

    Args:
        segments: list of dicts with 'segmentation' (bool mask), 'bbox'
            ([x, y, w, h]), 'height', 'category', 'label', 'semantic_id',
            and 'area' keys (as built by process_image).
        image: HxWx3 uint8 RGB ndarray the masks refer to.
        resolution: grid cell size in pixels (larger -> coarser, smaller file).

    Returns:
        List of JSON-serializable mesh dicts (vertices/faces/colors plus
        metadata), one per non-degenerate segment.
    """
    img_h, img_w = image.shape[:2]
    meshes_data = []
    cell = resolution
    cell_world = cell * 0.1  # world-space edge length of one grid cell

    for idx, segment in enumerate(segments):
        x, y, w, h = segment['bbox']
        if w < 3 or h < 3:
            continue  # degenerate bounding box: nothing worth meshing

        mask_crop = segment['segmentation'][y:y + h, x:x + w]
        image_crop = image[y:y + h, x:x + w]
        if not mask_crop.any():
            continue

        roof_y = segment['height'] * 0.5
        vertices, faces, colors = [], [], []

        for sy in range(0, mask_crop.shape[0] - cell, cell):
            for sx in range(0, mask_crop.shape[1] - cell, cell):
                if not mask_crop[sy, sx]:
                    continue

                wx = (x + sx - img_w / 2) * 0.1
                wz = (y + sy - img_h / 2) * 0.1
                base = len(vertices)

                # Roof quad corners at the extrusion height.
                vertices += [
                    [float(wx), float(roof_y), float(wz)],
                    [float(wx + cell_world), float(roof_y), float(wz)],
                    [float(wx + cell_world), float(roof_y), float(wz + cell_world)],
                    [float(wx), float(roof_y), float(wz + cell_world)],
                ]
                # Matching ground-level corners.
                vertices += [
                    [float(wx), 0.0, float(wz)],
                    [float(wx + cell_world), 0.0, float(wz)],
                    [float(wx + cell_world), 0.0, float(wz + cell_world)],
                    [float(wx), 0.0, float(wz + cell_world)],
                ]

                # Sample the image under the cell; fall back to gray when the
                # sample point lies outside the cropped image.
                if sy < image_crop.shape[0] and sx < image_crop.shape[1]:
                    px = image_crop[sy, sx] / 255.0
                    roof_color = [float(px[0]), float(px[1]), float(px[2])]
                else:
                    roof_color = [0.5, 0.5, 0.5]
                wall_color = [c * 0.7 for c in roof_color]
                colors += [roof_color] * 4 + [wall_color] * 4

                # Only the roof quad is triangulated.
                faces += [
                    [base, base + 1, base + 2],
                    [base, base + 2, base + 3],
                ]

        if vertices:
            meshes_data.append({
                'id': int(idx),
                'category': str(segment['category']),
                'label': str(segment['label']),
                'semantic_id': int(segment['semantic_id']),
                'vertices': vertices,
                'faces': faces,
                'colors': colors,
                'center': [
                    float((x + w / 2 - img_w / 2) * 0.1),
                    float(segment['height'] * 0.5),
                    float((y + h / 2 - img_h / 2) * 0.1),
                ],
                'bbox': [int(x), int(y), int(w), int(h)],
                'area': float(segment['area']),
                'height': float(segment['height']),
            })

    return meshes_data
229
+
230
def process_image(image, max_size, tile_size, tile_overlap, min_area, mesh_res,
                  apply_morphology, morph_kernel, detect_bound, bound_thickness,
                  apply_smoothing, smooth_iter, use_tiling):
    """Run the full pipeline: aerial photo -> segmentation -> 3D city ZIP.

    Steps: load model, resize input, segment (optionally tiled), remap
    ADE20K ids to city categories, optional smoothing/boundary detection,
    connected-component extraction, mesh generation, and ZIP packaging.

    Returns:
        (colored segmentation RGB ndarray, overlay RGB ndarray,
         stats text, path to the generated ZIP file).
    """
    # Lazy model load (cached in module globals).
    processor, model = load_model()

    # --- Image preprocessing -------------------------------------------------
    if isinstance(image, str):
        original_image = cv2.imread(image)
        original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
    else:
        original_image = np.array(image)

    original_height, original_width = original_image.shape[:2]

    # Downscale so the longest edge is at most max_size (keeps inference fast).
    if max(original_height, original_width) > max_size:
        scale_factor = max_size / max(original_height, original_width)
        new_width = int(original_width * scale_factor)
        new_height = int(original_height * scale_factor)
        resized_image = cv2.resize(original_image, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)
    else:
        resized_image = original_image.copy()

    # --- Segmentation --------------------------------------------------------
    predicted_seg = segment_with_tiling(resized_image, processor, model, tile_size, tile_overlap, use_tiling)

    # Remap ADE20K class ids to the app's city category ids.
    city_segmentation = np.zeros(predicted_seg.shape, dtype=np.uint8)
    id2label = model.config.id2label

    for class_id in np.unique(predicted_seg):
        city_category = map_ade20k_to_city(class_id, id2label)
        semantic_id = CITY_CATEGORIES[city_category]['semantic_id']
        city_segmentation[predicted_seg == class_id] = semantic_id

    # Optional median-filter smoothing to remove class speckle noise.
    if apply_smoothing:
        for _ in range(smooth_iter):
            city_segmentation = median_filter(city_segmentation, size=3)

    # Optional boundary detection: class transitions become 'other'.
    if detect_bound:
        city_segmentation, _ = detect_boundaries(city_segmentation, bound_thickness)

    # --- Segment extraction --------------------------------------------------
    segments_data = []
    segment_id = 0

    for cat_name, cat_info in CITY_CATEGORIES.items():
        semantic_id = cat_info['semantic_id']
        mask = city_segmentation == semantic_id
        if not mask.any():
            continue

        if apply_morphology:
            # Opening removes small islands; closing fills small holes.
            kernel = np.ones((morph_kernel, morph_kernel), dtype=bool)
            mask = binary_opening(mask, structure=kernel)
            mask = binary_closing(mask, structure=kernel)

        labeled, num_features = ndimage.label(mask)

        for i in range(1, num_features + 1):
            segment_mask = labeled == i
            area = np.sum(segment_mask)
            if area < min_area:
                continue  # drop noise-sized components

            rows, cols = np.where(segment_mask)
            if len(rows) == 0:
                continue

            y_min, y_max = rows.min(), rows.max()
            x_min, x_max = cols.min(), cols.max()

            segments_data.append({
                'id': segment_id,
                'category': cat_name,
                'label': cat_info['label'],
                'semantic_id': semantic_id,
                'color': cat_info['color'],
                'height': cat_info['height'],
                'segmentation': segment_mask,
                'bbox': [int(x_min), int(y_min), int(x_max - x_min), int(y_max - y_min)],
                'area': int(area)
            })
            segment_id += 1

    # --- 3D mesh generation --------------------------------------------------
    meshes = create_3d_mesh(segments_data, resized_image, mesh_res)

    # Metadata summary: per-category mesh counts.
    metadata = {
        'version': '2.1',
        'total_segments': len(meshes),
        'categories': {}
    }
    for mesh in meshes:
        cat = mesh['category']
        if cat not in metadata['categories']:
            metadata['categories'][cat] = {'label': mesh['label'], 'count': 0}
        metadata['categories'][cat]['count'] += 1

    # --- Visualization -------------------------------------------------------
    colored_seg = create_colored_segmentation(city_segmentation)
    overlay = (resized_image.astype(np.float32) * 0.5 + colored_seg.astype(np.float32) * 0.5).astype(np.uint8)

    # --- Output packaging ----------------------------------------------------
    output_data = {'metadata': metadata, 'meshes': meshes}
    json_str = json.dumps(output_data, ensure_ascii=False, indent=2)

    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
        zip_file.writestr('city_3d_model.json', json_str)
        # Segmentation image (cv2 expects BGR for PNG encoding).
        _, buffer = cv2.imencode('.png', cv2.cvtColor(colored_seg, cv2.COLOR_RGB2BGR))
        zip_file.writestr('segmentation_result.png', buffer.tobytes())
    zip_buffer.seek(0)

    # BUG FIX: gr.File expects a filepath (or file object), not raw bytes —
    # the old code returned zip_buffer.getvalue(), which the download
    # component cannot serve. Persist the archive to a temporary file and
    # return its path instead.
    with tempfile.NamedTemporaryFile(mode='wb', suffix='.zip', delete=False,
                                     prefix='city_3d_model_') as tmp:
        tmp.write(zip_buffer.getvalue())
        zip_path = tmp.name

    # Stats text shown in the UI.
    stats = f"総セグメント数: {len(meshes)}\n\n"
    for cat, info in metadata['categories'].items():
        stats += f"{info['label']}: {info['count']}個\n"

    return colored_seg, overlay, stats, zip_path
362
+
363
+ # Gradio UI
364
+ with gr.Blocks(title="3D City Map Generator") as demo:
365
+ gr.Markdown("# 🏙️ 3D City Map Generator")
366
+ gr.Markdown("航空写真から3D都市マップを生成します(Segformer B5モデル使用)")
367
+
368
+ with gr.Row():
369
+ with gr.Column():
370
+ input_image = gr.Image(label="航空写真をアップロード", type="numpy")
371
+
372
+ with gr.Accordion("⚙️ 詳細設定", open=False):
373
+ max_size = gr.Slider(640, 2048, value=800, step=64, label="最大画像サイズ")
374
+ use_tiling = gr.Checkbox(value=True, label="タイル分割処理を使用")
375
+ tile_size = gr.Slider(120, 640, value=320, step=40, label="タイルサイズ")
376
+ tile_overlap = gr.Slider(32, 128, value=64, step=8, label="タイル重複")
377
+ min_area = gr.Slider(20, 200, value=32, step=4, label="最小セグメント面積")
378
+ mesh_res = gr.Slider(1, 4, value=3, step=1, label="メッシュ解像度")
379
+
380
+ apply_morphology = gr.Checkbox(value=True, label="モルフォロジー処理")
381
+ morph_kernel = gr.Slider(3, 9, value=7, step=2, label="モルフォロジーカーネル")
382
+
383
+ detect_bound = gr.Checkbox(value=True, label="境界検出")
384
+ bound_thickness = gr.Slider(1, 5, value=5, step=1, label="境界の太さ")
385
+
386
+ apply_smoothing = gr.Checkbox(value=True, label="クラス平滑化")
387
+ smooth_iter = gr.Slider(1, 3, value=2, step=1, label="平滑化反復回数")
388
+
389
+ process_btn = gr.Button("🚀 3Dマップ生成", variant="primary")
390
+
391
+ with gr.Column():
392
+ seg_output = gr.Image(label="セグメンテーション結果")
393
+ overlay_output = gr.Image(label="オーバーレイ")
394
+ stats_output = gr.Textbox(label="統計情報", lines=10)
395
+ download_output = gr.File(label="📥 3Dモデルをダウンロード (ZIP)")
396
+
397
+ process_btn.click(
398
+ fn=process_image,
399
+ inputs=[input_image, max_size, tile_size, tile_overlap, min_area, mesh_res,
400
+ apply_morphology, morph_kernel, detect_bound, bound_thickness,
401
+ apply_smoothing, smooth_iter, use_tiling],
402
+ outputs=[seg_output, overlay_output, stats_output, download_output]
403
+ )
404
+
405
+ gr.Markdown("""
406
+ ### 使い方
407
+ 1. 航空写真をアップロード
408
+ 2. 必要に応じてパラメータを調整
409
+ 3. 「3Dマップ生成」をクリック
410
+ 4. ZIPファイルをダウンロードして、JSONファイルをBlenderなどで使用
411
+
412
+ ### パラメータ説明
413
+ - **最大画像サイズ**: 大きいほど精度向上(処理時間増加)
414
+ - **タイル分割**: 大きな画像の精度向上に重要
415
+ - **最小セグメント面積**: 増やすとノイズ削減
416
+ - **メッシュ解像度**: 増やすとファイルサイズ減少
417
+ - **境界検出**: 建物と道路の混合を防ぐ
418
+ """)
419
+
420
+ if __name__ == "__main__":
421
+ demo.launch()