Spaces:
Runtime error
Runtime error
| import math | |
| from functools import partial | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from .utils_aug import resize, center_crop | |
def cvtColor(image):
    """Return *image* in RGB so prediction code can rely on three RGB channels.

    Grayscale (and other non-RGB) images would otherwise make prediction fail;
    only RGB input is supported downstream, so every other kind of image is
    converted to RGB.

    Args:
        image: a PIL image, or an array-like that already holds pixel data.

    Returns:
        The same object when it is already RGB, otherwise ``image.convert('RGB')``.
    """
    mode = getattr(image, 'mode', None)
    if mode is not None:
        # PIL images expose their colour mode directly. Checking it avoids
        # materialising the pixel data as an array (which np.shape does for PIL
        # images), and it also catches 3-channel non-RGB modes such as 'YCbCr',
        # 'LAB' and 'HSV' that a pure channel-count test would pass through.
        if mode == 'RGB':
            return image
        return image.convert('RGB')
    # Array-like input without a mode: fall back to the channel-count test.
    shape = np.shape(image)
    if len(shape) == 3 and shape[2] == 3:
        return image
    return image.convert('RGB')
def letterbox_image(image, size, letterbox_image):
    """Resize the input image to the target ``size``, given as ``(w, h)``.

    Args:
        image: a PIL-style image; ``image.size`` is unpacked as ``(iw, ih)``.
        size: target ``(w, h)``.
        letterbox_image: when truthy, scale the image preserving its aspect
            ratio (bicubic) and paste it centred onto a grey ``(128, 128, 128)``
            RGB canvas of ``size``. Otherwise delegate to ``resize`` and then
            ``center_crop`` to ``[h, w]`` (helpers from ``utils_aug``).

    Returns:
        The resized image.
    """
    w, h = size
    # Unpacked before branching; only the letterbox path uses it.
    iw, ih = image.size
    if not letterbox_image:
        # Square targets pass a single int to resize; the exact semantics live
        # in utils_aug (presumably a shorter-side resize -- TODO confirm).
        resized = resize(image, h) if h == w else resize(image, [h, w])
        return center_crop(resized, [h, w])

    # Letterbox: shrink or grow uniformly so the image fits, then centre it.
    ratio = min(w / iw, h / ih)
    new_w = int(iw * ratio)
    new_h = int(ih * ratio)
    scaled = image.resize((new_w, new_h), Image.BICUBIC)
    canvas = Image.new('RGB', size, (128, 128, 128))
    canvas.paste(scaled, ((w - new_w) // 2, (h - new_h) // 2))
    return canvas
def get_classes(classes_path):
    """Read class names from a UTF-8 text file, one name per line.

    Surrounding whitespace is stripped from each line; blank lines are kept
    as empty names.

    Returns:
        ``(class_names, num_classes)``.
    """
    with open(classes_path, encoding='utf-8') as handle:
        names = [line.strip() for line in handle]
    return names, len(names)
def preprocess_input(x):
    """Preprocess training images: map pixel values from [0, 255] to [-1, 1].

    Args:
        x: a NumPy array or torch tensor of pixel values (or a plain number).

    Returns:
        A new array/tensor ``x / 127.5 - 1``. The input is not modified, and
        float inputs keep their dtype (e.g. float32 stays float32).
    """
    # Out-of-place on purpose: in-place ``x /= 127.5`` mutated the caller's
    # buffer and raised for integer arrays/tensors, since true division cannot
    # be cast back into an integer dtype in place.
    return x / 127.5 - 1.
def show_config(**kwargs):
    """Print the keyword arguments as a two-column key/value table.

    Keys are right-aligned in 25 columns and values in 40 (longer text is not
    truncated); each value is rendered with ``str``.
    """
    row = '|%25s | %40s|'
    rule = '-' * 70
    print('Configurations:')
    print(rule)
    print(row % ('keys', 'values'))
    print(rule)
    for name, setting in kwargs.items():
        print(row % (str(name), str(setting)))
    print(rule)
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns ``None`` when ``optimizer.param_groups`` is empty.
    """
    first_group = next(iter(optimizer.param_groups), None)
    return None if first_group is None else first_group['lr']
def weights_init(net, init_type='normal', init_gain=0.02):
    """Initialise conv and BatchNorm2d layers of *net* in place.

    Conv-like modules (class name containing 'Conv' that have a ``weight``
    attribute) get their weight initialised according to *init_type*; their
    bias is left untouched. Modules whose class name contains 'BatchNorm2d'
    get weight ~ N(1.0, 0.02) and bias = 0.

    Args:
        net: a ``torch.nn.Module``; the initialiser is applied to every
            submodule via ``net.apply``.
        init_type: 'normal', 'xavier', 'kaiming' or 'orthogonal'.
        init_gain: std for 'normal', gain for 'xavier'/'orthogonal'; not used
            by 'kaiming'.

    Raises:
        NotImplementedError: if *init_type* is unknown and *net* contains a
            conv-like module.
    """
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init_type == 'normal':
                torch.nn.init.normal_(m.weight, 0.0, init_gain)
            elif init_type == 'xavier':
                torch.nn.init.xavier_normal_(m.weight, gain=init_gain)
            elif init_type == 'kaiming':
                torch.nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                torch.nn.init.orthogonal_(m.weight, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif 'BatchNorm2d' in classname:
            # BatchNorm2d(affine=False) sets weight and bias to None; skip them
            # instead of crashing on ``None.data``.
            weight = getattr(m, 'weight', None)
            bias = getattr(m, 'bias', None)
            if weight is not None:
                torch.nn.init.normal_(weight, 1.0, 0.02)
            if bias is not None:
                torch.nn.init.constant_(bias, 0.0)
    print('initialize network with %s type' % init_type)
    net.apply(init_func)
def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
    """Build a learning-rate schedule as a one-argument callable.

    The returned callable takes the current step (in the same unit as
    ``total_iters``) and returns the learning rate for that step.

    ``lr_decay_type == "cos"``:
        * quadratic warmup from ``max(warmup_lr_ratio * lr, 1e-6)`` up to
          ``lr`` over ``min(max(warmup_iters_ratio * total_iters, 1), 3)`` steps;
        * cosine annealing from ``lr`` towards ``min_lr``;
        * held at ``min_lr`` for the last
          ``min(max(no_aug_iter_ratio * total_iters, 1), 15)`` steps.
    Any other value:
        step decay over ``step_num`` stages of ``total_iters / step_num`` steps
        each, multiplying by ``(min_lr / lr) ** (1 / (step_num - 1))`` per stage.

    Raises:
        ZeroDivisionError: in step mode when ``step_num == 1`` or ``lr == 0``.
        ValueError: raised by the step-mode callable when the stage length
            ``total_iters / step_num`` is below 1.
    """
    def _warmup_cosine(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
        # Warmup phase: quadratic ramp from warmup_lr_start to lr.
        if iters <= warmup_total_iters:
            progress = iters / float(warmup_total_iters)
            return (lr - warmup_lr_start) * pow(progress, 2) + warmup_lr_start
        # Tail phase: hold at the floor.
        if iters >= total_iters - no_aug_iter:
            return min_lr
        # Middle phase: half-cosine from lr down to min_lr.
        span = total_iters - warmup_total_iters - no_aug_iter
        phase = math.pi * (iters - warmup_total_iters) / span
        return min_lr + 0.5 * (lr - min_lr) * (1.0 + math.cos(phase))

    def _staircase(lr, decay_rate, step_size, iters):
        if step_size < 1:
            raise ValueError("step_size must above 1.")
        return lr * decay_rate ** (iters // step_size)

    if lr_decay_type != "cos":
        # Chosen so the final stage (index step_num - 1) lands on min_lr.
        decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
        step_size = total_iters / step_num
        return partial(_staircase, lr, decay_rate, step_size)

    warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
    warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
    no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
    return partial(_warmup_cosine, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
    """Set every parameter group's ``'lr'`` to ``lr_scheduler_func(epoch)``.

    The scheduler is called exactly once; the same value is written to all
    groups.
    """
    scheduled = lr_scheduler_func(epoch)
    for group in optimizer.param_groups:
        group['lr'] = scheduled
def download_weights(backbone, model_dir="./model_data"):
    """Fetch pretrained weights for *backbone* into *model_dir*.

    Args:
        backbone: one of 'vgg16', 'mobilenet', 'resnet50', 'vit'.
        model_dir: directory to store the checkpoint in; created if missing.

    Raises:
        KeyError: if *backbone* is not a known key (checked before any
            directory is created or anything is downloaded).
    """
    import os
    from torch.hub import load_state_dict_from_url
    download_urls = {
        'vgg16'     : 'https://download.pytorch.org/models/vgg16-397923af.pth',
        'mobilenet' : 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
        'resnet50'  : 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
        'vit'       : 'https://github.com/bubbliiiing/classification-pytorch/releases/download/v1.0/vit-patch_16.pth'
    }
    try:
        url = download_urls[backbone]
    except KeyError:
        raise KeyError('unknown backbone %r; expected one of %s'
                       % (backbone, sorted(download_urls))) from None
    # exist_ok=True avoids the race between a separate exists() check and
    # makedirs() when several processes start at once.
    os.makedirs(model_dir, exist_ok=True)
    # Called for its download/caching side effect; the returned state dict
    # is intentionally discarded.
    load_state_dict_from_url(url, model_dir)