Spaces:
Running
Running
| import torch | |
| from model.BiSeNet.build_bisenet import BiSeNet | |
| from model.BiSeNetV2.model import BiSeNetV2 | |
| # general loading function | |
| def loadModel(model:str = 'bisenet', device: str = 'cpu')->BiSeNet: | |
| """ | |
| Load the specified model and move it to the given device. | |
| Args: | |
| model (str): model to be loaded. | |
| device (str): Device to load the model onto ('cpu' or 'cuda'). | |
| Returns: | |
| model (BiSeNet): The loaded BiSeNet model. | |
| """ | |
| match model.lower() if isinstance(model, str) else model: | |
| case 'bisenet': model = loadBiSeNet(device) | |
| case 'bisenetv2': model = loadBiSeNetV2(device) | |
| case _: raise NotImplementedError(f"Model {model} is not implemented. Please choose 'bisenet' .") | |
| return model | |
# BiSeNet model loading function
def loadBiSeNet(device: str = 'cpu',
                weights_path: str = './weights/BiSeNet/weightADV.pth') -> BiSeNet:
    """
    Load the BiSeNet model and move it to the specified device.

    The network is built with 19 output classes and a 'resnet18' context path,
    then populated from a checkpoint dict that stores the weights under the
    'model_state_dict' key.

    Args:
        device (str): Device to load the model onto ('cpu' or 'cuda').
        weights_path (str): Path to the checkpoint file. The default is a
            relative path, so it is resolved against the current working
            directory.

    Returns:
        model (BiSeNet): The loaded BiSeNet model, switched to eval mode.

    Raises:
        FileNotFoundError: If `weights_path` does not exist.
        KeyError: If the checkpoint has no 'model_state_dict' entry.
        RuntimeError: If the stored state dict does not match the network.
    """
    model = BiSeNet(num_classes=19, context_path='resnet18').to(device)
    # map_location lets checkpoints saved on GPU load on a CPU-only host.
    checkpoint = torch.load(weights_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model
# BiSeNetV2 model loading function
def loadBiSeNetV2(device: str = 'cpu',
                  weights_path: str = './weights/BiSeNetV2/weightADV.pth') -> BiSeNetV2:
    """
    Load the BiSeNetV2 model and move it to the specified device.

    The network is built with 19 output classes, then populated from a
    checkpoint dict that stores the weights under the 'model_state_dict' key.

    Args:
        device (str): Device to load the model onto ('cpu' or 'cuda').
        weights_path (str): Path to the checkpoint file. The default is a
            relative path, so it is resolved against the current working
            directory.

    Returns:
        model (BiSeNetV2): The loaded BiSeNetV2 model, switched to eval mode.

    Raises:
        FileNotFoundError: If `weights_path` does not exist.
        KeyError: If the checkpoint has no 'model_state_dict' entry.
        RuntimeError: If the stored state dict does not match the network.
    """
    model = BiSeNetV2(n_classes=19).to(device)
    # map_location lets checkpoints saved on GPU load on a CPU-only host.
    checkpoint = torch.load(weights_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model