JasonYinnnn commited on
Commit
19506f1
·
1 Parent(s): c0c4541

lazy load seg; dynamic duration

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -93,7 +93,7 @@ generated_object_map = {}
93
 
94
  # Prepare models
95
  ## Grounding SAM
96
- sam2_predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
97
 
98
  ############## 3D-Fixer model
99
  model_dir = 'HorizonRobotics/3D-Fixer'
@@ -217,6 +217,10 @@ def run_segmentation(
217
  rgb_image = image_prompts["image"].convert("RGB")
218
 
219
  global work_space
 
 
 
 
220
 
221
  # pre-process the layers and get the xyxy boxes of each layer
222
  if len(image_prompts["points"]) == 0:
@@ -399,8 +403,15 @@ def export_scene_glb(trimeshes, work_space, scene_name):
399
 
400
  return scene_path
401
 
 
 
 
 
 
 
 
402
  @torch.no_grad()
403
- @spaces.GPU(duration=600)
404
  def run_generation(
405
  rgb_image: Any,
406
  seg_image: Union[str, Image.Image],
 
93
 
94
  # Prepare models
95
  ## Grounding SAM
96
+ sam2_predictor = None # SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
97
 
98
  ############## 3D-Fixer model
99
  model_dir = 'HorizonRobotics/3D-Fixer'
 
217
  rgb_image = image_prompts["image"].convert("RGB")
218
 
219
  global work_space
220
+ global sam2_predictor
221
+ if sam2_predictor is None:
222
+ # lazy initialization
223
+ sam2_predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
224
 
225
  # pre-process the layers and get the xyxy boxes of each layer
226
  if len(image_prompts["points"]) == 0:
 
403
 
404
  return scene_path
405
 
406
+ def get_duration(rgb_image, seg_image, seed, randomize_seed,
407
+ num_inference_steps, guidance_scale, cfg_interval_start,
408
+ cfg_interval_end, t_rescale):
409
+ instance_labels = np.unique(np.array(seg_image).reshape(-1, 3), axis=0)
410
+ step_duration = 15.0
411
+ return instance_labels.shape[0] * step_duration
412
+
413
  @torch.no_grad()
414
+ @spaces.GPU(duration=get_duration)
415
  def run_generation(
416
  rgb_image: Any,
417
  seg_image: Union[str, Image.Image],