import os
from typing import Optional, Union

import torch
import torchvision
from ultralytics import YOLO


def build_model(
    nclasses: int = 2,
    mode: Optional[str] = None,
    segment_model: Optional[str] = None,
    device: Union[str, torch.device] = "cuda",
):
    """
    Build the network for the requested task.

    @param[in] nclasses       number of output classes of the ResNet-18 frame
                              classifier (only used when mode == 'classify')
    @param[in] mode           'classify' -> torchvision ResNet-18 frame classifier
                              (no pretrained weights requested here), moved to `device`;
                              'mask'     -> ultralytics YOLO model for the
                              uninformative-part mask, loaded from `segment_model`
    @param[in] segment_model  model/weights spec handed to ultralytics.YOLO;
                              required when mode == 'mask'
    @param[in] device         device the classifier is moved to; defaults to
                              'cuda', matching the previous hard-coded .cuda()
    @return the constructed model (torch.nn.Module for 'classify', YOLO for 'mask')
    @throws ValueError if `mode` is not 'classify' or 'mask', or if
            mode == 'mask' and no `segment_model` is given
    """
    if mode == "classify":
        net = torchvision.models.resnet18(num_classes=nclasses)
        # Module.to() moves parameters in place and returns the same module.
        return net.to(device)

    if mode == "mask":
        # Fail early with a clear message instead of handing None to YOLO.
        if segment_model is None:
            raise ValueError("mode='mask' requires a segment_model for YOLO")
        return YOLO(segment_model)

    # Previously an unknown/missing mode fell through and raised
    # UnboundLocalError on `return net`; make the contract explicit.
    raise ValueError(f"unknown mode {mode!r}; expected 'classify' or 'mask'")


# Inference setup for the frame classifier: build the ResNet-18 via
# build_model(), wrap it in DataParallel, restore weights from a checkpoint
# and switch to evaluation mode.
#
# NOTE(review): `num_classes` is not defined anywhere in this file; it must be
# bound before this line runs, otherwise a NameError is raised — TODO confirm
# where it is supposed to come from.
net = build_model(nclasses=num_classes, mode='classify')
# NOTE(review): this looks like a placeholder / directory name rather than a
# checkpoint file path — confirm before running.
model_path = 'Video storyboard classification models'

# The wrapper is applied *before* load_state_dict, so the checkpoint's keys
# presumably carry DataParallel's "module." prefix (saved from a wrapped
# model) — TODO confirm against the training code.
net = torch.nn.DataParallel(net)
# Let cuDNN benchmark and pick the fastest convolution algorithms.
torch.backends.cudnn.benchmark = True

# The checkpoint is indexed with 'net' to obtain the state dict; all storages
# are mapped onto the CUDA device while loading.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state = torch.load(model_path, map_location='cuda')
net.load_state_dict(state['net'])
# Evaluation mode: layers such as batch-norm stop using per-batch statistics.
net.eval()