| import random |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
|
|
|
|
| |
| |
| |
| |
def cvtColor(image):
    """Return ``image`` as a 3-channel RGB image.

    Objects that expose a ``mode`` attribute (PIL images) are judged by that
    mode: an ``'RGB'`` image is returned unchanged, and every other mode
    (``'L'``, ``'P'``, ``'RGBA'``, ``'YCbCr'``, ``'HSV'``, ...) is converted
    with ``image.convert('RGB')``.

    Objects without ``mode`` (e.g. NumPy arrays) keep the shape-based rule:
    an array of shape ``(H, W, 3)`` is returned unchanged; anything else is
    passed to ``.convert('RGB')``, which raises ``AttributeError`` when the
    object has no such method.

    Args:
        image: A PIL image, or an array-like image.

    Returns:
        The very same object when it is already RGB, otherwise the result of
        ``image.convert('RGB')``.
    """
    mode = getattr(image, "mode", None)
    if mode is not None:
        # Read the mode instead of np.shape(image): np.shape on a PIL image
        # materialises the full pixel buffer as an array just to get its
        # shape, and a 3-channel shape does not imply RGB (YCbCr/LAB/HSV are
        # 3-channel too), so those modes would slip through unconverted.
        return image if mode == "RGB" else image.convert("RGB")
    shape = np.shape(image)
    if len(shape) == 3 and shape[2] == 3:
        return image
    return image.convert("RGB")
