# dounseen/utils/models.py
# Author: anas-gouda — "add point_per_side to params" (commit abce719)
from typing import Dict, Tuple
import torch
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from dounseen.core import UnseenClassifier
# Mapping from model-size key to its [hydra config, checkpoint path] pair,
# in ascending order of model size.
CHECKPOINTS = {
    "tiny": ["sam2_hiera_t.yaml", "models/sam2/sam2_hiera_tiny.pt"],
    "small": ["sam2_hiera_s.yaml", "models/sam2/sam2_hiera_small.pt"],
    "base_plus": ["sam2_hiera_b+.yaml", "models/sam2/sam2_hiera_base_plus.pt"],
    "large": ["sam2_hiera_l.yaml", "models/sam2/sam2_hiera_large.pt"],
}
# Derived from CHECKPOINTS so the two can never drift apart
# (dicts preserve insertion order, so the original ordering is kept).
CHECKPOINT_NAMES = list(CHECKPOINTS)
def load_sam2_models(
    device: torch.device,
) -> Dict[str, SAM2ImagePredictor]:
    """Build one SAM2 model per entry in CHECKPOINTS, keyed by size name.

    Args:
        device: torch device the models are built on.

    Returns:
        Dict mapping checkpoint name ("tiny", "small", ...) to the model
        returned by ``build_sam2``.
        NOTE(review): the annotation says ``SAM2ImagePredictor`` but the dict
        holds raw ``build_sam2`` outputs — confirm against the sam2 API.
    """
    return {
        name: build_sam2(cfg, ckpt, device=device)
        for name, (cfg, ckpt) in CHECKPOINTS.items()
    }
def make_sam2_mask_generators(
    models: Dict[str, SAM2ImagePredictor],
    point_per_side: int = 10,
) -> Dict[str, SAM2AutomaticMaskGenerator]:
    """Wrap each SAM2 model in an automatic mask generator.

    Args:
        models: models keyed by checkpoint name; must contain every key in
            ``CHECKPOINT_NAMES`` (a missing key raises ``KeyError``).
        point_per_side: grid density of sampled prompt points per image side.

    Returns:
        Dict mapping checkpoint name to its configured
        ``SAM2AutomaticMaskGenerator``.
    """
    # Thresholds below mirror the project's tuned defaults; iterate
    # CHECKPOINT_NAMES (not `models`) to keep the canonical key order.
    return {
        name: SAM2AutomaticMaskGenerator(
            model=models[name],
            points_per_side=point_per_side,
            points_per_batch=10,
            pred_iou_thresh=0.7,
            stability_score_thresh=0.92,
            stability_score_offset=0.7,
            crop_n_layers=1,
            box_nms_thresh=0.7,
        )
        for name in CHECKPOINT_NAMES
    }
def load_dounseen_model(
    device: torch.device,
) -> UnseenClassifier:
    """Load the DoUnseen ViT-B/16 unseen-object classifier.

    Args:
        device: torch device for the classifier.
            NOTE(review): currently unused in this body — presumably
            ``UnseenClassifier`` picks its own device; confirm and either
            pass it through or drop the parameter.

    Returns:
        An ``UnseenClassifier`` backed by the bundled checkpoint, with no
        gallery loaded (gallery is expected to be set later by the caller).
    """
    return UnseenClassifier(
        classification_model_path="models/dounseen/vit_b_16_epoch_199_augment.pth",
        gallery_images=None,
        gallery_buffered_path=None,
        augment_gallery=False,
        batch_size=100,
    )