| import os |
| import torch |
| import torchvision |
| from ultralytics import YOLO |
|
|
|
|
|
|
def build_model(nclasses: int = 2, mode: str = None, segment_model: str = None):
    """
    Build the network for the requested task.

    @param[in] nclasses       number of output classes of the frame classifier
                              (only used when mode == 'classify')
    @param[in] mode           'classify' -> torchvision ResNet-18 frame
                              classifier, moved to the GPU with .cuda();
                              'mask'     -> ultralytics YOLO model used for the
                              uninformative-part mask
    @param[in] segment_model  model path/name handed to YOLO; required when
                              mode == 'mask'
    @return                   the constructed model
    @throws ValueError        if mode is neither 'classify' nor 'mask', or if
                              mode == 'mask' and segment_model is None
    """
    if mode == 'classify':
        # No pretrained weights are requested, so the network starts from
        # random initialisation; presumably the caller loads trained weights.
        net = torchvision.models.resnet18(num_classes=nclasses)
        net.cuda()  # requires a CUDA device
    elif mode == 'mask':
        if segment_model is None:
            raise ValueError("segment_model must be given when mode='mask'")
        net = YOLO(segment_model)
    else:
        # Without this branch an unknown mode (including the default None)
        # fell through to `return net` and raised UnboundLocalError.
        raise ValueError(
            f"unknown mode {mode!r}; expected 'classify' or 'mask'"
        )
    return net
|
|
|
|
|
|
# Inference setup for the frame classifier: build the ResNet-18, wrap it in
# DataParallel, restore the trained weights and switch to evaluation mode.
# NOTE(review): `num_classes` is not defined anywhere in the visible code; it
# must already exist at module scope when this runs. TODO confirm.
# NOTE(review): `model_path` is passed straight to `torch.load`, so it must
# name a checkpoint *file*. The current value looks like a directory name;
# verify before running.
model_path = 'Video storyboard classification models'
torch.backends.cudnn.benchmark = True

net = torch.nn.DataParallel(
    build_model(nclasses=num_classes, mode='classify')
)

# The checkpoint is read as a dict whose 'net' entry holds the state dict.
# The model is already wrapped in DataParallel here, so (with the default
# strict loading) the saved keys must carry the 'module.' prefix. Presumably
# the checkpoint was saved from a DataParallel model; confirm.
state = torch.load(model_path, map_location=torch.device('cuda'))
net.load_state_dict(state['net'])
net.eval()
|
|