File size: 1,066 Bytes
56f7a23
4cf9e1d
56f7a23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os
import pytorch_lightning as L
from dataloader import AerialImageDataset
from train import UNet
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
import torch

# Root of the Aerial-Segmentation project; the train/val splits and the saved
# weights all live beneath it.
DATA_ROOT = "/teamspace/studios/this_studio/Aerial-Segmentation"
train_path = os.path.join(DATA_ROOT, "train")
val_path = os.path.join(DATA_ROOT, "val")
MODEL_PATH = os.path.join(DATA_ROOT, "model.pth")

IMAGE_SIZE = (512, 512)  # every sample is resized to this (height, width)
BATCH_SIZE = 4
MAX_EPOCHS = 100
N_CHANNELS = 3  # number of input image channels passed to UNet
N_CLASSES = 6  # number of output classes passed to UNet

# NOTE(review): whether AerialImageDataset applies this transform to the masks
# as well is not visible here. If it does, Resize's default bilinear
# interpolation and ToTensor's [0, 1] rescaling would corrupt integer class
# labels (masks would need nearest-neighbour resizing and no rescaling) --
# confirm against dataloader.py.
data_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])


def _make_dataset(split_dir):
    """Build an AerialImageDataset from ``<split_dir>/images`` and ``<split_dir>/masks``.

    Args:
        split_dir: Directory of one data split (e.g. ``train_path``).

    Returns:
        The dataset, using the module-level ``data_transform``.
    """
    return AerialImageDataset(
        os.path.join(split_dir, "images"),
        os.path.join(split_dir, "masks"),
        transform=data_transform,
    )


def main():
    """Train the UNet on the train split, validate on val, and save its weights.

    The final ``state_dict`` is written to ``MODEL_PATH``.
    """
    train_loader = DataLoader(
        _make_dataset(train_path), batch_size=BATCH_SIZE, shuffle=True
    )
    # Validation order does not matter for metrics, so keep it deterministic.
    val_loader = DataLoader(
        _make_dataset(val_path), batch_size=BATCH_SIZE, shuffle=False
    )

    model = UNet(n_channels=N_CHANNELS, n_classes=N_CLASSES)

    trainer = L.Trainer(max_epochs=MAX_EPOCHS)
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    # Under a multi-process strategy (e.g. DDP, which Lightning's "auto"
    # settings may pick on a multi-GPU machine) this code runs once per rank.
    # Write the file from rank 0 only, so the ranks do not race on one path.
    if trainer.is_global_zero:
        torch.save(model.state_dict(), MODEL_PATH)


# Guard the entry point so that importing this module does not start training,
# and so launchers/workers that re-import the script can do so safely.
if __name__ == "__main__":
    main()