Wildfire-FM / scripts /build_task_rank_map.py
yx21e's picture
Initial FireWx-FM artifact release
80ef3b2 verified
#!/usr/bin/env python3
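"""Build the RQ4 task-rank map from the released paper-table TeX sources.

Reads ``paper_outputs/tables/tab_primary_results.tex`` and
``paper_outputs/tables/tab_supporting_results.tex``, ranks every model within
each task row, and writes ``paper_outputs/figures/fig_task_rank_map.pdf``.

This uses only the Python standard library plus the small PDF helper bundled in
this repository. It does not require raw data, feature caches, or matplotlib.
"""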
from __future__ import annotations
import re
from pathlib import Path
from simple_pdf import PdfCanvas, mix
# parents[0] is the folder holding this script; parents[1] is its parent,
# which is treated as the repository root.
ROOT = Path(__file__).resolve().parents[1]
_PAPER_OUTPUTS = ROOT / "paper_outputs"
# Directory containing the LaTeX tables that main() reads.
TABLE_DIR = _PAPER_OUTPUTS / "tables"
# Destination of the rendered figure.
OUT = _PAPER_OUTPUTS / "figures" / "fig_task_rank_map.pdf"
# Model columns, left to right. Each entry is also the literal row label that
# means_from_row() searches for in the LaTeX tables.
MODELS = [
    "FireWx-FM ref.",
    "Prithvi-WxC",
    "Aurora",
    "ClimaX",
    "StormCast",
    "DLWP",
    "FCN",
    "FengWu",
    "FuXi",
    "Pangu-Weather",
    "AlphaEarth",
]
# Column header text drawn above the grid. It must stay parallel to MODELS
# (same length, same order). It is currently an independent copy of MODELS;
# replace it with an explicit list to relabel columns without changing the
# table lookup labels.
DISPLAY_MODELS = list(MODELS)
# Captures the mean (first argument) of each \ms{mean}{spread} cell. Optional
# whitespace inside the braces, an explicit sign, a leading-dot decimal and an
# exponent are all accepted. A cell that fails to match would be silently
# skipped and would shift every later value in the row, so main() would read
# the wrong column.
_MS_MEAN_RE = re.compile(
    r"\\ms\{\s*([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)\s*\}\s*\{"
)


def means_from_row(block: str, label: str, occurrence: int = 0) -> list[float]:
    r"""Return the ``\ms{mean}{...}`` means from one row of a LaTeX table.

    The row starts at the ``occurrence``-th (0-based) match of ``label`` in
    ``block`` and ends at the next ``\\`` row terminator. ``label`` is matched
    as a plain substring, so any earlier mention of it (for example in a
    caption) also counts as an occurrence.

    Args:
        block: Full LaTeX source of the table.
        label: Row label to look for, matched literally.
        occurrence: Which match of ``label`` to use (0 = first).

    Returns:
        The cell means in left-to-right order.

    Raises:
        ValueError: If ``label`` has too few matches, the row has no ``\\``
            terminator, or the row contains no ``\ms`` cells.
    """
    starts = [m.start() for m in re.finditer(re.escape(label), block)]
    if len(starts) <= occurrence:
        raise ValueError(f"Missing row {label!r} occurrence {occurrence}")
    start = starts[occurrence]
    end = block.find(r"\\", start)
    if end < 0:
        raise ValueError(f"Missing row terminator after {label!r}")
    row = block[start:end]
    values = [float(x) for x in _MS_MEAN_RE.findall(row)]
    if not values:
        raise ValueError(f"No values found for row {label!r}")
    return values
def rank_values(values: list[float], higher_is_better: bool) -> list[int]:
    """Return the 1-based rank of each entry of ``values``, in input order.

    Rank 1 goes to the largest value when ``higher_is_better`` is true and to
    the smallest value otherwise. Ties still receive distinct ranks. Because
    the sort is stable, the tied entry that appears first in ``values`` gets
    the better rank.
    """
    ordered = sorted(enumerate(values), key=lambda item: item[1], reverse=higher_is_better)
    rank_of = {position: place for place, (position, _) in enumerate(ordered, start=1)}
    return [rank_of[position] for position in range(len(values))]
def rank_color(rank: int, n_cols: int) -> tuple[float, float, float]:
    """Return the fill colour for a 1-based ``rank`` out of ``n_cols`` columns.

    ``strength`` runs from 0 (rank ``n_cols``) to 1 (rank 1). The colour is
    blended over two segments: light to mid teal for ``strength <= 0.5`` and
    mid teal to dark teal above that. This assumes ``mix(a, b, f)`` blends from
    ``a`` toward ``b`` as ``f`` goes from 0 to 1 (TODO confirm in simple_pdf),
    which would make rank 1 the darkest fill.
    """
    light = (0.93, 0.95, 0.94)
    middle = (0.55, 0.78, 0.75)
    dark = (0.05, 0.40, 0.42)
    # max(1, ...) avoids dividing by zero when there is a single column.
    strength = (n_cols - rank) / max(1, n_cols - 1)
    if strength > 0.5:
        return mix(middle, dark, (strength - 0.5) / 0.5)
    return mix(light, middle, strength / 0.5)
def fmt_value(value: float) -> str:
    """Format a cell value with 2 decimals if ``abs(value) >= 1``, else 3 decimals."""
    precision = 3
    if abs(value) >= 1:
        precision = 2
    return f"{value:.{precision}f}"
def _pick_column(rows: dict[str, list[float]], index: int, table: str) -> list[float]:
    """Return ``rows[model][index]`` for every model, in ``MODELS`` order.

    Raises:
        ValueError: If a model's row has no value at ``index``.
    """
    picked = []
    for model in MODELS:
        row = rows[model]
        if index >= len(row):
            # Without this check, a row with fewer cells than expected would
            # surface as a bare IndexError that names neither table nor row.
            raise ValueError(
                f"{table}: row {model!r} has {len(row)} values, "
                f"but column {index} was requested"
            )
        picked.append(row[index])
    return picked


def _load_tasks() -> list[tuple[str, str, str, list[float], bool]]:
    """Read both paper tables and return one entry per heat-map row.

    Each entry is ``(task, metric, direction_label, values, higher_is_better)``.
    ``values`` holds one mean per model, in ``MODELS`` order.
    """
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    primary = (TABLE_DIR / "tab_primary_results.tex").read_text(encoding="utf-8")
    supporting = (TABLE_DIR / "tab_supporting_results.tex").read_text(encoding="utf-8")
    primary_rows = {label: means_from_row(primary, label) for label in MODELS}
    # Each model label is looked up twice in the supporting table. The first
    # match feeds the "top" tasks and the second match feeds the "bottom" tasks.
    supporting_top = {label: means_from_row(supporting, label, occurrence=0) for label in MODELS}
    supporting_bottom = {label: means_from_row(supporting, label, occurrence=1) for label in MODELS}
    return [
        ("Occupancy", "Union F1 (%)", "higher better",
         _pick_column(primary_rows, 2, "primary"), True),
        ("Fire spread", "AP (%)", "higher better",
         _pick_column(primary_rows, 5, "primary"), True),
        ("Burned area", "log-RMSE", "lower better",
         _pick_column(supporting_top, 0, "supporting (first match)"), False),
        ("Analog retrieval", "nDCG@10", "higher better",
         _pick_column(supporting_top, 3, "supporting (first match)"), True),
        ("Smoke PM2.5", "RMSE", "lower better",
         _pick_column(supporting_bottom, 0, "supporting (second match)"), False),
        ("Extreme heat", "RMSE-C", "lower better",
         _pick_column(supporting_bottom, 3, "supporting (second match)"), False),
    ]


def _draw_rank_key(c: PdfCanvas, key_x: float, key_y: float, n_cols: int) -> None:
    """Draw the colour key, with rank ``n_cols`` on the left and rank 1 on the right."""
    label_color = (0.25, 0.26, 0.27)
    steps, pitch = 80, 2.5  # 80 swatches x 2.5 pt = a 200 pt wide bar
    c.text(key_x + 110, key_y + 25, "within-row rank", size=9.0, bold=True,
           align="center", color=(0.24, 0.25, 0.26))
    for i in range(steps):
        # Equal-width bands, one per rank. i * n_cols // steps runs from 0 to
        # n_cols - 1, so the right-hand end really shows rank 1, matching the
        # "rank 1" label beneath it.
        band = i * n_cols // steps
        c.rect(key_x + i * pitch, key_y + 8, 2.6, 10, fill=rank_color(n_cols - band, n_cols))
    c.text(key_x, key_y - 4, f"rank {n_cols}", size=7.0, color=label_color)
    c.text(key_x + 200, key_y - 4, "rank 1", size=7.0, align="right", color=label_color)


def main() -> None:
    """Render the task-rank heat map to ``OUT``.

    Each row is one task and each column is one model from ``MODELS``. A cell
    shows the model's within-row rank (1 = best under that row's metric
    direction) and the mean value read from the paper tables.

    Raises:
        ValueError: If a table row is missing or short, or if
            ``DISPLAY_MODELS`` does not line up with ``MODELS``.
    """
    if len(DISPLAY_MODELS) != len(MODELS):
        # Headers come from DISPLAY_MODELS and data from MODELS, so a length
        # mismatch would silently misalign the labels with the columns.
        raise ValueError(
            f"DISPLAY_MODELS has {len(DISPLAY_MODELS)} entries "
            f"but MODELS has {len(MODELS)}"
        )
    tasks = _load_tasks()
    n_rows = len(tasks)
    n_cols = len(MODELS)

    c = PdfCanvas(width=1120, height=430)
    c.rect(0, 0, c.width, c.height, fill=(1, 1, 1))
    x0, y0 = 108, 90
    cell_w, cell_h = 86, 42
    grid_top = y0 + n_rows * cell_h
    ink = (0.12, 0.14, 0.16)

    for j, header in enumerate(DISPLAY_MODELS):
        c.text(x0 + j * cell_w + cell_w / 2, grid_top + 34, header,
               size=8.7, align="center", bold=True, color=ink)

    for i, (task, metric, direction, values, higher) in enumerate(tasks):
        # Row i spans [grid_top - (i + 1) * cell_h, grid_top - i * cell_h],
        # so row 0 sits directly beneath the header line.
        y = grid_top - (i + 1) * cell_h
        c.text(12, y + 25, task, size=7.7, bold=True, color=ink)
        c.text(12, y + 14, metric, size=7.1, bold=True, color=ink)
        c.text(12, y + 3, direction, size=6.4, color=(0.42, 0.44, 0.46))
        ranks = rank_values(values, higher)
        for j, (rank, value) in enumerate(zip(ranks, values)):
            x = x0 + j * cell_w
            fill = rank_color(rank, n_cols)
            # Ranks 1-2 sit at the dark end of the rank_color ramp, so they
            # get white text for contrast.
            text_color = (1, 1, 1) if rank <= 2 else (0.07, 0.09, 0.11)
            c.rect(x, y, cell_w, cell_h, fill=fill, stroke=(1, 1, 1), lw=0.8)
            c.text(x + cell_w / 2, y + 24, f"#{rank}", size=11.2, align="center",
                   bold=True, color=text_color)
            c.text(x + cell_w / 2, y + 9, fmt_value(value), size=7.0, align="center",
                   color=text_color)

    c.rect(x0, y0, cell_w * n_cols, cell_h * n_rows, stroke=(0.20, 0.22, 0.24), lw=0.8)
    # Place the key below the grid, ending 20 pt short of the grid's right edge.
    _draw_rank_key(c, x0 + cell_w * n_cols - 220, 38, n_cols)

    OUT.parent.mkdir(parents=True, exist_ok=True)
    c.save(OUT)
if __name__ == "__main__":
    # Executed as a script (not imported): regenerate the figure.
    main()