Spaces:
Sleeping
Sleeping
| """ | |
| Visualization utilities for anomaly detection results | |
| """ | |
| import numpy as np | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| from matplotlib import cm | |
| from PIL import Image | |
| from typing import Tuple, Optional | |
| import config | |
def apply_heatmap(image: np.ndarray,
                  anomaly_map: np.ndarray,
                  alpha: float = 0.4,
                  colormap: str = "jet") -> np.ndarray:
    """
    Overlay anomaly heatmap on original image

    Args:
        image: Original image [H, W, 3] in RGB. Either an integer image in
            [0, 255] or a floating-point image in [0, 1] or [0, 255].
            Grayscale ([H, W] or [H, W, 1]) and RGBA ([H, W, 4]) inputs are
            converted to 3-channel RGB before blending.
        anomaly_map: Anomaly map [H', W']; it is resized to the image size
            and min-max rescaled to [0, 1] inside this function.
        alpha: Overlay transparency (weight of the heatmap in the blend)
        colormap: Matplotlib colormap name

    Returns:
        Overlayed image [H, W, 3] in RGB, dtype uint8
    """
    H, W = image.shape[:2]

    # Resize anomaly map to match image size (cv2 expects (width, height))
    anomaly_map_resized = cv2.resize(anomaly_map, (W, H))

    # Min-max normalize to [0, 1]; the epsilon guards against a constant map
    map_min = anomaly_map_resized.min()
    map_max = anomaly_map_resized.max()
    anomaly_map_norm = (anomaly_map_resized - map_min) / (map_max - map_min + 1e-8)

    # Apply colormap. plt.get_cmap is used instead of cm.get_cmap because the
    # latter was deprecated and then removed from Matplotlib.
    cmap = plt.get_cmap(colormap)
    heatmap = cmap(anomaly_map_norm)[:, :, :3]  # Remove alpha channel
    heatmap = (heatmap * 255).astype(np.uint8)

    # Bring the image to exactly three channels so it matches the heatmap;
    # cv2.addWeighted needs both inputs to have the same shape.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:
        image = image[:, :, :3]

    # Ensure image is uint8. Only floating-point images can be in the [0, 1]
    # convention; rescaling by value alone would blow up dark integer images.
    if np.issubdtype(image.dtype, np.floating):
        if image.max() <= 1.0:
            image = image * 255
        image = np.clip(image, 0, 255)
    elif image.dtype != np.uint8:
        image = np.clip(image, 0, 255)
    image = np.ascontiguousarray(image.astype(np.uint8))

    # Blend image and heatmap
    overlayed = cv2.addWeighted(image, 1 - alpha, heatmap, alpha, 0)
    return overlayed


def create_result_visualization(original_image: Image.Image,
                                anomaly_score: float,
                                anomaly_map: np.ndarray,
                                threshold: float = 0.5,
                                ground_truth: Optional[np.ndarray] = None) -> Image.Image:
    """
    Create comprehensive result visualization

    Renders a single-row figure with the original image, the heatmap overlay
    (titled with the prediction and score), the raw anomaly map with a
    colorbar and, when provided, the ground truth mask.

    Args:
        original_image: PIL Image
        anomaly_score: Image-level anomaly score
        anomaly_map: Pixel-level anomaly map
        threshold: Decision threshold; scores strictly above it are
            reported as DEFECTIVE
        ground_truth: Optional ground truth mask

    Returns:
        Combined visualization as PIL Image (RGB)
    """
    # Convert to numpy array
    img_np = np.array(original_image)
    is_defective = anomaly_score > threshold

    # Create figure with subplots
    n_cols = 4 if ground_truth is not None else 3
    fig, axes = plt.subplots(1, n_cols, figsize=(n_cols * 4, 4))

    # The figure is registered with pyplot, so it must be closed even when
    # drawing fails; otherwise every failed call leaks a figure.
    try:
        # Original image
        axes[0].imshow(img_np)
        axes[0].set_title("Original Image")
        axes[0].axis('off')

        # Anomaly heatmap
        heatmap_overlay = apply_heatmap(img_np, anomaly_map, alpha=config.HEATMAP_ALPHA)
        axes[1].imshow(heatmap_overlay)
        prediction = "DEFECTIVE" if is_defective else "NORMAL"
        color = "red" if is_defective else "green"
        axes[1].set_title(f"Prediction: {prediction}\nScore: {anomaly_score:.3f}",
                          color=color, fontweight='bold')
        axes[1].axis('off')

        # Raw anomaly map
        im = axes[2].imshow(anomaly_map, cmap=config.HEATMAP_COLORMAP)
        axes[2].set_title("Anomaly Map")
        axes[2].axis('off')
        fig.colorbar(im, ax=axes[2], fraction=0.046)

        # Ground truth (if available)
        if ground_truth is not None:
            axes[3].imshow(ground_truth, cmap='gray')
            axes[3].set_title("Ground Truth")
            axes[3].axis('off')

        fig.tight_layout()

        # Convert to PIL Image. buffer_rgba() replaces tostring_rgb(), which
        # was deprecated and then removed from Matplotlib. The RGBA buffer is
        # owned by the canvas, so the RGB channels are copied out before the
        # figure is closed.
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        vis_np = np.ascontiguousarray(rgba[:, :, :3])
        return Image.fromarray(vis_np)
    finally:
        plt.close(fig)


def plot_roc_curve(fpr: np.ndarray, tpr: np.ndarray, auc: float,
                   save_path: Optional[str] = None):
    """
    Draw the image-level ROC curve next to the chance diagonal.

    Args:
        fpr: False positive rates (x values)
        tpr: True positive rates (y values)
        auc: Area under the curve, shown in the legend
        save_path: When truthy, the figure is written to this path

    The figure is always closed before returning.
    """
    figure, axis = plt.subplots(figsize=(8, 6))

    # Measured curve first, then the diagonal of a random classifier
    axis.plot(fpr, tpr, linewidth=2, label=f'ROC (AUC = {auc:.3f})')
    axis.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random')

    axis.set_xlabel('False Positive Rate', fontsize=12)
    axis.set_ylabel('True Positive Rate', fontsize=12)
    axis.set_title('ROC Curve - Image-Level Anomaly Detection', fontsize=14)
    axis.legend(fontsize=11)
    axis.grid(alpha=0.3)

    if save_path:
        figure.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"ROC curve saved to {save_path}")

    plt.close(figure)


def save_prediction(image: Image.Image,
                    anomaly_score: float,
                    anomaly_map: np.ndarray,
                    save_path: str,
                    threshold: float = 0.5):
    """
    Save prediction result with overlay

    Args:
        image: PIL Image to annotate
        anomaly_score: Image-level anomaly score
        anomaly_map: Pixel-level anomaly map used for the heatmap overlay
        save_path: Destination file path
        threshold: Decision threshold; scores strictly above it are
            labelled DEFECTIVE
    """
    blended = apply_heatmap(np.array(image), anomaly_map, alpha=config.HEATMAP_ALPHA)

    # Pick the verdict and its colour together. The tuple is applied to the
    # overlay returned by apply_heatmap, which is handed to PIL unchanged.
    if anomaly_score > threshold:
        verdict, text_color = "DEFECTIVE", (255, 0, 0)
    else:
        verdict, text_color = "NORMAL", (0, 255, 0)

    caption = f"{verdict} ({anomaly_score:.3f})"
    cv2.putText(blended, caption, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                1, text_color, thickness=2, lineType=cv2.LINE_AA)

    # Save
    Image.fromarray(blended).save(save_path)