File size: 969 Bytes
d218927
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# scripts/evaluate.py

import torch
from model.lightning_unet import LightningUNet
from model.config import UNetConfig
from data.datamodule import SegmentationDataModule

def evaluate(model, dataloader):
    """Return the mean per-element BCE-with-logits loss of *model* over *dataloader*.

    The model is switched to eval mode (and left in it) and run without
    gradient tracking. Losses are summed over every target element and divided
    by the total element count. This gives the true dataset-level mean, unlike
    a plain average of per-batch means, which over-weights a smaller final batch.
    Batches are counted while iterating, so loaders without ``__len__`` also work.

    Args:
        model: Module mapping a batch of images to logits; called as
            ``model(images)``.
        dataloader: Iterable of ``(images, masks)`` pairs. ``masks`` must have
            the same shape as the model output, as required by
            ``binary_cross_entropy_with_logits``. It is cast to the output's
            dtype, so integer or bool masks are accepted.

    Returns:
        The mean loss over all target elements, as a Python float.

    Raises:
        ValueError: If ``dataloader`` yields no target elements (e.g. it is empty).
    """
    model.eval()
    total_loss = 0.0
    total_elements = 0
    with torch.no_grad():
        for images, masks in dataloader:
            outputs = model(images)
            # BCE-with-logits needs a floating-point target matching the logits.
            targets = masks.to(dtype=outputs.dtype)
            # Sum (not mean) so each batch contributes in proportion to its size;
            # an empty batch contributes 0 instead of NaN.
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                outputs, targets, reduction="sum"
            )
            total_loss += loss.item()
            total_elements += targets.numel()
    if total_elements == 0:
        raise ValueError("dataloader yielded no elements; cannot compute a mean loss")
    return total_loss / total_elements

if __name__ == '__main__':
    # Configuration
    config = UNetConfig()
    data_dir = "data"

    # Load the model.
    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
    # machine; nothing below moves the model or data to another device, so
    # evaluation runs on CPU regardless.
    # NOTE(review): torch.load unpickles the file — only load trusted
    # checkpoints (consider weights_only=True on PyTorch >= 1.13).
    model = LightningUNet(config)
    state_dict = torch.load('weights/unet_weights.pth', map_location="cpu")
    model.load_state_dict(state_dict)

    # Prepare the data
    datamodule = SegmentationDataModule(data_dir, config)
    datamodule.setup()

    # Evaluate the model on the validation split
    val_loss = evaluate(model, datamodule.val_dataloader())
    print(f"Validation Loss: {val_loss:.4f}")