File size: 1,827 Bytes
5a17bb3
 
 
60fd570
5a17bb3
60fd570
5a17bb3
 
 
 
 
 
 
 
 
 
 
 
 
60fd570
5a17bb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c06c582
5a17bb3
 
60fd570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a17bb3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch

from model.BiSeNet.build_bisenet import BiSeNet
from model.BiSeNetV2.model import BiSeNetV2

# general loading function
def loadModel(model:str = 'bisenet', device: str = 'cpu')->BiSeNet:
    """
    Load the specified model and move it to the given device.
    
    Args:
        model (str): model to be loaded.
        device (str): Device to load the model onto ('cpu' or 'cuda').
        
    Returns:
        model (BiSeNet): The loaded BiSeNet model.
    """
    match model.lower() if isinstance(model, str) else model:
        case 'bisenet': model = loadBiSeNet(device)
        case 'bisenetv2': model = loadBiSeNetV2(device)
        case _: raise NotImplementedError(f"Model {model} is not implemented. Please choose 'bisenet' .")
    
    return model


# BiSeNet model loading function
def loadBiSeNet(
    device: str = 'cpu',
    *,
    weights_path: str = './weights/BiSeNet/weightADV.pth',
) -> BiSeNet:
    """
    Load the pretrained BiSeNet model and move it to the specified device.

    The network is built with 19 output classes and a ResNet-18 context path,
    then initialized from the ``'model_state_dict'`` entry of the checkpoint.

    Args:
        device (str): Device to load the model onto ('cpu' or 'cuda').
        weights_path (str): Path to the checkpoint file. The default is a
            relative path, so it is resolved against the current working
            directory.

    Returns:
        model (BiSeNet): The loaded BiSeNet model, switched to eval mode.
    """
    model = BiSeNet(num_classes=19, context_path='resnet18').to(device)
    # map_location remaps stored tensors onto `device`, so a checkpoint saved
    # on GPU can still be loaded on a CPU-only host.
    # NOTE: torch.load unpickles the file; only point weights_path at trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Inference mode: affects layers such as dropout/batch-norm, if present.
    model.eval()

    return model


def loadBiSeNetV2(
    device: str = 'cpu',
    *,
    weights_path: str = './weights/BiSeNetV2/weightADV.pth',
) -> BiSeNetV2:
    """
    Load the pretrained BiSeNetV2 model and move it to the specified device.

    The network is built with 19 output classes, then initialized from the
    ``'model_state_dict'`` entry of the checkpoint.

    Args:
        device (str): Device to load the model onto ('cpu' or 'cuda').
        weights_path (str): Path to the checkpoint file. The default is a
            relative path, so it is resolved against the current working
            directory.

    Returns:
        model (BiSeNetV2): The loaded BiSeNetV2 model, switched to eval mode.
    """
    model = BiSeNetV2(n_classes=19).to(device)
    # map_location remaps stored tensors onto `device`, so a checkpoint saved
    # on GPU can still be loaded on a CPU-only host.
    # NOTE: torch.load unpickles the file; only point weights_path at trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Inference mode: affects layers such as dropout/batch-norm, if present.
    model.eval()

    return model