FOXES / analysis /spatial_performance.py
griffingoodwin04's picture
refactor configuration files and update paths
0affdc2
"""
Flux-Weighted Error Heatmap on Solar Disk
==========================================
For each matched flux map, accumulates:
mae_sum[i,j] += flux[i,j] * |log10 error|
bias_sum[i,j] += flux[i,j] * log10 error
weight[i,j] += flux[i,j]
Then normalizes to get flux-weighted mean error per patch.
Usage
-----
python analysis/spatial_performance.py
Outputs
-------
analysis/flux_weighted_errors_b<BIN_SIZE>.npz — accumulation cache (e.g. flux_weighted_errors_b1.npz)
analysis/performance_heatmap_all.png
"""
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from concurrent.futures import ProcessPoolExecutor, as_completed
from matplotlib.colors import LogNorm
from tqdm import tqdm
from pathlib import Path
from cmap import Colormap
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from forecasting.inference.evaluation import setup_barlow_font
# ---------------------------------------------------------------------------
# Paths — override via CLI args or environment variables
# ---------------------------------------------------------------------------
# Input locations default to environment variables; the CLI flags parsed in
# the __main__ section take precedence over these values.
FLUX_DIR = os.getenv("FOXES_FLUX_DIR", "")
PREDICTIONS_CSV = os.getenv("FOXES_PREDICTIONS_CSV", "")
# Outputs land next to this script unless --out_dir is given.
OUT_DIR = Path(__file__).parent

# --- Grid geometry ---------------------------------------------------------
# Patch grid edge length: 512 px image / 8 px patch size.
GRID_SIZE = 64
# Spatial downsample factor applied to the patch grid (1 = full 64×64).
BIN_SIZE = 1
# AIA images are cropped at 1.1 solar radii.
CROP_FACTOR = 1.1
# Solar limb radius expressed in patches (≈ 29.1).
SOLAR_RADIUS_PATCHES = (GRID_SIZE / 2) / CROP_FACTOR

# --- Masking and colour scaling --------------------------------------------
# Patches farther than ±PATCH_CROP_RADIUS from the centre (measured in
# original 64×64 patch units) are masked out.
PATCH_CROP_RADIUS = 24
# Percentile used as the colorbar cap (computed over non-NaN values);
# e.g. 99 clips the top 1% so structure in the bulk stays visible.
VMAX_PERCENTILE = 99
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def normalize_ts(series: pd.Series) -> pd.Series:
    """Parse a column of timestamp strings into second-resolution datetimes.

    Underscores are replaced by colons before parsing, and the parsed values
    are floored to whole seconds.
    """
    as_text = series.astype(str)
    with_colons = as_text.str.replace("_", ":", regex=False)
    parsed = pd.to_datetime(with_colons, utc=False)
    return parsed.dt.floor("s")
def _ts_key(fpath: str) -> str:
raw = os.path.basename(fpath).replace('.npy', '').replace('_', ':')
return pd.Timestamp(raw).floor('s').isoformat()
def load_predictions(predictions_csv: str) -> pd.DataFrame:
    """Read the predictions CSV and attach log10 error columns.

    The CSV must provide ``timestamp``, ``predictions`` and ``groundtruth``
    columns. The returned frame has ``timestamp`` normalised via
    ``normalize_ts`` plus four added columns: ``log_pred``, ``log_gt``,
    ``log_error`` (prediction minus ground truth, in log10 space) and
    ``log_abs_error``.
    """
    frame = pd.read_csv(predictions_csv)
    frame["timestamp"] = normalize_ts(frame["timestamp"])

    log_pred = np.log10(frame["predictions"])
    log_gt = np.log10(frame["groundtruth"])
    frame["log_pred"] = log_pred
    frame["log_gt"] = log_gt

    # Signed error first; its magnitude is derived from it.
    frame["log_error"] = log_pred - log_gt
    frame["log_abs_error"] = frame["log_error"].abs()

    print(f"Loaded {len(frame)} predictions")
    return frame
