RhodWeo committed on
Commit
5e88f2a
·
verified ·
1 Parent(s): f47399c

eval: add self-contained evaluate.py with HF push support

Browse files
Files changed (1) hide show
  1. evaluate.py +395 -252
evaluate.py CHANGED
@@ -1,37 +1,40 @@
1
  """
2
- evaluate.py
3
- ===========
4
- Evaluate WEO-SAS/sen2sr models using the opensr-test benchmark suite, then
5
- update the HuggingFace model card Evaluation Results (model-index YAML) for
6
- the WEO-SAS/sen2sr repo.
7
-
8
- This script lives inside each branch of WEO-SAS/sen2sr and is meant to be run
9
- from the directory returned by snapshot_download():
10
-
11
- from huggingface_hub import snapshot_download
12
- local_dir = snapshot_download("WEO-SAS/sen2sr") # or specify revision
13
- import subprocess, sys
14
- subprocess.run([sys.executable, f"{local_dir}/evaluate.py", "--push"])
15
-
16
- Or from the command line after cloning/downloading:
17
-
18
- python evaluate.py --push --token hf_...
19
-
20
- Requirements
21
- ------------
22
- pip install opensr-test huggingface_hub sen2sr safetensors
23
-
24
- Outputs
25
- -------
26
- 1. A CSV file with per-sample metric values.
27
- 2. Updated model-index YAML in the WEO-SAS/sen2sr main-branch README.md,
28
- using the HuggingFace EvalResult / ModelCardData API.
29
-
30
- HF Evaluation Results format
31
- -----------------------------
32
- Each result is keyed by (task_type, dataset_type, metric_type) and indexed
33
- under the model variant name. Running this script from different variants
34
- accumulates results in the shared README on the main branch.
 
 
 
35
  """
36
 
37
  from __future__ import annotations
@@ -39,7 +42,6 @@ from __future__ import annotations
39
  import argparse
40
  import csv
41
  import json
42
- import os
43
  import sys
44
  from pathlib import Path
45
  from typing import Dict, List, Optional
@@ -49,74 +51,128 @@ import torch
49
 
50
 
51
  # ---------------------------------------------------------------------------
52
- # Metric metadata
53
  # ---------------------------------------------------------------------------
54
 
