# Source: hf_segm_test / scripts/evaluate.py
# Author: ryali93 — "first model version" (commit d218927)
# scripts/evaluate.py
import torch
from model.lightning_unet import LightningUNet
from model.config import UNetConfig
from data.datamodule import SegmentationDataModule
def evaluate(model, dataloader):
    """Compute the mean per-batch BCE-with-logits loss of *model* over *dataloader*.

    Args:
        model: a ``torch.nn.Module`` that maps an image batch to raw logits
            with the same shape as the corresponding mask batch.
        dataloader: iterable yielding ``(images, masks)`` tensor pairs.

    Returns:
        The average per-batch loss as a ``float``; ``0.0`` if the
        dataloader yields no batches.
    """
    model.eval()  # switch off dropout / put batchnorm in inference mode
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for images, masks in dataloader:
            outputs = model(images)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, masks)
            total_loss += loss.item()
            num_batches += 1
    # Count batches while iterating instead of calling len(dataloader):
    # this works for any iterable and avoids ZeroDivisionError when empty.
    return total_loss / num_batches if num_batches else 0.0
if __name__ == '__main__':
    # Configuration
    config = UNetConfig()
    data_dir = "data"
    # Load the model. map_location='cpu' makes the checkpoint loadable on
    # CPU-only machines even if it was saved from a GPU; the script never
    # moves the model to another device, so evaluation runs on CPU anyway.
    model = LightningUNet(config)
    model.load_state_dict(torch.load('weights/unet_weights.pth', map_location='cpu'))
    # Prepare the data
    datamodule = SegmentationDataModule(data_dir, config)
    datamodule.setup()
    # Evaluate the model
    val_loss = evaluate(model, datamodule.val_dataloader())
    print(f"Validation Loss: {val_loss:.4f}")