ryali93's picture
first model version
d218927
# scripts/train.py
import sys
sys.path.append("/home/ryali93/Desktop/projects/hf_segm")
import pytorch_lightning as pl
from model.lightning_unet import LightningUNet
from model.config import UNetConfig
from data.datamodule import SegmentationDataModule
import torch
if __name__ == '__main__':
# Configuraciones
config = UNetConfig()
data_dir = "data"
# Inicializa el modelo y el DataModule
model = LightningUNet(config)
datamodule = SegmentationDataModule(data_dir, config)
# Entrena el modelo
trainer = pl.Trainer(max_epochs=10) # Puedes añadir más configuraciones al Trainer si lo necesitas
trainer.fit(model, datamodule)
# Guarda los pesos del modelo
torch.save(model.state_dict(), 'weights/unet_weights.pth')