Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -21,32 +21,37 @@ from transformers import (
|
|
| 21 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 22 |
print(f"๐ฅ๏ธ Using compute device: {device}")
|
| 23 |
|
| 24 |
-
#
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
print("
|
| 39 |
try:
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
IMG_PROCESSOR = Sam3Processor.from_pretrained("DiffusionWave/sam3")
|
| 42 |
|
| 43 |
-
TRK_MODEL = Sam3TrackerModel.from_pretrained("DiffusionWave/sam3", device_map=
|
| 44 |
TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained("DiffusionWave/sam3")
|
| 45 |
|
| 46 |
-
print("โ
All Models loaded successfully
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
| 50 |
|
| 51 |
# ============ LAYER MANAGEMENT ============
|
| 52 |
class LayerManager:
|
|
@@ -248,9 +253,12 @@ def draw_points_on_image(image, layer_manager):
|
|
| 248 |
|
| 249 |
# ============ UI FUNCTIONS ============
|
| 250 |
def update_layer_selector_choices(manager):
|
| 251 |
-
"""๋ ์ด์ด ์ ํ
|
| 252 |
-
choices = [
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
def create_new_layer(name, current_manager):
|
| 256 |
"""์ ๋ ์ด์ด ์์ฑ"""
|
|
@@ -384,6 +392,11 @@ def segment_all_layers(current_manager, image, opacity, border_width):
|
|
| 384 |
print(f"\n[segment_all_layers] Processing layer: {layer_name}")
|
| 385 |
print(f"[segment_all_layers] Points: {len(layer['points'])}, Labels: {layer['point_labels']}")
|
| 386 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
# SAM3 Tracker๋ก ์ธ๊ทธ๋ฉํ
์ด์
|
| 388 |
points_list = layer['points']
|
| 389 |
labels_list = layer['point_labels']
|
|
@@ -391,7 +404,9 @@ def segment_all_layers(current_manager, image, opacity, border_width):
|
|
| 391 |
input_points = [[points_list]]
|
| 392 |
input_labels = [[labels_list]]
|
| 393 |
|
| 394 |
-
|
|
|
|
|
|
|
| 395 |
|
| 396 |
with torch.no_grad():
|
| 397 |
outputs = TRK_MODEL(**inputs, multimask_output=False)
|
|
@@ -485,8 +500,8 @@ with gr.Blocks() as demo:
|
|
| 485 |
gr.Markdown("### Layers Status")
|
| 486 |
layer_buttons_html = gr.HTML("<div style='padding: 10px; text-align: center; color: #888;'>No layers created</div>")
|
| 487 |
|
| 488 |
-
# ๋ ์ด์ด ์ ํ
|
| 489 |
-
layer_selector = gr.
|
| 490 |
|
| 491 |
# ํฌ์ธํธ ๋ชจ๋ ์ ํ
|
| 492 |
gr.Markdown("### Point Mode")
|
|
@@ -537,13 +552,21 @@ with gr.Blocks() as demo:
|
|
| 537 |
)
|
| 538 |
|
| 539 |
# ๋ ์ด์ด ์ ํ
|
| 540 |
-
def on_layer_select(
|
| 541 |
if mgr is None:
|
| 542 |
mgr = LayerManager()
|
| 543 |
|
| 544 |
-
if
|
| 545 |
-
|
| 546 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
return mgr, create_layer_status_html(mgr), "Please select a layer"
|
| 548 |
|
| 549 |
layer_selector.change(
|
|
|
|
| 21 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 22 |
print(f"๐ฅ๏ธ Using compute device: {device}")
|
| 23 |
|
| 24 |
+
# Models will be loaded lazily in functions to avoid build timeouts
|
| 25 |
+
IMG_MODEL = None
|
| 26 |
+
IMG_PROCESSOR = None
|
| 27 |
+
TRK_MODEL = None
|
| 28 |
+
TRK_PROCESSOR = None
|
| 29 |
+
|
| 30 |
+
@spaces.GPU
|
| 31 |
+
def load_models():
|
| 32 |
+
"""Lazy load models when needed"""
|
| 33 |
+
global IMG_MODEL, IMG_PROCESSOR, TRK_MODEL, TRK_PROCESSOR
|
| 34 |
+
|
| 35 |
+
if IMG_MODEL is not None:
|
| 36 |
+
return True
|
| 37 |
+
|
| 38 |
+
print("โณ Loading SAM3 Models...")
|
| 39 |
try:
|
| 40 |
+
# GPU๊ฐ ์ฌ์ฉ ๊ฐ๋ฅํ๋ฉด GPU๋ก ๋ก๋
|
| 41 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 42 |
+
dtype = torch.float16 if device == "cuda" else torch.float32
|
| 43 |
+
|
| 44 |
+
IMG_MODEL = Sam3Model.from_pretrained("DiffusionWave/sam3", device_map=device, torch_dtype=dtype)
|
| 45 |
IMG_PROCESSOR = Sam3Processor.from_pretrained("DiffusionWave/sam3")
|
| 46 |
|
| 47 |
+
TRK_MODEL = Sam3TrackerModel.from_pretrained("DiffusionWave/sam3", device_map=device, torch_dtype=dtype)
|
| 48 |
TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained("DiffusionWave/sam3")
|
| 49 |
|
| 50 |
+
print(f"โ
All Models loaded successfully on {device}!")
|
| 51 |
+
return True
|
| 52 |
+
except Exception as e:
|
| 53 |
+
print(f"โ Model loading failed: {e}")
|
| 54 |
+
return False
|
| 55 |
|
| 56 |
# ============ LAYER MANAGEMENT ============
|
| 57 |
class LayerManager:
|
|
|
|
| 253 |
|
| 254 |
# ============ UI FUNCTIONS ============
|
| 255 |
def update_layer_selector_choices(manager):
|
| 256 |
+
"""๋ ์ด์ด ์ ํ ๋ผ๋์ค ๋ฒํผ์ choices ์
๋ฐ์ดํธ"""
|
| 257 |
+
choices = [layer['name'] for layer in manager.layers.values()]
|
| 258 |
+
current_value = None
|
| 259 |
+
if manager.current_layer_id and manager.current_layer_id in manager.layers:
|
| 260 |
+
current_value = manager.layers[manager.current_layer_id]['name']
|
| 261 |
+
return gr.Radio(choices=choices, interactive=True, value=current_value)
|
| 262 |
|
| 263 |
def create_new_layer(name, current_manager):
|
| 264 |
"""์ ๋ ์ด์ด ์์ฑ"""
|
|
|
|
| 392 |
print(f"\n[segment_all_layers] Processing layer: {layer_name}")
|
| 393 |
print(f"[segment_all_layers] Points: {len(layer['points'])}, Labels: {layer['point_labels']}")
|
| 394 |
|
| 395 |
+
# Load models if needed
|
| 396 |
+
if not load_models():
|
| 397 |
+
print(f"[segment_all_layers] Failed to load models for layer: {layer_name}")
|
| 398 |
+
continue
|
| 399 |
+
|
| 400 |
# SAM3 Tracker๋ก ์ธ๊ทธ๋ฉํ
์ด์
|
| 401 |
points_list = layer['points']
|
| 402 |
labels_list = layer['point_labels']
|
|
|
|
| 404 |
input_points = [[points_list]]
|
| 405 |
input_labels = [[labels_list]]
|
| 406 |
|
| 407 |
+
# Use the same device as the model
|
| 408 |
+
model_device = next(TRK_MODEL.parameters()).device
|
| 409 |
+
inputs = TRK_PROCESSOR(images=image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(model_device)
|
| 410 |
|
| 411 |
with torch.no_grad():
|
| 412 |
outputs = TRK_MODEL(**inputs, multimask_output=False)
|
|
|
|
| 500 |
gr.Markdown("### Layers Status")
|
| 501 |
layer_buttons_html = gr.HTML("<div style='padding: 10px; text-align: center; color: #888;'>No layers created</div>")
|
| 502 |
|
| 503 |
+
# ๋ ์ด์ด ์ ํ (๋ผ๋์ค ๋ฒํผ์ผ๋ก ๋ณ๊ฒฝ)
|
| 504 |
+
layer_selector = gr.Radio(label="Select Layer to Add Points", choices=[], interactive=True)
|
| 505 |
|
| 506 |
# ํฌ์ธํธ ๋ชจ๋ ์ ํ
|
| 507 |
gr.Markdown("### Point Mode")
|
|
|
|
| 552 |
)
|
| 553 |
|
| 554 |
# ๋ ์ด์ด ์ ํ
|
| 555 |
+
def on_layer_select(selected_name, mgr):
|
| 556 |
if mgr is None:
|
| 557 |
mgr = LayerManager()
|
| 558 |
|
| 559 |
+
if selected_name:
|
| 560 |
+
# ์ด๋ฆ์ผ๋ก layer_id ์ฐพ๊ธฐ
|
| 561 |
+
layer_id = None
|
| 562 |
+
for lid, layer in mgr.layers.items():
|
| 563 |
+
if layer['name'] == selected_name:
|
| 564 |
+
layer_id = lid
|
| 565 |
+
break
|
| 566 |
+
|
| 567 |
+
if layer_id:
|
| 568 |
+
mgr.set_current_layer(layer_id)
|
| 569 |
+
return mgr, create_layer_status_html(mgr), f"Layer '{selected_name}' selected"
|
| 570 |
return mgr, create_layer_status_html(mgr), "Please select a layer"
|
| 571 |
|
| 572 |
layer_selector.change(
|