Spaces:
Running
on
Zero
Running
on
Zero
| from typing import Dict, Tuple | |
| import torch | |
| from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator | |
| from sam2.build_sam import build_sam2 | |
| from sam2.sam2_image_predictor import SAM2ImagePredictor | |
| from dounseen.core import UnseenClassifier | |
# Maps each SAM2 model size to a pair of
# [config file name passed to build_sam2, checkpoint path].
# The checkpoint paths are relative, so they resolve against the current
# working directory at load time.
CHECKPOINTS = {
    "tiny": ["sam2_hiera_t.yaml", "models/sam2/sam2_hiera_tiny.pt"],
    "small": ["sam2_hiera_s.yaml", "models/sam2/sam2_hiera_small.pt"],
    "base_plus": ["sam2_hiera_b+.yaml", "models/sam2/sam2_hiera_base_plus.pt"],
    "large": ["sam2_hiera_l.yaml", "models/sam2/sam2_hiera_large.pt"],
}
# Model sizes in declaration order. This is derived from CHECKPOINTS so the two
# constants cannot drift apart when a size is added or removed.
CHECKPOINT_NAMES = list(CHECKPOINTS)
def load_sam2_models(
    device: torch.device,
) -> Dict[str, torch.nn.Module]:
    """Build every SAM2 model listed in ``CHECKPOINTS`` on ``device``.

    All entries are loaded eagerly, one ``build_sam2`` call per checkpoint.

    Args:
        device: Device forwarded to ``build_sam2`` for each model.

    Returns:
        A dict mapping each checkpoint name (e.g. ``"tiny"``) to the model
        returned by ``build_sam2``. These are the underlying SAM2 models, not
        ``SAM2ImagePredictor`` wrappers, hence the ``torch.nn.Module``
        annotation. They are meant to be passed as ``model=`` to
        ``SAM2AutomaticMaskGenerator``.
    """
    return {
        name: build_sam2(config, checkpoint, device=device)
        for name, (config, checkpoint) in CHECKPOINTS.items()
    }
def make_sam2_mask_generators(
    models: Dict[str, torch.nn.Module],
    point_per_side: int = 10,
) -> Dict[str, SAM2AutomaticMaskGenerator]:
    """Wrap each SAM2 model in an automatic mask generator.

    Args:
        models: Mapping from checkpoint name to SAM2 model, as returned by
            ``load_sam2_models``.
        point_per_side: Forwarded to ``SAM2AutomaticMaskGenerator`` as
            ``points_per_side``. The singular spelling is kept for backward
            compatibility with existing callers.

    Returns:
        A dict with one ``SAM2AutomaticMaskGenerator`` per entry of ``models``,
        under the same keys and in the same order. Every generator shares the
        fixed tuning values below.
    """
    # Iterate over the models actually supplied rather than the module-level
    # CHECKPOINT_NAMES. Callers that loaded only a subset of checkpoints then
    # get generators for exactly those models instead of a KeyError.
    return {
        name: SAM2AutomaticMaskGenerator(
            model=model,
            points_per_side=point_per_side,
            points_per_batch=10,
            pred_iou_thresh=0.7,
            stability_score_thresh=0.92,
            stability_score_offset=0.7,
            crop_n_layers=1,
            box_nms_thresh=0.7,
        )
        for name, model in models.items()
    }
def load_dounseen_model(
    device: torch.device,
) -> UnseenClassifier:
    """Construct a DoUnseen ``UnseenClassifier`` from the bundled checkpoint.

    The classifier starts with no gallery: no gallery images, no buffered
    gallery file, and gallery augmentation disabled. Presumably the caller
    supplies the gallery later (TODO: confirm against callers).

    Args:
        device: Currently unused. NOTE(review): this argument is not forwarded
            to ``UnseenClassifier``; confirm how that class selects its device.

    Returns:
        A newly constructed ``UnseenClassifier`` using a batch size of 100.
    """
    weights_path = "models/dounseen/vit_b_16_epoch_199_augment.pth"
    classifier_kwargs = dict(
        classification_model_path=weights_path,
        gallery_images=None,
        gallery_buffered_path=None,
        augment_gallery=False,
        batch_size=100,
    )
    return UnseenClassifier(**classifier_kwargs)