55
- METRIC_COLS = [
56
- "reflectance", "spectral", "spatial",
57
- "synthesis", "hallucination", "omission", "improvement",
58
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  DATASETS = ["naip", "spot", "venus", "spain_crops", "spain_urban"]
61
 
62
- # Human-readable names used in the HF model card
63
- DATASET_NAMES = {
64
- "naip": "NAIP",
65
- "spot": "SPOT",
66
- "venus": "Venus",
67
- "spain_crops": "Spain Crops",
68
- "spain_urban": "Spain Urban",
69
- }
70
-
71
- METRIC_NAMES = {
72
- "reflectance": "Reflectance Distance (L1)",
73
- "spectral": "Spectral Angle Distance",
74
- "spatial": "Phase Correlation Error",
75
- "synthesis": "Synthesis Score",
76
- "hallucination": "Hallucination Score",
77
- "omission": "Omission Score",
78
- "improvement": "Improvement Score",
79
  }
 
80
 
81
 
82
  # ---------------------------------------------------------------------------
83
  # Model loading
84
  # ---------------------------------------------------------------------------
85
 
86
- def load_model_from_local(local_dir: str):
87
- """Load the model from the snapshot directory."""
88
- config_path = os.path.join(local_dir, "config.json")
89
- with open(config_path) as f:
90
- config = json.load(f)
 
 
 
 
 
 
 
 
 
91
 
92
- if local_dir not in sys.path:
93
- sys.path.insert(0, local_dir)
94
 
95
  # Clear any cached module from a previous variant
96
  for mod in ["model", "sen2sr_pt", "predictor", "base"]:
97
  sys.modules.pop(mod, None)
98
 
99
- # Dynamically load model.py from the local dir
100
- import importlib.util
101
- spec = importlib.util.spec_from_file_location("model", os.path.join(local_dir, "model.py"))
102
- module = importlib.util.module_from_spec(spec)
103
- sys.modules["model"] = module
104
- spec.loader.exec_module(module)
105
-
106
- return module.Model(local_dir=local_dir), config
107
 
108
 
109
  # ---------------------------------------------------------------------------
110
- # Inference
111
  # ---------------------------------------------------------------------------
112
 
113
- def run_sr(model, lr_np: np.ndarray, in_channels: int) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  """
115
  Run SR on a single LR patch.
116
 
117
- lr_np : (C_avail, H, W) float32 in [0, 1]
118
- in_channels: channels the model expects
119
- Returns : (C_out, H*sf, W*sf) float32
120
  """
121
  C_avail = lr_np.shape[0]
122
 
@@ -128,23 +184,75 @@ def run_sr(model, lr_np: np.ndarray, in_channels: int) -> np.ndarray:
128
  else:
129
  inp = lr_np[:in_channels]
130
 
131
- return model.predict(inp)
 
 
 
 
 
 
 
 
 
132
 
133
 
134
  # ---------------------------------------------------------------------------
135
  # Per-dataset evaluation
136
  # ---------------------------------------------------------------------------
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  def evaluate_dataset(
139
  model,
140
- in_channels: int,
141
  dataset_name: str,
142
- max_samples: Optional[int] = None,
143
- verbose: bool = True,
144
  ) -> Dict[str, float]:
145
  """
146
- Evaluate model on one opensr-test dataset.
147
- Returns dict of metric_name → mean_value (nan if unavailable).
 
148
  """
149
  try:
150
  import opensr_test
@@ -154,137 +262,160 @@ def evaluate_dataset(
154
  try:
155
  dataset = opensr_test.load(dataset_name)
156
  except Exception as e:
157
- print(f" [WARN] Could not load '{dataset_name}': {e}")
158
  return {}
159
 
 
 
 
 
160
  metrics_obj = opensr_test.Metrics()
 
 
 
161
  accum: Dict[str, list] = {m: [] for m in METRIC_COLS}
162
- n = len(dataset) if max_samples is None else min(max_samples, len(dataset))
 
163
 
164
  for i in range(n):
165
- sample = dataset[i]
166
- lr = sample["lr"]
167
- hr = sample["hr"]
168
-
169
- if isinstance(lr, torch.Tensor):
170
- lr = lr.cpu().numpy()
171
- if isinstance(hr, torch.Tensor):
172
- hr = hr.cpu().numpy()
173
- lr = lr.astype(np.float32)
174
- hr = hr.astype(np.float32)
175
-
176
- if lr.ndim == 2:
177
- lr = lr[np.newaxis]
178
- if hr.ndim == 2:
179
- hr = hr[np.newaxis]
180
 
181
  try:
182
- sr = run_sr(model, lr, in_channels)
183
  except Exception as e:
184
  print(f" [WARN] SR failed on sample {i}: {e}")
185
  continue
186
 
 
 
 
 
 
 
 
 
 
 
 
187
  lr_t = torch.from_numpy(lr)
188
- sr_t = torch.from_numpy(sr[:lr_t.shape[0]])
189
  hr_t = torch.from_numpy(hr)
190
 
 
 
 
 
191
  try:
192
  result = metrics_obj.compute(lr=lr_t, sr=sr_t, hr=hr_t)
 
 
193
  except Exception as e:
194
  print(f" [WARN] Metrics failed on sample {i}: {e}")
195
  continue
196
 
197
- for m in METRIC_COLS:
198
- val = result.get(m)
199
  if val is not None:
200
  v = float(val.mean()) if hasattr(val, "mean") else float(val)
201
- accum[m].append(v)
202
 
203
- if verbose and (i + 1) % 10 == 0:
204
- print(f" {i+1}/{n}", end="\r")
205
-
206
- if verbose:
207
- print()
208
 
 
209
  return {m: float(np.mean(vs)) if vs else float("nan") for m, vs in accum.items()}
210
 
211
 
212
  # ---------------------------------------------------------------------------
213
- # HF model card update
214
  # ---------------------------------------------------------------------------
215
 
216
- def build_eval_results(
217
- variant: str,
218
- results: Dict[str, Dict[str, float]], # dataset → metric → value
219
- ) -> list:
220
- """
221
- Build a list of huggingface_hub.EvalResult objects for one variant.
222
-
223
- One EvalResult per (dataset × metric) combination.
224
- """
225
- from huggingface_hub import EvalResult
226
-
227
- eval_results = []
228
- for ds_name, metrics in results.items():
229
- for metric_name, value in metrics.items():
230
- if np.isnan(value):
231
- continue
232
- eval_results.append(
233
- EvalResult(
234
- task_type = "image-to-image",
235
- task_name = "Super-Resolution",
236
- dataset_type = f"opensr-test-{ds_name}",
237
- dataset_name = DATASET_NAMES.get(ds_name, ds_name),
238
- dataset_config = ds_name,
239
- metric_type = metric_name,
240
- metric_name = METRIC_NAMES.get(metric_name, metric_name),
241
- metric_value = round(value, 6),
242
- model_name = variant,
243
- )
244
- )
245
- return eval_results
246
-
247
-
248
- def update_model_card(
249
- variant: str,
250
- eval_results: list,
251
- repo_id: str = "WEO-SAS/sen2sr",
252
- token: Optional[str] = None,
253
- push: bool = False,
254
- ) -> None:
255
- """
256
- Load the model card from the HF main branch, merge/replace this variant's
257
- eval results, and optionally push back.
258
- """
259
- from huggingface_hub import ModelCard, ModelCardData
260
- from huggingface_hub.repocard_data import model_index_to_eval_results, eval_results_to_model_index
261
 
262
- print(f"\nLoading model card from {repo_id} (main)...")
263
- try:
264
- card = ModelCard.load(repo_id, token=token)
265
- except Exception as e:
266
- print(f" [WARN] Could not load card: {e}. Creating empty card.")
267
- card = ModelCard("---\n---\n")
268
 
269
- existing: list = card.data.eval_results or []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
- # Remove old entries for this variant, keep other variants
272
- kept = [r for r in existing if getattr(r, "model_name", None) != variant]
273
- merged = kept + eval_results
 
 
 
274
 
275
- card.data.eval_results = merged
 
 
 
 
 
 
276
 
277
- print(f" Model-index now has {len(merged)} EvalResult entries "
278
- f"({len(eval_results)} from '{variant}', {len(kept)} from other variants).")
279
 
280
- if push:
281
- print(f" Pushing updated card to {repo_id}...")
282
- card.push_to_hub(repo_id, token=token)
283
- print(" Done.")
284
- else:
285
- print(" --push not set; card not pushed. Pass --push to update HF.")
286
- print("\n--- model-index YAML preview ---")
287
- print(card.data.to_yaml())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
 
290
  # ---------------------------------------------------------------------------
@@ -292,116 +423,128 @@ def update_model_card(
292
  # ---------------------------------------------------------------------------
293
 
294
  def main():
295
- # Detect local_dir: script is inside the snapshot directory
296
- local_dir = str(Path(__file__).parent.resolve())
297
-
298
- parser = argparse.ArgumentParser(
299
- description="Evaluate WEO-SAS/sen2sr and update HF model card"
300
- )
301
  parser.add_argument(
302
- "--local-dir", default=local_dir,
303
- help="Path to the snapshot_download directory (default: script location)",
 
304
  )
305
  parser.add_argument(
306
  "--datasets", nargs="+", default=DATASETS, choices=DATASETS,
307
- help="Datasets to evaluate on (default: all 5)",
308
  )
309
  parser.add_argument(
310
  "--max-samples", type=int, default=None,
311
- help="Cap samples per dataset for a quick smoke-test",
312
  )
313
  parser.add_argument(
314
- "--output", default=None,
315
- help="CSV output path (default: sen2sr_<variant>_eval.csv in local_dir)",
316
  )
317
  parser.add_argument(
318
- "--repo-id", default="WEO-SAS/sen2sr",
319
- help="HuggingFace repo whose main-branch card to update",
320
  )
321
  parser.add_argument(
322
- "--token", default=os.environ.get("HF_TOKEN"),
323
- help="HuggingFace token (default: $HF_TOKEN)",
324
  )
325
  parser.add_argument(
326
- "--push", action="store_true",
327
- help="Push updated model card to HF after evaluation",
328
  )
329
  parser.add_argument(
330
- "--dry-run", action="store_true",
331
- help="Print model-index YAML preview without pushing",
 
 
 
 
332
  )
333
  args = parser.parse_args()
334
 
335
- # Load model + config
336
- print(f"Loading model from {args.local_dir} ...")
337
- model, config = load_model_from_local(args.local_dir)
338
- variant = config.get("variant", "unknown")
339
- in_channels = config.get("in_channels", 4)
340
- print(f"Variant : {variant}")
341
- print(f"In-ch : {in_channels}")
342
- print(f"Desc : {config.get('description', '')}")
343
 
344
- # CSV output path
345
- csv_path = args.output or os.path.join(args.local_dir, f"sen2sr_{variant}_eval.csv")
346
-
347
- # Evaluate
348
- all_results: Dict[str, Dict[str, float]] = {}
349
  rows = []
350
 
351
- for ds in args.datasets:
352
- print(f"\n[{variant}] Dataset: {ds}")
353
- metrics = evaluate_dataset(model, in_channels, ds, args.max_samples)
354
- if not metrics:
 
 
 
 
 
 
355
  continue
356
 
357
- all_results[ds] = metrics
358
- rows.append({"variant": variant, "dataset": ds, **metrics})
 
 
 
359
 
360
- print(f" {'Metric':<18} {'Value':>10}")
361
- print(f" {'-'*30}")
362
- for m in METRIC_COLS:
363
- arrow = "↑" if m in ("synthesis", "improvement") else "↓"
364
- print(f" {m:<18} {metrics.get(m, float('nan')):>9.4f} {arrow}")
 
 
 
 
 
365
 
366
  # Save CSV
367
  if rows:
368
  fieldnames = ["variant", "dataset"] + METRIC_COLS
369
- with open(csv_path, "w", newline="") as f:
370
  writer = csv.DictWriter(f, fieldnames=fieldnames)
371
  writer.writeheader()
372
  writer.writerows(rows)
373
- print(f"\nCSV saved: {csv_path}")
374
-
375
- # Build HF EvalResult objects
376
- if not all_results:
377
- print("No results to push.")
378
- return
379
-
380
- eval_results = build_eval_results(variant, all_results)
381
- print(f"\nBuilt {len(eval_results)} EvalResult entries for '{variant}'.")
382
-
383
- # Update model card (optionally push)
384
- update_model_card(
385
- variant = variant,
386
- eval_results = eval_results,
387
- repo_id = args.repo_id,
388
- token = args.token,
389
- push = args.push and not args.dry_run,
390
- )
391
 
392
- # Summary table (mean across datasets)
393
- print("\n" + "="*60)
394
- print(f"SUMMARY {variant} — mean across {list(all_results.keys())}")
395
- print("="*60)
396
- means = {
397
- m: float(np.nanmean([v[m] for v in all_results.values() if m in v]))
398
- for m in METRIC_COLS
399
- }
400
- print(f" {'Metric':<18} {'Mean':>10}")
401
- print(f" {'-'*30}")
402
- for m in METRIC_COLS:
403
- arrow = "" if m in ("synthesis", "improvement") else "↓"
404
- print(f" {m:<18} {means.get(m, float('nan')):>9.4f} {arrow}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
 
406
 
407
  if __name__ == "__main__":
 
1
  """
2
+ sen2sr_evaluate.py
3
+ ==================
4
+ Evaluate WEO-SAS/sen2sr variants using the opensr-test benchmark suite.
5
+
6
+ Metrics computed per variant × dataset:
7
+ - reflectance (↓) L1 distance — radiometric fidelity
8
+ - spectral (↓) Spectral Angle Distance — colour consistency
9
+ - spatial (↓) Phase Correlation — geometric stability
10
+ - synthesis (↑) High-frequency detail added
11
+ - hallucination(↓) False details not in HR
12
+ - omission (↓) Real details missing from SR
13
+ - improvement (↑) Correct new details introduced
14
+
15
+ Usage
16
+ -----
17
+ pip install opensr-test huggingface_hub sen2sr safetensors rasterio
18
+
19
+ # Evaluate everything (RGBN-compatible datasets + variants)
20
+ python sen2sr_evaluate.py
21
+
22
+ # Specific variants and/or datasets
23
+ python sen2sr_evaluate.py --variants main mamba-rgbn-x4 --datasets naip spot
24
+
25
+ # Skip download if already cached
26
+ python sen2sr_evaluate.py --cache-dir ./model_cache
27
+
28
+ Notes
29
+ -----
30
+ - RGBN variants (main, lite-rgbn-x4, mamba-rgbn-x4) are evaluated on all
31
+ opensr-test datasets (NAIP, SPOT, Venus, Spain Crops, Spain Urban).
32
+ - Full-pipeline 10-band variants (lite-main, mamba-main) and RSWIR variants
33
+ (lite-rswir-x2, mamba-rswir-x2) require all 10 Sentinel-2 bands.
34
+ opensr-test only provides 4-band RGBN patches, so these variants use the
35
+ 4 RGBN bands for input and the remaining 6 channels are zero-padded.
36
+ For a fair evaluation of those variants, use your own 10-band Sentinel-2
37
+ tiles and call evaluate_custom() directly.
38
  """
39
 
40
  from __future__ import annotations
 
42
  import argparse
43
  import csv
44
  import json
 
45
  import sys
46
  from pathlib import Path
47
  from typing import Dict, List, Optional
 
51
 
52
 
53
  # ---------------------------------------------------------------------------
54
+ # Variant registry
55
  # ---------------------------------------------------------------------------
56
 
57
+ VARIANTS: Dict[str, dict] = {
58
+ "main": {
59
+ "repo_id": "WEO-SAS/sen2sr",
60
+ "revision": None,
61
+ "in_channels": 4,
62
+ "scale": 4,
63
+ "note": "SEN2SRLite RGBN 4x (CNN)",
64
+ },
65
+ "lite-rswir-x2": {
66
+ "repo_id": "WEO-SAS/sen2sr",
67
+ "revision": "lite-rswir-x2",
68
+ "in_channels": 10,
69
+ "scale": 2,
70
+ "note": "SEN2SRLite RSWIR 2x (CNN) — zero-pads channels 4-9",
71
+ },
72
+ "lite-main": {
73
+ "repo_id": "WEO-SAS/sen2sr",
74
+ "revision": "lite-main",
75
+ "in_channels": 10,
76
+ "scale": 4,
77
+ "note": "SEN2SRLite full 10-band 4x (CNN) — zero-pads channels 4-9",
78
+ },
79
+ "mamba-rgbn-x4": {
80
+ "repo_id": "WEO-SAS/sen2sr",
81
+ "revision": "mamba-rgbn-x4",
82
+ "in_channels": 4,
83
+ "scale": 4,
84
+ "note": "SEN2SR RGBN 4x (Mamba)",
85
+ },
86
+ "mamba-rswir-x2": {
87
+ "repo_id": "WEO-SAS/sen2sr",
88
+ "revision": "mamba-rswir-x2",
89
+ "in_channels": 10,
90
+ "scale": 2,
91
+ "note": "SEN2SR RSWIR 2x (Swin2SR) — zero-pads channels 4-9",
92
+ },
93
+ "mamba-main": {
94
+ "repo_id": "WEO-SAS/sen2sr",
95
+ "revision": "mamba-main",
96
+ "in_channels": 10,
97
+ "scale": 4,
98
+ "note": "SEN2SR full 10-band 4x (Mamba+Swin) — zero-pads channels 4-9",
99
+ },
100
+ "srresnet": {
101
+ "repo_id": "WEO-SAS/srresnet",
102
+ "revision": None,
103
+ "in_channels": 4,
104
+ "scale": 4,
105
+ "note": "SRResNet RGBN→RGB 4x (baseline)",
106
+ },
107
+ }
108
 
109
  DATASETS = ["naip", "spot", "venus", "spain_crops", "spain_urban"]
110
 
111
+ # Canonical output column names → actual opensr_test.Metrics key
112
+ METRIC_MAP = {
113
+ "reflectance": "reflectance",
114
+ "spectral": "spectral",
115
+ "spatial": "spatial",
116
+ "synthesis": "synthesis",
117
+ "hallucination": "ha_metric",
118
+ "omission": "om_metric",
119
+ "improvement": "im_metric",
 
 
 
 
 
 
 
 
120
  }
121
+ METRIC_COLS = list(METRIC_MAP.keys())
122
 
123
 
124
  # ---------------------------------------------------------------------------
125
  # Model loading
126
  # ---------------------------------------------------------------------------
127
 
128
+ def load_model(variant: str, cache_dir: str, local_models_dir: Optional[str] = None):
129
+ """Load a WEO-SAS model variant from a local dir or by downloading from HF Hub."""
130
+ if local_models_dir:
131
+ local_dir = str(Path(local_models_dir) / variant)
132
+ if not Path(local_dir).is_dir():
133
+ raise FileNotFoundError(f"Model dir not found: {local_dir}")
134
+ else:
135
+ from huggingface_hub import snapshot_download
136
+ repo_id = VARIANTS[variant].get("repo_id", "WEO-SAS/sen2sr")
137
+ revision = VARIANTS[variant]["revision"]
138
+ kwargs = dict(repo_id=repo_id, local_dir=f"{cache_dir}/{variant}")
139
+ if revision:
140
+ kwargs["revision"] = revision
141
+ local_dir = snapshot_download(**kwargs)
142
 
143
+ sys.path.insert(0, local_dir)
 
144
 
145
  # Clear any cached module from a previous variant
146
  for mod in ["model", "sen2sr_pt", "predictor", "base"]:
147
  sys.modules.pop(mod, None)
148
 
149
+ from model import Model # noqa: PLC0415
150
+ return Model(local_dir=local_dir)
 
 
 
 
 
 
151
 
152
 
153
  # ---------------------------------------------------------------------------
154
+ # Inference helpers
155
  # ---------------------------------------------------------------------------
156
 
157
+ def _pad_to_multiple(arr: np.ndarray, multiple: int) -> tuple:
158
+ """Pad (C, H, W) to the next multiple of `multiple`; return (padded, orig_h, orig_w)."""
159
+ _, h, w = arr.shape
160
+ h_pad = ((h + multiple - 1) // multiple) * multiple
161
+ w_pad = ((w + multiple - 1) // multiple) * multiple
162
+ if h_pad == h and w_pad == w:
163
+ return arr, h, w
164
+ padded = np.zeros((arr.shape[0], h_pad, w_pad), dtype=arr.dtype)
165
+ padded[:, :h, :w] = arr
166
+ return padded, h, w
167
+
168
+
169
+ def run_sr(model, lr_np: np.ndarray, in_channels: int, scale: int = 4,
170
+ patch_size: int = 128) -> np.ndarray:
171
  """
172
  Run SR on a single LR patch.
173
 
174
+ lr_np : (C_avail, H, W) float32 in [0, 1] — opensr-test provides C=4 (RGBN)
175
+ Returns : (C_out, H*scale, W*scale) float32, cropped to exact expected size
 
176
  """
177
  C_avail = lr_np.shape[0]
178
 
 
184
  else:
185
  inp = lr_np[:in_channels]
186
 
187
+ # Pre-pad to patch_size so HardConstraint sees consistent LR↔SR sizes
188
+ orig_h, orig_w = inp.shape[1], inp.shape[2]
189
+ inp, _, _ = _pad_to_multiple(inp, patch_size)
190
+
191
+ sr = model.predict(inp)
192
+
193
+ # Crop to exact expected size based on original (unpadded) LR dimensions
194
+ h_out = orig_h * scale
195
+ w_out = orig_w * scale
196
+ return sr[:, :h_out, :w_out]
197
 
198
 
199
  # ---------------------------------------------------------------------------
200
  # Per-dataset evaluation
201
  # ---------------------------------------------------------------------------
202
 
203
+ def _save_comparison(
204
+ lr: np.ndarray,
205
+ sr: np.ndarray,
206
+ hr: np.ndarray,
207
+ path: Path,
208
+ title: str,
209
+ variant: str,
210
+ ) -> None:
211
+ try:
212
+ import matplotlib
213
+ matplotlib.use("Agg")
214
+ import matplotlib.pyplot as plt
215
+ from skimage.transform import resize as sk_resize
216
+
217
+ def to_rgb(arr):
218
+ rgb = np.clip(arr[:3].transpose(1, 2, 0), 0, 1)
219
+ return (rgb * 255).astype(np.uint8)
220
+
221
+ hr_h, hr_w = hr.shape[1], hr.shape[2]
222
+ lr_big = sk_resize(to_rgb(lr), (hr_h, hr_w), order=1, preserve_range=True).astype(np.uint8)
223
+ sr_rgb = to_rgb(sr)
224
+ hr_rgb = to_rgb(hr)
225
+
226
+ fig, axes = plt.subplots(1, 3, figsize=(12, 4))
227
+ for ax, img, label in zip(
228
+ axes,
229
+ [lr_big, sr_rgb, hr_rgb],
230
+ ["LR (bicubic)", f"SR ({variant})", "HR (reference)"],
231
+ ):
232
+ ax.imshow(img)
233
+ ax.set_title(label, fontsize=10)
234
+ ax.axis("off")
235
+ fig.suptitle(f"{variant} — {title}", fontsize=12, fontweight="bold")
236
+ plt.tight_layout()
237
+ path.parent.mkdir(parents=True, exist_ok=True)
238
+ plt.savefig(path, dpi=100, bbox_inches="tight")
239
+ plt.close(fig)
240
+ print(f" Saved image: {path.name}")
241
+ except Exception as e:
242
+ print(f" [WARN] Could not save image: {e}")
243
+
244
+
245
  def evaluate_dataset(
246
  model,
247
+ variant: str,
248
  dataset_name: str,
249
+ max_samples: Optional[int] = None,
250
+ save_images_dir: Optional[Path] = None,
251
  ) -> Dict[str, float]:
252
  """
253
+ Run a variant against one opensr-test dataset and return mean metrics.
254
+
255
+ Returns a dict mapping metric name → mean value, or empty dict on error.
256
  """
257
  try:
258
  import opensr_test
 
262
  try:
263
  dataset = opensr_test.load(dataset_name)
264
  except Exception as e:
265
+ print(f" [WARN] Could not load dataset '{dataset_name}': {e}")
266
  return {}
267
 
268
+ # opensr-test dataset is a dict: {"L2A": (N,C,H,W) uint16, "HRharm": (N,C,H,W) uint16}
269
+ lr_all = dataset["L2A"]
270
+ hr_all = dataset["HRharm"]
271
+
272
  metrics_obj = opensr_test.Metrics()
273
+ vinfo = VARIANTS[variant]
274
+ in_ch = vinfo["in_channels"]
275
+ scale = vinfo["scale"]
276
  accum: Dict[str, list] = {m: [] for m in METRIC_COLS}
277
+ n = lr_all.shape[0] if max_samples is None else min(max_samples, lr_all.shape[0])
278
+ saved_image = False
279
 
280
  for i in range(n):
281
+ lr = lr_all[i].astype(np.float32) / 10000.0 # (C, H, W) → [0, 1]
282
+ hr = hr_all[i].astype(np.float32) / 10000.0
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
  try:
285
+ sr = run_sr(model, lr, in_ch, scale)
286
  except Exception as e:
287
  print(f" [WARN] SR failed on sample {i}: {e}")
288
  continue
289
 
290
+ # For x2 models on 4x datasets: SR is half the HR size — skip metrics
291
+ if sr.shape[1] != hr.shape[1] or sr.shape[2] != hr.shape[2]:
292
+ if i == 0:
293
+ print(f" [SKIP] SR {sr.shape} != HR {hr.shape} — scale mismatch, skipping dataset")
294
+ continue
295
+
296
+ if save_images_dir and not saved_image:
297
+ img_path = save_images_dir / f"{variant}_{dataset_name}.png"
298
+ _save_comparison(lr, sr, hr, img_path, dataset_name, variant)
299
+ saved_image = True
300
+
301
  lr_t = torch.from_numpy(lr)
302
+ sr_t = torch.from_numpy(sr)
303
  hr_t = torch.from_numpy(hr)
304
 
305
+ # Align channels: metrics require lr/sr/hr to have the same count
306
+ min_ch = min(lr_t.shape[0], sr_t.shape[0], hr_t.shape[0])
307
+ lr_t, sr_t, hr_t = lr_t[:min_ch], sr_t[:min_ch], hr_t[:min_ch]
308
+
309
  try:
310
  result = metrics_obj.compute(lr=lr_t, sr=sr_t, hr=hr_t)
311
+ if not isinstance(result, dict):
312
+ result = vars(result) if hasattr(result, "__dict__") else {}
313
  except Exception as e:
314
  print(f" [WARN] Metrics failed on sample {i}: {e}")
315
  continue
316
 
317
+ for col, api_key in METRIC_MAP.items():
318
+ val = result.get(api_key)
319
  if val is not None:
320
  v = float(val.mean()) if hasattr(val, "mean") else float(val)
321
+ accum[col].append(v)
322
 
323
+ if (i + 1) % 10 == 0:
324
+ print(f" {i+1}/{n} samples processed", end="\r")
 
 
 
325
 
326
+ print()
327
  return {m: float(np.mean(vs)) if vs else float("nan") for m, vs in accum.items()}
328
 
329
 
330
  # ---------------------------------------------------------------------------
331
+ # HF output helpers
332
  # ---------------------------------------------------------------------------
333
 
334
+ def build_eval_json(rows: list) -> dict:
335
+ """Build eval_results.json dict from accumulated CSV rows."""
336
+ from collections import defaultdict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
+ per_dataset: dict = {}
339
+ agg: dict = defaultdict(lambda: defaultdict(list))
 
 
 
 
340
 
341
+ for row in rows:
342
+ v = row["variant"]
343
+ ds = row["dataset"]
344
+ per_dataset.setdefault(ds, {})
345
+ m_vals = {}
346
+ for m in METRIC_COLS:
347
+ val = row.get(m, float("nan"))
348
+ m_vals[m] = val
349
+ if not (isinstance(val, float) and np.isnan(val)):
350
+ agg[v][m].append(val)
351
+ per_dataset[ds][v] = m_vals
352
+
353
+ aggregate = {
354
+ v: {m: float(np.mean(vs)) if vs else float("nan") for m, vs in metrics.items()}
355
+ for v, metrics in agg.items()
356
+ }
357
 
358
+ variants_meta = {
359
+ v: {"note": VARIANTS[v]["note"], "in_channels": VARIANTS[v]["in_channels"],
360
+ "scale": VARIANTS[v]["scale"]}
361
+ for v in VARIANTS
362
+ if v in agg
363
+ }
364
 
365
+ return {
366
+ "eval_type": "super_resolution",
367
+ "model_name": "SEN2SR",
368
+ "variants": variants_meta,
369
+ "per_dataset": per_dataset,
370
+ "aggregate": aggregate,
371
+ }
372
 
 
 
373
 
374
+ def push_to_hf(
375
+ eval_json: dict,
376
+ images_dir: Optional[Path],
377
+ csv_path: str,
378
+ hf_token: str,
379
+ commit_message: str = "eval: update benchmark results",
380
+ ) -> None:
381
+ from huggingface_hub import HfApi
382
+ api = HfApi(token=hf_token)
383
+
384
+ repo_id = "WEO-SAS/sen2sr"
385
+
386
+ # Push eval_results.json
387
+ eval_str = json.dumps(eval_json, indent=2)
388
+ api.upload_file(
389
+ path_or_fileobj=eval_str.encode(),
390
+ path_in_repo="eval_results.json",
391
+ repo_id=repo_id,
392
+ repo_type="model",
393
+ commit_message=commit_message,
394
+ )
395
+ print("Pushed eval_results.json")
396
+
397
+ # Push CSV
398
+ if Path(csv_path).exists():
399
+ api.upload_file(
400
+ path_or_fileobj=csv_path,
401
+ path_in_repo=f"eval/{Path(csv_path).name}",
402
+ repo_id=repo_id,
403
+ repo_type="model",
404
+ commit_message=commit_message,
405
+ )
406
+ print(f"Pushed eval/{Path(csv_path).name}")
407
+
408
+ # Push images
409
+ if images_dir and images_dir.exists():
410
+ for img_path in sorted(images_dir.glob("*.png")):
411
+ api.upload_file(
412
+ path_or_fileobj=str(img_path),
413
+ path_in_repo=f"eval_images/{img_path.name}",
414
+ repo_id=repo_id,
415
+ repo_type="model",
416
+ commit_message=commit_message,
417
+ )
418
+ print(f"Pushed eval_images/{img_path.name}")
419
 
420
 
421
  # ---------------------------------------------------------------------------
 
423
  # ---------------------------------------------------------------------------
424
 
425
  def main():
426
+ parser = argparse.ArgumentParser(description="Evaluate WEO-SAS/sen2sr variants")
 
 
 
 
 
427
  parser.add_argument(
428
+ "--variants", nargs="+", default=list(VARIANTS.keys()),
429
+ choices=list(VARIANTS.keys()),
430
+ help="Variants to evaluate (default: all)",
431
  )
432
  parser.add_argument(
433
  "--datasets", nargs="+", default=DATASETS, choices=DATASETS,
434
+ help="Datasets to use (default: all)",
435
  )
436
  parser.add_argument(
437
  "--max-samples", type=int, default=None,
438
+ help="Cap samples per dataset (useful for a quick smoke-test)",
439
  )
440
  parser.add_argument(
441
+ "--cache-dir", default="./sen2sr_model_cache",
442
+ help="Directory to cache downloaded model weights",
443
  )
444
  parser.add_argument(
445
+ "--local-models-dir", default=None,
446
+ help="Use pre-downloaded models instead of HF Hub (subdir per variant: main/, lite-main/, etc.)",
447
  )
448
  parser.add_argument(
449
+ "--output", default="sen2sr_eval_results.csv",
450
+ help="Output CSV path",
451
  )
452
  parser.add_argument(
453
+ "--images-dir", default="./eval_images",
454
+ help="Directory for visual comparison PNG files",
455
  )
456
  parser.add_argument(
457
+ "--hf-token", default=None,
458
+ help="HuggingFace write token (or set HF_TOKEN env var)",
459
+ )
460
+ parser.add_argument(
461
+ "--no-push", action="store_true",
462
+ help="Skip HF push (dry-run)",
463
  )
464
  args = parser.parse_args()
465
 
466
+ import os
467
+ hf_token = args.hf_token or os.environ.get("HF_TOKEN")
468
+ images_dir = Path(args.images_dir)
469
+ images_dir.mkdir(parents=True, exist_ok=True)
 
 
 
 
470
 
471
+ Path(args.cache_dir).mkdir(parents=True, exist_ok=True)
 
 
 
 
472
  rows = []
473
 
474
+ for variant in args.variants:
475
+ print(f"\n{'='*60}")
476
+ print(f"Variant: {variant} ({VARIANTS[variant]['note']})")
477
+ print(f"{'='*60}")
478
+
479
+ try:
480
+ print(" Loading model...")
481
+ model = load_model(variant, args.cache_dir, args.local_models_dir)
482
+ except Exception as e:
483
+ print(f" [ERROR] Could not load model: {e}")
484
  continue
485
 
486
+ for ds in args.datasets:
487
+ print(f" Dataset: {ds}")
488
+ metrics = evaluate_dataset(model, variant, ds, args.max_samples, images_dir)
489
+ if not metrics:
490
+ continue
491
 
492
+ row = {"variant": variant, "dataset": ds}
493
+ row.update(metrics)
494
+ rows.append(row)
495
+
496
+ # Pretty-print
497
+ print(f" {'Metric':<16} {'Value':>10}")
498
+ print(f" {'-'*28}")
499
+ for m in METRIC_COLS:
500
+ arrow = "↑" if m in ("synthesis", "improvement") else "↓"
501
+ print(f" {m:<16} {metrics.get(m, float('nan')):>9.4f} {arrow}")
502
 
503
  # Save CSV
504
  if rows:
505
  fieldnames = ["variant", "dataset"] + METRIC_COLS
506
+ with open(args.output, "w", newline="") as f:
507
  writer = csv.DictWriter(f, fieldnames=fieldnames)
508
  writer.writeheader()
509
  writer.writerows(rows)
510
+ print(f"\nResults saved to: {args.output}")
511
+ else:
512
+ print("\nNo results to save.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
 
514
+ # Summary table
515
+ if rows:
516
+ print("\n" + "="*60)
517
+ print("SUMMARY — mean across all datasets")
518
+ print("="*60)
519
+ from collections import defaultdict
520
+ agg: dict = defaultdict(lambda: defaultdict(list))
521
+ for row in rows:
522
+ for m in METRIC_COLS:
523
+ v = row.get(m, float("nan"))
524
+ if not np.isnan(v):
525
+ agg[row["variant"]][m].append(v)
526
+
527
+ header = f"{'Variant':<20}" + "".join(f"{m[:8]:>11}" for m in METRIC_COLS)
528
+ print(header)
529
+ print("-" * len(header))
530
+ for variant in args.variants:
531
+ if variant not in agg:
532
+ continue
533
+ vals = "".join(
534
+ f"{np.mean(agg[variant].get(m, [float('nan')])):>11.4f}"
535
+ for m in METRIC_COLS
536
+ )
537
+ print(f"{variant:<20}{vals}")
538
+
539
+ # Push to HF
540
+ if rows and not args.no_push:
541
+ if not hf_token:
542
+ print("\n[WARN] No HF token — skipping push. Pass --hf-token or set HF_TOKEN.")
543
+ else:
544
+ print("\nPushing results to HuggingFace...")
545
+ eval_json = build_eval_json(rows)
546
+ push_to_hf(eval_json, images_dir, args.output, hf_token)
547
+ print("Done.")
548
 
549
 
550
  if __name__ == "__main__":