Spaces:
Running
Running
| import numpy as np | |
| from sam2.build_sam import build_sam2 | |
| from sam2.sam2_image_predictor import SAM2ImagePredictor | |
class Predictor:
    """Thin wrapper around SAM2's ``SAM2ImagePredictor``.

    Builds a SAM2 model from a config/checkpoint pair, wraps it in a
    ``SAM2ImagePredictor`` and tracks whether an image has been set
    successfully, so that ``predict`` fails fast with a clear error otherwise.
    """

    def __init__(self, model_cfg, checkpoint, device):
        """Build the SAM2 model and its image predictor.

        Args:
            model_cfg: Model config, forwarded unchanged to ``build_sam2``.
            checkpoint: Checkpoint, forwarded unchanged to ``build_sam2``.
            device: Device forwarded to ``build_sam2`` (presumably a torch
                device string such as ``"cuda"`` or ``"cpu"`` -- confirm
                against ``build_sam2``).
        """
        self.device = device
        self.model = build_sam2(model_cfg, checkpoint, device=device)
        self.predictor = SAM2ImagePredictor(self.model)
        # Last image successfully handed to the predictor; None until then,
        # so reading it before the first set_image no longer raises.
        self.image = None
        self.image_set = False

    def set_image(self, image):
        """Set the image for SAM prediction.

        The ready state is cleared *before* calling into the underlying
        predictor, so a failed ``set_image`` can never leave the wrapper
        claiming that a stale (or the failed) image is ready.

        Args:
            image: Image forwarded unchanged to
                ``SAM2ImagePredictor.set_image`` (presumably an HxWxC RGB
                array or PIL image -- confirm against the SAM2 docs).

        Raises:
            Whatever ``SAM2ImagePredictor.set_image`` raises; in that case
            ``image_set`` is False and ``image`` is None afterwards.
        """
        self.image = None
        self.image_set = False
        self.predictor.set_image(image)
        # Only record the image once the predictor has accepted it.
        self.image = image
        self.image_set = True

    def predict(
        self,
        point_coords=None,
        point_labels=None,
        multimask_output=False,
        **kwargs,
    ):
        """Run SAM prediction on the image set by ``set_image``.

        Args:
            point_coords: Point prompts, forwarded unchanged to
                ``SAM2ImagePredictor.predict``. Optional so that other prompt
                types can be supplied through ``**kwargs`` instead.
            point_labels: Labels for ``point_coords``, forwarded unchanged.
            multimask_output: Forwarded unchanged; this wrapper defaults it
                to ``False``.
            **kwargs: Any further keyword arguments, forwarded unchanged to
                ``SAM2ImagePredictor.predict`` (e.g. ``box`` or
                ``mask_input`` -- confirm for the installed SAM2 version).

        Returns:
            Whatever ``SAM2ImagePredictor.predict`` returns (documented
            upstream as a ``(masks, scores, logits)`` tuple -- confirm).

        Raises:
            RuntimeError: If no image has been set successfully.
        """
        if not self.image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )
        return self.predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=multimask_output,
            **kwargs,
        )