Spaces:
Running
Running
File size: 1,827 Bytes
5a17bb3 60fd570 5a17bb3 60fd570 5a17bb3 60fd570 5a17bb3 c06c582 5a17bb3 60fd570 5a17bb3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 | import torch
from model.BiSeNet.build_bisenet import BiSeNet
from model.BiSeNetV2.model import BiSeNetV2
# general loading function
def loadModel(model:str = 'bisenet', device: str = 'cpu')->BiSeNet:
"""
Load the specified model and move it to the given device.
Args:
model (str): model to be loaded.
device (str): Device to load the model onto ('cpu' or 'cuda').
Returns:
model (BiSeNet): The loaded BiSeNet model.
"""
match model.lower() if isinstance(model, str) else model:
case 'bisenet': model = loadBiSeNet(device)
case 'bisenetv2': model = loadBiSeNetV2(device)
case _: raise NotImplementedError(f"Model {model} is not implemented. Please choose 'bisenet' .")
return model
# BiSeNet model loading function
def loadBiSeNet(device: str = 'cpu', *, weights_path: str = './weights/BiSeNet/weightADV.pth') -> BiSeNet:
    """
    Build a BiSeNet (19 classes, 'resnet18' context path), load its weights,
    and return it in evaluation mode on the specified device.

    Args:
        device (str): Device to load the model onto ('cpu' or 'cuda').
        weights_path (str): Path to the checkpoint file. The checkpoint must
            store the weights under the 'model_state_dict' key. Relative paths
            are resolved against the current working directory. Defaults to
            the original hard-coded location.

    Returns:
        model (BiSeNet): The loaded BiSeNet model, in eval mode.
    """
    model = BiSeNet(num_classes=19, context_path='resnet18').to(device)
    # map_location remaps stored tensors onto `device`, so a checkpoint saved
    # on a GPU can be loaded on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Put the module in evaluation mode (affects layers such as dropout/batch norm).
    model.eval()
    return model
def loadBiSeNetV2(device: str = 'cpu', *, weights_path: str = './weights/BiSeNetV2/weightADV.pth') -> BiSeNetV2:
    """
    Build a BiSeNetV2 (19 classes), load its weights, and return it in
    evaluation mode on the specified device.

    Args:
        device (str): Device to load the model onto ('cpu' or 'cuda').
        weights_path (str): Path to the checkpoint file. The checkpoint must
            store the weights under the 'model_state_dict' key. Relative paths
            are resolved against the current working directory. Defaults to
            the original hard-coded location.

    Returns:
        model (BiSeNetV2): The loaded BiSeNetV2 model, in eval mode.
    """
    model = BiSeNetV2(n_classes=19).to(device)
    # map_location remaps stored tensors onto `device`, so a checkpoint saved
    # on a GPU can be loaded on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Put the module in evaluation mode (affects layers such as dropout/batch norm).
    model.eval()
    return model