File size: 5,259 Bytes
80ef3b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#!/usr/bin/env python3
"""Build the RQ4 task-rank map from released paper table TeX.

This uses only the Python standard library plus the small PDF helper bundled in
this repository. It does not require raw data, feature caches, or matplotlib.
"""

from __future__ import annotations

import re
from pathlib import Path

from simple_pdf import PdfCanvas, mix


# Directory one level above the folder that contains this (resolved) script.
ROOT = Path(__file__).resolve().parents[1]
# Folder holding the released table TeX files read by main().
TABLE_DIR = ROOT.joinpath("paper_outputs", "tables")
# Destination of the rendered rank-map PDF.
OUT = ROOT.joinpath("paper_outputs", "figures", "fig_task_rank_map.pdf")

# Row labels used to look models up in the table TeX. The list order is also
# the left-to-right column order of the figure.
MODELS = [
    "FireWx-FM ref.",
    "Prithvi-WxC",
    "Aurora",
    "ClimaX",
    "StormCast",
    "DLWP",
    "FCN",
    "FengWu",
    "FuXi",
    "Pangu-Weather",
    "AlphaEarth",
]

# Column headers drawn in the figure, position-aligned with MODELS. They are
# currently the same strings; this is an independent list so a header can be
# changed without touching the lookup labels above.
DISPLAY_MODELS = list(MODELS)


def means_from_row(block: str, label: str, occurrence: int = 0) -> list[float]:
    r"""Return the means of every ``\ms{mean}{...}`` cell in one table row.

    The row is the text running from the ``occurrence``-th literal match of
    ``label`` in ``block`` up to (not including) the next ``\\`` terminator.

    Args:
        block: TeX source to search.
        label: Row label, matched literally (regex metacharacters escaped).
        occurrence: Zero-based index of the label match to use.

    Returns:
        The first argument of each ``\ms{...}{`` macro in the row, as floats,
        in order of appearance.

    Raises:
        ValueError: If the label does not occur often enough, the row has no
            ``\\`` terminator, or the row contains no ``\ms`` values.
    """
    label_pattern = re.compile(re.escape(label))
    offsets = [hit.start() for hit in label_pattern.finditer(block)]
    if occurrence >= len(offsets):
        raise ValueError(f"Missing row {label!r} occurrence {occurrence}")

    row_begin = offsets[occurrence]
    row_end = block.find(r"\\", row_begin)
    if row_end < 0:
        raise ValueError(f"Missing row terminator after {label!r}")

    # Only the first brace group (the mean) is captured; the second group is
    # required to open but its content is ignored.
    cells = re.finditer(r"\\ms\{(-?\d+(?:\.\d+)?)\}\{", block[row_begin:row_end])
    means = [float(cell.group(1)) for cell in cells]
    if not means:
        raise ValueError(f"No values found for row {label!r}")
    return means


def rank_values(values: list[float], higher_is_better: bool) -> list[int]:
    """Rank ``values`` from best (1) to worst, with ties sharing a rank.

    Uses competition ranking ("1, 1, 3"): equal values receive the same rank
    and the following rank is skipped. Without this, tied entries would be
    separated purely by their position in the list, which would show a
    ranking difference that the data does not support.

    Args:
        values: Metric values, one per entry.
        higher_is_better: If true the largest value gets rank 1, otherwise
            the smallest does.

    Returns:
        A list the same length as ``values`` where element ``i`` is the
        1-based rank of ``values[i]``. Empty input yields an empty list.
    """
    order = sorted(range(len(values)), key=values.__getitem__, reverse=higher_is_better)
    ranks = [0] * len(values)
    shared_rank = 0
    previous = None
    for position, idx in enumerate(order, start=1):
        # Start a new rank only when the value differs from the one before it
        # in sorted order; otherwise reuse the rank of the tied predecessor.
        if position == 1 or values[idx] != previous:
            shared_rank = position
        ranks[idx] = shared_rank
        previous = values[idx]
    return ranks


def rank_color(rank: int, n_cols: int) -> tuple[float, float, float]:
    """Return the cell fill colour for ``rank`` out of ``n_cols`` entries.

    The rank is mapped to a strength in which rank 1 gives 1.0 and rank
    ``n_cols`` gives 0.0. Strengths up to 0.5 are passed to ``mix`` between
    the light and middle colours; higher strengths between the middle and
    dark colours.

    NOTE(review): ``mix`` comes from ``simple_pdf``; presumably it blends its
    first two arguments by the given fraction. Confirm against that helper.
    """
    light = (0.93, 0.95, 0.94)
    middle = (0.55, 0.78, 0.75)
    dark = (0.05, 0.40, 0.42)

    # max(1, ...) avoids a zero denominator when there is a single column.
    strength = (n_cols - rank) / max(1, n_cols - 1)
    in_lower_half = strength <= 0.5
    if not in_lower_half:
        return mix(middle, dark, (strength - 0.5) / 0.5)
    return mix(light, middle, strength / 0.5)


def fmt_value(value: float) -> str:
    """Format ``value`` for a cell label.

    Values whose magnitude is at least 1 get two decimals; smaller ones get
    three.
    """
    decimals = 2 if abs(value) >= 1 else 3
    return f"{value:.{decimals}f}"


