JasonYinnnn commited on
Commit
0fc1efb
·
1 Parent(s): c4d999e

test SAM 1

Browse files
Files changed (1) hide show
  1. app.py +8 -21
app.py CHANGED
@@ -28,7 +28,8 @@ from threeDFixer.datasets.utils import (
28
  project2ply
29
  )
30
  from threeDFixer.utils import render_utils, postprocessing_utils
31
- from scripts.grounding_sam2 import plot_segmentation, segment
 
32
  from sam2.build_sam import build_sam2
33
  from sam2.sam2_image_predictor import SAM2ImagePredictor
34
  import copy
@@ -55,18 +56,11 @@ generated_object_map = {}
55
 
56
  # Prepare models
57
  ## Grounding SAM
58
- if not os.path.exists("./checkpoints/sam2.1_hiera_large.pt"):
59
- os.makedirs("./checkpoints/", exist_ok=True)
60
- os.system("wget -O ./checkpoints/sam2.1_hiera_large.pt https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt")
61
- SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
62
- SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
63
-
64
- sam2_model = build_sam2(
65
- config_file=SAM2_CONFIG,
66
- ckpt_path=SAM2_CHECKPOINT,
67
- device=DEVICE
68
  )
69
- sam2_predictor = SAM2ImagePredictor(sam2_model)
70
 
71
  ############## 3D-Fixer model
72
  model_dir = 'HorizonRobotics/3D-Fixer'
@@ -192,14 +186,6 @@ def run_segmentation(
192
  rgb_image = image_prompts["image"].convert("RGB")
193
 
194
  global work_space
195
- # global sam2_predictor
196
-
197
- # if sam2_predictor is None:
198
- # sam2_model = build_sam2(
199
- # config_file=SAM2_CONFIG,
200
- # ckpt_path=SAM2_CHECKPOINT,
201
- # )
202
- # sam2_predictor = SAM2ImagePredictor(sam2_model)
203
 
204
  # pre-process the layers and get the xyxy boxes of each layer
205
  if len(image_prompts["points"]) == 0:
@@ -214,7 +200,8 @@ def run_segmentation(
214
  ]
215
 
216
  detections = segment(
217
- sam2_predictor,
 
218
  rgb_image,
219
  boxes=[boxes],
220
  polygon_refinement=polygon_refinement,
 
28
  project2ply
29
  )
30
  from threeDFixer.utils import render_utils, postprocessing_utils
31
+ from transformers import AutoModelForMaskGeneration, AutoProcessor
32
+ from scripts.grounding_sam import plot_segmentation, segment
33
  from sam2.build_sam import build_sam2
34
  from sam2.sam2_image_predictor import SAM2ImagePredictor
35
  import copy
 
56
 
57
  # Prepare models
58
  ## Grounding SAM
59
+ segmenter_id = "facebook/sam-vit-base"
60
+ sam_processor = AutoProcessor.from_pretrained(segmenter_id)
61
+ sam_segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(
62
+ DEVICE, DTYPE
 
 
 
 
 
 
63
  )
 
64
 
65
  ############## 3D-Fixer model
66
  model_dir = 'HorizonRobotics/3D-Fixer'
 
186
  rgb_image = image_prompts["image"].convert("RGB")
187
 
188
  global work_space
 
 
 
 
 
 
 
 
189
 
190
  # pre-process the layers and get the xyxy boxes of each layer
191
  if len(image_prompts["points"]) == 0:
 
200
  ]
201
 
202
  detections = segment(
203
+ sam_processor,
204
+ sam_segmentator,
205
  rgb_image,
206
  boxes=[boxes],
207
  polygon_refinement=polygon_refinement,