|
|
| |
| |
| |
def resize_image(image, size, letterbox_image):
    """Resize a PIL image to ``size`` = ``(width, height)`` using bicubic resampling.

    When ``letterbox_image`` is truthy the aspect ratio is kept: the image is
    scaled by ``min(width / iw, height / ih)`` (pixel sizes truncated to int)
    and pasted, centred, onto a new ``'RGB'`` canvas of ``size`` filled with
    grey ``(128, 128, 128)``. Otherwise the image is stretched to exactly
    ``(width, height)``.

    Returns:
        A new PIL image (the input image is not modified).
    """
    src_w, src_h = image.size
    dst_w, dst_h = size
    if not letterbox_image:
        return image.resize((dst_w, dst_h), Image.BICUBIC)

    ratio = min(dst_w / src_w, dst_h / src_h)
    fit_w, fit_h = int(src_w * ratio), int(src_h * ratio)
    scaled = image.resize((fit_w, fit_h), Image.BICUBIC)

    # Grey padding fills the bands left over on the shorter side.
    canvas = Image.new('RGB', size, (128, 128, 128))
    offset = ((dst_w - fit_w) // 2, (dst_h - fit_h) // 2)
    canvas.paste(scaled, offset)
    return canvas
|
|
| |
| |
| |
def get_classes(classes_path):
    """Read class names from a text file, one name per line.

    Each line is stripped of surrounding whitespace, and blank lines are
    skipped, so a trailing empty line no longer produces a bogus ``''`` class
    (which would also inflate the class count). The file is decoded as UTF-8;
    an optional leading byte-order mark is discarded rather than being glued
    onto the first class name.

    Args:
        classes_path: Path to the class-list text file.

    Returns:
        A tuple ``(class_names, num_classes)`` where ``num_classes`` equals
        ``len(class_names)``.
    """
    # 'utf-8-sig' decodes plain UTF-8 identically but also drops a leading
    # BOM; str.strip() would keep it because U+FEFF is not whitespace.
    with open(classes_path, encoding='utf-8-sig') as f:
        stripped = [line.strip() for line in f]
    class_names = [name for name in stripped if name]
    return class_names, len(class_names)
|
|
| |
| |
| |
def preprocess_input(inputs):
    """Subtract the fixed per-channel offsets ``(104, 117, 123)`` from ``inputs``.

    The 3-tuple broadcasts against the last axis of ``inputs`` (NumPy
    semantics), so the last dimension must have length 3.

    NOTE(review): these values look like the Caffe/VGG channel means, which
    are conventionally listed in BGR order; confirm they match the channel
    order of ``inputs``.
    """
    channel_means = (104, 117, 123)
    centered = inputs - channel_means
    return centered
|
|
| |
| |
| |
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Only the first entry of ``optimizer.param_groups`` is read; when
    ``param_groups`` is empty the function returns ``None``.
    """
    lrs = (group['lr'] for group in optimizer.param_groups)
    return next(lrs, None)
|
|
| |
| |
| |
def seed_everything(seed=11):
    """Seed all global RNGs with ``seed`` and switch cuDNN to deterministic mode.

    Seeds, in order: Python's ``random``, NumPy's global RNG, torch's
    generator, then the CUDA generators for the current device and for all
    devices. Afterwards it sets ``cudnn.deterministic = True`` and
    ``cudnn.benchmark = False`` so cuDNN does not autotune kernel choice
    (PyTorch's reproducibility notes list autotuning as a source of
    run-to-run variation).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
| |
| |
| |
def worker_init_fn(worker_id, rank, seed):
    """Seed Python's, NumPy's and torch's RNGs for one data-loading worker.

    The signature fits ``DataLoader(worker_init_fn=...)`` once ``rank`` and
    ``seed`` are bound (e.g. with ``functools.partial``); presumably that is
    how callers use it, so confirm against the call site.

    Args:
        worker_id: Index of the worker within its process.
        rank: Index of the training process (e.g. distributed rank).
        seed: Base seed for the run.
    """
    # worker_id must be part of the seed: with ``rank + seed`` alone every
    # worker of a process got the same seed and produced identical random
    # streams (duplicated augmentations). SeedSequence mixes the three ints
    # into a well-spread, reproducible 32-bit seed and avoids the collisions
    # of plain sums (e.g. rank 0/worker 1 vs rank 1/worker 0).
    worker_seed = int(np.random.SeedSequence([seed, rank, worker_id]).generate_state(1)[0])
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)
|
|
def show_config(**kwargs):
    """Print the keyword arguments as a two-column ``keys | values`` table.

    Each key and value is converted with ``str()`` and right-aligned in a
    25- and a 40-character column respectively (longer text is not cut).
    The header and the table are framed by 70-character rules of ``-``.
    """
    rule = '-' * 70
    row = '|{:>25} | {:>40}|'.format

    print('Configurations:')
    print(rule)
    print(row('keys', 'values'))
    print(rule)
    for name, setting in kwargs.items():
        print(row(str(name), str(setting)))
    print(rule)
|
|
def download_weights(backbone, model_dir="./model_data"):
    """Fetch the pretrained checkpoint for ``backbone`` into ``model_dir``.

    The URL is taken from a fixed table and handed to
    ``torch.hub.load_state_dict_from_url``. The returned state dict is
    discarded; the point of the call is the file left in ``model_dir``.

    Args:
        backbone: One of ``'vgg'``, ``'mobilenetv2'`` or ``'resnet50'``.
        model_dir: Target directory; created if it does not exist.

    Raises:
        KeyError: If ``backbone`` is not in the table. This is raised before
            any directory is created or any download is attempted.
    """
    import os
    from torch.hub import load_state_dict_from_url

    download_urls = {
        'vgg'        : 'https://download.pytorch.org/models/vgg16-397923af.pth',
        'mobilenetv2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
        'resnet50'   : 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth',
    }
    try:
        url = download_urls[backbone]
    except KeyError:
        # Same exception type as before, but say what the valid names are.
        raise KeyError(
            "Unsupported backbone %r; expected one of %s" % (backbone, sorted(download_urls))
        ) from None

    # exist_ok=True replaces the exists()/makedirs() pair, which could fail
    # if another process created the directory between the check and the call.
    os.makedirs(model_dir, exist_ok=True)
    load_state_dict_from_url(url, model_dir)