Ameya729's picture
Upload 474 files
56ec9ba verified
"""
Visualization utilities for anomaly detection results
"""
import numpy as np
import cv2
import matplotlib.pyplot as plt
from matplotlib import cm
from PIL import Image
from typing import Tuple, Optional
import config
def apply_heatmap(image: np.ndarray,
                  anomaly_map: np.ndarray,
                  alpha: float = 0.4,
                  colormap: str = "jet") -> np.ndarray:
    """
    Overlay an anomaly heatmap on an image.

    Args:
        image: Image array [H, W, 3] in RGB. Floating-point images with
            max <= 1.0 are treated as [0, 1] and scaled to [0, 255]; other
            images are treated as [0, 255] and saturated into uint8.
            Grayscale [H, W] / [H, W, 1] inputs are replicated to 3 channels
            and RGBA [H, W, 4] inputs have their alpha channel dropped.
        anomaly_map: Anomaly map [H', W']. It is resized to [H, W] and
            min-max normalized to [0, 1] here, so any value range is accepted.
        alpha: Weight of the heatmap in the blend (image weight is 1 - alpha).
        colormap: Matplotlib colormap name (or Colormap instance).

    Returns:
        Blended uint8 image [H, W, 3] in RGB.
    """
    H, W = image.shape[:2]

    # cv2.resize rejects dtypes such as bool/int64/float16; float32 is always
    # supported, so normalize the dtype (and memory layout) up front.
    anomaly_map = np.ascontiguousarray(anomaly_map, dtype=np.float32)
    anomaly_map_resized = cv2.resize(anomaly_map, (W, H))

    # Min-max normalize to [0, 1]; the epsilon avoids division by zero for a
    # constant map (which then maps entirely to the low end of the colormap).
    map_min = anomaly_map_resized.min()
    map_max = anomaly_map_resized.max()
    anomaly_map_norm = (anomaly_map_resized - map_min) / (map_max - map_min + 1e-8)

    # plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in
    # Matplotlib 3.9.
    cmap = plt.get_cmap(colormap)
    heatmap = cmap(anomaly_map_norm)[:, :, :3]  # drop the alpha channel
    heatmap = (heatmap * 255).astype(np.uint8)

    # Bring the image to 3-channel RGB so it has the same shape as the heatmap
    # (cv2.addWeighted requires matching size and channel count).
    image = np.asarray(image)
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:
        image = image[:, :, :3]

    # Only floating-point images are interpreted as [0, 1]; a very dark uint8
    # image whose max happens to be <= 1 must not be multiplied by 255.
    if np.issubdtype(image.dtype, np.floating) and image.max() <= 1.0:
        image = image * 255.0
    # Saturate instead of letting astype wrap out-of-range values.
    image = np.clip(image, 0, 255).astype(np.uint8)

    # Blend image and heatmap
    overlayed = cv2.addWeighted(image, 1 - alpha, heatmap, alpha, 0)
    return overlayed
