import torch
import clip
from .config import *
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import groundingdino.datasets.transforms as T
from groundingdino.util.inference import load_model as load_dino_model

# Shared compute device for every model in this module: CUDA when a GPU is
# visible at import time, otherwise CPU.
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Lazily-initialized singletons, populated on first use by the get_* accessors
# below so that importing this module stays cheap.
_sam = None                # SAM2ImagePredictor wrapping the built SAM2 model
_clip_model = None         # CLIP model (eval mode, on _device)
_clip_preprocess = None    # CLIP image preprocessing callable
_grounding_dino = None     # GroundingDINO detection model
_dino_large = None         # DINOv2 ViT-L/14 backbone from torch.hub


def get_device():
    """Return the module-wide torch.device (CUDA if available, else CPU)."""
    return _device

def get_sam_predictor():
    """Return the cached SAM2ImagePredictor, building it on first call."""
    global _sam
    if _sam is not None:
        return _sam
    # First call: build the SAM2 backbone and wrap it in the predictor.
    _sam = SAM2ImagePredictor(build_sam2(SAM_CFG, SAM_CHECKPT))
    return _sam

def get_clip():
    """Return the cached CLIP model and its preprocessing transform.

    Returns:
        tuple: (CLIP model in eval mode on the shared device,
                image-preprocessing callable).
    """
    global _clip_model, _clip_preprocess
    if _clip_model is not None:
        return _clip_model, _clip_preprocess
    # First call: load model + preprocess, then pin the model to the device
    # in eval mode.
    loaded_model, loaded_preprocess = clip.load(CLIP_MODEL, device=_device)
    _clip_model = loaded_model.eval().to(_device)
    _clip_preprocess = loaded_preprocess
    return _clip_model, _clip_preprocess

def get_groundingdino_model():
    """Return the cached GroundingDINO model, loading it on first call.

    Returns:
        The GroundingDINO model placed on the module's shared device.
    """
    global _grounding_dino
    if _grounding_dino is None:
        # Fix: device was hard-coded to "cuda", which crashes on CPU-only
        # hosts and ignores the device selected at import time (L121-style
        # _device). Pass the shared device instead, stringified because
        # load_model expects a device string.
        _grounding_dino = load_dino_model(DINO_CFG, DINO_CHECKPT, device=str(_device))

    return _grounding_dino

def get_dinov2_large():
    """Return the cached DINOv2 ViT-L/14 model, fetching it via torch.hub
    on first call and moving it to the shared device in eval mode."""
    global _dino_large

    if _dino_large is not None:
        return _dino_large

    # First call: download/load the backbone from the hub repo.
    model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
    model.eval()
    _dino_large = model.to(get_device())
    return _dino_large