"""Lazy, process-wide singletons for the vision models used by this package.

Each ``get_*`` accessor builds its model on first call and returns the cached
instance afterwards, so expensive checkpoint loads happen at most once.
"""

import torch
import clip

from .config import *
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import groundingdino.datasets.transforms as T
from groundingdino.util.inference import load_model as load_dino_model

# Single device decision for the whole module: CUDA when available, else CPU.
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Singleton caches — populated lazily by the accessors below.
_sam = None
_clip_model = None
_clip_preprocess = None
_grounding_dino = None
_dino_large = None


def get_device():
    """Return the torch.device all models in this module are placed on."""
    return _device


def get_sam_predictor():
    """Return the cached SAM2 image predictor, building it on first call.

    Uses ``SAM_CFG`` / ``SAM_CHECKPT`` from the package config.
    """
    global _sam
    if _sam is None:
        sam = build_sam2(SAM_CFG, SAM_CHECKPT)
        _sam = SAM2ImagePredictor(sam)
    return _sam


def get_clip():
    """Return ``(model, preprocess)`` for CLIP, loading it on first call.

    Returns:
        tuple: the CLIP model (in eval mode, on ``_device``) and its
        preprocessing transform.
    """
    global _clip_model, _clip_preprocess
    if _clip_model is None:
        # clip.load already places the model on `device`; the extra .to() is
        # kept as a harmless belt-and-braces move matching prior behavior.
        model, preprocess_fn = clip.load(CLIP_MODEL, device=_device)
        _clip_model = model.eval().to(_device)
        _clip_preprocess = preprocess_fn
    return _clip_model, _clip_preprocess


def get_groundingdino_model():
    """Return the cached GroundingDINO model, loading it on first call.

    Bug fix: the device was previously hard-coded to "cuda", which crashed on
    CPU-only hosts despite the module-level CPU fallback. Now honors
    ``_device`` like every other accessor here.
    """
    global _grounding_dino
    if _grounding_dino is None:
        _grounding_dino = load_dino_model(DINO_CFG, DINO_CHECKPT, device=str(_device))
    return _grounding_dino


def get_dinov2_large():
    """Return the cached DINOv2 ViT-L/14 model, loading it on first call.

    Fetched via torch.hub from ``facebookresearch/dinov2`` (downloads on the
    first ever call), set to eval mode, and moved to ``_device``.
    """
    global _dino_large
    if _dino_large is None:
        _dino_large = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
        _dino_large.eval()
        _dino_large = _dino_large.to(get_device())
    return _dino_large