# object-memory/core/models.py
# (header artifact from a web paste — author: russ4stall, commit 24f3fb6 "fresh history")
import torch
import clip
from .config import *
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import groundingdino.datasets.transforms as T
from groundingdino.util.inference import load_model as load_dino_model
# Module-wide compute device: prefer CUDA when available, else CPU.
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Lazily-populated singletons; each get_* accessor below fills its cache
# on first call and returns the cached object thereafter.
_sam = None               # SAM2ImagePredictor instance
_clip_model = None        # CLIP model (eval mode, on _device)
_clip_preprocess = None   # CLIP preprocessing transform paired with _clip_model
_grounding_dino = None    # GroundingDINO detection model
_dino_large = None        # DINOv2 ViT-L/14 feature extractor
def get_device():
    """Return the module-wide torch device (CUDA when available, else CPU)."""
    return _device
def get_sam_predictor():
    """Return the cached SAM2ImagePredictor, building it on first call.

    Uses SAM_CFG / SAM_CHECKPT from .config; the predictor is stored in
    the module-level `_sam` singleton so the model is only built once.
    """
    global _sam
    if _sam is not None:
        return _sam
    backbone = build_sam2(SAM_CFG, SAM_CHECKPT)
    _sam = SAM2ImagePredictor(backbone)
    return _sam
def get_clip():
    """Return the cached ``(clip_model, preprocess)`` pair, loading on first use.

    The model named by CLIP_MODEL is loaded onto the module device, put in
    eval mode, and memoized in module-level globals together with its
    matching preprocessing transform.
    """
    global _clip_model, _clip_preprocess
    if _clip_model is not None:
        return _clip_model, _clip_preprocess
    loaded_model, loaded_preprocess = clip.load(CLIP_MODEL, device=_device)
    _clip_model = loaded_model.eval().to(_device)
    _clip_preprocess = loaded_preprocess
    return _clip_model, _clip_preprocess
def get_groundingdino_model():
    """Lazily load and cache the GroundingDINO model.

    Loads the model from DINO_CFG / DINO_CHECKPT on the module-wide
    device and memoizes it in the `_grounding_dino` singleton.

    Returns:
        The cached GroundingDINO model.
    """
    global _grounding_dino
    if _grounding_dino is None:
        # Fix: use the module-wide device instead of hard-coding "cuda" —
        # the original raised on CPU-only machines and was inconsistent
        # with every other loader in this module.
        _grounding_dino = load_dino_model(DINO_CFG, DINO_CHECKPT, device=str(_device))
    return _grounding_dino
def get_dinov2_large():
    """Return the cached DINOv2 ViT-L/14 model, fetching it on first call.

    The model is pulled via ``torch.hub`` (downloads weights on first use),
    switched to eval mode, moved to the module device, and memoized in the
    `_dino_large` singleton.
    """
    global _dino_large
    if _dino_large is not None:
        return _dino_large
    model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
    model.eval()
    _dino_large = model.to(get_device())
    return _dino_large