Spaces:
Sleeping
Sleeping
| # visualizer.py | |
| import io | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") # non-interactive backend — required for server/Gradio | |
| import matplotlib.pyplot as plt | |
| import matplotlib.colors as mcolors | |
| from PIL import Image | |
# PM2.5 concentration bands (μg/m³) intended for colour-band annotation.
# Each row is (lower bound, upper bound, category label, hex colour).
# The original author described these as "WHO/AEQ" breakpoints —
# NOTE(review): confirm which standard the bounds come from. No code in this
# module reads the table, so whether a bound is inclusive or exclusive is
# decided by whichever caller consumes it.
PM25_LEVELS = [
    # lower  upper  label        colour
    (0,      15,    "Good",      "#00e400"),
    (15,     35,    "Moderate",  "#ffff00"),
    (35,     55,    "Sensitive", "#ff7e00"),
    (55,     150,   "Unhealthy", "#ff0000"),
    (150,    300,   "Hazardous", "#7e0023"),
]
def make_heatmap(
    data_2d: np.ndarray,
    title: str,
    vmin: float = 0.0,
    vmax: float = 200.0,
    lat: np.ndarray = None,
    lon: np.ndarray = None,
    figsize: tuple = (6, 5),
    dpi: int = 110,
) -> Image.Image:
    """
    Render a PM2.5 spatial heatmap on a dark background.

    The figure is built with the object-oriented ``matplotlib.figure.Figure``
    API instead of ``pyplot``. pyplot keeps a global "current figure" and a
    global figure registry; ``plt.savefig``/``plt.tight_layout`` act on that
    current figure, so concurrent calls (e.g. Gradio worker threads) could
    render each other's figures, and a figure is leaked if anything raises
    before ``plt.close``. A bare ``Figure`` is private to this call and is
    reclaimed by normal garbage collection.

    Parameters
    ----------
    data_2d : np.ndarray (H, W)
        PM2.5 values in μg/m³.
    title : str
        Title drawn above the map.
    vmin, vmax : float
        Color scale range.
    lat, lon : np.ndarray (H, W) or None
        Geographic coordinates. Only their (NaN-ignoring) minimum and maximum
        are used, to set the axes extent. If either is None, pixel indices
        are used.
    figsize, dpi : figure size and resolution.

    Returns
    -------
    PIL.Image.Image — PNG image ready for Gradio.
    """
    # Local import: only this function needs the OO Figure class, and it
    # keeps the figure out of pyplot's global registry.
    from matplotlib.figure import Figure

    fig = Figure(figsize=figsize, dpi=dpi)
    ax = fig.subplots()
    fig.patch.set_facecolor("#0f0f0f")
    ax.set_facecolor("#0f0f0f")

    # Extent for geographic axes
    if lat is not None and lon is not None:
        # nanmin/nanmax so a few missing coordinates don't turn the whole
        # extent into NaN. NOTE(review): with origin="lower", row 0 of
        # data_2d is drawn at lat-min; confirm callers pass south-first rows.
        extent = [
            float(np.nanmin(lon)), float(np.nanmax(lon)),
            float(np.nanmin(lat)), float(np.nanmax(lat)),
        ]
        xlabel, ylabel = "Longitude (°E)", "Latitude (°N)"
    else:
        extent = None
        xlabel, ylabel = "Grid X", "Grid Y"

    im = ax.imshow(
        data_2d,
        cmap="RdYlGn_r",  # reversed: low values green, high values red
        vmin=vmin,
        vmax=vmax,
        aspect="auto",
        extent=extent,
        origin="lower",
    )

    # Colorbar
    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label("PM2.5 (μg/m³)", color="white", fontsize=8)
    cbar.ax.yaxis.set_tick_params(color="white", labelcolor="white")

    # Axes labels and ticks
    ax.set_xlabel(xlabel, color="white", fontsize=8)
    ax.set_ylabel(ylabel, color="white", fontsize=8)
    ax.tick_params(colors="white", labelsize=7)
    for spine in ax.spines.values():
        spine.set_edgecolor("#444444")
    ax.set_title(title, color="white", fontsize=9, pad=8, fontweight="bold")

    # Figure methods (not plt.*) so exactly this figure is laid out and saved.
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight", dpi=dpi,
                facecolor=fig.get_facecolor())
    buf.seek(0)
    with Image.open(buf) as img:
        # .copy() forces a full decode, detaching the result from the buffer.
        return img.copy()
def compute_stats(pred_2d: np.ndarray, input_2d: np.ndarray) -> dict:
    """
    Summarise one prediction frame against the frame it was predicted from.

    Parameters
    ----------
    pred_2d : np.ndarray
        Predicted PM2.5 values in μg/m³ (expected to be a single (H, W) frame).
    input_2d : np.ndarray
        Input PM2.5 values in μg/m³.

    Returns
    -------
    dict
        Human-readable strings keyed by metric name, in a fixed order:
        predicted mean, predicted max, input mean, the number of predicted
        pixels strictly above 75 and strictly above 150, and the signed
        difference of the means (prediction minus input).
    """
    # NOTE(review): the 75/150 thresholds and the "Unhealthy" label for >150
    # do not line up with PM25_LEVELS, which calls 150+ "Hazardous". The keys
    # are left unchanged because callers may look them up by name.
    pred_mean = pred_2d.mean()
    input_mean = input_2d.mean()

    def _count_above(threshold):
        # Strict ">" comparison, so NaN cells are never counted.
        return str(int((pred_2d > threshold).sum()))

    stats = {}
    stats["Predicted Mean PM2.5"] = f"{float(pred_mean):.1f} μg/m³"
    stats["Predicted Max PM2.5"] = f"{float(pred_2d.max()):.1f} μg/m³"
    stats["Input Mean PM2.5"] = f"{float(input_mean):.1f} μg/m³"
    stats["High-Risk Pixels (>75 μg/m³)"] = _count_above(75)
    stats["Unhealthy Pixels (>150 μg/m³)"] = _count_above(150)
    stats["Change vs Input"] = f"{float(pred_mean - input_mean):+.1f} μg/m³"
    return stats