def main() -> None:
    """Build the task-by-model rank map and save it to ``OUT``.

    Reads ``tab_primary_results.tex`` and ``tab_supporting_results.tex`` from
    ``TABLE_DIR``, pulls one metric per task for every model in ``MODELS``,
    ranks the models within each task, and draws a grid with one row per task
    and one column per model. Each cell shows the rank and the metric value,
    shaded by rank. A colour key is drawn below the grid.

    Raises:
        ValueError: Propagated from ``means_from_row`` when an expected row
            or its values cannot be found in the table TeX.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding.
    primary = (TABLE_DIR / "tab_primary_results.tex").read_text(encoding="utf-8")
    supporting = (TABLE_DIR / "tab_supporting_results.tex").read_text(encoding="utf-8")

    primary_rows = {label: means_from_row(primary, label) for label in MODELS}
    # Each model label is looked up twice in the supporting table: the first
    # match feeds the "top" metrics and the second match the "bottom" ones.
    supporting_top = {label: means_from_row(supporting, label, occurrence=0) for label in MODELS}
    supporting_bottom = {label: means_from_row(supporting, label, occurrence=1) for label in MODELS}

    # (task, metric, direction text, per-model values, higher_is_better).
    # The numeric index selects which mean within the parsed row is used.
    tasks = [
        ("Occupancy", "Union F1 (%)", "higher better", [primary_rows[m][2] for m in MODELS], True),
        ("Fire spread", "AP (%)", "higher better", [primary_rows[m][5] for m in MODELS], True),
        ("Burned area", "log-RMSE", "lower better", [supporting_top[m][0] for m in MODELS], False),
        ("Analog retrieval", "nDCG@10", "higher better", [supporting_top[m][3] for m in MODELS], True),
        ("Smoke PM2.5", "RMSE", "lower better", [supporting_bottom[m][0] for m in MODELS], False),
        ("Extreme heat", "RMSE-C", "lower better", [supporting_bottom[m][3] for m in MODELS], False),
    ]

    n_rows = len(tasks)
    n_cols = len(MODELS)
    c = PdfCanvas(width=1120, height=430)
    # White page background.
    c.rect(0, 0, c.width, c.height, fill=(1, 1, 1))

    # Grid origin (x0, y0) and cell size; rows are laid out downward from
    # grid_top, so the first task is the topmost row.
    x0, y0 = 108, 90
    cell_w, cell_h = 86, 42
    grid_top = y0 + n_rows * cell_h

    # Column headers, centred above each column.
    for j, model in enumerate(DISPLAY_MODELS):
        c.text(x0 + j * cell_w + cell_w / 2, grid_top + 34, model, size=8.7, align="center", bold=True, color=(0.12, 0.14, 0.16))

    for i, (task, metric, direction, values, higher) in enumerate(tasks):
        y = grid_top - (i + 1) * cell_h
        # Three-line row label in the left margin: task, metric, direction.
        c.text(12, y + 25, task, size=7.7, bold=True, color=(0.12, 0.14, 0.16))
        c.text(12, y + 14, metric, size=7.1, bold=True, color=(0.12, 0.14, 0.16))
        c.text(12, y + 3, direction, size=6.4, color=(0.42, 0.44, 0.46))
        ranks = rank_values(values, higher)
        for j, (rank, value) in enumerate(zip(ranks, values)):
            x = x0 + j * cell_w
            color = rank_color(rank, n_cols)
            # The two best ranks use the darkest fills, so their text is white.
            text_color = (1, 1, 1) if rank <= 2 else (0.07, 0.09, 0.11)
            c.rect(x, y, cell_w, cell_h, fill=color, stroke=(1, 1, 1), lw=0.8)
            c.text(x + cell_w / 2, y + 24, f"#{rank}", size=11.2, align="center", bold=True, color=text_color)
            c.text(x + cell_w / 2, y + 9, fmt_value(value), size=7.0, align="center", color=text_color)

    # Outline around the whole grid.
    c.rect(x0, y0, cell_w * n_cols, cell_h * n_rows, stroke=(0.20, 0.22, 0.24), lw=0.8)

    # Colour key, right-aligned under the grid.
    key_x, key_y = x0 + cell_w * n_cols - 220, 38
    c.text(key_x + 110, key_y + 25, "within-row rank", size=9.0, bold=True, align="center", color=(0.24, 0.25, 0.26))
    key_steps = 80
    for i in range(key_steps):
        # Divide the bar into n_cols equal bands running from rank n_cols on
        # the left to rank 1 on the right, so the colour at the "rank 1" end
        # of the key really is the rank-1 colour.
        color = rank_color(n_cols - i * n_cols // key_steps, n_cols)
        c.rect(key_x + i * 2.5, key_y + 8, 2.6, 10, fill=color)
    c.text(key_x, key_y - 4, f"rank {n_cols}", size=7.0, color=(0.25, 0.26, 0.27))
    c.text(key_x + 200, key_y - 4, "rank 1", size=7.0, align="right", color=(0.25, 0.26, 0.27))

    OUT.parent.mkdir(parents=True, exist_ok=True)
    c.save(OUT)


# Build the figure only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()