"""Loaders for the SAM 2 segmentation models and the DoUnseen classifier."""

from typing import Dict, Tuple

import torch
from dounseen.core import UnseenClassifier
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Maps a short variant name to [config file name, checkpoint path].
# The checkpoint paths are relative; NOTE(review): presumably they are resolved
# against the current working directory -- confirm against the deployment layout.
CHECKPOINTS = {
    "tiny": ["sam2_hiera_t.yaml", "models/sam2/sam2_hiera_tiny.pt"],
    "small": ["sam2_hiera_s.yaml", "models/sam2/sam2_hiera_small.pt"],
    "base_plus": ["sam2_hiera_b+.yaml", "models/sam2/sam2_hiera_base_plus.pt"],
    "large": ["sam2_hiera_l.yaml", "models/sam2/sam2_hiera_large.pt"],
}

# Derived from CHECKPOINTS so the two can never drift apart; the resulting
# list is ["tiny", "small", "base_plus", "large"], in that order.
CHECKPOINT_NAMES = list(CHECKPOINTS)


def load_sam2_models(
    device: torch.device,
) -> Dict[str, torch.nn.Module]:
    """Build every SAM 2 variant listed in ``CHECKPOINTS``.

    Args:
        device: Device forwarded to ``build_sam2`` for each model.

    Returns:
        A dict keyed by variant name (the keys of ``CHECKPOINTS``) whose
        values are whatever ``build_sam2`` returns. That is the underlying
        SAM 2 model, not a ``SAM2ImagePredictor`` wrapper, hence the
        ``torch.nn.Module`` annotation.
    """
    return {
        name: build_sam2(config, checkpoint, device=device)
        for name, (config, checkpoint) in CHECKPOINTS.items()
    }


def make_sam2_mask_generators(
    models: Dict[str, torch.nn.Module],
    point_per_side: int = 10,
) -> Dict[str, SAM2AutomaticMaskGenerator]:
    """Wrap each given SAM 2 model in a ``SAM2AutomaticMaskGenerator``.

    Args:
        models: Models keyed by name, e.g. the result of
            ``load_sam2_models``. Any subset of variants is accepted; one
            generator is created per entry.
        point_per_side: Passed to the generator as ``points_per_side``.

    Returns:
        A dict with the same keys as ``models`` mapping to the configured
        mask generators. The remaining generator settings are fixed to the
        values below.
    """
    # Iterate over the models actually supplied rather than the global name
    # list, so that passing only some variants does not raise KeyError.
    return {
        name: SAM2AutomaticMaskGenerator(
            model=model,
            points_per_side=point_per_side,
            points_per_batch=10,
            pred_iou_thresh=0.7,
            stability_score_thresh=0.92,
            stability_score_offset=0.7,
            crop_n_layers=1,
            box_nms_thresh=0.7,
        )
        for name, model in models.items()
    }


def load_dounseen_model(
    device: torch.device,
) -> UnseenClassifier:
    """Create the DoUnseen classifier with an empty gallery.

    Args:
        device: Currently unused; it is not forwarded to ``UnseenClassifier``.
            NOTE(review): confirm how the classifier selects its device and
            whether this argument should be passed through.

    Returns:
        An ``UnseenClassifier`` constructed from the ViT-B/16 weights at
        ``models/dounseen/vit_b_16_epoch_199_augment.pth``, with no gallery
        images, no buffered gallery, gallery augmentation disabled and a
        batch size of 100.
    """
    return UnseenClassifier(
        classification_model_path="models/dounseen/vit_b_16_epoch_199_augment.pth",
        gallery_images=None,
        gallery_buffered_path=None,
        augment_gallery=False,
        batch_size=100,
    )