# ---------------------------------------------------------------------------
# Heatmap accumulation
# ---------------------------------------------------------------------------
# NOTE: module-level for ProcessPoolExecutor (spawn on macOS)
def _accumulate_flux_map(args):
fpath, log_abs_error, log_error, bin_size = args
fmap = np.load(fpath).astype(np.float64)
active = fmap[fmap > 0]
if active.size == 0:
return None
fmap = np.where(fmap > 0, fmap, 0.0)
# Spatially bin before normalization — sum preserves relative log-flux within each bin
if bin_size > 1:
h, w = fmap.shape
bh, bw = h // bin_size, w // bin_size
fmap = fmap[:bh * bin_size, :bw * bin_size].reshape(bh, bin_size, bw, bin_size).sum(axis=(1, 3))
total = fmap.sum()
if total == 0:
return None
fmap = fmap / total # normalise: each timestamp contributes equal total weight
return fmap * log_abs_error, fmap * log_error, fmap
def _crop_mask(shape: tuple, bin_size: int, radius: int = PATCH_CROP_RADIUS) -> np.ndarray:
    """Boolean row mask for the central band of the grid (y-axis only).

    A row is True when its distance from the vertical centre is at most
    ``radius / bin_size`` binned patches. Only ``shape[0]`` is used, and the
    result is a column vector of shape ``(shape[0], 1)`` meant to broadcast
    against the full grid.
    """
    n_rows = shape[0]
    centre_row = (n_rows - 1) / 2
    half_width = radius / bin_size
    row_index = np.arange(n_rows).reshape(-1, 1)
    distance = np.abs(row_index - centre_row)
    return distance <= half_width
def _finalize_error_maps(mae_sum: np.ndarray, bias_sum: np.ndarray,
                         weights: np.ndarray, count: float, bin_size: int):
    """Convert accumulated sums into flux-weighted mean error maps."""
    if count <= 0:
        return np.full(weights.shape, np.nan), np.full(weights.shape, np.nan), weights
    # Cells outside the crop band, or that never received any weight, are NaN.
    valid = _crop_mask(weights.shape, bin_size) & (weights > 0)
    # The division is evaluated for every cell before masking, so silence the
    # 0/0 warnings raised by zero-weight cells.
    with np.errstate(divide="ignore", invalid="ignore"):
        mae = np.where(valid, mae_sum / weights, np.nan)
        bias = np.where(valid, bias_sum / weights, np.nan)
    return mae, bias, weights


