tmdwo commited on
Commit
4e5b950
ยท
verified ยท
1 Parent(s): 5fbd3a5

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -31
app.py CHANGED
@@ -21,32 +21,37 @@ from transformers import (
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
  print(f"๐Ÿ–ฅ๏ธ Using compute device: {device}")
23
 
24
- # Load models
25
- print("โณ Loading SAM3 Models permanently into memory...")
26
- try:
27
- # ์˜คํ”„๋ผ์ธ ๋ชจ๋“œ๋กœ ์บ์‹œ์—์„œ ๋กœ๋“œ ์‹œ๋„
28
- print(" ... Loading from local cache (offline mode)")
29
- IMG_MODEL = Sam3Model.from_pretrained("DiffusionWave/sam3", local_files_only=True, device_map="cpu", torch_dtype=torch.float32)
30
- IMG_PROCESSOR = Sam3Processor.from_pretrained("DiffusionWave/sam3", local_files_only=True)
31
-
32
- TRK_MODEL = Sam3TrackerModel.from_pretrained("DiffusionWave/sam3", local_files_only=True, device_map="cpu", torch_dtype=torch.float32)
33
- TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained("DiffusionWave/sam3", local_files_only=True)
34
-
35
- print("โœ… All Models loaded successfully from local cache!")
36
- except Exception as e:
37
- print(f"โŒ Cache loading failed: {e}")
38
- print(" Trying online loading...")
39
  try:
40
- IMG_MODEL = Sam3Model.from_pretrained("DiffusionWave/sam3", device_map="cpu", torch_dtype=torch.float32)
 
 
 
 
41
  IMG_PROCESSOR = Sam3Processor.from_pretrained("DiffusionWave/sam3")
42
 
43
- TRK_MODEL = Sam3TrackerModel.from_pretrained("DiffusionWave/sam3", device_map="cpu", torch_dtype=torch.float32)
44
  TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained("DiffusionWave/sam3")
45
 
46
- print("โœ… All Models loaded successfully (CPU mode)!")
47
- except Exception as e2:
48
- print(f"โŒ Online loading also failed: {e2}")
49
- IMG_MODEL = IMG_PROCESSOR = TRK_MODEL = TRK_PROCESSOR = None
 
50
 
51
  # ============ LAYER MANAGEMENT ============
52
  class LayerManager:
@@ -248,9 +253,12 @@ def draw_points_on_image(image, layer_manager):
248
 
249
  # ============ UI FUNCTIONS ============
250
  def update_layer_selector_choices(manager):
251
- """๋ ˆ์ด์–ด ์„ ํƒ ๋“œ๋กญ๋‹ค์šด์˜ choices ์—…๋ฐ์ดํŠธ"""
252
- choices = [(layer['name'], lid) for lid, layer in manager.layers.items()]
253
- return gr.Dropdown(choices=choices, interactive=True, value=manager.current_layer_id)
 
 
 
254
 
255
  def create_new_layer(name, current_manager):
256
  """์ƒˆ ๋ ˆ์ด์–ด ์ƒ์„ฑ"""
@@ -384,6 +392,11 @@ def segment_all_layers(current_manager, image, opacity, border_width):
384
  print(f"\n[segment_all_layers] Processing layer: {layer_name}")
385
  print(f"[segment_all_layers] Points: {len(layer['points'])}, Labels: {layer['point_labels']}")
386
 
 
 
 
 
 
387
  # SAM3 Tracker๋กœ ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜
388
  points_list = layer['points']
389
  labels_list = layer['point_labels']
@@ -391,7 +404,9 @@ def segment_all_layers(current_manager, image, opacity, border_width):
391
  input_points = [[points_list]]
392
  input_labels = [[labels_list]]
393
 
394
- inputs = TRK_PROCESSOR(images=image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
 
 
395
 
396
  with torch.no_grad():
397
  outputs = TRK_MODEL(**inputs, multimask_output=False)
@@ -485,8 +500,8 @@ with gr.Blocks() as demo:
485
  gr.Markdown("### Layers Status")
486
  layer_buttons_html = gr.HTML("<div style='padding: 10px; text-align: center; color: #888;'>No layers created</div>")
487
 
488
- # ๋ ˆ์ด์–ด ์„ ํƒ
489
- layer_selector = gr.Dropdown(label="Select Layer to Add Points", choices=[], interactive=True, value=None)
490
 
491
  # ํฌ์ธํŠธ ๋ชจ๋“œ ์„ ํƒ
492
  gr.Markdown("### Point Mode")
@@ -537,13 +552,21 @@ with gr.Blocks() as demo:
537
  )
538
 
539
  # ๋ ˆ์ด์–ด ์„ ํƒ
540
- def on_layer_select(layer_id, mgr):
541
  if mgr is None:
542
  mgr = LayerManager()
543
 
544
- if layer_id:
545
- mgr.set_current_layer(layer_id)
546
- return mgr, create_layer_status_html(mgr), f"Layer '{mgr.layers[layer_id]['name']}' selected"
 
 
 
 
 
 
 
 
547
  return mgr, create_layer_status_html(mgr), "Please select a layer"
548
 
549
  layer_selector.change(
 
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
  print(f"๐Ÿ–ฅ๏ธ Using compute device: {device}")
23
 
24
+ # Models will be loaded lazily in functions to avoid build timeouts
25
+ IMG_MODEL = None
26
+ IMG_PROCESSOR = None
27
+ TRK_MODEL = None
28
+ TRK_PROCESSOR = None
29
+
30
+ @spaces.GPU
31
+ def load_models():
32
+ """Lazy load models when needed"""
33
+ global IMG_MODEL, IMG_PROCESSOR, TRK_MODEL, TRK_PROCESSOR
34
+
35
+ if IMG_MODEL is not None:
36
+ return True
37
+
38
+ print("โณ Loading SAM3 Models...")
39
  try:
40
+ # GPU๊ฐ€ ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•˜๋ฉด GPU๋กœ ๋กœ๋“œ
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
+ dtype = torch.float16 if device == "cuda" else torch.float32
43
+
44
+ IMG_MODEL = Sam3Model.from_pretrained("DiffusionWave/sam3", device_map=device, torch_dtype=dtype)
45
  IMG_PROCESSOR = Sam3Processor.from_pretrained("DiffusionWave/sam3")
46
 
47
+ TRK_MODEL = Sam3TrackerModel.from_pretrained("DiffusionWave/sam3", device_map=device, torch_dtype=dtype)
48
  TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained("DiffusionWave/sam3")
49
 
50
+ print(f"โœ… All Models loaded successfully on {device}!")
51
+ return True
52
+ except Exception as e:
53
+ print(f"โŒ Model loading failed: {e}")
54
+ return False
55
 
56
  # ============ LAYER MANAGEMENT ============
57
  class LayerManager:
 
253
 
254
  # ============ UI FUNCTIONS ============
255
  def update_layer_selector_choices(manager):
256
+ """๋ ˆ์ด์–ด ์„ ํƒ ๋ผ๋””์˜ค ๋ฒ„ํŠผ์˜ choices ์—…๋ฐ์ดํŠธ"""
257
+ choices = [layer['name'] for layer in manager.layers.values()]
258
+ current_value = None
259
+ if manager.current_layer_id and manager.current_layer_id in manager.layers:
260
+ current_value = manager.layers[manager.current_layer_id]['name']
261
+ return gr.Radio(choices=choices, interactive=True, value=current_value)
262
 
263
  def create_new_layer(name, current_manager):
264
  """์ƒˆ ๋ ˆ์ด์–ด ์ƒ์„ฑ"""
 
392
  print(f"\n[segment_all_layers] Processing layer: {layer_name}")
393
  print(f"[segment_all_layers] Points: {len(layer['points'])}, Labels: {layer['point_labels']}")
394
 
395
+ # Load models if needed
396
+ if not load_models():
397
+ print(f"[segment_all_layers] Failed to load models for layer: {layer_name}")
398
+ continue
399
+
400
  # SAM3 Tracker๋กœ ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜
401
  points_list = layer['points']
402
  labels_list = layer['point_labels']
 
404
  input_points = [[points_list]]
405
  input_labels = [[labels_list]]
406
 
407
+ # Use the same device as the model
408
+ model_device = next(TRK_MODEL.parameters()).device
409
+ inputs = TRK_PROCESSOR(images=image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(model_device)
410
 
411
  with torch.no_grad():
412
  outputs = TRK_MODEL(**inputs, multimask_output=False)
 
500
  gr.Markdown("### Layers Status")
501
  layer_buttons_html = gr.HTML("<div style='padding: 10px; text-align: center; color: #888;'>No layers created</div>")
502
 
503
+ # ๋ ˆ์ด์–ด ์„ ํƒ (๋ผ๋””์˜ค ๋ฒ„ํŠผ์œผ๋กœ ๋ณ€๊ฒฝ)
504
+ layer_selector = gr.Radio(label="Select Layer to Add Points", choices=[], interactive=True)
505
 
506
  # ํฌ์ธํŠธ ๋ชจ๋“œ ์„ ํƒ
507
  gr.Markdown("### Point Mode")
 
552
  )
553
 
554
  # ๋ ˆ์ด์–ด ์„ ํƒ
555
+ def on_layer_select(selected_name, mgr):
556
  if mgr is None:
557
  mgr = LayerManager()
558
 
559
+ if selected_name:
560
+ # ์ด๋ฆ„์œผ๋กœ layer_id ์ฐพ๊ธฐ
561
+ layer_id = None
562
+ for lid, layer in mgr.layers.items():
563
+ if layer['name'] == selected_name:
564
+ layer_id = lid
565
+ break
566
+
567
+ if layer_id:
568
+ mgr.set_current_layer(layer_id)
569
+ return mgr, create_layer_status_html(mgr), f"Layer '{selected_name}' selected"
570
  return mgr, create_layer_status_html(mgr), "Please select a layer"
571
 
572
  layer_selector.change(