|
|
|
|
|
import sys |
|
|
sys.path.append("/home/ryali93/Desktop/projects/hf_segm") |
|
|
import pytorch_lightning as pl |
|
|
from model.lightning_unet import LightningUNet |
|
|
from model.config import UNetConfig |
|
|
from data.datamodule import SegmentationDataModule |
|
|
import torch |
|
|
|
|
|
def main(
    data_dir: str = "data",
    weights_path: str = "weights/unet_weights.pth",
    max_epochs: int = 10,
) -> None:
    """Train the U-Net segmentation model and save its weights.

    Args:
        data_dir: Directory handed to ``SegmentationDataModule``. Relative
            paths are resolved against the current working directory.
        weights_path: Destination file for the trained model's
            ``state_dict``. Its parent directory is created if missing.
        max_epochs: Number of training epochs passed to ``pl.Trainer``.
    """
    from pathlib import Path

    config = UNetConfig()
    model = LightningUNet(config)
    datamodule = SegmentationDataModule(data_dir, config)

    trainer = pl.Trainer(max_epochs=max_epochs)
    # Pass the datamodule by keyword: the second positional parameter of
    # ``Trainer.fit`` is ``train_dataloaders``, not ``datamodule``.
    trainer.fit(model, datamodule=datamodule)

    # When Lightning runs several processes, only the global-zero one writes
    # the file; on a single device this is always true.
    if trainer.is_global_zero:
        out = Path(weights_path)
        # torch.save does not create missing parent directories; without this
        # a fresh checkout would crash only after training had finished.
        out.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), out)


if __name__ == '__main__':
    main()
|
|
|