def compute_flux_weighted_errors(flux_dir: str, df: pd.DataFrame, cache_path: Path,
                                 bin_size: int = BIN_SIZE) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Accumulate flux-weighted log10 errors over all matched flux maps.

    Parameters
    ----------
    flux_dir : str
        Directory containing ``.npy`` flux maps whose file names encode a
        timestamp (parsed by ``_ts_key``).
    df : pd.DataFrame
        Predictions with ``timestamp``, ``log_abs_error`` and ``log_error``
        columns.
    cache_path : Path
        Base path of the ``.npz`` cache; ``_b{bin_size}`` is appended to the
        stem. If that file exists it is loaded instead of recomputing.
    bin_size : int
        Spatial binning factor applied to each flux map.

    Returns
    -------
    (mae, bias, flux_distribution)
        Flux-weighted mean absolute error and mean signed error per patch
        (NaN outside the crop band or where no flux was accumulated), and the
        accumulated normalised flux weights.
    """
    cache_path = cache_path.with_stem(f"{cache_path.stem}_b{bin_size}")
    if cache_path.exists():
        print(f"Loading cached flux-weighted error maps from {cache_path}")
        # Context manager closes the underlying zip file once arrays are read.
        with np.load(cache_path) as data:
            cached_count = float(data['count'])
            cached_mae = data['mae_sum']
            cached_bias = data['bias_sum']
            cached_weights = data['flux_distribution']
        return _finalize_error_maps(cached_mae, cached_bias, cached_weights,
                                    cached_count, bin_size)

    # Map ISO timestamp key -> (|log error|, log error); later rows win on duplicates.
    lookup = {
        pd.Timestamp(ts).floor('s').isoformat(): (float(abs_err), float(err))
        for ts, abs_err, err in zip(df['timestamp'], df['log_abs_error'], df['log_error'])
    }

    binned_grid = GRID_SIZE // bin_size
    shape = (binned_grid, binned_grid)
    mae_sum = np.zeros(shape)
    bias_sum = np.zeros(shape)
    flux_distribution = np.zeros(shape)
    count = 0

    files = sorted(os.path.join(flux_dir, f)
                   for f in os.listdir(flux_dir) if f.endswith('.npy'))

    args_list = []
    for fpath in files:
        try:
            ts_key = _ts_key(fpath)
        except (ValueError, OverflowError):
            # File name does not encode a parseable timestamp — skip it.
            continue
        errors = lookup.get(ts_key)
        if errors is None:
            continue
        abs_err, err = errors
        args_list.append((fpath, abs_err, err, bin_size))
    print(f"Matched {len(args_list)} / {len(files)} flux maps")

    if args_list:
        n_workers = os.cpu_count() or 1
        # Batch tasks to cut per-item IPC overhead; map() yields results in
        # submission order, so the floating-point sums are reproducible.
        chunksize = max(1, len(args_list) // (4 * n_workers))
        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            results = executor.map(_accumulate_flux_map, args_list, chunksize=chunksize)
            for result in tqdm(results, total=len(args_list),
                               desc="Accumulating flux-weighted errors"):
                if result is None:
                    continue
                mae_c, bias_c, flux_c = result
                mae_sum += mae_c
                bias_sum += bias_c
                flux_distribution += flux_c
                count += 1

    if count > 0:
        np.savez(cache_path, mae_sum=mae_sum, bias_sum=bias_sum,
                 flux_distribution=flux_distribution, count=np.array(count))
        print(f"Saved → {cache_path}")
    else:
        # Do not persist an empty accumulation: it would be silently reused
        # on the next run even after the inputs are fixed.
        print(f"No flux maps accumulated; cache not written to {cache_path}")

    return _finalize_error_maps(mae_sum, bias_sum, flux_distribution, count, bin_size)
# ---------------------------------------------------------------------------
# Plot
# ---------------------------------------------------------------------------
def _bin_grid(grid: np.ndarray, bin_size: int) -> np.ndarray:
if bin_size == 1:
return grid
h, w = grid.shape
bh, bw = h // bin_size, w // bin_size
cropped = grid[:bh * bin_size, :bw * bin_size]
return np.nanmean(cropped.reshape(bh, bin_size, bw, bin_size), axis=(1, 3))
def _percentile_cap(values: np.ndarray, pct: float, fallback: float = 1.0) -> float:
    """Percentile of the non-NaN entries, or ``fallback`` if none/non-finite."""
    kept = values[~np.isnan(values)]
    if kept.size == 0:
        return fallback
    cap = float(np.percentile(kept, pct))
    return cap if np.isfinite(cap) else fallback


def plot_flux_weighted_heatmap(mae_grid: np.ndarray, bias_grid: np.ndarray,
                               weight_grid: np.ndarray, out_path: Path,
                               subtitle: str = "", bin_size: int = BIN_SIZE,
                               vmax_pct: int = VMAX_PERCENTILE):
    """Draw the flux-weighted MAE and MBE heatmaps side by side and save them.

    Parameters
    ----------
    mae_grid, bias_grid : np.ndarray
        Square, already-binned error maps; NaN cells are left blank.
    weight_grid : np.ndarray
        Accepted for interface compatibility; not plotted.
    out_path : Path
        Destination image file.
    subtitle : str
        Accepted for interface compatibility; not rendered.
    bin_size : int
        Binning factor of the grids, used to scale the limb circle and the
        tick labels back to original patch units.
    vmax_pct : int
        Percentile of the non-NaN values used as the colour-scale cap.
    """
    from matplotlib.ticker import FuncFormatter

    setup_barlow_font()
    text_color = "#111111"
    theta = np.linspace(0, 2 * np.pi, 300)

    # Grids are already pre-binned during accumulation — use directly.
    n_bins = mae_grid.shape[0]
    cy = cx = n_bins / 2
    # Solar limb radius in binned-patch units.
    r_limb = SOLAR_RADIUS_PATCHES / bin_size

    # Linear colour scales: MAE from 0 to the cap, bias symmetric about 0.
    # All-NaN grids fall back to a cap of 1.0 instead of NaN limits.
    mae_vmax = _percentile_cap(mae_grid, vmax_pct)
    bias_cap = _percentile_cap(np.abs(bias_grid), vmax_pct)
    mae_norm = plt.Normalize(vmin=0, vmax=mae_vmax)
    bias_norm = plt.Normalize(vmin=-bias_cap, vmax=bias_cap)

    panels = [
        (mae_grid, r"Normalized Flux-Weighted MAE",
         Colormap('cmocean:thermal').to_mpl(), mae_norm),
        (bias_grid, r"Normalized Flux-Weighted MBE",
         Colormap('cmasher:fusion_r').to_mpl(), bias_norm),
    ]

    def _fmt(x, _):
        # Compact scientific notation, e.g. 1.20e-3 rather than 1.20e-03.
        m, e = f"{x:.2e}".split("e")
        return f"{m}e{int(e)}"

    tick_formatter = FuncFormatter(_fmt)
    tick_bins = np.linspace(0, n_bins, 7)
    # Label ticks as signed offsets from disk centre in original patch units.
    tick_labels = [f"{int((t - n_bins / 2) * bin_size)}" for t in tick_bins]

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    fig.patch.set_facecolor("white")

    for ax, (grid, title, cmap, norm) in zip(axes, panels):
        im = ax.imshow(grid, origin="lower", cmap=cmap, norm=norm,
                       interpolation="bicubic", extent=[0, n_bins, 0, n_bins])

        # The colorbar takes its norm from the image; no separate norm is passed.
        cbar = fig.colorbar(im, ax=ax, shrink=0.82)
        cbar.ax.tick_params(labelsize=9, colors=text_color)
        cbar.ax.yaxis.set_major_formatter(tick_formatter)
        for lbl in cbar.ax.get_yticklabels():
            lbl.set_fontfamily("Barlow")
            lbl.set_fontsize(9)
            lbl.set_color(text_color)

        ax.plot(cx + r_limb * np.cos(theta), cy + r_limb * np.sin(theta),
                color="#4488FF", linestyle="--", linewidth=1.2, alpha=0.8,
                label="Solar Limb")

        ax.set_xticks(tick_bins)
        ax.set_xticklabels(tick_labels)
        ax.set_yticks(tick_bins)
        ax.set_yticklabels(tick_labels)

        ax.set_title(title, fontsize=10, color=text_color, fontfamily="Barlow")
        ax.set_xlabel("Solar X (ViT Patches From Center)", fontsize=9,
                      color=text_color, fontfamily="Barlow")
        ax.set_ylabel("Solar Y (ViT Patches From Center)", fontsize=9,
                      color=text_color, fontfamily="Barlow")
        ax.tick_params(labelsize=8, colors=text_color)
        ax.legend(fontsize=7, facecolor="white", edgecolor="grey", loc="upper right")
        for spine in ax.spines.values():
            spine.set_color(text_color)

    fig.tight_layout()
    fig.savefig(out_path, dpi=400, bbox_inches="tight", facecolor="white")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"Saved → {out_path}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main(argv=None):
    """Command-line entry point: load predictions, accumulate errors, plot.

    ``argv`` defaults to ``sys.argv[1:]`` when None. Exits with a usage error
    if the flux directory or predictions CSV is not supplied via a flag or
    the corresponding environment variable.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Flux-weighted error heatmap on the solar disk.")
    parser.add_argument("--flux_dir", default=FLUX_DIR,
                        help="Directory of .npy flux maps (default: $FOXES_FLUX_DIR)")
    parser.add_argument("--predictions_csv", default=PREDICTIONS_CSV,
                        help="Predictions CSV (default: $FOXES_PREDICTIONS_CSV)")
    parser.add_argument("--out_dir", default=str(OUT_DIR),
                        help="Directory for the cache and the output figure")
    args = parser.parse_args(argv)

    # The defaults are empty strings when the environment variables are unset;
    # fail early with a clear message instead of an opaque I/O error later.
    if not args.flux_dir:
        parser.error("--flux_dir is required (or set FOXES_FLUX_DIR)")
    if not os.path.isdir(args.flux_dir):
        parser.error(f"--flux_dir is not a directory: {args.flux_dir}")
    if not args.predictions_csv:
        parser.error("--predictions_csv is required (or set FOXES_PREDICTIONS_CSV)")
    if not os.path.isfile(args.predictions_csv):
        parser.error(f"--predictions_csv is not a file: {args.predictions_csv}")

    out = Path(args.out_dir)
    out.mkdir(parents=True, exist_ok=True)

    df = load_predictions(args.predictions_csv)
    mae, bias, weight = compute_flux_weighted_errors(
        args.flux_dir, df, out / "flux_weighted_errors.npz"
    )
    plot_flux_weighted_heatmap(mae, bias, weight,
                               out / "performance_heatmap_all.png",
                               subtitle="All flares")


if __name__ == "__main__":
    main()