RhodWeo commited on
Commit
a48a226
·
1 Parent(s): d04f678

Add evaluate.py and variant field in config.json

Browse files
Files changed (2) hide show
  1. config.json +3 -2
  2. evaluate.py +408 -0
config.json CHANGED
@@ -15,5 +15,6 @@
15
  "hard_constraint_bands": null,
16
  "weights_file": "model.safetensor",
17
  "hard_constraint_file": "hard_constraint.safetensor",
18
- "description": "SEN2SRLite RGBN x4: Sentinel-2 RGBN 10m -> 2.5m super-resolution (4x, CNN)"
19
- }
 
 
15
  "hard_constraint_bands": null,
16
  "weights_file": "model.safetensor",
17
  "hard_constraint_file": "hard_constraint.safetensor",
18
+ "description": "SEN2SRLite RGBN x4: Sentinel-2 RGBN 10m -> 2.5m super-resolution (4x, CNN)",
19
+ "variant": "main"
20
+ }
evaluate.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ evaluate.py
3
+ ===========
4
+ Evaluate WEO-SAS/sen2sr models using the opensr-test benchmark suite, then
5
+ update the HuggingFace model card Evaluation Results (model-index YAML) for
6
+ the WEO-SAS/sen2sr repo.
7
+
8
+ This script lives inside each branch of WEO-SAS/sen2sr and is meant to be run
9
+ from the directory returned by snapshot_download():
10
+
11
+ from huggingface_hub import snapshot_download
12
+ local_dir = snapshot_download("WEO-SAS/sen2sr") # or specify revision
13
+ import subprocess, sys
14
+ subprocess.run([sys.executable, f"{local_dir}/evaluate.py", "--push"])
15
+
16
+ Or from the command line after cloning/downloading:
17
+
18
+ python evaluate.py --push --token hf_...
19
+
20
+ Requirements
21
+ ------------
22
+ pip install opensr-test huggingface_hub sen2sr safetensors
23
+
24
+ Outputs
25
+ -------
26
+ 1. A CSV file with per-sample metric values.
27
+ 2. Updated model-index YAML in the WEO-SAS/sen2sr main-branch README.md,
28
+ using the HuggingFace EvalResult / ModelCardData API.
29
+
30
+ HF Evaluation Results format
31
+ -----------------------------
32
+ Each result is keyed by (task_type, dataset_type, metric_type) and indexed
33
+ under the model variant name. Running this script from different variants
34
+ accumulates results in the shared README on the main branch.
35
+ """
36
+
37
+ from __future__ import annotations
38
+
39
+ import argparse
40
+ import csv
41
+ import json
42
+ import os
43
+ import sys
44
+ from pathlib import Path
45
+ from typing import Dict, List, Optional
46
+
47
+ import numpy as np
48
+ import torch
49
+
50
+
51
+ # ---------------------------------------------------------------------------
52
+ # Metric metadata
53
+ # ---------------------------------------------------------------------------
54
+
55
+ METRIC_COLS = [
56
+ "reflectance", "spectral", "spatial",
57
+ "synthesis", "hallucination", "omission", "improvement",
58
+ ]
59
+
60
+ DATASETS = ["naip", "spot", "venus", "spain_crops", "spain_urban"]
61
+
62
+ # Human-readable names used in the HF model card
63
+ DATASET_NAMES = {
64
+ "naip": "NAIP",
65
+ "spot": "SPOT",
66
+ "venus": "Venus",
67
+ "spain_crops": "Spain Crops",
68
+ "spain_urban": "Spain Urban",
69
+ }
70
+
71
+ METRIC_NAMES = {
72
+ "reflectance": "Reflectance Distance (L1)",
73
+ "spectral": "Spectral Angle Distance",
74
+ "spatial": "Phase Correlation Error",
75
+ "synthesis": "Synthesis Score",
76
+ "hallucination": "Hallucination Score",
77
+ "omission": "Omission Score",
78
+ "improvement": "Improvement Score",
79
+ }
80
+
81
+
82
+ # ---------------------------------------------------------------------------
83
+ # Model loading
84
+ # ---------------------------------------------------------------------------
85
+
86
+ def load_model_from_local(local_dir: str):
87
+ """Load the model from the snapshot directory."""
88
+ config_path = os.path.join(local_dir, "config.json")
89
+ with open(config_path) as f:
90
+ config = json.load(f)
91
+
92
+ if local_dir not in sys.path:
93
+ sys.path.insert(0, local_dir)
94
+
95
+ # Clear any cached module from a previous variant
96
+ for mod in ["model", "sen2sr_pt", "predictor", "base"]:
97
+ sys.modules.pop(mod, None)
98
+
99
+ # Dynamically load model.py from the local dir
100
+ import importlib.util
101
+ spec = importlib.util.spec_from_file_location("model", os.path.join(local_dir, "model.py"))
102
+ module = importlib.util.module_from_spec(spec)
103
+ sys.modules["model"] = module
104
+ spec.loader.exec_module(module)
105
+
106
+ return module.Model(local_dir=local_dir), config
107
+
108
+
109
+ # ---------------------------------------------------------------------------
110
+ # Inference
111
+ # ---------------------------------------------------------------------------
112
+
113
+ def run_sr(model, lr_np: np.ndarray, in_channels: int) -> np.ndarray:
114
+ """
115
+ Run SR on a single LR patch.
116
+
117
+ lr_np : (C_avail, H, W) float32 in [0, 1]
118
+ in_channels: channels the model expects
119
+ Returns : (C_out, H*sf, W*sf) float32
120
+ """
121
+ C_avail = lr_np.shape[0]
122
+
123
+ if in_channels == C_avail:
124
+ inp = lr_np
125
+ elif in_channels > C_avail:
126
+ pad = np.zeros((in_channels - C_avail,) + lr_np.shape[1:], dtype=np.float32)
127
+ inp = np.concatenate([lr_np, pad], axis=0)
128
+ else:
129
+ inp = lr_np[:in_channels]
130
+
131
+ return model.predict(inp)
132
+
133
+
134
+ # ---------------------------------------------------------------------------
135
+ # Per-dataset evaluation
136
+ # ---------------------------------------------------------------------------
137
+
138
+ def evaluate_dataset(
139
+ model,
140
+ in_channels: int,
141
+ dataset_name: str,
142
+ max_samples: Optional[int] = None,
143
+ verbose: bool = True,
144
+ ) -> Dict[str, float]:
145
+ """
146
+ Evaluate model on one opensr-test dataset.
147
+ Returns dict of metric_name → mean_value (nan if unavailable).
148
+ """
149
+ try:
150
+ import opensr_test
151
+ except ImportError:
152
+ raise ImportError("pip install opensr-test")
153
+
154
+ try:
155
+ dataset = opensr_test.load(dataset_name)
156
+ except Exception as e:
157
+ print(f" [WARN] Could not load '{dataset_name}': {e}")
158
+ return {}
159
+
160
+ metrics_obj = opensr_test.Metrics()
161
+ accum: Dict[str, list] = {m: [] for m in METRIC_COLS}
162
+ n = len(dataset) if max_samples is None else min(max_samples, len(dataset))
163
+
164
+ for i in range(n):
165
+ sample = dataset[i]
166
+ lr = sample["lr"]
167
+ hr = sample["hr"]
168
+
169
+ if isinstance(lr, torch.Tensor):
170
+ lr = lr.cpu().numpy()
171
+ if isinstance(hr, torch.Tensor):
172
+ hr = hr.cpu().numpy()
173
+ lr = lr.astype(np.float32)
174
+ hr = hr.astype(np.float32)
175
+
176
+ if lr.ndim == 2:
177
+ lr = lr[np.newaxis]
178
+ if hr.ndim == 2:
179
+ hr = hr[np.newaxis]
180
+
181
+ try:
182
+ sr = run_sr(model, lr, in_channels)
183
+ except Exception as e:
184
+ print(f" [WARN] SR failed on sample {i}: {e}")
185
+ continue
186
+
187
+ lr_t = torch.from_numpy(lr)
188
+ sr_t = torch.from_numpy(sr[:lr_t.shape[0]])
189
+ hr_t = torch.from_numpy(hr)
190
+
191
+ try:
192
+ result = metrics_obj.compute(lr=lr_t, sr=sr_t, hr=hr_t)
193
+ except Exception as e:
194
+ print(f" [WARN] Metrics failed on sample {i}: {e}")
195
+ continue
196
+
197
+ for m in METRIC_COLS:
198
+ val = result.get(m)
199
+ if val is not None:
200
+ v = float(val.mean()) if hasattr(val, "mean") else float(val)
201
+ accum[m].append(v)
202
+
203
+ if verbose and (i + 1) % 10 == 0:
204
+ print(f" {i+1}/{n}", end="\r")
205
+
206
+ if verbose:
207
+ print()
208
+
209
+ return {m: float(np.mean(vs)) if vs else float("nan") for m, vs in accum.items()}
210
+
211
+
212
+ # ---------------------------------------------------------------------------
213
+ # HF model card update
214
+ # ---------------------------------------------------------------------------
215
+
216
+ def build_eval_results(
217
+ variant: str,
218
+ results: Dict[str, Dict[str, float]], # dataset → metric → value
219
+ ) -> list:
220
+ """
221
+ Build a list of huggingface_hub.EvalResult objects for one variant.
222
+
223
+ One EvalResult per (dataset × metric) combination.
224
+ """
225
+ from huggingface_hub import EvalResult
226
+
227
+ eval_results = []
228
+ for ds_name, metrics in results.items():
229
+ for metric_name, value in metrics.items():
230
+ if np.isnan(value):
231
+ continue
232
+ eval_results.append(
233
+ EvalResult(
234
+ task_type = "image-to-image",
235
+ task_name = "Super-Resolution",
236
+ dataset_type = f"opensr-test-{ds_name}",
237
+ dataset_name = DATASET_NAMES.get(ds_name, ds_name),
238
+ dataset_config = ds_name,
239
+ metric_type = metric_name,
240
+ metric_name = METRIC_NAMES.get(metric_name, metric_name),
241
+ metric_value = round(value, 6),
242
+ model_name = variant,
243
+ )
244
+ )
245
+ return eval_results
246
+
247
+
248
+ def update_model_card(
249
+ variant: str,
250
+ eval_results: list,
251
+ repo_id: str = "WEO-SAS/sen2sr",
252
+ token: Optional[str] = None,
253
+ push: bool = False,
254
+ ) -> None:
255
+ """
256
+ Load the model card from the HF main branch, merge/replace this variant's
257
+ eval results, and optionally push back.
258
+ """
259
+ from huggingface_hub import ModelCard, ModelCardData
260
+ from huggingface_hub.repocard_data import model_index_to_eval_results, eval_results_to_model_index
261
+
262
+ print(f"\nLoading model card from {repo_id} (main)...")
263
+ try:
264
+ card = ModelCard.load(repo_id, token=token)
265
+ except Exception as e:
266
+ print(f" [WARN] Could not load card: {e}. Creating empty card.")
267
+ card = ModelCard("---\n---\n")
268
+
269
+ existing: list = card.data.eval_results or []
270
+
271
+ # Remove old entries for this variant, keep other variants
272
+ kept = [r for r in existing if getattr(r, "model_name", None) != variant]
273
+ merged = kept + eval_results
274
+
275
+ card.data.eval_results = merged
276
+
277
+ print(f" Model-index now has {len(merged)} EvalResult entries "
278
+ f"({len(eval_results)} from '{variant}', {len(kept)} from other variants).")
279
+
280
+ if push:
281
+ print(f" Pushing updated card to {repo_id}...")
282
+ card.push_to_hub(repo_id, token=token)
283
+ print(" Done.")
284
+ else:
285
+ print(" --push not set; card not pushed. Pass --push to update HF.")
286
+ print("\n--- model-index YAML preview ---")
287
+ print(card.data.to_yaml())
288
+
289
+
290
+ # ---------------------------------------------------------------------------
291
+ # Main
292
+ # ---------------------------------------------------------------------------
293
+
294
+ def main():
295
+ # Detect local_dir: script is inside the snapshot directory
296
+ local_dir = str(Path(__file__).parent.resolve())
297
+
298
+ parser = argparse.ArgumentParser(
299
+ description="Evaluate WEO-SAS/sen2sr and update HF model card"
300
+ )
301
+ parser.add_argument(
302
+ "--local-dir", default=local_dir,
303
+ help="Path to the snapshot_download directory (default: script location)",
304
+ )
305
+ parser.add_argument(
306
+ "--datasets", nargs="+", default=DATASETS, choices=DATASETS,
307
+ help="Datasets to evaluate on (default: all 5)",
308
+ )
309
+ parser.add_argument(
310
+ "--max-samples", type=int, default=None,
311
+ help="Cap samples per dataset for a quick smoke-test",
312
+ )
313
+ parser.add_argument(
314
+ "--output", default=None,
315
+ help="CSV output path (default: sen2sr_<variant>_eval.csv in local_dir)",
316
+ )
317
+ parser.add_argument(
318
+ "--repo-id", default="WEO-SAS/sen2sr",
319
+ help="HuggingFace repo whose main-branch card to update",
320
+ )
321
+ parser.add_argument(
322
+ "--token", default=os.environ.get("HF_TOKEN"),
323
+ help="HuggingFace token (default: $HF_TOKEN)",
324
+ )
325
+ parser.add_argument(
326
+ "--push", action="store_true",
327
+ help="Push updated model card to HF after evaluation",
328
+ )
329
+ parser.add_argument(
330
+ "--dry-run", action="store_true",
331
+ help="Print model-index YAML preview without pushing",
332
+ )
333
+ args = parser.parse_args()
334
+
335
+ # Load model + config
336
+ print(f"Loading model from {args.local_dir} ...")
337
+ model, config = load_model_from_local(args.local_dir)
338
+ variant = config.get("variant", "unknown")
339
+ in_channels = config.get("in_channels", 4)
340
+ print(f"Variant : {variant}")
341
+ print(f"In-ch : {in_channels}")
342
+ print(f"Desc : {config.get('description', '')}")
343
+
344
+ # CSV output path
345
+ csv_path = args.output or os.path.join(args.local_dir, f"sen2sr_{variant}_eval.csv")
346
+
347
+ # Evaluate
348
+ all_results: Dict[str, Dict[str, float]] = {}
349
+ rows = []
350
+
351
+ for ds in args.datasets:
352
+ print(f"\n[{variant}] Dataset: {ds}")
353
+ metrics = evaluate_dataset(model, in_channels, ds, args.max_samples)
354
+ if not metrics:
355
+ continue
356
+
357
+ all_results[ds] = metrics
358
+ rows.append({"variant": variant, "dataset": ds, **metrics})
359
+
360
+ print(f" {'Metric':<18} {'Value':>10}")
361
+ print(f" {'-'*30}")
362
+ for m in METRIC_COLS:
363
+ arrow = "↑" if m in ("synthesis", "improvement") else "↓"
364
+ print(f" {m:<18} {metrics.get(m, float('nan')):>9.4f} {arrow}")
365
+
366
+ # Save CSV
367
+ if rows:
368
+ fieldnames = ["variant", "dataset"] + METRIC_COLS
369
+ with open(csv_path, "w", newline="") as f:
370
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
371
+ writer.writeheader()
372
+ writer.writerows(rows)
373
+ print(f"\nCSV saved: {csv_path}")
374
+
375
+ # Build HF EvalResult objects
376
+ if not all_results:
377
+ print("No results to push.")
378
+ return
379
+
380
+ eval_results = build_eval_results(variant, all_results)
381
+ print(f"\nBuilt {len(eval_results)} EvalResult entries for '{variant}'.")
382
+
383
+ # Update model card (optionally push)
384
+ update_model_card(
385
+ variant = variant,
386
+ eval_results = eval_results,
387
+ repo_id = args.repo_id,
388
+ token = args.token,
389
+ push = args.push and not args.dry_run,
390
+ )
391
+
392
+ # Summary table (mean across datasets)
393
+ print("\n" + "="*60)
394
+ print(f"SUMMARY — {variant} — mean across {list(all_results.keys())}")
395
+ print("="*60)
396
+ means = {
397
+ m: float(np.nanmean([v[m] for v in all_results.values() if m in v]))
398
+ for m in METRIC_COLS
399
+ }
400
+ print(f" {'Metric':<18} {'Mean':>10}")
401
+ print(f" {'-'*30}")
402
+ for m in METRIC_COLS:
403
+ arrow = "↑" if m in ("synthesis", "improvement") else "↓"
404
+ print(f" {m:<18} {means.get(m, float('nan')):>9.4f} {arrow}")
405
+
406
+
407
+ if __name__ == "__main__":
408
+ main()