File size: 1,313 Bytes
4edb0a5
 
 
 
 
 
 
 
 
 
 
 
 
5af56ca
 
 
4edb0a5
5af56ca
 
 
 
4edb0a5
5af56ca
 
4edb0a5
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os

from annotator.uniformer.mmseg.apis import init_segmentor, inference_segmentor, show_result_pyplot
from annotator.uniformer.mmseg.core.evaluation import get_palette
from annotator.util import annotator_ckpts_path


# Remote URL of the upernet_global_small checkpoint. UniformerDetector downloads
# it into annotator_ckpts_path on first construction if the file is missing.
checkpoint_file = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/upernet_global_small.pth"


class UniformerDetector:
    """Semantic-segmentation annotator built on the upernet_global_small model.

    On construction the ``upernet_global_small.pth`` checkpoint is downloaded
    from ``checkpoint_file`` into ``annotator_ckpts_path`` if it is not already
    there. The model is then built from
    ``<parent of annotator_ckpts_path>/uniformer/exp/upernet_global_small/config.py``.
    Calling the instance on an image returns the colour-coded segmentation.
    """

    def __init__(self, device=None):
        """Load the segmentation model, downloading its checkpoint if needed.

        Args:
            device: Device string forwarded to ``init_segmentor``. When ``None``
                (the default), ``"cuda"`` is used if CUDA is available,
                otherwise ``"cpu"``.
        """
        # torch is used for device selection below but is not imported at
        # module level; import it lazily, like the torch.hub import below.
        import torch

        modelpath = os.path.join(annotator_ckpts_path, "upernet_global_small.pth")
        # exist_ok=True avoids the race between an existence check and creation.
        os.makedirs(annotator_ckpts_path, exist_ok=True)

        if not os.path.exists(modelpath):
            from torch.hub import download_url_to_file
            print(f"Downloading upernet_global_small from {checkpoint_file}...")
            download_url_to_file(checkpoint_file, modelpath)

        # normpath strips a trailing separator; without it dirname() would
        # return the ckpts directory itself instead of its parent.
        annotator_root = os.path.dirname(os.path.normpath(annotator_ckpts_path))
        config_file = os.path.join(
            annotator_root, "uniformer", "exp", "upernet_global_small", "config.py"
        )

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = init_segmentor(config_file, modelpath, device=device)

    def __call__(self, img):
        """Segment *img* and return the rendered segmentation.

        Args:
            img: Image passed unchanged to ``inference_segmentor``
                (presumably an HxWx3 array; confirm against callers).

        Returns:
            Whatever ``show_result_pyplot`` returns when given the 'ade'
            palette and ``opacity=1``. With opacity 1 the palette colours
            presumably replace the input pixels entirely; verify against
            the bundled mmseg implementation.
        """
        result = inference_segmentor(self.model, img)
        res_img = show_result_pyplot(self.model, img, result, get_palette('ade'), opacity=1)
        return res_img