|
|
|
|
|
|
|
|
import torch |
|
|
from model.lightning_unet import LightningUNet |
|
|
from model.config import UNetConfig |
|
|
from data.datamodule import SegmentationDataModule |
|
|
|
|
|
def evaluate(model, dataloader):
    """Return the mean binary cross-entropy (with logits) of ``model`` over ``dataloader``.

    The loss is averaged per element across everything the dataloader yields,
    so a smaller final batch is weighted by its actual size instead of counting
    as much as a full batch.

    Args:
        model: A ``torch.nn.Module`` mapping a batch of images to logits that
            have the same shape as the corresponding masks.
        dataloader: Any iterable yielding ``(images, masks)`` pairs; it does
            not need to define ``__len__``.

    Returns:
        The per-element mean loss as a Python ``float``.

    Raises:
        ValueError: If ``dataloader`` yields no mask elements at all.
    """
    was_training = model.training
    model.eval()

    total_loss = 0.0
    total_elements = 0
    try:
        with torch.no_grad():
            for images, masks in dataloader:
                outputs = model(images)
                # BCE-with-logits needs floating-point targets matching the
                # logits' dtype; integer/bool masks would otherwise raise.
                targets = masks.to(dtype=outputs.dtype)
                # Sum rather than mean so each batch contributes in proportion
                # to its true size.
                loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    outputs, targets, reduction="sum"
                )
                total_loss += loss.item()
                total_elements += targets.numel()
    finally:
        # Hand the model back in whatever mode the caller gave it to us in.
        model.train(was_training)

    if total_elements == 0:
        raise ValueError("evaluate() got a dataloader with no mask elements to evaluate")
    return total_loss / total_elements
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: load trained weights and report the validation loss.
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate a trained LightningUNet on the validation split."
    )
    parser.add_argument(
        "--data-dir",
        default="data",
        help="Dataset root passed to SegmentationDataModule (default: %(default)s).",
    )
    parser.add_argument(
        "--weights",
        default="weights/unet_weights.pth",
        help="Path to the saved model state_dict (default: %(default)s).",
    )
    args = parser.parse_args()

    config = UNetConfig()

    model = LightningUNet(config)
    # This script never moves the model to another device, so map the stored
    # tensors onto the CPU; without this, a checkpoint saved from a GPU run
    # cannot be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(args.weights, map_location="cpu")
    model.load_state_dict(state_dict)

    datamodule = SegmentationDataModule(args.data_dir, config)
    # NOTE(review): called without a `stage` argument, as before; confirm that
    # SegmentationDataModule.setup provides a default for it.
    datamodule.setup()

    val_loss = evaluate(model, datamodule.val_dataloader())
    print(f"Validation Loss: {val_loss:.4f}")
|
|
|