JasonYinnnn committed on
Commit
e0c665a
·
1 Parent(s): 3f835d3

try xformers

Browse files
Files changed (2) hide show
  1. app.py +43 -57
  2. requirements.txt +2 -1
app.py CHANGED
@@ -20,16 +20,16 @@ import trimesh
20
  import random
21
  import imageio
22
  from einops import repeat
23
- # from threeDFixer.pipelines import ThreeDFixerPipeline
24
- # from threeDFixer.datasets.utils import (
25
- # edge_mask_morph_gradient,
26
- # process_scene_image,
27
- # process_instance_image,
28
- # transform_vertices,
29
- # normalize_vertices,
30
- # project2ply
31
- # )
32
- # from threeDFixer.utils import render_utils, postprocessing_utils
33
  from transformers import AutoModelForMaskGeneration, AutoProcessor
34
  from scripts.grounding_sam import plot_segmentation, segment
35
  import copy
@@ -176,57 +176,43 @@ def cleanup_tmp(tmp_root: str = "./tmp", expire_seconds: int = 3600) -> None:
176
  except Exception as e:
177
  print(f"[cleanup_tmp] failed to remove {path}: {e}")
178
 
179
-
180
  def run_segmentation(
181
  image_prompts: Any,
182
  polygon_refinement: bool = True,
183
  ) -> Image.Image:
184
- try:
185
- gr.Info('in run_segmentation')
186
- rgb_image = image_prompts["image"].convert("RGB")
187
- gr.Info('in run_segmentation, read image')
188
 
189
- global work_space
190
 
191
- # device = "cuda" if torch.cuda.is_available() else "cpu"
192
- device = "cpu"
193
- sam_segmentator.to(device=device, dtype=DTYPE if device == 'cuda' else torch.float32)
194
- gr.Info('in run_segmentation, move sam')
195
 
196
- # pre-process the layers and get the xyxy boxes of each layer
197
- if len(image_prompts["points"]) == 0:
198
- raise gr.Error("No points provided for segmentation. Please add points to the image.")
199
-
200
- boxes = [
201
- [
202
- [int(box[0]), int(box[1]), int(box[3]), int(box[4])]
203
- for box in image_prompts["points"]
204
- ]
205
  ]
 
206
 
207
- gr.Info('in run_segmentation, run info')
208
- with torch.no_grad():
209
- detections = segment(
210
- sam_processor,
211
- sam_segmentator,
212
- rgb_image,
213
- boxes=[boxes],
214
- polygon_refinement=polygon_refinement,
215
- )
216
- seg_map_pil = plot_segmentation(rgb_image, detections)
217
-
218
- torch.cuda.empty_cache()
219
-
220
- cleanup_tmp(TMP_DIR, expire_seconds=3600)
221
-
222
- work_space = os.path.join(TMP_DIR, f"work_space_{uuid.uuid4()}")
223
- os.makedirs(work_space, exist_ok=True)
224
- seg_map_pil.save(os.path.join(work_space, 'mask.png'))
225
-
226
- except Exception as e:
227
- import traceback
228
- traceback.print_exc()
229
- raise gr.Error(f"run_segmentation failed: {type(e).__name__}: {e}")
230
 
231
  return seg_map_pil
232
 
@@ -901,15 +887,15 @@ if __name__ == '__main__':
901
  segmenter_id = "facebook/sam-vit-base"
902
  sam_processor = AutoProcessor.from_pretrained(segmenter_id)
903
  sam_segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(
904
- DEVICE, DTYPE
905
  )
906
 
907
  ############## 3D-Fixer model
908
  model_dir = 'HorizonRobotics/3D-Fixer'
909
- # pipeline = ThreeDFixerPipeline.from_pretrained(
910
- # model_dir, compile=False
911
- # )
912
- # pipeline.to(device=DEVICE)
913
  ############## 3D-Fixer model
914
 
915
  rot = np.array([
 
20
  import random
21
  import imageio
22
  from einops import repeat
23
+ from threeDFixer.pipelines import ThreeDFixerPipeline
24
+ from threeDFixer.datasets.utils import (
25
+ edge_mask_morph_gradient,
26
+ process_scene_image,
27
+ process_instance_image,
28
+ transform_vertices,
29
+ normalize_vertices,
30
+ project2ply
31
+ )
32
+ from threeDFixer.utils import render_utils, postprocessing_utils
33
  from transformers import AutoModelForMaskGeneration, AutoProcessor
34
  from scripts.grounding_sam import plot_segmentation, segment
35
  import copy
 
176
  except Exception as e:
177
  print(f"[cleanup_tmp] failed to remove {path}: {e}")
178
 
179
+ # run seg on CPU
180
  def run_segmentation(
181
  image_prompts: Any,
182
  polygon_refinement: bool = True,
183
  ) -> Image.Image:
184
+ rgb_image = image_prompts["image"].convert("RGB")
 
 
 
185
 
186
+ global work_space
187
 
188
+ device = "cpu"
189
+ sam_segmentator.to(device=device, dtype=DTYPE if device == 'cuda' else torch.float32)
 
 
190
 
191
+ # pre-process the layers and get the xyxy boxes of each layer
192
+ if len(image_prompts["points"]) == 0:
193
+ raise gr.Error("No points provided for segmentation. Please add points to the image.")
194
+
195
+ boxes = [
196
+ [
197
+ [int(box[0]), int(box[1]), int(box[3]), int(box[4])]
198
+ for box in image_prompts["points"]
 
199
  ]
200
+ ]
201
 
202
+ with torch.no_grad():
203
+ detections = segment(
204
+ sam_processor,
205
+ sam_segmentator,
206
+ rgb_image,
207
+ boxes=[boxes],
208
+ polygon_refinement=polygon_refinement,
209
+ )
210
+ seg_map_pil = plot_segmentation(rgb_image, detections)
211
+
212
+ cleanup_tmp(TMP_DIR, expire_seconds=3600)
213
+ work_space = os.path.join(TMP_DIR, f"work_space_{uuid.uuid4()}")
214
+ os.makedirs(work_space, exist_ok=True)
215
+ seg_map_pil.save(os.path.join(work_space, 'mask.png'))
 
 
 
 
 
 
 
 
 
216
 
217
  return seg_map_pil
218
 
 
887
  segmenter_id = "facebook/sam-vit-base"
888
  sam_processor = AutoProcessor.from_pretrained(segmenter_id)
889
  sam_segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(
890
+ DEVICE, DTYPE if DEVICE == 'cuda' else torch.float32
891
  )
892
 
893
  ############## 3D-Fixer model
894
  model_dir = 'HorizonRobotics/3D-Fixer'
895
+ pipeline = ThreeDFixerPipeline.from_pretrained(
896
+ model_dir, compile=False
897
+ )
898
+ pipeline.to(device=DEVICE)
899
  ############## 3D-Fixer model
900
 
901
  rot = np.array([
requirements.txt CHANGED
@@ -43,4 +43,5 @@ pydantic==2.10.6
43
  # httpx==0.27.0
44
  kaolin==0.18.0
45
  flash-attn==2.8.3+pt2.8.0cu129
46
- nvdiffrast==0.4.0+253ac4fpt2.8.0cu129
 
 
43
  # httpx==0.27.0
44
  kaolin==0.18.0
45
  flash-attn==2.8.3+pt2.8.0cu129
46
+ nvdiffrast==0.4.0+253ac4fpt2.8.0cu129
47
+ xformers