# scripts/train.py import sys sys.path.append("/home/ryali93/Desktop/projects/hf_segm") import pytorch_lightning as pl from model.lightning_unet import LightningUNet from model.config import UNetConfig from data.datamodule import SegmentationDataModule import torch if __name__ == '__main__': # Configuraciones config = UNetConfig() data_dir = "data" # Inicializa el modelo y el DataModule model = LightningUNet(config) datamodule = SegmentationDataModule(data_dir, config) # Entrena el modelo trainer = pl.Trainer(max_epochs=10) # Puedes añadir más configuraciones al Trainer si lo necesitas trainer.fit(model, datamodule) # Guarda los pesos del modelo torch.save(model.state_dict(), 'weights/unet_weights.pth')