Spaces:
Running
on
Zero
Running
on
Zero
Update eval/grounded_sam/grounded_sam2_florence2_autolabel_pipeline.py
Browse files
eval/grounded_sam/grounded_sam2_florence2_autolabel_pipeline.py
CHANGED
|
@@ -60,7 +60,7 @@ class FlorenceSAM:
|
|
| 60 |
self.torch_dtype = torch.bfloat16
|
| 61 |
|
| 62 |
FLORENCE2_MODEL_ID = os.getenv('FLORENCE2_MODEL_PATH', "microsoft/Florence-2-large")
|
| 63 |
-
SAM2_CHECKPOINT = os.getenv('SAM2_MODEL_PATH')
|
| 64 |
SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
| 65 |
|
| 66 |
self.florence2_model = Florence2ForConditionalGeneration.from_pretrained(
|
|
@@ -127,7 +127,7 @@ class FlorenceSAM:
|
|
| 127 |
|
| 128 |
def segmentation(self, image, input_boxes, seg_model="sam"):
|
| 129 |
if seg_model == "sam":
|
| 130 |
-
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.
|
| 131 |
sam2_predictor = self.sam2_predictor
|
| 132 |
sam2_predictor.set_image(np.array(image))
|
| 133 |
masks, scores, logits = sam2_predictor.predict(
|
|
|
|
| 60 |
self.torch_dtype = torch.bfloat16
|
| 61 |
|
| 62 |
FLORENCE2_MODEL_ID = os.getenv('FLORENCE2_MODEL_PATH', "microsoft/Florence-2-large")
|
| 63 |
+
SAM2_CHECKPOINT = os.getenv('SAM2_MODEL_PATH', "facebook/sam2-hiera-large")
|
| 64 |
SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
| 65 |
|
| 66 |
self.florence2_model = Florence2ForConditionalGeneration.from_pretrained(
|
|
|
|
| 127 |
|
| 128 |
def segmentation(self, image, input_boxes, seg_model="sam"):
|
| 129 |
if seg_model == "sam":
|
| 130 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
| 131 |
sam2_predictor = self.sam2_predictor
|
| 132 |
sam2_predictor.set_image(np.array(image))
|
| 133 |
masks, scores, logits = sam2_predictor.predict(
|