Spaces:
Sleeping
Sleeping
| import os | |
| import pytorch_lightning as L | |
| from dataloader import AerialImageDataset | |
| from train import UNet | |
| from torch.utils.data import DataLoader | |
| from torchvision.transforms import transforms | |
| import torch | |
# Dataset roots; the 'images' and 'masks' subfolders of each are joined on below.
train_path = "/teamspace/studios/this_studio/Aerial-Segmentation/train"
val_path = "/teamspace/studios/this_studio/Aerial-Segmentation/val"

# Preprocessing passed to both datasets: resize to 512x512, then convert to a tensor.
# NOTE(review): if AerialImageDataset applies this same transform to the masks,
# Resize's default interpolation and ToTensor's value scaling could alter class
# labels -- confirm against the dataset implementation.
data_transform = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ]
)


def _dataset_for(root):
    """Build an AerialImageDataset from the 'images' and 'masks' folders under *root*."""
    image_dir = os.path.join(root, 'images')
    mask_dir = os.path.join(root, 'masks')
    return AerialImageDataset(image_dir, mask_dir, transform=data_transform)


train_dataset = _dataset_for(train_path)
val_dataset = _dataset_for(val_path)

# Same batch size for both splits; only the training batches are shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False)

# U-Net configured with 3 input channels and 6 output classes.
model = UNet(n_channels=3, n_classes=6)

# Train for at most 100 epochs, running validation on val_loader.
trainer = L.Trainer(max_epochs=100)
trainer.fit(model, train_loader, val_loader)

# Persist only the parameters (state_dict), not the whole module object.
torch.save(
    model.state_dict(),
    "/teamspace/studios/this_studio/Aerial-Segmentation/model.pth",
)