tmdwo commited on
Commit
28d939c
ยท
verified ยท
1 Parent(s): 182b195

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +57 -14
  2. app.py +681 -0
  3. requirements.txt +14 -0
README.md CHANGED
@@ -1,14 +1,57 @@
1
- ---
2
- title: SAM3LayerSegmentationTool
3
- emoji: ๐Ÿจ
4
- colorFrom: green
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 6.2.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- short_description: Layer-based object separation and area analysis tool | 1. Cr
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SAM3 Layer Segmentation Tool
3
+ emoji: ๐ŸŽจ
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 5.9.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ short_description: Layer-based object segmentation and area analysis with SAM3
12
+ ---
13
+
14
+ # SAM3 Layer Segmentation Tool ๐ŸŽจ
15
+
16
+ **๋ ˆ์ด์–ด ๊ธฐ๋ฐ˜ ๊ฐ์ฒด ๋ถ„๋ฆฌ ๋ฐ ๋ฉด์  ๋ถ„์„ ๋„๊ตฌ**
17
+
18
+ SAM3 (Segment Anything Model 3)๋ฅผ ํ™œ์šฉํ•œ ์ง๊ด€์ ์ธ ์ด๋ฏธ์ง€ ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜ ๋„๊ตฌ์ž…๋‹ˆ๋‹ค.
19
+
20
+ ## โœจ ์ฃผ์š” ๊ธฐ๋Šฅ
21
+
22
+ - ๐ŸŽฏ **๋ ˆ์ด์–ด ๊ธฐ๋ฐ˜ ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜**: ์—ฌ๋Ÿฌ ๊ฐ์ฒด๋ฅผ ๋ ˆ์ด์–ด๋ณ„๋กœ ๋ถ„๋ฆฌํ•˜์—ฌ ๊ด€๋ฆฌ
23
+ - ๐Ÿ–ฑ๏ธ **Include/Exclude ํฌ์ธํŠธ**: ๋นจ๊ฐ„์ƒ‰ ์ (ํฌํ•จ) ๋ฐ ํŒŒ๋ž€์ƒ‰ ์ (์ œ์™ธ)์œผ๋กœ ์ •๋ฐ€ํ•œ ์˜์—ญ ์„ ํƒ
24
+ - ๐Ÿ“Š **๋ฉด์  ๋ถ„์„**: ๊ฐ ๋ ˆ์ด์–ด์˜ ํ”ฝ์…€ ์ˆ˜ ๋ฐ ๋น„์œจ ์ž๋™ ๊ณ„์‚ฐ
25
+ - ๐ŸŽจ **์‹œ๊ฐํ™” ์„ค์ •**: ํˆฌ๋ช…๋„, ํ…Œ๋‘๋ฆฌ ๋‘๊ป˜ ์กฐ์ • ๊ฐ€๋Šฅ
26
+ - ๐Ÿ’พ **๋‹ค์ค‘ ๋ ˆ์ด์–ด ๊ด€๋ฆฌ**: ์—ฌ๋Ÿฌ ๊ฐ์ฒด๋ฅผ ๋™์‹œ์— ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜
27
+
28
+ ## ๐Ÿš€ ์‚ฌ์šฉ ๋ฐฉ๋ฒ•
29
+
30
+ 1. **์ด๋ฏธ์ง€ ์—…๋กœ๋“œ**: ์ขŒ์ธก์— ์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค
31
+ 2. **๋ ˆ์ด์–ด ์ƒ์„ฑ**: ๋ถ„๋ฆฌํ•˜๊ณ  ์‹ถ์€ ๊ฐ์ฒด๋งˆ๋‹ค ๋ ˆ์ด์–ด๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค (์˜ˆ: ๋ฒค์น˜, ๋‚˜๋ฌด, ์‚ฌ๋žŒ)
32
+ 3. **ํฌ์ธํŠธ ๋ชจ๋“œ ์„ ํƒ**: Include Point (๋นจ๊ฐ•) ๋˜๋Š” Exclude Point (ํŒŒ๋ž‘) ์„ ํƒ
33
+ 4. **ํฌ์ธํŠธ ์ถ”๊ฐ€**: ์ด๋ฏธ์ง€๋ฅผ ํด๋ฆญํ•˜์—ฌ ํฌ์ธํŠธ๋ฅผ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค
34
+ 5. **์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜ ์‹คํ–‰**: "Run All Segmentation" ๋ฒ„ํŠผ์„ ํด๋ฆญํ•ฉ๋‹ˆ๋‹ค
35
+ 6. **๊ฒฐ๊ณผ ํ™•์ธ**: ์šฐ์ธก์—์„œ ๊ฒฐ๊ณผ ์ด๋ฏธ์ง€์™€ ๋ฉด์  ๋ถ„์„ ํ…Œ์ด๋ธ”์„ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค
36
+
37
+ ## ๐ŸŽฏ ํฌ์ธํŠธ ๋ชจ๋“œ
38
+
39
+ - **Include Point (๋นจ๊ฐ„ ์› โ—)**: ์ด ์˜์—ญ์„ ํฌํ•จ์‹œํ‚ต๋‹ˆ๋‹ค
40
+ - **Exclude Point (ํŒŒ๋ž€ X)**: ์ด ์˜์—ญ์„ ์ œ์™ธํ•ฉ๋‹ˆ๋‹ค
41
+
42
+ ## ๐Ÿ“Š ๊ธฐ์ˆ  ์Šคํƒ
43
+
44
+ - **SAM3**: Meta์˜ ์ตœ์‹  ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜ ๋ชจ๋ธ
45
+ - **Gradio**: ์›น ์ธํ„ฐํŽ˜์ด์Šค
46
+ - **PyTorch**: ๋”ฅ๋Ÿฌ๋‹ ํ”„๋ ˆ์ž„์›Œํฌ
47
+
48
+ ## โšก ์„ฑ๋Šฅ
49
+
50
+ - CPU ๋ชจ๋“œ ์ง€์› (GPU ๊ถŒ์žฅ)
51
+ - ์ฒซ ์‹คํ–‰ ์‹œ ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ๋กœ ์‹œ๊ฐ„ ์†Œ์š”
52
+
53
+ ## ๐Ÿ“ License
54
+
55
+ Apache 2.0
56
+
57
+ tmd9564@gmail.com
app.py ADDED
@@ -0,0 +1,681 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import tempfile
4
+ import spaces
5
+ import gradio as gr
6
+ import numpy as np
7
+ import torch
8
+ import matplotlib
9
+ import matplotlib.pyplot as plt
10
+ import pandas as pd
11
+ from PIL import Image, ImageDraw
12
+ from typing import Iterable
13
+ from gradio.themes import Soft
14
+ from gradio.themes.utils import colors, fonts, sizes
15
+ from transformers import (
16
+ Sam3Model, Sam3Processor,
17
+ Sam3TrackerModel, Sam3TrackerProcessor
18
+ )
19
+
20
+ # ============ THEME SETUP ============
21
+ colors.steel_blue = colors.Color(
22
+ name="steel_blue",
23
+ c50="#EBF3F8",
24
+ c100="#D3E5F0",
25
+ c200="#A8CCE1",
26
+ c300="#7DB3D2",
27
+ c400="#529AC3",
28
+ c500="#4682B4",
29
+ c600="#3E72A0",
30
+ c700="#36638C",
31
+ c800="#2E5378",
32
+ c900="#264364",
33
+ c950="#1E3450",
34
+ )
35
+
36
+ class CustomBlueTheme(Soft):
37
+ def __init__(
38
+ self,
39
+ *,
40
+ primary_hue: colors.Color | str = colors.gray,
41
+ secondary_hue: colors.Color | str = colors.steel_blue,
42
+ neutral_hue: colors.Color | str = colors.slate,
43
+ text_size: sizes.Size | str = sizes.text_lg,
44
+ font: fonts.Font | str | Iterable[fonts.Font | str] = (
45
+ fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
46
+ ),
47
+ font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
48
+ fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
49
+ ),
50
+ ):
51
+ super().__init__(
52
+ primary_hue=primary_hue,
53
+ secondary_hue=secondary_hue,
54
+ neutral_hue=neutral_hue,
55
+ text_size=text_size,
56
+ font=font,
57
+ font_mono=font_mono,
58
+ )
59
+ super().set(
60
+ background_fill_primary="*primary_50",
61
+ background_fill_primary_dark="*primary_900",
62
+ body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
63
+ body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
64
+ button_primary_text_color="white",
65
+ button_primary_text_color_hover="white",
66
+ button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
67
+ button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
68
+ button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
69
+ button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
70
+ slider_color="*secondary_500",
71
+ slider_color_dark="*secondary_600",
72
+ block_title_text_weight="600",
73
+ block_border_width="3px",
74
+ block_shadow="*shadow_drop_lg",
75
+ button_primary_shadow="*shadow_drop_lg",
76
+ button_large_padding="11px",
77
+ color_accent_soft="*primary_100",
78
+ block_label_background_fill="*primary_200",
79
+ )
80
+
81
+ app_theme = CustomBlueTheme()
82
+
83
+ # ============ GLOBAL SETUP ============
84
+ device = "cuda" if torch.cuda.is_available() else "cpu"
85
+ print(f"๐Ÿ–ฅ๏ธ Using compute device: {device}")
86
+
87
+ # Load models
88
+ print("โณ Loading SAM3 Models permanently into memory...")
89
+ try:
90
+ # ์˜คํ”„๋ผ์ธ ๋ชจ๋“œ๋กœ ์บ์‹œ์—์„œ ๋กœ๋“œ ์‹œ๋„
91
+ print(" ... Loading from local cache (offline mode)")
92
+ IMG_MODEL = Sam3Model.from_pretrained("DiffusionWave/sam3", local_files_only=True, device_map="cpu", torch_dtype=torch.float32)
93
+ IMG_PROCESSOR = Sam3Processor.from_pretrained("DiffusionWave/sam3", local_files_only=True)
94
+
95
+ TRK_MODEL = Sam3TrackerModel.from_pretrained("DiffusionWave/sam3", local_files_only=True, device_map="cpu", torch_dtype=torch.float32)
96
+ TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained("DiffusionWave/sam3", local_files_only=True)
97
+
98
+ print("โœ… All Models loaded successfully from local cache!")
99
+ except Exception as e:
100
+ print(f"โŒ Cache loading failed: {e}")
101
+ print(" Trying online loading...")
102
+ try:
103
+ IMG_MODEL = Sam3Model.from_pretrained("DiffusionWave/sam3", device_map="cpu", torch_dtype=torch.float32)
104
+ IMG_PROCESSOR = Sam3Processor.from_pretrained("DiffusionWave/sam3")
105
+
106
+ TRK_MODEL = Sam3TrackerModel.from_pretrained("DiffusionWave/sam3", device_map="cpu", torch_dtype=torch.float32)
107
+ TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained("DiffusionWave/sam3")
108
+
109
+ print("โœ… All Models loaded successfully (CPU mode)!")
110
+ except Exception as e2:
111
+ print(f"โŒ Online loading also failed: {e2}")
112
+ IMG_MODEL = IMG_PROCESSOR = TRK_MODEL = TRK_PROCESSOR = None
113
+
114
+ # ============ LAYER MANAGEMENT ============
115
+ class LayerManager:
116
+ """๋ ˆ์ด์–ด ๊ธฐ๋ฐ˜ ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜ ๊ด€๋ฆฌ ํด๋ž˜์Šค"""
117
+ def __init__(self):
118
+ self.layers = {} # layer_id -> {'name': str, 'color': tuple, 'points': list, 'point_labels': list, 'masks': list, 'area': float}
119
+ self.current_layer_id = None
120
+ self.layer_counter = 0
121
+
122
+ def create_layer(self, name, color=None):
123
+ """์ƒˆ ๋ ˆ์ด์–ด ์ƒ์„ฑ"""
124
+ if color is None:
125
+ # ๋ฌด์ž‘์œ„ ์ƒ‰์ƒ ์ƒ์„ฑ
126
+ import random
127
+ color = (random.randint(50, 200), random.randint(50, 200), random.randint(50, 200))
128
+
129
+ layer_id = f"layer_{self.layer_counter}"
130
+ self.layers[layer_id] = {
131
+ 'name': name,
132
+ 'color': color,
133
+ 'points': [],
134
+ 'point_labels': [], # 1: positive, 0: negative
135
+ 'masks': [],
136
+ 'area': 0.0
137
+ }
138
+ self.layer_counter += 1
139
+ return layer_id
140
+
141
+ def add_point_to_layer(self, layer_id, point, label=1):
142
+ """๋ ˆ์ด์–ด์— ํฌ์ธํŠธ ์ถ”๊ฐ€"""
143
+ if layer_id in self.layers:
144
+ self.layers[layer_id]['points'].append(point)
145
+ self.layers[layer_id]['point_labels'].append(label)
146
+ print(f"[add_point_to_layer] Added point to '{self.layers[layer_id]['name']}': {point}, label={label}")
147
+ print(f"[add_point_to_layer] Total points in '{self.layers[layer_id]['name']}': {len(self.layers[layer_id]['points'])}")
148
+
149
+ def add_mask_to_layer(self, layer_id, mask):
150
+ """๋ ˆ์ด์–ด์— ๋งˆ์Šคํฌ ์ถ”๊ฐ€"""
151
+ if layer_id in self.layers:
152
+ # ๊ธฐ์กด ๋งˆ์Šคํฌ๋ฅผ ๊ต์ฒด (๊ฐ™์€ ๋ ˆ์ด์–ด์— ์žฌ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜ ์‹œ)
153
+ self.layers[layer_id]['masks'] = [mask]
154
+
155
+ # ๋ฉด์  ๊ณ„์‚ฐ - mask๋ฅผ numpy array๋กœ ๋ณ€ํ™˜
156
+ if isinstance(mask, torch.Tensor):
157
+ mask_np = mask.cpu().numpy()
158
+ else:
159
+ mask_np = mask
160
+
161
+ # ๋ฉด์  ๊ณ„์‚ฐ
162
+ area = np.sum(mask_np > 0)
163
+ self.layers[layer_id]['area'] = area
164
+
165
+ # ๋””๋ฒ„๊น…: ๋งˆ์Šคํฌ ์ •๋ณด ์ถœ๋ ฅ
166
+ print(f"[add_mask_to_layer] Layer: {self.layers[layer_id]['name']}, Mask shape: {mask_np.shape}, Area: {area}")
167
+
168
+ def get_current_layer(self):
169
+ """ํ˜„์žฌ ์„ ํƒ๋œ ๋ ˆ์ด์–ด ๋ฐ˜ํ™˜"""
170
+ if self.current_layer_id and self.current_layer_id in self.layers:
171
+ return self.layers[self.current_layer_id]
172
+ return None
173
+
174
+ def set_current_layer(self, layer_id):
175
+ """ํ˜„์žฌ ๋ ˆ์ด์–ด ์„ค์ •"""
176
+ self.current_layer_id = layer_id
177
+
178
+ def clear_current_layer(self):
179
+ """ํ˜„์žฌ ๋ ˆ์ด์–ด ์ดˆ๊ธฐํ™”"""
180
+ if self.current_layer_id and self.current_layer_id in self.layers:
181
+ self.layers[self.current_layer_id]['points'] = []
182
+ self.layers[self.current_layer_id]['point_labels'] = []
183
+ self.layers[self.current_layer_id]['masks'] = []
184
+ self.layers[self.current_layer_id]['area'] = 0.0
185
+
186
+ def calculate_total_area_ratio(layer_manager, total_pixels):
187
+ """์ „์ฒด ์ด๋ฏธ์ง€ ๋Œ€๋น„ ๊ฐ ๋ ˆ์ด์–ด์˜ ๋ฉด์  ๋น„์œจ ๊ณ„์‚ฐ"""
188
+ ratios = []
189
+ for layer_id, layer in layer_manager.layers.items():
190
+ area = layer['area']
191
+ ratio = (area / total_pixels) * 100 if total_pixels > 0 and area > 0 else 0
192
+ has_mask = len(layer['masks']) > 0
193
+
194
+ # ๋””๋ฒ„๊น…: ๋ ˆ์ด์–ด ์ •๋ณด ์ถœ๋ ฅ
195
+ print(f"[calculate_total_area_ratio] Layer: {layer['name']}, Area: {area}, Ratio: {ratio}%, Masks: {len(layer['masks'])}, Has mask: {has_mask}")
196
+
197
+ ratios.append({
198
+ 'layer_name': layer['name'],
199
+ 'area_pixels': int(area),
200
+ 'ratio_percent': round(ratio, 2)
201
+ })
202
+ return ratios
203
+
204
+ def create_area_chart_data(ratios):
205
+ """๋ฉด์  ๋ฐ์ดํ„ฐ๋ฅผ ํ…Œ์ด๋ธ” ํฌ๋งท์œผ๋กœ ๋ณ€ํ™˜"""
206
+ if not ratios:
207
+ return pd.DataFrame(columns=["Layer", "Area (pixels)", "Ratio(%)"])
208
+
209
+ data = []
210
+ for ratio in ratios:
211
+ data.append({
212
+ "Layer": ratio['layer_name'],
213
+ "Area (pixels)": f"{ratio['area_pixels']:,}",
214
+ "Ratio(%)": f"{ratio['ratio_percent']}%"
215
+ })
216
+
217
+ return pd.DataFrame(data)
218
+
219
+ # ============ UTILITY FUNCTIONS ============
220
+ def compose_all_layers(base_image, layer_manager, opacity=0.5, border_width=2):
221
+ """๋ชจ๋“  ๋ ˆ์ด์–ด๋ฅผ ํ•ฉ์„ฑํ•˜์—ฌ ์ตœ์ข… ์ด๋ฏธ์ง€ ์ƒ์„ฑ"""
222
+ if isinstance(base_image, np.ndarray):
223
+ base_image = Image.fromarray(base_image)
224
+ base_image = base_image.convert("RGBA")
225
+
226
+ if not layer_manager.layers:
227
+ return base_image.convert("RGB")
228
+
229
+ composite_layer = Image.new("RGBA", base_image.size, (0, 0, 0, 0))
230
+
231
+ for layer_id, layer in layer_manager.layers.items():
232
+ if not layer['masks']:
233
+ continue
234
+
235
+ layer_color = layer['color']
236
+
237
+ for mask in layer['masks']:
238
+ if isinstance(mask, torch.Tensor):
239
+ mask = mask.cpu().numpy()
240
+ mask = mask.astype(np.uint8)
241
+
242
+ if mask.ndim == 3: mask = mask[0]
243
+ if mask.ndim == 2 and mask.shape[0] == 1: mask = mask[0]
244
+
245
+ # ๋งˆ์Šคํฌ๋ฅผ PIL ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜
246
+ mask_img = Image.fromarray((mask * 255).astype(np.uint8))
247
+
248
+ # ์ƒ‰์ƒ ๋ ˆ์ด์–ด ์ƒ์„ฑ
249
+ color_layer = Image.new("RGBA", base_image.size, layer_color + (0,))
250
+ mask_alpha = mask_img.point(lambda v: int(v * opacity * 255) if v > 0 else 0)
251
+ color_layer.putalpha(mask_alpha)
252
+
253
+ # ํ…Œ๋‘๋ฆฌ ์ถ”๊ฐ€
254
+ if border_width > 0:
255
+ try:
256
+ # ๋งˆ์Šคํฌ์˜ ํ…Œ๋‘๋ฆฌ ์ฐพ๊ธฐ
257
+ mask_np = np.array(mask_img)
258
+ kernel_size = border_width * 2 + 1
259
+ dilated = cv2.dilate(mask_np, np.ones((kernel_size, kernel_size), np.uint8))
260
+ border = dilated - mask_np
261
+ border_img = Image.fromarray(border)
262
+
263
+ border_layer = Image.new("RGBA", base_image.size, (255, 255, 255, 255)) # ํฐ์ƒ‰ ํ…Œ๋‘๋ฆฌ
264
+ border_alpha = border_img.point(lambda v: 255 if v > 0 else 0)
265
+ border_layer.putalpha(border_alpha)
266
+
267
+ # ํ…Œ๋‘๋ฆฌ๋ฅผ ๋จผ์ € ํ•ฉ์„ฑ
268
+ composite_layer = Image.alpha_composite(composite_layer, border_layer)
269
+ except Exception as e:
270
+ print(f"Border creation error: {e}")
271
+
272
+ # ๋งˆ์Šคํฌ ๋ ˆ์ด์–ด ํ•ฉ์„ฑ
273
+ composite_layer = Image.alpha_composite(composite_layer, color_layer)
274
+
275
+ # ์ตœ์ข… ํ•ฉ์„ฑ
276
+ final_result = Image.alpha_composite(base_image, composite_layer)
277
+ return final_result.convert("RGB")
278
+
279
+ def draw_points_on_image(image, layer_manager):
280
+ """์ด๋ฏธ์ง€์— ๋ชจ๋“  ๋ ˆ์ด์–ด์˜ ํฌ์ธํŠธ๋“ค์„ ํ‘œ์‹œ"""
281
+ if isinstance(image, np.ndarray):
282
+ image = Image.fromarray(image)
283
+
284
+ draw_img = image.copy()
285
+ draw = ImageDraw.Draw(draw_img)
286
+
287
+ for layer_id, layer in layer_manager.layers.items():
288
+ is_current = (layer_id == layer_manager.current_layer_id)
289
+
290
+ for i, point in enumerate(layer['points']):
291
+ x, y = point
292
+ label = layer['point_labels'][i]
293
+
294
+ # ํฌ์ง€ํ‹ฐ๋ธŒ: ๋นจ๊ฐ„์ƒ‰ ์›, ๋„ค๊ฑฐํ‹ฐ๋ธŒ: ํŒŒ๋ž€์ƒ‰ Xํ‘œ์‹œ
295
+ if label == 1: # Positive
296
+ # ํฐ ๋นจ๊ฐ„์ƒ‰ ์›
297
+ r = 15 if is_current else 10
298
+ draw.ellipse((x-r, y-r, x+r, y+r), fill="red", outline="white", width=3)
299
+ # ์ž‘์€ ํฐ์ƒ‰ ์› (์ค‘์•™)
300
+ draw.ellipse((x-3, y-3, x+3, y+3), fill="white")
301
+ else: # Negative (0)
302
+ # ํฐ ํŒŒ๋ž€์ƒ‰ ์›
303
+ r = 15 if is_current else 10
304
+ draw.ellipse((x-r, y-r, x+r, y+r), fill="blue", outline="white", width=3)
305
+ # X ํ‘œ์‹œ
306
+ line_length = 8
307
+ draw.line([(x-line_length, y-line_length), (x+line_length, y+line_length)], fill="white", width=3)
308
+ draw.line([(x-line_length, y+line_length), (x+line_length, y-line_length)], fill="white", width=3)
309
+
310
+ return draw_img
311
+
312
+ # ============ UI FUNCTIONS ============
313
+ def create_new_layer(name, current_manager):
314
+ """์ƒˆ ๋ ˆ์ด์–ด ์ƒ์„ฑ"""
315
+ if not name.strip():
316
+ return current_manager, create_layer_status_html(current_manager), gr.Dropdown(choices=[]), "Please enter a layer name"
317
+
318
+ # ์ค‘๋ณต ์ด๋ฆ„ ์ฒดํฌ
319
+ for layer_id, layer in current_manager.layers.items():
320
+ if layer['name'] == name.strip():
321
+ return current_manager, create_layer_status_html(current_manager), gr.Dropdown(choices=[(layer['name'], lid) for lid, layer in current_manager.layers.items()]), f"Layer name '{name}' already exists"
322
+
323
+ layer_id = current_manager.create_layer(name.strip())
324
+ current_manager.set_current_layer(layer_id)
325
+
326
+ # ๋“œ๋กญ๋‹ค์šด ์„ ํƒ์ง€ ์—…๋ฐ์ดํŠธ
327
+ choices = [(layer['name'], lid) for lid, layer in current_manager.layers.items()]
328
+
329
+ return current_manager, create_layer_status_html(current_manager), gr.Dropdown(choices=choices, value=layer_id), f"Layer '{name}' created"
330
+
331
+ def create_layer_status_html(current_manager):
332
+ """๋ ˆ์ด์–ด ์ƒํƒœ ํ‘œ์‹œ HTML ์ƒ์„ฑ (์‹œ๊ฐ์  ํ‘œ์‹œ๋งŒ)"""
333
+ if not current_manager.layers:
334
+ return "<div style='padding: 10px; text-align: center; color: #888;'>No layers created</div>"
335
+
336
+ html = "<div style='display: flex; flex-wrap: wrap; gap: 8px; padding: 10px;'>"
337
+
338
+ for layer_id, layer in current_manager.layers.items():
339
+ is_active = (current_manager.current_layer_id == layer_id)
340
+
341
+ # ์ƒ‰์ƒ ์ถ”์ถœ
342
+ r, g, b = layer['color']
343
+ color_hex = f"#{r:02x}{g:02x}{b:02x}"
344
+
345
+ # ํ™œ์„ฑํ™” ์ƒํƒœ์— ๋”ฐ๋ฅธ ์Šคํƒ€์ผ
346
+ if is_active:
347
+ style = f"""
348
+ background: linear-gradient(135deg, {color_hex}, {color_hex}dd);
349
+ color: white;
350
+ border: 3px solid #4682B4;
351
+ box-shadow: 0 4px 12px rgba(70, 130, 180, 0.4);
352
+ """
353
+ else:
354
+ style = f"""
355
+ background: linear-gradient(135deg, {color_hex}aa, {color_hex}77);
356
+ color: white;
357
+ border: 2px solid {color_hex};
358
+ opacity: 0.7;
359
+ """
360
+
361
+ # ํฌ์ธํŠธ ๊ฐœ์ˆ˜ ๊ณ„์‚ฐ (ํฌ์ง€ํ‹ฐ๋ธŒ/๋„ค๊ฑฐํ‹ฐ๋ธŒ ๊ตฌ๋ถ„)
362
+ positive_points = sum(1 for label in layer['point_labels'] if label == 1)
363
+ negative_points = sum(1 for label in layer['point_labels'] if label == 0)
364
+ masks_count = len(layer['masks'])
365
+ has_mask = masks_count > 0
366
+
367
+ # ์ƒํƒœ ์•„์ด์ฝ˜
368
+ status_icon = "[OK]" if has_mask else "[ ]"
369
+
370
+ html += f"""
371
+ <div style="{style}
372
+ padding: 12px 20px;
373
+ border-radius: 8px;
374
+ font-weight: 600;
375
+ font-size: 14px;
376
+ min-width: 150px;">
377
+ {status_icon} {layer['name']}<br>
378
+ <small style='font-size: 11px; opacity: 0.9;'>
379
+ <span style='color: #ffcccc;'>+{positive_points}</span>
380
+ <span style='color: #ccccff;'>-{negative_points}</span>
381
+ {masks_count}mask
382
+ </small>
383
+ </div>
384
+ """
385
+
386
+ html += "</div>"
387
+ return html
388
+
389
+ def click_on_image(current_manager, image, point_mode, evt: gr.SelectData):
390
+ """์ด๋ฏธ์ง€ ํด๋ฆญ ์ฒ˜๋ฆฌ - Include/Exclude ๋ชจ๋“œ์— ๋”ฐ๋ผ ํฌ์ธํŠธ ์ถ”๊ฐ€"""
391
+ if image is None or current_manager.current_layer_id is None:
392
+ return image, current_manager, create_layer_status_html(current_manager), "Please select image and layer"
393
+
394
+ x, y = evt.index
395
+
396
+ # ํฌ์ธํŠธ ๋ชจ๋“œ์— ๋”ฐ๋ผ ๋ ˆ์ด๋ธ” ๊ฒฐ์ • (positive=1, negative=0)
397
+ label = 1 if point_mode == "positive" else 0
398
+
399
+ layer_name = current_manager.layers[current_manager.current_layer_id]['name']
400
+ print(f"\n[click_on_image] ================")
401
+ print(f"[click_on_image] Layer: {layer_name}")
402
+ print(f"[click_on_image] Point mode: {point_mode}, Label: {label}, Position: ({x}, {y})")
403
+
404
+ current_manager.add_point_to_layer(current_manager.current_layer_id, [x, y], label)
405
+
406
+ # ํฌ์ธํŠธ ํ‘œ์‹œ๋œ ์ด๋ฏธ์ง€ ์ƒ์„ฑ (์›๋ณธ ์ด๋ฏธ์ง€์— ํฌ์ธํŠธ ํ‘œ์‹œ)
407
+ result_image = draw_points_on_image(image, current_manager)
408
+
409
+ mode_text = "Include" if label == 1 else "Exclude"
410
+
411
+ return result_image, current_manager, create_layer_status_html(current_manager), f"{mode_text} point added to '{layer_name}' at ({x}, {y})"
412
+
413
+ def segment_all_layers(current_manager, image, opacity, border_width):
414
+ """๋ชจ๋“  ๋ ˆ์ด์–ด๋ฅผ ์ˆœ์„œ๋Œ€๋กœ ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜ ์‹คํ–‰"""
415
+ if image is None:
416
+ return None, current_manager, create_layer_status_html(current_manager), "Please upload an image", pd.DataFrame()
417
+
418
+ if not current_manager.layers:
419
+ return None, current_manager, create_layer_status_html(current_manager), "Please create layers first", pd.DataFrame()
420
+
421
+ try:
422
+ print(f"\n[segment_all_layers] Starting segmentation for all layers...")
423
+ segmented_count = 0
424
+ skipped_count = 0
425
+
426
+ # ๋ชจ๋“  ๋ ˆ์ด์–ด๋ฅผ ์ˆœํšŒํ•˜๋ฉฐ ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜
427
+ for layer_id, layer in current_manager.layers.items():
428
+ layer_name = layer['name']
429
+
430
+ # ํฌ์ธํŠธ๊ฐ€ ์—†๋Š” ๋ ˆ์ด์–ด๋Š” ๊ฑด๋„ˆ๋›ฐ๊ธฐ
431
+ if not layer['points']:
432
+ print(f"[segment_all_layers] Skipping '{layer_name}' - no points")
433
+ skipped_count += 1
434
+ continue
435
+
436
+ print(f"\n[segment_all_layers] Processing layer: {layer_name}")
437
+ print(f"[segment_all_layers] Points: {len(layer['points'])}, Labels: {layer['point_labels']}")
438
+
439
+ # SAM3 Tracker๋กœ ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜
440
+ points_list = layer['points']
441
+ labels_list = layer['point_labels']
442
+
443
+ input_points = [[points_list]]
444
+ input_labels = [[labels_list]]
445
+
446
+ inputs = TRK_PROCESSOR(images=image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
447
+
448
+ with torch.no_grad():
449
+ outputs = TRK_MODEL(**inputs, multimask_output=False)
450
+
451
+ masks = TRK_PROCESSOR.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"], binarize=True)[0]
452
+
453
+ # ๋ ˆ์ด์–ด์— ๋งˆ์Šคํฌ ์ถ”๊ฐ€
454
+ current_manager.add_mask_to_layer(layer_id, masks[0])
455
+ segmented_count += 1
456
+ print(f"[segment_all_layers] Completed '{layer_name}'")
457
+
458
+ # ๊ฒฐ๊ณผ ์ด๋ฏธ์ง€ ์ƒ์„ฑ (ํฌ์ธํŠธ ํฌํ•จ)
459
+ result_image = compose_all_layers(image, current_manager, opacity, border_width)
460
+ result_image = draw_points_on_image(result_image, current_manager)
461
+
462
+ # ๋ฉด์  ๋ถ„์„
463
+ total_pixels = image.size[0] * image.size[1]
464
+ ratios = calculate_total_area_ratio(current_manager, total_pixels)
465
+ chart_data = create_area_chart_data(ratios)
466
+
467
+ status_msg = f"Segmentation completed! Processed: {segmented_count} layers, Skipped: {skipped_count} layers"
468
+ print(f"\n[segment_all_layers] {status_msg}")
469
+
470
+ return result_image, current_manager, create_layer_status_html(current_manager), status_msg, chart_data
471
+
472
+ except Exception as e:
473
+ import traceback
474
+ print(f"[segment_all_layers] Error: {str(e)}")
475
+ traceback.print_exc()
476
+ return None, current_manager, create_layer_status_html(current_manager), f"Error: {str(e)}", pd.DataFrame()
477
+
478
+ def clear_current_layer(current_manager, image, opacity, border_width):
479
+ """ํ˜„์žฌ ๋ ˆ์ด์–ด ์ดˆ๊ธฐํ™”"""
480
+ if current_manager.current_layer_id:
481
+ current_manager.clear_current_layer()
482
+
483
+ if image:
484
+ result_image = compose_all_layers(image, current_manager, opacity, border_width)
485
+ result_image = draw_points_on_image(result_image, current_manager)
486
+ else:
487
+ result_image = None
488
+
489
+ total_pixels = image.size[0] * image.size[1] if image else 0
490
+ ratios = calculate_total_area_ratio(current_manager, total_pixels)
491
+ chart_data = create_area_chart_data(ratios)
492
+
493
+ return result_image, current_manager, create_layer_status_html(current_manager), "Layer cleared", chart_data
494
+
495
+ return None, current_manager, create_layer_status_html(current_manager), "Please select a layer", pd.DataFrame()
496
+
497
+ def refresh_visualization(current_manager, image, opacity, border_width):
498
+ """์‹œ๊ฐํ™” ์ƒˆ๋กœ๊ณ ์นจ"""
499
+ if image is None:
500
+ return None, "Please upload an image", pd.DataFrame()
501
+
502
+ result_image = compose_all_layers(image, current_manager, opacity, border_width)
503
+ result_image = draw_points_on_image(result_image, current_manager)
504
+
505
+ total_pixels = image.size[0] * image.size[1]
506
+ ratios = calculate_total_area_ratio(current_manager, total_pixels)
507
+ chart_data = create_area_chart_data(ratios)
508
+
509
+ return result_image, "Visualization updated", chart_data
510
+
511
+
512
+ # ============ GRADIO INTERFACE ============
513
+ custom_css="""
514
+ #col-container { margin: 0 auto; max-width: 1200px; }
515
+ #main-title h1 { font-size: 2.1em !important; }
516
+ .layer-button { margin: 2px; }
517
+ """
518
+
519
+ # No custom JavaScript needed anymore
520
+ custom_js = ""
521
+
522
+ # ์ „์—ญ ๋ ˆ์ด์–ด ๋งค๋‹ˆ์ €
523
+ layer_manager = LayerManager()
524
+
525
+ with gr.Blocks() as demo:
526
+ with gr.Column(elem_id="col-container"):
527
+ gr.Markdown("# **SAM3 Layer Segmentation Tool**", elem_id="main-title")
528
+ gr.Markdown("**Layer-based object separation and area analysis tool** | 1. Create layers 2. Select point mode and click 3. Run segmentation (processes all layers)")
529
+
530
+ with gr.Row():
531
+ with gr.Column(scale=1):
532
+ img_input = gr.Image(type="pil", label="Upload Image", interactive=True, height=400)
533
+
534
+ # ๋ ˆ์ด์–ด ์ƒ์„ฑ
535
+ with gr.Row():
536
+ layer_name_input = gr.Textbox(label="Layer Name", placeholder="e.g. bench, tree, person")
537
+ create_layer_btn = gr.Button("Create", variant="primary")
538
+
539
+ # ๋ ˆ์ด์–ด ์ƒํƒœ ํ‘œ์‹œ
540
+ gr.Markdown("### Layers Status")
541
+ layer_buttons_html = gr.HTML("<div style='padding: 10px; text-align: center; color: #888;'>No layers created</div>")
542
+
543
+ # ๋ ˆ์ด์–ด ์„ ํƒ
544
+ layer_selector = gr.Dropdown(label="Select Layer to Add Points", choices=[], interactive=True)
545
+
546
+ # ํฌ์ธํŠธ ๋ชจ๋“œ ์„ ํƒ
547
+ gr.Markdown("### Point Mode")
548
+ with gr.Row():
549
+ include_btn = gr.Button("Include Point", variant="primary", size="sm")
550
+ exclude_btn = gr.Button("Exclude Point", variant="secondary", size="sm")
551
+
552
+ point_mode_text = gr.Textbox(label="Current Mode", value="Include Point (Red)", interactive=False)
553
+
554
+ # ํฌ์ธํŠธ ์•ˆ๋‚ด
555
+ gr.Markdown("""
556
+ **Instructions:**
557
+ - Select a layer from dropdown
558
+ - Choose point mode (Include/Exclude)
559
+ - Click on image to add point
560
+ - **Red circle (โ—)**: Include this area
561
+ - **Blue circle with X**: Exclude this area
562
+ """)
563
+
564
+ # ์ปจํŠธ๋กค
565
+ with gr.Row():
566
+ segment_btn = gr.Button("Run All Segmentation", variant="primary", size="lg")
567
+ clear_btn = gr.Button("Clear Current Layer", variant="secondary")
568
+
569
+ # ์ƒํƒœ
570
+ status_text = gr.Textbox(label="Status", interactive=False)
571
+ st_layer_manager = gr.State(layer_manager)
572
+ point_mode_state = gr.State("positive") # "positive" or "negative"
573
+
574
+ with gr.Column(scale=2):
575
+ img_output = gr.Image(type="pil", label="Segmentation Result", height=400, interactive=False)
576
+
577
+ # ๋ฉด์  ํ…Œ์ด๋ธ”
578
+ area_table = gr.Dataframe(
579
+ label="Area Ratio by Layer",
580
+ headers=["Layer", "Area (pixels)", "Ratio(%)"],
581
+ datatype=["str", "str", "str"],
582
+ interactive=False,
583
+ wrap=True
584
+ )
585
+
586
+ # ์„ค์ •
587
+ with gr.Accordion("Visualization Settings", open=False):
588
+ opacity_slider = gr.Slider(0.1, 1.0, value=0.5, step=0.1, label="Mask Opacity")
589
+ border_slider = gr.Slider(0, 5, value=2, step=1, label="Border Width")
590
+
591
+ # ์ด๋ฒคํŠธ ์—ฐ๊ฒฐ
592
+ create_layer_btn.click(
593
+ create_new_layer,
594
+ inputs=[layer_name_input, st_layer_manager],
595
+ outputs=[st_layer_manager, layer_buttons_html, layer_selector, status_text]
596
+ )
597
+
598
+ # ๋ ˆ์ด์–ด ์„ ํƒ
599
+ def on_layer_select(layer_id, mgr):
600
+ if layer_id:
601
+ mgr.set_current_layer(layer_id)
602
+ return mgr, create_layer_status_html(mgr), f"Layer '{mgr.layers[layer_id]['name']}' selected"
603
+ return mgr, create_layer_status_html(mgr), "Please select a layer"
604
+
605
+ layer_selector.change(
606
+ on_layer_select,
607
+ inputs=[layer_selector, st_layer_manager],
608
+ outputs=[st_layer_manager, layer_buttons_html, status_text]
609
+ )
610
+
611
+ # ํฌ์ธํŠธ ๋ชจ๋“œ ๋ณ€๊ฒฝ
612
+ def set_include_mode():
613
+ return "positive", "Include Point (Red)"
614
+
615
+ def set_exclude_mode():
616
+ return "negative", "Exclude Point (Blue)"
617
+
618
+ include_btn.click(
619
+ set_include_mode,
620
+ outputs=[point_mode_state, point_mode_text]
621
+ )
622
+
623
+ exclude_btn.click(
624
+ set_exclude_mode,
625
+ outputs=[point_mode_state, point_mode_text]
626
+ )
627
+
628
+ # ์ด๋ฏธ์ง€ ํด๋ฆญ ์ด๋ฒคํŠธ - img_input๊ณผ img_output ๋ชจ๋‘์—์„œ ํด๋ฆญ ๋ฐ›๊ธฐ
629
+ img_input.select(
630
+ click_on_image,
631
+ inputs=[st_layer_manager, img_input, point_mode_state],
632
+ outputs=[img_output, st_layer_manager, layer_buttons_html, status_text]
633
+ )
634
+
635
+ img_output.select(
636
+ click_on_image,
637
+ inputs=[st_layer_manager, img_input, point_mode_state],
638
+ outputs=[img_output, st_layer_manager, layer_buttons_html, status_text]
639
+ )
640
+
641
+ # ๋ชจ๋“  ๋ ˆ์ด์–ด ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜ ์‹คํ–‰
642
+ segment_btn.click(
643
+ segment_all_layers,
644
+ inputs=[st_layer_manager, img_input, opacity_slider, border_slider],
645
+ outputs=[img_output, st_layer_manager, layer_buttons_html, status_text, area_table]
646
+ )
647
+
648
+ clear_btn.click(
649
+ clear_current_layer,
650
+ inputs=[st_layer_manager, img_input, opacity_slider, border_slider],
651
+ outputs=[img_output, st_layer_manager, layer_buttons_html, status_text, area_table]
652
+ )
653
+
654
+ # ํˆฌ๋ช…๋„ ๋ฐ ํ…Œ๋‘๋ฆฌ ์Šฌ๋ผ์ด๋” ์‹ค์‹œ๊ฐ„ ์—…๋ฐ์ดํŠธ
655
+ opacity_slider.change(
656
+ refresh_visualization,
657
+ inputs=[st_layer_manager, img_input, opacity_slider, border_slider],
658
+ outputs=[img_output, status_text, area_table]
659
+ )
660
+
661
+ border_slider.change(
662
+ refresh_visualization,
663
+ inputs=[st_layer_manager, img_input, opacity_slider, border_slider],
664
+ outputs=[img_output, status_text, area_table]
665
+ )
666
+
667
+ # ์ด๋ฏธ์ง€ ์—…๋กœ๋“œ ์‹œ ์ดˆ๊ธฐํ™”
668
+ def on_image_upload(img):
669
+ new_manager = LayerManager()
670
+ empty_html = "<div style='padding: 10px; text-align: center; color: #888;'>No layers created</div>"
671
+ # ์—…๋กœ๋“œํ•œ ์ด๋ฏธ์ง€๋ฅผ ์ถœ๋ ฅ์—๋„ ํ‘œ์‹œ
672
+ return new_manager, img, pd.DataFrame(), empty_html, gr.Dropdown(choices=[], value=None), "positive", "Include Point (Red)", "New image uploaded"
673
+
674
+ img_input.change(
675
+ on_image_upload,
676
+ inputs=[img_input],
677
+ outputs=[st_layer_manager, img_output, area_table, layer_buttons_html, layer_selector, point_mode_state, point_mode_text, status_text]
678
+ )
679
+
680
+ if __name__ == "__main__":
681
+ demo.launch(show_error=True, theme=app_theme, css=custom_css)
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/transformers.git
2
+ sentencepiece
3
+ opencv-python-headless
4
+ imageio[pyav]
5
+ torchvision
6
+ matplotlib
7
+ accelerate
8
+ pillow
9
+ gradio
10
+ spaces
11
+ numpy
12
+ pandas
13
+ torch
14
+ peft