Spaces:
Running
Running
| """ | |
| utils/visualization.py | |
| Result visualisation utilities: | |
| - GradCAM heatmap overlay | |
| - FFT spectrum display | |
| - Result card with confidence bars | |
| """ | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.cm as cm | |
| from matplotlib.figure import Figure | |
| from PIL import Image | |
| from typing import Dict, Tuple | |
| import io | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # GradCAM overlay | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def overlay_gradcam( | |
| image: Image.Image, | |
| heatmap: np.ndarray, | |
| alpha: float = 0.5, | |
| colormap: str = "jet", | |
| ) -> Image.Image: | |
| """ | |
| Overlay a GradCAM heatmap on top of the original image. | |
| Args: | |
| image: original PIL Image | |
| heatmap: 2D np.ndarray [H, W] in [0, 1] | |
| alpha: overlay opacity | |
| colormap: matplotlib colormap name | |
| Returns: | |
| PIL Image with heatmap overlay | |
| """ | |
| # Resize heatmap to match image | |
| h, w = image.size[1], image.size[0] | |
| hm_pil = Image.fromarray((heatmap * 255).astype(np.uint8)) | |
| hm_pil = hm_pil.resize((w, h), Image.BILINEAR) | |
| hm_arr = np.array(hm_pil) / 255.0 | |
| # Apply colormap | |
| cmap = cm.get_cmap(colormap) | |
| colored = (cmap(hm_arr)[:, :, :3] * 255).astype(np.uint8) | |
| heat_pil = Image.fromarray(colored) | |
| # Blend | |
| img_rgb = image.convert("RGB") | |
| blended = Image.blend(img_rgb, heat_pil, alpha) | |
| return blended | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # FFT spectrum visualisation | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def visualize_fft_spectrum( | |
| spectrum: np.ndarray, | |
| title: str = "FFT Frequency Spectrum", | |
| ) -> Image.Image: | |
| """ | |
| Render the log-magnitude FFT spectrum as a PIL Image. | |
| Args: | |
| spectrum: 2D np.ndarray [H, W] in [0, 1] | |
| title: plot title | |
| Returns: | |
| PIL Image of the spectrum plot | |
| """ | |
| fig, ax = plt.subplots(1, 1, figsize=(5, 5), facecolor="#1a1a2e") | |
| ax.imshow(spectrum, cmap="plasma", interpolation="bilinear") | |
| ax.set_title(title, color="white", fontsize=12, pad=10) | |
| ax.axis("off") | |
| fig.tight_layout(pad=0.5) | |
| buf = io.BytesIO() | |
| fig.savefig(buf, format="png", dpi=120, bbox_inches="tight", | |
| facecolor="#1a1a2e") | |
| buf.seek(0) | |
| plt.close(fig) | |
| return Image.open(buf).copy() | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Result card | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def create_result_card(result: Dict) -> Image.Image: | |
| """ | |
| Create a rich result card matplotlib figure with: | |
| - Large REAL / FAKE badge | |
| - Ensemble confidence gauge | |
| - Per-model score bars (CLIP, CNN, Frequency) | |
| Returns: | |
| PIL Image | |
| """ | |
| label = result["label"] | |
| confidence = result["confidence"] | |
| scores = result["scores"] | |
| weights = result["weights"] | |
| label_color = "#ff4c4c" if label == "FAKE" else "#00e676" | |
| bg = "#0f0f1a" | |
| card_bg = "#1a1a2e" | |
| text_color = "#e0e0e0" | |
| fig = plt.figure(figsize=(8, 5), facecolor=bg) | |
| # ββ Layout βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| gs = fig.add_gridspec(2, 2, hspace=0.45, wspace=0.4, left=0.1, right=0.95, | |
| top=0.88, bottom=0.08) | |
| ax_badge = fig.add_subplot(gs[0, 0]) | |
| ax_gauge = fig.add_subplot(gs[0, 1]) | |
| ax_bars = fig.add_subplot(gs[1, :]) | |
| # ββ Badge ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| ax_badge.set_facecolor(card_bg) | |
| ax_badge.text(0.5, 0.6, label, | |
| color=label_color, fontsize=38, fontweight="bold", | |
| ha="center", va="center", transform=ax_badge.transAxes, | |
| fontfamily="monospace") | |
| ax_badge.text(0.5, 0.22, f"{confidence*100:.1f}% confident", | |
| color=text_color, fontsize=11, | |
| ha="center", va="center", transform=ax_badge.transAxes) | |
| ax_badge.set_xlim(0, 1); ax_badge.set_ylim(0, 1) | |
| ax_badge.axis("off") | |
| # ββ Gauge (pie / donut) ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| ax_gauge.set_facecolor(card_bg) | |
| fake_prob = result["fake_prob"] | |
| real_prob = result["real_prob"] | |
| wedges, _ = ax_gauge.pie( | |
| [real_prob, fake_prob], | |
| colors = ["#00e676", "#ff4c4c"], | |
| startangle = 90, | |
| wedgeprops = dict(width=0.45, edgecolor=bg, linewidth=2), | |
| counterclock = False, | |
| ) | |
| ax_gauge.text(0, 0, f"{fake_prob*100:.0f}%\nFAKE", | |
| color=text_color, fontsize=10, ha="center", va="center", | |
| fontweight="bold") | |
| ax_gauge.set_title("Fake Probability", color=text_color, | |
| fontsize=10, pad=6) | |
| # ββ Per-model score bars ββββββββββββββββββββββββββββββββββββββββββββββ | |
| ax_bars.set_facecolor(card_bg) | |
| model_names = ["CLIP\n(zero-shot)", "CNN\n(EfficientNet)", "Frequency\n(DCT/FFT)", "Ensemble\n(combined)"] | |
| model_keys = ["clip", "cnn", "frequency", "ensemble"] | |
| model_vals = [scores.get(k, 0.5) for k in model_keys] | |
| bar_colors = [("#ff4c4c" if v >= 0.5 else "#00e676") for v in model_vals] | |
| bar_colors[-1] = "#7c4dff" # ensemble always purple | |
| bars = ax_bars.barh(model_names, model_vals, color=bar_colors, | |
| edgecolor=bg, height=0.55) | |
| ax_bars.axvline(0.5, color="#888", linewidth=1.2, linestyle="--", alpha=0.7) | |
| ax_bars.set_xlim(0, 1) | |
| ax_bars.set_xlabel("Fake Probability β", color=text_color, fontsize=9) | |
| ax_bars.tick_params(colors=text_color, labelsize=9) | |
| for spine in ax_bars.spines.values(): | |
| spine.set_edgecolor("#333") | |
| ax_bars.set_facecolor(card_bg) | |
| ax_bars.xaxis.label.set_color(text_color) | |
| # Value labels on bars | |
| for bar, val in zip(bars, model_vals): | |
| ax_bars.text( | |
| min(val + 0.02, 0.95), bar.get_y() + bar.get_height() / 2, | |
| f"{val*100:.1f}%", va="center", color=text_color, fontsize=9, | |
| ) | |
| ax_bars.set_title("Model Score Breakdown", color=text_color, fontsize=10, pad=6) | |
| # Title | |
| fig.text(0.5, 0.96, "π Image Authenticity Detection Report", | |
| color=text_color, fontsize=13, ha="center", va="top", | |
| fontweight="bold") | |
| buf = io.BytesIO() | |
| fig.savefig(buf, format="png", dpi=130, bbox_inches="tight", facecolor=bg) | |
| buf.seek(0) | |
| plt.close(fig) | |
| return Image.open(buf).copy() | |