Spaces:
Sleeping
Sleeping
| # loader.py | |
| import os | |
| import json | |
| import numpy as np | |
# -- UPDATE THIS after creating your HF dataset repo --------------------------
# Dataset repository on the Hugging Face Hub that holds the demo files.
HF_REPO_ID = "sumit1703/pm25-forecasting-data"
# -----------------------------------------------------------------------------

# Directory that contains this module; the bundled data lives in "data" below it.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
LOCAL_DATA_DIR = os.path.join(_MODULE_DIR, "data")

# Where files fetched from the Hub are stored (and looked up on later calls).
HUB_CACHE_DIR = "/tmp/pm25_data"

# Every file the loader needs; the local folder is used only if all are present.
REQUIRED_FILES = [
    "demo_preds.npy",
    "demo_inputs.npy",
    "lat_lon.npy",
    "demo_stats.json",
    "sample_indices.npy",
]

# Grid height and width that the lat/lon array is checked against.
H = 140
W = 124
def load_demo_data() -> dict:
    """Load the demo data, preferring the local ``data/`` folder over the Hub.

    When every entry of ``REQUIRED_FILES`` exists under ``LOCAL_DATA_DIR`` the
    files are read from there. Otherwise ``_download_from_hub()`` is called and
    the files are read from the directory it returns.

    Returns:
        The dict built by ``_load_files``, with keys (shapes are the ones
        documented by the original author and are not fully validated here):

            preds           np.ndarray  (22, 140, 124, 16)
            inputs          np.ndarray  (22, 10, 140, 124)
            lat             np.ndarray  (140, 124) or None
            lon             np.ndarray  (140, 124) or None
            stats           dict        {feature: {mean, std}}
            sample_indices  np.ndarray  (22,)
    """
    # A single missing file is enough to fall back to the Hub copy.
    missing_locally = any(
        not os.path.exists(os.path.join(LOCAL_DATA_DIR, name))
        for name in REQUIRED_FILES
    )

    if missing_locally:
        print("[loader] Local data/ not found β downloading from HF Hub...")
        return _load_files(_download_from_hub())

    print("[loader] Using local data/ folder.")
    return _load_files(LOCAL_DATA_DIR)
def _download_from_hub() -> str:
    """Make sure every required file is in ``HUB_CACHE_DIR`` and return that path.

    The cache directory is created if needed. Each name in ``REQUIRED_FILES``
    that already exists there is skipped; the others are fetched from the
    dataset repository ``HF_REPO_ID`` with ``hf_hub_download``. Progress is
    reported on stdout.

    Returns:
        ``HUB_CACHE_DIR``, the directory holding the files.
    """
    # Imported here rather than at module level, so huggingface_hub is only
    # imported when this function actually runs.
    from huggingface_hub import hf_hub_download

    os.makedirs(HUB_CACHE_DIR, exist_ok=True)

    for name in REQUIRED_FILES:
        cached_path = os.path.join(HUB_CACHE_DIR, name)
        if not os.path.exists(cached_path):
            print(f" [downloading] {name} ...")
            hf_hub_download(
                repo_id=HF_REPO_ID,
                filename=name,
                repo_type="dataset",
                local_dir=HUB_CACHE_DIR,
            )
            print(f" [done] {name}")
        else:
            print(f" [cache hit] {name}")

    return HUB_CACHE_DIR
def _load_files(data_dir: str) -> dict:
    """Read the demo arrays and statistics from ``data_dir``.

    Args:
        data_dir: Directory containing ``demo_preds.npy``, ``demo_inputs.npy``,
            ``lat_lon.npy``, ``sample_indices.npy`` and ``demo_stats.json``.

    Returns:
        A dict with keys ``preds``, ``inputs``, ``lat``, ``lon``, ``stats`` and
        ``sample_indices``. ``lat`` and ``lon`` are ``None`` when
        ``lat_lon.npy`` has neither shape ``(2, H, W)`` nor ``(H, W, 2)``.

    Raises:
        ValueError: If ``demo_preds.npy`` or ``demo_inputs.npy`` is not 4-D.
    """
    preds = np.load(os.path.join(data_dir, "demo_preds.npy"))
    inputs = np.load(os.path.join(data_dir, "demo_inputs.npy"))
    lat_lon = np.load(os.path.join(data_dir, "lat_lon.npy"))
    sample_indices = np.load(os.path.join(data_dir, "sample_indices.npy"))

    # JSON is UTF-8 by specification; do not rely on the platform default.
    with open(os.path.join(data_dir, "demo_stats.json"), encoding="utf-8") as f:
        stats = json.load(f)

    # Validate rank with explicit raises: `assert` statements are removed when
    # Python runs with -O, which would let malformed arrays through unnoticed.
    for fname, arr in (("demo_preds.npy", preds), ("demo_inputs.npy", inputs)):
        if arr.ndim != 4:
            raise ValueError(f"{fname} expected 4D, got {arr.shape}")

    # Parse lat/lon - accept both channel-first (2, H, W) and channel-last
    # (H, W, 2); anything else leaves both as None.
    lat, lon = None, None
    if lat_lon.shape == (2, H, W):
        lat, lon = lat_lon[0], lat_lon[1]
    elif lat_lon.shape == (H, W, 2):
        lat, lon = lat_lon[..., 0], lat_lon[..., 1]
    else:
        print(f"[loader] Warning: unexpected lat_lon shape {lat_lon.shape} β axes will show pixel indices")

    print(f"[loader] Loaded β preds:{preds.shape} inputs:{inputs.shape}")
    return {
        "preds": preds,
        "inputs": inputs,
        "lat": lat,
        "lon": lon,
        "stats": stats,
        "sample_indices": sample_indices,
    }