Spaces:
Sleeping
Sleeping
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as T

from train_autoencoder import ConvAutoencoder

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the autoencoder and load its trained weights once, at import time.
# The checkpoint path is relative to the current working directory.
# NOTE(review): torch.load unpickles the checkpoint. Because only a state_dict
# is needed here, consider passing weights_only=True (torch >= 1.13) so that
# arbitrary pickled objects are refused.
model = ConvAutoencoder()
model = model.to(device)
model.load_state_dict(
    torch.load("models/autoencoder.pth", map_location=device)
)
model.eval()  # switch layers such as dropout / batch-norm to inference behaviour

# Preprocessing applied to every input image before it reaches the model:
# a fixed 224x224 resize, conversion to a float tensor in [0, 1], then
# per-channel standardisation with the commonly used ImageNet mean/std.
transform = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def detect_hazard(img_pil, *, threshold=0.02):
    """
    Flag an image as anomalous using the autoencoder's reconstruction error.

    The image goes through the module-level ``transform`` (resize, tensor
    conversion, normalisation) and then through ``model``. The mean squared
    error between the reconstruction and the normalised input is the score.

    Args:
        img_pil: Input ``PIL.Image``. Any mode is accepted; it is converted to
            RGB first because the ``Normalize`` step uses 3-channel mean/std.
        threshold: Scores strictly greater than this are reported as
            anomalous. The default 0.02 is a placeholder; calibrate it on
            validation data.

    Returns: (is_anomalous: bool, recon_error: float, reconstructed_img: PIL)
        ``reconstructed_img`` is the reconstruction mapped back out of the
        normalised space and clamped to [0, 1] so it can be displayed.
    """
    # RGBA / grayscale / palette images would make the 3-channel Normalize fail.
    img_tensor = transform(img_pil.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        recon = model(img_tensor)
    recon_error = torch.nn.functional.mse_loss(recon, img_tensor).item()

    # The score above compares recon with the *normalised* input, so recon is
    # treated as living in that same normalised space. Undo Normalize before
    # building an image: ToPILImage scales floats by 255 and casts to uint8,
    # so values outside [0, 1] would wrap into garbage pixels.
    # NOTE(review): if ConvAutoencoder instead outputs [0, 1] (e.g. a final
    # Sigmoid), both the score and this inversion need revisiting — confirm.
    recon_img = recon.squeeze(0).cpu()
    normalize = next(
        (t for t in transform.transforms if isinstance(t, T.Normalize)), None
    )
    if normalize is not None:
        mean = torch.as_tensor(normalize.mean, dtype=recon_img.dtype).view(-1, 1, 1)
        std = torch.as_tensor(normalize.std, dtype=recon_img.dtype).view(-1, 1, 1)
        recon_img = recon_img * std + mean
    recon_pil = T.ToPILImage()(recon_img.clamp(0.0, 1.0))

    is_anomalous = recon_error > threshold
    return is_anomalous, recon_error, recon_pil