ArchEGraph's picture
Initial commit after history reset
eecbf34
from __future__ import annotations
from pathlib import Path
import matplotlib
import numpy as np
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Series plotted by default, in subplot order (matched against the npz
# ``columns`` names without regard to case).
DEFAULT_WEATHER_COLUMNS = [
    "dry_bulb",
    "dew_point",
    "relative_humidity",
    "global_horizontal_radiation",
    "direct_normal_radiation",
    "diffuse_horizontal_radiation",
    "wind_speed",
]
# Units appended to in-plot labels. Keys are lowercase snake_case because the
# label lookup normalizes names (strip, lower, spaces/hyphens -> "_") first.
WEATHER_UNITS = {
    "dry_bulb": "degC",
    "dew_point": "degC",
    "relative_humidity": "%",
    "global_horizontal_radiation": "W/m2",
    "direct_normal_radiation": "W/m2",
    "diffuse_horizontal_radiation": "W/m2",
    # wind_speed is one of the default series, so it needs a unit too.
    # NOTE(review): m/s assumed (EPW convention) — confirm against the
    # exporter that writes the npz.
    "wind_speed": "m/s",
    "wind_direction": "deg",
}
# Figure width and height (square figure), in inches as used by figsize.
FIG_SIZE_IN = 4.2
# Font sizes (points) for the x-axis label, tick labels and in-plot series labels.
AXIS_LABEL_FONT_SIZE = 8
TICK_LABEL_FONT_SIZE = 8
INPLOT_LABEL_FONT_SIZE = 10
def _to_hourly_feature(values: np.ndarray) -> np.ndarray:
arr = np.asarray(values)
if arr.ndim != 2:
raise ValueError(f"Expected 2D weather matrix, got shape={arr.shape}")
if arr.shape[0] == 8760:
return np.asarray(arr, dtype=float)
if arr.shape[1] == 8760:
return np.asarray(arr.T, dtype=float)
raise ValueError(f"Neither axis is 8760 for weather matrix: shape={arr.shape}")
def _decode_columns(columns_arr: np.ndarray | None, width: int) -> list[str]:
if columns_arr is None:
return [f"feature_{i}" for i in range(width)]
cols = np.asarray(columns_arr).reshape(-1)
names = [str(c, "utf-8") if isinstance(c, (bytes, np.bytes_)) else str(c) for c in cols]
if len(names) < width:
names.extend([f"feature_{i}" for i in range(len(names), width)])
return names[:width]
def _pick_weather_indices(column_names: list[str]) -> list[int]:
    """Choose the column indices to plot, at most ``len(DEFAULT_WEATHER_COLUMNS)``.

    Columns named in ``DEFAULT_WEATHER_COLUMNS`` are taken first, in that
    order. Names are compared after stripping whitespace, lower-casing and
    mapping spaces/hyphens to underscores, so "Dry Bulb" or "dew-point" are
    recognized. This is the same normalization used for unit lookup. If
    fewer default columns are found than the target count, the remaining
    slots are filled with the lowest-indexed columns not already chosen.
    """

    def _normalize(name: str) -> str:
        return name.strip().lower().replace(" ", "_").replace("-", "_")

    target = len(DEFAULT_WEATHER_COLUMNS)
    # If two columns normalize to the same key, the later column wins.
    normalized_to_idx = {_normalize(name): idx for idx, name in enumerate(column_names)}

    selected: list[int] = []
    for default_name in DEFAULT_WEATHER_COLUMNS:
        idx = normalized_to_idx.get(_normalize(default_name))
        if idx is not None:
            selected.append(idx)

    # Top up with arbitrary leftover columns so the figure still shows up to
    # `target` series when the file uses unfamiliar names.
    taken = set(selected)
    for idx in range(len(column_names)):
        if len(selected) >= target:
            break
        if idx not in taken:
            selected.append(idx)
            taken.add(idx)
    return selected
def _time_window(hourly: np.ndarray, start_hour: int, window_hours: int) -> tuple[np.ndarray, int, int]:
total = int(hourly.shape[0])
if total < 1:
raise ValueError("No hourly weather records found")
start_idx = max(0, min(total - 1, int(start_hour) - 1))
window = max(1, int(window_hours))
end_idx = min(total, start_idx + window)
if end_idx <= start_idx:
raise ValueError(f"Invalid weather window: start={start_hour}, hours={window_hours}")
return hourly[start_idx:end_idx, :], start_idx + 1, end_idx
def _major_ticks(length: int) -> list[int]:
if length <= 8:
return list(range(1, length + 1))
tick_count = 6
ticks = np.linspace(1, length, num=tick_count, dtype=int)
uniq = sorted(set(int(t) for t in ticks))
if uniq[-1] != length:
uniq.append(length)
return uniq
def _label_with_unit(name: str) -> str:
    """Return ``name`` with its unit in parentheses when ``WEATHER_UNITS`` knows it.

    The lookup key is ``name`` stripped, lower-cased and with spaces and
    hyphens turned into underscores. The displayed text keeps the original
    ``name`` unchanged.
    """
    lookup_key = name.strip().lower().translate(str.maketrans(" -", "__"))
    unit = WEATHER_UNITS.get(lookup_key)
    return name if unit is None else f"{name} ({unit})"
def visualize_weather(
    weather_npz: str | Path,
    output_png: str | Path,
    *,
    start_hour: int = 1,
    window_hours: int = 24,
    dpi: int = 220,
) -> Path:
    """Plot weather subplots from PACK weather npz (values + columns) in a selected time window.

    One subplot per selected series is stacked vertically on a shared x-axis.
    Each subplot carries an in-plot label that includes the unit when one is
    known.

    Args:
        weather_npz: ``.npz`` archive with a 2D ``values`` array (one axis of
            length 8760) and, optionally, a ``columns`` array of series names.
        output_png: Destination image path. Parent directories are created.
        start_hour: 1-based first hour of the window (clamped into range).
        window_hours: Requested window length in hours (at least 1, truncated
            at the end of the record).
        dpi: Resolution passed to ``savefig``.

    Returns:
        ``output_png`` as a ``Path``.

    Raises:
        KeyError: if the archive has no ``values`` entry.
        ValueError: if ``values`` has an unusable shape, the record is empty,
            or no series is available to plot.
    """
    weather_npz = Path(weather_npz)
    output_png = Path(output_png)

    # SECURITY NOTE(review): allow_pickle=True lets np.load unpickle object
    # arrays (e.g. an object-dtype ``columns``), which can execute arbitrary
    # code. Only use this on trusted files.
    with np.load(weather_npz, allow_pickle=True) as data:
        if "values" not in data:
            keys = ", ".join(sorted(data.files))
            raise KeyError(f"Missing key 'values' in {weather_npz}; keys=[{keys}]")
        values = np.asarray(data["values"], dtype=float)
        columns = np.asarray(data["columns"], dtype=object) if "columns" in data else None

    hourly = _to_hourly_feature(values)
    window, window_start, window_end = _time_window(hourly, start_hour=start_hour, window_hours=window_hours)
    names = _decode_columns(columns, window.shape[1])
    idx_list = _pick_weather_indices(names)
    if len(idx_list) == 0:
        raise ValueError(f"No weather series available in {weather_npz}")

    # squeeze=False always yields a 2D array of Axes, so a single series needs
    # no special case.
    fig, axes_grid = plt.subplots(
        len(idx_list),
        1,
        figsize=(FIG_SIZE_IN, FIG_SIZE_IN),
        sharex=True,
        gridspec_kw={"hspace": 0.0},
        squeeze=False,
    )
    axes = axes_grid[:, 0]
    try:
        x = np.arange(1, window.shape[0] + 1, dtype=int)
        major_ticks = _major_ticks(window.shape[0])
        for row_idx, feat_idx in enumerate(idx_list):
            ax = axes[row_idx]
            ax.plot(x, window[:, feat_idx], linewidth=0.9, color="#4C72B0")
            ax.set_ylabel("")
            # The series name sits inside the axes (top-left) instead of as a
            # y-label, which keeps the stacked subplots compact.
            ax.text(
                0.02,
                0.86,
                _label_with_unit(names[feat_idx]),
                transform=ax.transAxes,
                ha="left",
                va="top",
                rotation=0,
                fontsize=INPLOT_LABEL_FONT_SIZE,
                bbox={"facecolor": "white", "alpha": 0.65, "edgecolor": "none", "pad": 1.5},
            )
            ax.set_xticks(major_ticks)
            ax.grid(axis="y", alpha=0.3, linewidth=0.5)
            ax.grid(axis="x", alpha=0.22, linewidth=0.45)
            ax.tick_params(axis="both", which="both", labelsize=TICK_LABEL_FONT_SIZE)
            # Only the bottom subplot shows x tick labels.
            if row_idx < len(idx_list) - 1:
                ax.tick_params(axis="x", which="both", labelbottom=False)

        axes[-1].set_xlabel(f"hour index in window ({window_start}-{window_end})", fontsize=AXIS_LABEL_FONT_SIZE)
        axes[-1].set_xlim(1, window.shape[0] + 0.5)
        axes[-1].set_xticks(major_ticks)
        axes[-1].set_xticklabels([str(t) for t in major_ticks], fontsize=TICK_LABEL_FONT_SIZE)
        fig.subplots_adjust(left=0.18, right=0.94, bottom=0.14, top=0.98, hspace=0.0)

        output_png.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_png, dpi=dpi)
    finally:
        # Release the figure even if drawing, mkdir or savefig fails, so
        # repeated calls do not accumulate open pyplot figures.
        plt.close(fig)
    return output_png