import torch
from model.BiSeNet.build_bisenet import BiSeNet
from model.BiSeNetV2.model import BiSeNetV2


def _load_trained_weights(model: torch.nn.Module, weights_path: str, device: str) -> torch.nn.Module:
    """Load the ``'model_state_dict'`` entry of a checkpoint into *model* and put it in eval mode.

    Args:
        model (torch.nn.Module): Already-constructed model, already moved to *device*.
        weights_path (str): Path of the checkpoint file passed to ``torch.load``.
        device (str): Used as ``map_location`` so tensors are loaded onto this device.

    Returns:
        torch.nn.Module: The same *model* object, with weights loaded and ``eval()`` called.

    Raises:
        KeyError: If the loaded checkpoint has no ``'model_state_dict'`` key.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from trusted sources.
    # Consider weights_only=True once the minimum supported torch version allows it.
    checkpoint = torch.load(weights_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


# BiSeNet model loading function
def loadBiSeNet(
    device: str = 'cpu',
    weights_path: str = './weights/BiSeNet/weightADV.pth',
    num_classes: int = 19,
) -> BiSeNet:
    """
    Load the BiSeNet model and move it to the specified device.

    Args:
        device (str): Device to load the model onto ('cpu' or 'cuda').
        weights_path (str): Checkpoint file containing a 'model_state_dict' entry.
        num_classes (int): Number of output classes the checkpoint was trained with.

    Returns:
        model (BiSeNet): The loaded BiSeNet model (resnet18 context path), in eval mode.
    """
    model = BiSeNet(num_classes=num_classes, context_path='resnet18').to(device)
    return _load_trained_weights(model, weights_path, device)


def loadBiSeNetV2(
    device: str = 'cpu',
    weights_path: str = './weights/BiSeNetV2/weightADV.pth',
    num_classes: int = 19,
) -> BiSeNetV2:
    """
    Load the BiSeNetV2 model and move it to the specified device.

    Args:
        device (str): Device to load the model onto ('cpu' or 'cuda').
        weights_path (str): Checkpoint file containing a 'model_state_dict' entry.
        num_classes (int): Number of output classes the checkpoint was trained with.

    Returns:
        model (BiSeNetV2): The loaded BiSeNetV2 model, in eval mode.
    """
    model = BiSeNetV2(n_classes=num_classes).to(device)
    return _load_trained_weights(model, weights_path, device)