File size: 3,418 Bytes
6d01d4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# loader.py
import os
import json
import numpy as np

# Hugging Face dataset repository that hosts the demo files.
# UPDATE THIS after creating your HF dataset repo.
HF_REPO_ID = "sumit1703/pm25-forecasting-data"

# "data" folder located next to this module.
LOCAL_DATA_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "data",
)
# Directory that hub downloads are written to and re-used from.
HUB_CACHE_DIR = "/tmp/pm25_data"

# Files that must all be present before a data directory can be loaded.
REQUIRED_FILES = [
    "demo_preds.npy",
    "demo_inputs.npy",
    "lat_lon.npy",
    "demo_stats.json",
    "sample_indices.npy",
]

# Grid height and width used when interpreting the lat/lon array.
H = 140
W = 124


def load_demo_data() -> dict:
    """Load the demo dataset, preferring the local ``data/`` folder.

    When every name in ``REQUIRED_FILES`` exists under ``LOCAL_DATA_DIR`` the
    files are read from there. Otherwise ``_download_from_hub()`` is called
    and the files are read from the directory it returns.

    Returns:
        The dict produced by ``_load_files``. Keys and expected shapes:

            preds          np.ndarray  (22, 140, 124, 16)
            inputs         np.ndarray  (22, 10, 140, 124)
            lat            np.ndarray  (140, 124)  or None
            lon            np.ndarray  (140, 124)  or None
            stats          dict        {feature: {mean, std}}
            sample_indices np.ndarray  (22,)
    """
    # The local folder is used only if it is complete; a single missing file
    # sends the whole load through the hub path.
    all_local = all(
        os.path.exists(os.path.join(LOCAL_DATA_DIR, f))
        for f in REQUIRED_FILES
    )

    if all_local:
        print("[loader] Using local data/ folder.")
        data_dir = LOCAL_DATA_DIR
    else:
        # The message previously contained a mis-encoded em dash.
        print("[loader] Local data/ not found — downloading from HF Hub...")
        data_dir = _download_from_hub()

    return _load_files(data_dir)


def _download_from_hub() -> str:
    """Make sure every required file is present in the hub cache directory.

    Creates ``HUB_CACHE_DIR`` if needed, then walks ``REQUIRED_FILES``: a file
    already present in that directory is reported as a cache hit, any other
    file is fetched from the ``HF_REPO_ID`` dataset repo with
    ``hf_hub_download``.

    Returns:
        ``HUB_CACHE_DIR``.
    """
    # Imported inside the function, so huggingface_hub is only required when
    # this function actually runs.
    from huggingface_hub import hf_hub_download

    os.makedirs(HUB_CACHE_DIR, exist_ok=True)

    for required in REQUIRED_FILES:
        cached_path = os.path.join(HUB_CACHE_DIR, required)
        if not os.path.exists(cached_path):
            print(f"  [downloading] {required} ...")
            hf_hub_download(
                repo_id=HF_REPO_ID,
                filename=required,
                repo_type="dataset",
                local_dir=HUB_CACHE_DIR,
            )
            print(f"  [done] {required}")
        else:
            print(f"  [cache hit] {required}")

    return HUB_CACHE_DIR


def _load_files(data_dir: str) -> dict:
    """Read the demo arrays and statistics stored in *data_dir*.

    Loads ``demo_preds.npy``, ``demo_inputs.npy``, ``lat_lon.npy`` and
    ``sample_indices.npy`` with ``np.load`` and ``demo_stats.json`` with
    ``json.load``.

    Args:
        data_dir: Directory containing the five files listed above.

    Returns:
        A dict with keys ``"preds"``, ``"inputs"``, ``"lat"``, ``"lon"``,
        ``"stats"`` and ``"sample_indices"``. ``"lat"`` and ``"lon"`` are
        ``None`` when ``lat_lon.npy`` has neither shape ``(2, H, W)`` nor
        ``(H, W, 2)``.

    Raises:
        ValueError: If ``demo_preds.npy`` or ``demo_inputs.npy`` is not 4-D.
        OSError: If one of the files cannot be opened.
    """
    preds          = np.load(os.path.join(data_dir, "demo_preds.npy"))
    inputs         = np.load(os.path.join(data_dir, "demo_inputs.npy"))
    lat_lon        = np.load(os.path.join(data_dir, "lat_lon.npy"))
    sample_indices = np.load(os.path.join(data_dir, "sample_indices.npy"))

    # Explicit encoding so parsing does not depend on the platform default.
    with open(os.path.join(data_dir, "demo_stats.json"), encoding="utf-8") as f:
        stats = json.load(f)

    # Validate shapes. Raised explicitly rather than asserted, because
    # assert statements are removed when Python runs with -O.
    if preds.ndim != 4:
        raise ValueError(f"demo_preds.npy expected 4D, got {preds.shape}")
    if inputs.ndim != 4:
        raise ValueError(f"demo_inputs.npy expected 4D, got {inputs.shape}")

    # Parse lat/lon — handle both (2, H, W) and (H, W, 2)
    lat, lon = None, None
    if lat_lon.shape == (2, H, W):
        lat, lon = lat_lon[0], lat_lon[1]
    elif lat_lon.shape == (H, W, 2):
        lat, lon = lat_lon[..., 0], lat_lon[..., 1]
    else:
        # Best effort: keep going without coordinates instead of failing.
        print(f"[loader] Warning: unexpected lat_lon shape {lat_lon.shape} — axes will show pixel indices")

    print(f"[loader] Loaded — preds:{preds.shape}  inputs:{inputs.shape}")

    return {
        "preds":          preds,
        "inputs":         inputs,
        "lat":            lat,
        "lon":            lon,
        "stats":          stats,
        "sample_indices": sample_indices,
    }