Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| import click | |
| import torch | |
| from sklearn.metrics import f1_score | |
| from torch.utils import data | |
| from utils import * | |
| from model import createDeepLabv3 | |
| from trainer import train_model | |
def main(data_directory, exp_directory, epochs, batch_size):
    """Fine-tune a DeepLabv3 model on a single-folder dataset and save it.

    Args:
        data_directory: Dataset folder (str or Path) handed to
            ``get_dataloader_single_folder``.
        exp_directory: Experiment folder (str or Path). It is passed to
            ``train_model`` as ``bpath`` and receives ``weights.pt``.
            Created, including any missing parents, if it does not exist.
        epochs: Number of epochs forwarded to ``train_model``.
        batch_size: Batch size forwarded to the dataloader factory.
    """
    # Create the deeplabv3 resnet101 model which is pretrained on a subset
    # of COCO train2017, on the 20 categories that are present in the
    # Pascal VOC dataset.
    model = createDeepLabv3()
    model.train()

    data_directory = Path(data_directory)

    # Create the experiment directory if not present. `parents=True` lets
    # nested paths such as "runs/exp1" work on a fresh checkout, and
    # `exist_ok=True` avoids the check-then-create race of testing
    # `exists()` first.
    exp_directory = Path(exp_directory)
    exp_directory.mkdir(parents=True, exist_ok=True)

    # Specify the loss function
    criterion = torch.nn.MSELoss(reduction='mean')
    # Specify the optimizer with a lower learning rate
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Specify the evaluation metrics.
    # NOTE(review): `iou` is presumably provided by `from utils import *`
    # -- confirm it is exported there.
    metrics = {'f1_score': f1_score, 'iou': iou}

    # Create the dataloader
    dataloaders = get_dataloader_single_folder(
        data_directory, batch_size=batch_size)

    # The return value of train_model is intentionally discarded; the
    # `model` object itself is what gets saved below.
    _ = train_model(model,
                    criterion,
                    dataloaders,
                    optimizer,
                    bpath=exp_directory,
                    metrics=metrics,
                    num_epochs=epochs)

    # Save the trained model (the whole model object, not only a state_dict)
    torch.save(model, exp_directory / 'weights.pt')
def parse_args(argv=None):
    """Parse the command-line options for the training script.

    Args:
        argv: Optional list of argument strings. ``None`` makes argparse
            read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``data_directory``, ``exp_directory``,
        ``epochs`` and ``batch_size`` attributes.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import list.
    import argparse

    parser = argparse.ArgumentParser(
        description="Fine-tune DeepLabv3 on a custom dataset.")
    # Both hyphenated and underscored spellings are accepted.
    parser.add_argument("--data-directory", "--data_directory",
                        dest="data_directory", required=True,
                        help="Specify the data directory.")
    parser.add_argument("--exp-directory", "--exp_directory",
                        dest="exp_directory", required=True,
                        help="Specify the experiment directory.")
    parser.add_argument("--epochs", type=int, default=25,
                        help="Number of epochs to train for (default: 25).")
    parser.add_argument("--batch-size", "--batch_size",
                        dest="batch_size", type=int, default=4,
                        help="Batch size for the dataloader (default: 4).")
    return parser.parse_args(argv)


if __name__ == "__main__":
    # `main` takes four required positional parameters, so calling it bare
    # raises TypeError; feed it the parsed command-line values instead.
    cli_args = parse_args()
    main(cli_args.data_directory,
         cli_args.exp_directory,
         cli_args.epochs,
         cli_args.batch_size)