def create_result_visualization(original_image: Image.Image,
                                anomaly_score: float,
                                anomaly_map: np.ndarray,
                                threshold: float = 0.5,
                                ground_truth: Optional[np.ndarray] = None) -> Image.Image:
    """
    Create a side-by-side result figure and render it to a PIL image.

    Panels: original image, heatmap overlay titled with the prediction and
    score, raw anomaly map with a colorbar, and (optionally) the ground truth.

    Args:
        original_image: PIL Image (converted to RGB for display and blending).
        anomaly_score: Image-level anomaly score.
        anomaly_map: Pixel-level anomaly map.
        threshold: Scores strictly greater than this are labelled "DEFECTIVE".
        ground_truth: Optional ground truth mask, shown in gray as a 4th panel.

    Returns:
        Combined visualization as an RGB PIL Image.
    """
    # Work on a 3-channel RGB array: grayscale PIL images would otherwise be
    # drawn with a false-colour colormap by imshow and could not be blended
    # with the 3-channel heatmap.
    if isinstance(original_image, Image.Image):
        original_image = original_image.convert("RGB")
    img_np = np.array(original_image)

    is_defective = anomaly_score > threshold
    prediction = "DEFECTIVE" if is_defective else "NORMAL"
    color = "red" if is_defective else "green"

    n_cols = 4 if ground_truth is not None else 3
    fig, axes = plt.subplots(1, n_cols, figsize=(n_cols * 4, 4))
    try:
        # Original image
        axes[0].imshow(img_np)
        axes[0].set_title("Original Image")
        axes[0].axis('off')

        # Heatmap overlay; use the same colormap as the raw-map panel so the
        # colorbar there also describes the overlay colours.
        heatmap_overlay = apply_heatmap(img_np, anomaly_map,
                                        alpha=config.HEATMAP_ALPHA,
                                        colormap=config.HEATMAP_COLORMAP)
        axes[1].imshow(heatmap_overlay)
        axes[1].set_title(f"Prediction: {prediction}\nScore: {anomaly_score:.3f}",
                          color=color, fontweight='bold')
        axes[1].axis('off')

        # Raw anomaly map
        im = axes[2].imshow(anomaly_map, cmap=config.HEATMAP_COLORMAP)
        axes[2].set_title("Anomaly Map")
        axes[2].axis('off')
        fig.colorbar(im, ax=axes[2], fraction=0.046)

        # Ground truth (if available)
        if ground_truth is not None:
            axes[3].imshow(ground_truth, cmap='gray')
            axes[3].set_title("Ground Truth")
            axes[3].axis('off')

        fig.tight_layout()

        # Render and copy the pixels out. buffer_rgba() replaces
        # tostring_rgb() (removed in Matplotlib 3.10) and carries the real
        # pixel shape, which can differ from get_width_height() on HiDPI
        # canvases. The copy is taken before the figure is closed.
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        vis_np = np.ascontiguousarray(rgba[:, :, :3])
        vis_pil = Image.fromarray(vis_np)
    finally:
        # Always release the figure, even if drawing fails.
        plt.close(fig)
    return vis_pil
def plot_roc_curve(fpr: np.ndarray, tpr: np.ndarray, auc: float,
                   save_path: Optional[str] = None):
    """
    Plot an ROC curve with its AUC and a chance-level diagonal.

    Args:
        fpr: False positive rates (x-axis).
        tpr: True positive rates (y-axis), same length as ``fpr``.
        auc: Area under the curve, shown in the legend.
        save_path: If truthy, the figure is saved there (dpi=150). Otherwise
            the figure is built and discarded without being saved or shown.
    """
    # Use an explicit figure handle instead of pyplot's implicit current
    # figure, and close it in `finally` so a failing savefig (e.g. a missing
    # directory) does not leak the figure.
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.plot(fpr, tpr, linewidth=2, label=f'ROC (AUC = {auc:.3f})')
        ax.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random')
        ax.set_xlabel('False Positive Rate', fontsize=12)
        ax.set_ylabel('True Positive Rate', fontsize=12)
        ax.set_title('ROC Curve - Image-Level Anomaly Detection', fontsize=14)
        ax.legend(fontsize=11)
        ax.grid(alpha=0.3)
        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"ROC curve saved to {save_path}")
    finally:
        plt.close(fig)
def save_prediction(image: Image.Image,
                    anomaly_score: float,
                    anomaly_map: np.ndarray,
                    save_path: str,
                    threshold: float = 0.5):
    """
    Save the image with a heatmap overlay and a prediction label.

    Args:
        image: PIL Image (converted to RGB before blending).
        anomaly_score: Image-level anomaly score, printed in the label.
        anomaly_map: Pixel-level anomaly map overlaid on the image.
        save_path: Output file path; the format is chosen by PIL from it.
        threshold: Scores strictly greater than this are labelled "DEFECTIVE".
    """
    # The overlay requires a 3-channel RGB array; grayscale/RGBA/palette PIL
    # images would otherwise fail in the blend.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    img_np = np.array(image)
    overlay = apply_heatmap(img_np, anomaly_map,
                            alpha=config.HEATMAP_ALPHA,
                            colormap=config.HEATMAP_COLORMAP)

    # Add text annotation. The overlay is RGB (not OpenCV's BGR), so
    # (255, 0, 0) renders red and (0, 255, 0) renders green.
    is_defective = anomaly_score > threshold
    prediction = "DEFECTIVE" if is_defective else "NORMAL"
    color = (255, 0, 0) if is_defective else (0, 255, 0)
    cv2.putText(overlay, f"{prediction} ({anomaly_score:.3f})",
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                1, color, 2, cv2.LINE_AA)

    # Save
    Image.fromarray(overlay).save(save_path)