from __future__ import annotations from pathlib import Path import matplotlib import numpy as np matplotlib.use("Agg") import matplotlib.pyplot as plt FIG_SIZE_IN = 4.2 MAX_ENERGY_SUBPLOTS = 6 AXIS_LABEL_FONT_SIZE = 8 TICK_LABEL_FONT_SIZE = 8 INPLOT_LABEL_FONT_SIZE = 10 def _to_hourly_zone(values: np.ndarray) -> np.ndarray: arr = np.asarray(values) if arr.ndim != 2: raise ValueError(f"Expected 2D energy matrix, got shape={arr.shape}") if arr.shape[0] == 8760: return np.asarray(arr, dtype=float) if arr.shape[1] == 8760: return np.asarray(arr.T, dtype=float) raise ValueError(f"Neither axis is 8760 for energy matrix: shape={arr.shape}") def _decode_zone_names(columns_arr: np.ndarray | None, zone_count: int) -> list[str]: if columns_arr is None: return [f"zone_{i}" for i in range(zone_count)] cols = np.asarray(columns_arr).reshape(-1) names = [str(c, "utf-8") if isinstance(c, (bytes, np.bytes_)) else str(c) for c in cols] if len(names) < zone_count: names.extend([f"zone_{i}" for i in range(len(names), zone_count)]) return names[:zone_count] def _time_window(hourly_zone: np.ndarray, start_hour: int, window_hours: int) -> tuple[np.ndarray, int, int]: total = int(hourly_zone.shape[0]) if total < 1: raise ValueError("No hourly energy records found") start_idx = max(0, min(total - 1, int(start_hour) - 1)) window = max(1, int(window_hours)) end_idx = min(total, start_idx + window) if end_idx <= start_idx: raise ValueError(f"Invalid energy window: start={start_hour}, hours={window_hours}") return hourly_zone[start_idx:end_idx, :], start_idx + 1, end_idx def _major_ticks(length: int) -> list[int]: if length <= 8: return list(range(1, length + 1)) tick_count = 6 ticks = np.linspace(1, length, num=tick_count, dtype=int) uniq = sorted(set(int(t) for t in ticks)) if uniq[-1] != length: uniq.append(length) return uniq def visualize_energy( energy_npz: str | Path, output_png: str | Path, *, max_zones: int | None = None, zone_index: int | None = None, start_hour: int = 1, window_hours: int = 24, dpi: int = 220, ) -> Path: """Plot energy curves in compressed square layout with shared x-axis for a selected time window.""" energy_npz = Path(energy_npz) output_png = Path(output_png) with np.load(energy_npz, allow_pickle=True) as data: if "values" not in data: keys = ", ".join(sorted(data.files)) raise KeyError(f"Missing key 'values' in {energy_npz}; keys=[{keys}]") values = np.asarray(data["values"], dtype=float) columns = np.asarray(data["columns"], dtype=object) if "columns" in data else None hourly_zone = _to_hourly_zone(values) window, window_start, window_end = _time_window(hourly_zone, start_hour=start_hour, window_hours=window_hours) zone_count = window.shape[1] if zone_count < 1: raise ValueError(f"No zones found in {energy_npz}") names = _decode_zone_names(columns, zone_count) zone_indices = list(range(zone_count)) if zone_index is not None: zi = int(zone_index) if zi < 0 or zi >= zone_count: raise ValueError(f"zone_index out of range: {zi}, valid=[0, {zone_count - 1}]") zone_indices = [zi] elif max_zones is not None and max_zones > 0: zone_indices = zone_indices[: max_zones] if len(zone_indices) < 1: raise ValueError("No zones selected for plotting") plotted_zone_indices = zone_indices[:MAX_ENERGY_SUBPLOTS] omitted_count = len(zone_indices) - len(plotted_zone_indices) cmap = plt.get_cmap("tab20") x = np.arange(1, window.shape[0] + 1, dtype=int) major_ticks = _major_ticks(window.shape[0]) row_count = len(plotted_zone_indices) fig, axes = plt.subplots( row_count, 1, figsize=(FIG_SIZE_IN, FIG_SIZE_IN), sharex=True, gridspec_kw={"hspace": 0.0}, ) if row_count == 1: axes = [axes] for row_idx, zone_idx in enumerate(plotted_zone_indices): ax = axes[row_idx] color = cmap(row_idx % 20) ax.plot(x, window[:, zone_idx], color=color, linewidth=0.9, alpha=0.9) ax.set_ylabel("") ax.text( 0.02, 0.86, names[zone_idx], transform=ax.transAxes, ha="left", va="top", rotation=0, fontsize=INPLOT_LABEL_FONT_SIZE, bbox={"facecolor": "white", "alpha": 0.65, "edgecolor": "none", "pad": 1.5}, ) ax.set_xticks(major_ticks) ax.grid(axis="y", alpha=0.25, linewidth=0.5) ax.grid(axis="x", alpha=0.22, linewidth=0.45) ax.tick_params(axis="both", which="both", labelsize=TICK_LABEL_FONT_SIZE) if row_idx < row_count - 1: ax.tick_params(axis="x", which="both", labelbottom=False) if omitted_count > 0: axes[-1].text( 0.98, 0.86, f"... (+{omitted_count})", transform=axes[-1].transAxes, ha="right", va="top", fontsize=INPLOT_LABEL_FONT_SIZE, bbox={"facecolor": "white", "alpha": 0.65, "edgecolor": "none", "pad": 1.5}, ) axes[-1].set_xlabel(f"hour index in window ({window_start}-{window_end})", fontsize=AXIS_LABEL_FONT_SIZE) axes[-1].set_xlim(1, window.shape[0] + 0.5) axes[-1].set_xticks(major_ticks) axes[-1].set_xticklabels([str(t) for t in major_ticks], fontsize=TICK_LABEL_FONT_SIZE) fig.subplots_adjust(left=0.18, right=0.94, bottom=0.14, top=0.98, hspace=0.0) output_png.parent.mkdir(parents=True, exist_ok=True) fig.savefig(output_png, dpi=dpi) plt.close(fig) return output_png