Spaces:
Runtime error
Runtime error
| from config import MODEL_DIR, MODEL_INPUT_SIZE, TRANSFORMS_TO_APPLY, MODEL_BACKBONE, MODEL_OBJECTIVE, LAST_N_LAYERS_TO_TRAIN | |
| import os | |
| import torch | |
| import json | |
| from base import TransformationType, ModelBackbone, TrainingObjective | |
| from torchvision import transforms | |
| import torchvision | |
| import torch.nn as nn | |
def save_model(model, config_json, model_dir=None):
    """Persist a model's weights and its configuration into a directory.

    Two files are written into the target directory: ``model.pth`` holding
    ``model.state_dict()`` and ``config.json`` holding ``config_json``.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        config_json: JSON-serialisable configuration stored next to the weights.
        model_dir: Directory to write into. It is created if missing. When
            ``None``, a fresh ``model_<n>`` directory is created under
            ``MODEL_DIR``.

    Returns:
        The directory the two files were written to.
    """
    if model_dir is None:
        model_basedir = MODEL_DIR
        # The base directory may not exist yet on a first run.
        os.makedirs(model_basedir, exist_ok=True)
        # Start from the historical naming scheme (number of entries already
        # present) but skip forward if that name is taken, which happens when
        # earlier runs were deleted or unrelated files live in the directory.
        index = len(os.listdir(model_basedir))
        while True:
            candidate = os.path.join(model_basedir, 'model_{}'.format(index))
            try:
                os.mkdir(candidate)
            except FileExistsError:
                index += 1
            else:
                model_dir = candidate
                break
    else:
        # A caller-supplied directory is reused if present, created otherwise.
        os.makedirs(model_dir, exist_ok=True)

    model_path = os.path.join(model_dir, 'model.pth')
    torch.save(model.state_dict(), model_path)

    config_path = os.path.join(model_dir, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(config_json, f)
    return model_dir
def get_transforms_to_apply_(transformation_type, config_json=None):
    """Build the torchvision transform for a single ``TransformationType``.

    Args:
        transformation_type: The ``TransformationType`` member to translate.
        config_json: Optional mapping; when truthy, its ``'MODEL_INPUT_SIZE'``
            entry replaces the module-level ``MODEL_INPUT_SIZE`` as the size
            handed to the resize and crop transforms.

    Returns:
        A freshly constructed ``torchvision.transforms`` object.

    Raises:
        Exception: If ``transformation_type`` matches none of the supported
            members.
    """
    input_size = config_json['MODEL_INPUT_SIZE'] if config_json else MODEL_INPUT_SIZE

    # Ordered (member name, factory) pairs. Factories are lambdas so that a
    # transform is only constructed for the member that actually matches.
    candidates = (
        ('RESIZE', lambda: transforms.Resize(input_size)),
        ('TO_TENSOR', lambda: transforms.ToTensor()),
        ('RANDOM_HORIZONTAL_FLIP', lambda: transforms.RandomHorizontalFlip(p=0.5)),
        # presumably the ImageNet channel statistics -- confirm against the
        # weights the backbone was pretrained with
        ('NORMALIZE', lambda: transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                   std=[0.229, 0.224, 0.225])),
        ('RANDOM_ROTATION', lambda: transforms.RandomRotation(degrees=10)),
        # The member is named RANDOM_CLIP but maps to RandomCrop.
        ('RANDOM_CLIP', lambda: transforms.RandomCrop(input_size)),
    )
    for member_name, build in candidates:
        if transformation_type == getattr(TransformationType, member_name):
            return build()
    raise Exception("Invalid transformation type")
def get_transforms_to_apply():
    """Compose the transforms named in ``TRANSFORMS_TO_APPLY``.

    Each entry of ``TRANSFORMS_TO_APPLY`` is looked up by name in
    ``TransformationType`` and converted with ``get_transforms_to_apply_``
    (using the module-level defaults); the results are chained in the
    configured order.

    Returns:
        A ``transforms.Compose`` wrapping the configured transforms.
    """
    pipeline = [
        get_transforms_to_apply_(TransformationType[name])
        for name in TRANSFORMS_TO_APPLY
    ]
    return transforms.Compose(pipeline)
def get_model_architecture(config_json=None):
    """Instantiate the network for the configured backbone and objective.

    Args:
        config_json: Optional mapping. When truthy, its ``'MODEL_BACKBONE'``
            and ``'MODEL_OBJECTIVE'`` entries (enum member names) are used
            instead of the module-level ``MODEL_BACKBONE`` / ``MODEL_OBJECTIVE``.

    Returns:
        An EfficientNet-B0 whose ``classifier[1]`` has been replaced by a
        two-layer head ending in a single output unit.

    Raises:
        Exception: If the backbone is not ``EFFICIENT_NET_B0`` or the
            objective is not ``REGRESSION``.
    """
    if not config_json:
        backbone = MODEL_BACKBONE
        objective = MODEL_OBJECTIVE
    else:
        backbone = ModelBackbone[config_json['MODEL_BACKBONE']]
        objective = TrainingObjective[config_json['MODEL_OBJECTIVE']]

    # Reject unsupported combinations up front; the backbone is checked first.
    if backbone != ModelBackbone.EFFICIENT_NET_B0:
        raise Exception("Invalid model backbone")
    if objective != TrainingObjective.REGRESSION:
        raise Exception("Invalid model objective")

    net = torchvision.models.efficientnet_b0(pretrained=True)
    # Swap the stock final layer for a small head with one output.
    head_inputs = net.classifier[1].in_features
    net.classifier[1] = nn.Sequential(
        nn.Linear(head_inputs, 2048),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(2048, 1),
    )
    return net
def get_training_params(model):
    """Collect the parameters that should be handed to the optimizer.

    When ``LAST_N_LAYERS_TO_TRAIN`` is positive, every entry of
    ``model.features`` except the last ``LAST_N_LAYERS_TO_TRAIN`` has
    ``requires_grad`` switched off, and the parameters of those last entries
    are collected. The parameters of ``model.classifier[1]`` are always
    collected.

    Args:
        model: Network exposing ``features`` and ``classifier`` as built for
            the ``EFFICIENT_NET_B0`` backbone.

    Returns:
        A list of parameters to train.

    Raises:
        Exception: If ``MODEL_BACKBONE`` is not ``EFFICIENT_NET_B0``.
    """
    # NOTE(review): nesting reconstructed from a whitespace-stripped listing;
    # confirm the classifier head is meant to be collected regardless of
    # LAST_N_LAYERS_TO_TRAIN.
    if MODEL_BACKBONE != ModelBackbone.EFFICIENT_NET_B0:
        raise Exception("Invalid model backbone")

    trainable = []
    if LAST_N_LAYERS_TO_TRAIN > 0:
        frozen_part = model.features[:-LAST_N_LAYERS_TO_TRAIN]
        for weight in frozen_part.parameters():
            weight.requires_grad = False
        tail = model.features[-LAST_N_LAYERS_TO_TRAIN:]
        trainable.extend(tail.parameters())
    trainable.extend(model.classifier[1].parameters())
    return trainable
def get_criterion():
    """Return the loss module matching the configured ``MODEL_OBJECTIVE``.

    Returns:
        ``nn.MSELoss()`` when the objective is ``REGRESSION``.

    Raises:
        Exception: For any other objective.
    """
    if MODEL_OBJECTIVE != TrainingObjective.REGRESSION:
        raise Exception("Invalid model objective")
    return nn.MSELoss()