import torch
import numpy as np
from PIL import Image
import torchvision.transforms as T

from train_autoencoder import ConvAutoencoder

# Per-channel normalization statistics. They are shared by the input transform
# and by the inverse transform applied to reconstructions, so the two can
# never drift apart.
_MEAN = [0.485, 0.456, 0.406]
_STD = [0.229, 0.224, 0.225]

# Default reconstruction-MSE threshold (measured in normalized space).
# Adjust based on validation data.
DEFAULT_THRESHOLD = 0.02

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ConvAutoencoder().to(device)
# The checkpoint is loaded as a state_dict (tensors only). The restricted
# weights_only unpickler is therefore sufficient and avoids executing
# arbitrary pickled code.
model.load_state_dict(
    torch.load("models/autoencoder.pth", map_location=device, weights_only=True)
)
model.eval()

transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=_MEAN, std=_STD),
])


def _denormalize(tensor):
    """Invert ``T.Normalize(_MEAN, _STD)`` on a (C, H, W) tensor and clamp to [0, 1].

    Clamping matters: ``T.ToPILImage`` scales float tensors by 255 and casts
    to uint8. Out-of-range values would otherwise wrap around instead of
    saturating.
    """
    mean = torch.tensor(_MEAN, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std = torch.tensor(_STD, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    return (tensor * std + mean).clamp(0.0, 1.0)


def detect_hazard(img_pil, threshold=DEFAULT_THRESHOLD):
    """Flag an image as anomalous based on autoencoder reconstruction error.

    Args:
        img_pil: Input ``PIL.Image``. It is converted to RGB first because the
            normalization step expects exactly three channels (RGBA, L and P
            images would otherwise fail).
        threshold: Reconstruction MSE above which the image is considered
            anomalous. Defaults to ``DEFAULT_THRESHOLD`` (0.02).

    Returns:
        (is_anomalous: bool, recon_error: float, reconstructed_img: PIL)

        ``recon_error`` is the mean squared error between the model output
        and the normalized, 224x224-resized input.

        ``reconstructed_img`` is the model output mapped back from normalized
        space to [0, 1] before conversion to a PIL image.
    """
    img_tensor = transform(img_pil.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        recon = model(img_tensor)
        # The reconstruction is scored against the *normalized* input, so the
        # model output lives in normalized space as well.
        recon_error = torch.nn.functional.mse_loss(recon, img_tensor).item()

    # Undo the normalization before building the image. Feeding normalized
    # values (which can be negative or > 1) to ToPILImage corrupts colors.
    recon_pil = T.ToPILImage()(_denormalize(recon.squeeze(0).cpu()))

    is_anomalous = recon_error > threshold
    return is_anomalous, recon_error, recon_pil