# scripts/evaluate.py
"""Evaluate a trained U-Net on the validation split and print its mean BCE loss."""
import torch


def evaluate(model, dataloader, device=None):
    """Return the mean binary-cross-entropy-with-logits loss of *model* over *dataloader*.

    The loss is averaged over every element of every batch rather than
    averaging per-batch means. This keeps a smaller final batch from skewing
    the result.

    Args:
        model: A ``torch.nn.Module`` that maps an image batch to logits with
            the same shape as the corresponding mask batch.
        dataloader: An iterable of ``(images, masks)`` tensor pairs.
        device: Device to move each batch to. Defaults to the device of the
            model's first parameter. If the model has no parameters, batches
            are left where they are.

    Returns:
        The per-element mean loss as a Python float.

    Raises:
        ValueError: If *dataloader* yields no elements, so no mean exists.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None

    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_elements = 0
    try:
        with torch.no_grad():
            for images, masks in dataloader:
                if device is not None:
                    images = images.to(device)
                    masks = masks.to(device)
                outputs = model(images)
                # BCE-with-logits requires float targets with the logits' dtype;
                # integer/bool masks would otherwise raise.
                targets = masks.to(outputs.dtype)
                # Sum here and divide once at the end for an exact element-wise
                # mean, independent of how the data is split into batches.
                loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    outputs, targets, reduction="sum"
                )
                total_loss += loss.item()
                total_elements += targets.numel()
    finally:
        # Restore whatever train/eval mode the caller had set.
        model.train(was_training)

    if total_elements == 0:
        raise ValueError("dataloader yielded no data; cannot compute a mean loss")
    return total_loss / total_elements


def main():
    """Load the saved U-Net weights and print the validation loss."""
    # Project-local imports live here so that `evaluate` can be imported and
    # reused without pulling in the model/data packages.
    from data.datamodule import SegmentationDataModule
    from model.config import UNetConfig
    from model.lightning_unet import LightningUNet

    # Configuration
    config = UNetConfig()
    data_dir = "data"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the model. map_location lets weights saved on a GPU load on a
    # CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (pass weights_only=True on torch versions that support it).
    model = LightningUNet(config)
    state_dict = torch.load("weights/unet_weights.pth", map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)

    # Prepare the data
    datamodule = SegmentationDataModule(data_dir, config)
    datamodule.setup()

    # Evaluate the model
    val_loss = evaluate(model, datamodule.val_dataloader(), device=device)
    print(f"Validation Loss: {val_loss:.4f}")


if __name__ == "__main__":
    main()