File size: 764 Bytes
d218927 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
# scripts/train.py
import sys
sys.path.append("/home/ryali93/Desktop/projects/hf_segm")
import pytorch_lightning as pl
from model.lightning_unet import LightningUNet
from model.config import UNetConfig
from data.datamodule import SegmentationDataModule
import torch
if __name__ == '__main__':
# Configuraciones
config = UNetConfig()
data_dir = "data"
# Inicializa el modelo y el DataModule
model = LightningUNet(config)
datamodule = SegmentationDataModule(data_dir, config)
# Entrena el modelo
trainer = pl.Trainer(max_epochs=10) # Puedes añadir más configuraciones al Trainer si lo necesitas
trainer.fit(model, datamodule)
# Guarda los pesos del modelo
torch.save(model.state_dict(), 'weights/unet_weights.pth')
|