Spaces:
Build error
Build error
import os
import traceback

import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

from modules.mask_utils import *
from modules.model_downloader import *
class SamInference:
    """Wrap a Segment Anything (SAM) model for automatic mask generation.

    The SAM checkpoint is loaded lazily on first use, and downloaded first when
    it is missing. A ``SamAutomaticMaskGenerator`` is built from
    ``tunable_params``. The generator is rebuilt whenever those parameters
    change, but the model weights are loaded only once per instance.
    """

    # UI-tunable generator settings. Listed in the positional order that
    # ``generate_mask_app`` receives them, each paired with the type its raw
    # UI value is cast to.
    _PARAM_SPECS = (
        ('points_per_side', int),
        ('pred_iou_thresh', float),
        ('stability_score_thresh', float),
        ('crop_n_layers', int),
        ('crop_n_points_downscale_factor', int),
        ('min_mask_region_area', int),
    )

    def __init__(self):
        self.model = None
        # ViT-H SAM checkpoint, relative to the current working directory.
        self.model_path = "models/sam_vit_h_4b8939.pth"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.mask_generator = None
        # Tunable parameters; all start at their default values.
        self.tunable_params = {
            'points_per_side': 32,
            'pred_iou_thresh': 0.88,
            'stability_score_thresh': 0.95,
            'crop_n_layers': 0,
            'crop_n_points_downscale_factor': 1,
            'min_mask_region_area': 0
        }

    @classmethod
    def _parse_tunable_params(cls, params):
        """Convert raw positional UI values into a tunable-parameter dict.

        Args:
            params: sequence holding at least one value per entry of
                ``_PARAM_SPECS``, in that order. Extra trailing values are
                ignored.

        Returns:
            dict mapping each parameter name to its value cast to int/float.

        Raises:
            IndexError: if ``params`` has too few values.
            ValueError, TypeError: if a value cannot be cast.
        """
        return {
            name: cast(params[index])
            for index, (name, cast) in enumerate(cls._PARAM_SPECS)
        }

    def set_mask_generator(self):
        """(Re)build ``self.mask_generator`` from ``self.tunable_params``.

        The first call loads the checkpoint at ``self.model_path`` (downloading
        it if the file does not exist) and moves it to ``self.device``. Later
        calls reuse the loaded model and rebuild only the generator, because
        changing generator settings does not require new weights.
        """
        print("applying configs to model..")
        # Clear the old generator first: if construction below fails, the next
        # call sees None and rebuilds instead of reusing a generator that no
        # longer matches self.tunable_params.
        self.mask_generator = None
        if self.model is None:
            if not os.path.exists(self.model_path):
                print("No needed SAM model detected. downloading VIT H SAM model....")
                download_sam_model_url()
            self.model = sam_model_registry["default"](checkpoint=self.model_path)
            self.model.to(device=self.device)
        self.mask_generator = SamAutomaticMaskGenerator(
            self.model,
            points_per_side=self.tunable_params['points_per_side'],
            pred_iou_thresh=self.tunable_params['pred_iou_thresh'],
            stability_score_thresh=self.tunable_params['stability_score_thresh'],
            crop_n_layers=self.tunable_params['crop_n_layers'],
            crop_n_points_downscale_factor=self.tunable_params['crop_n_points_downscale_factor'],
            min_mask_region_area=self.tunable_params['min_mask_region_area'],
            output_mode="coco_rle",
        )

    def generate_mask(self, image):
        """Run automatic mask generation on ``image``.

        Builds the generator on first use. Returns a one-element list that
        holds the generator's output for ``image``; masks are requested with
        ``output_mode="coco_rle"``.
        """
        if self.mask_generator is None:
            self.set_mask_generator()
        return [self.mask_generator.generate(image)]

    def generate_mask_app(self, image, *params):
        """Generate masks for the UI and return the images to display.

        Args:
            image: input image, passed unchanged to the mask generator and the
                mask_utils helpers (presumably an HxWxC array; confirm against
                the calling UI component).
            *params: the tunable values in ``_PARAM_SPECS`` order. The
                generator is rebuilt only when they differ from the current
                settings.

        Returns:
            ``[combined_image] + gallery`` on success. On failure, an empty
            list is returned after the error and its traceback are printed.

        Raises:
            IndexError, ValueError, TypeError: if ``params`` cannot be parsed.
                Parsing happens before the best-effort section, as before.
        """
        tunable_params = self._parse_tunable_params(params)
        try:
            if (self.model is None or self.mask_generator is None
                    or self.tunable_params != tunable_params):
                self.tunable_params = tunable_params
                self.set_mask_generator()
            masks = self.mask_generator.generate(image)
            combined_image = create_mask_combined_images(image, masks)
            gallery = create_mask_gallery(image, masks)
            return [combined_image] + gallery
        except Exception as e:
            # Best-effort for the UI: report the full failure rather than
            # crashing the app, and keep the return type a list.
            print(e)
            traceback.print_exc()
            return []