File size: 2,474 Bytes
398659b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
"""
Visualization utilities for compression results.
"""
import numpy as np
import cv2
from PIL import Image
from typing import Tuple
import matplotlib.pyplot as plt
def highlight_roi(
    image: Image.Image,
    mask: np.ndarray,
    alpha: float = 0.3,
    color: Tuple[int, int, int] = (0, 255, 0)
) -> Image.Image:
    """
    Highlight ROI regions in image with colored overlay.

    Pixels where ``mask > 0.5`` are painted with ``color`` on a copy of the
    image, and that copy is alpha-blended onto the original. ROI pixels become
    ``(1 - alpha) * pixel + alpha * color``; all other pixels are left as is.

    Args:
        image: PIL Image. Non-RGB modes (e.g. RGBA, L, P) are converted to
            RGB first, because ``color`` is an RGB triple.
        mask: Mask of shape (H, W) matching the image size; a trailing
            singleton channel (H, W, 1) is also accepted. Values > 0.5 mark
            the ROI.
        alpha: Overlay opacity in [0, 1]; 0 leaves the image unchanged,
            1 paints the ROI solid ``color``.
        color: RGB color tuple for ROI highlight

    Returns:
        RGB image with ROI highlighted

    Raises:
        ValueError: If ``alpha`` is outside [0, 1] or the mask's spatial
            shape does not match the image's (H, W).
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")

    # The overlay assignment writes a 3-value RGB color, so the pixel array
    # must have exactly 3 channels; RGBA / L / P input would otherwise fail.
    if image.mode != "RGB":
        image = image.convert("RGB")
    img_array = np.array(image)

    mask = np.asarray(mask)
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask[..., 0]
    if mask.shape != img_array.shape[:2]:
        raise ValueError(
            f"mask shape {mask.shape} does not match image size "
            f"{img_array.shape[:2]}"
        )

    # Create colored overlay
    overlay = img_array.copy()
    overlay[mask > 0.5] = color
    # Blend: (1 - alpha) * image + alpha * overlay, saturated to uint8.
    result = cv2.addWeighted(img_array, 1 - alpha, overlay, alpha, 0)
    return Image.fromarray(result)
def create_comparison_grid(
    original: Image.Image,
    compressed: Image.Image,
    mask: np.ndarray,
    bpp: float,
    sigma: float,
    lambda_val: float,
    highlight: bool = True
) -> Image.Image:
    """
    Create side-by-side comparison of original and compressed images.

    Renders a one-row matplotlib figure and rasterizes it into a PIL image.
    The panels are:

    - the original image;
    - the compressed image, titled with sigma, lambda and BPP;
    - optionally, the original with the ROI mask overlaid in green.

    Args:
        original: Original PIL Image
        compressed: Compressed PIL Image
        mask: Binary mask used; passed to ``highlight_roi`` and only read
            when ``highlight`` is True
        bpp: Bits per pixel
        sigma: Sigma value used
        lambda_val: Lambda value used
        highlight: Whether to add a third panel showing the ROI overlay

    Returns:
        RGB PIL Image of the rendered figure. Each panel is 5 x 5 inches,
        rendered at matplotlib's default figure DPI.
    """
    n_panels = 3 if highlight else 2
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5))
    # pyplot keeps every figure it creates alive until closed, so close it
    # even if drawing raises; otherwise repeated calls leak figures.
    try:
        # Original
        axes[0].imshow(original)
        axes[0].set_title('Original', fontsize=14)
        axes[0].axis('off')
        # Compressed
        axes[1].imshow(compressed)
        axes[1].set_title(f'Compressed (σ={sigma:.2f}, λ={lambda_val}, BPP={bpp:.3f})', fontsize=14)
        axes[1].axis('off')
        # ROI overlay
        if highlight:
            highlighted = highlight_roi(original, mask, alpha=0.4, color=(0, 255, 0))
            axes[2].imshow(highlighted)
            axes[2].set_title('ROI Mask (green)', fontsize=14)
            axes[2].axis('off')
        fig.tight_layout()

        # Rasterize the figure.
        # buffer_rgba() already exposes an (H, W, 4) buffer in physical
        # pixels, so no manual reshape from get_width_height() is needed;
        # that call reports logical size and mismatches on HiDPI canvases.
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop alpha and copy, so the result is contiguous and does not
        # alias the canvas memory that is released with the figure.
        rgb = rgba[:, :, :3].copy()
    finally:
        plt.close(fig)
    return Image.fromarray(rgb)
|