File size: 764 Bytes
d218927
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# scripts/train.py
import sys
sys.path.append("/home/ryali93/Desktop/projects/hf_segm")
import pytorch_lightning as pl
from model.lightning_unet import LightningUNet
from model.config import UNetConfig
from data.datamodule import SegmentationDataModule
import torch

if __name__ == '__main__':
    # Configuraciones
    config = UNetConfig()
    data_dir = "data"

    # Inicializa el modelo y el DataModule
    model = LightningUNet(config)
    datamodule = SegmentationDataModule(data_dir, config)

    # Entrena el modelo
    trainer = pl.Trainer(max_epochs=10)  # Puedes añadir más configuraciones al Trainer si lo necesitas
    trainer.fit(model, datamodule)

    # Guarda los pesos del modelo
    torch.save(model.state_dict(), 'weights/unet_weights.pth')