"""
model.py
========

Public entry point for WEO-SAS CHM models stored on HuggingFace Hub.

Works identically for both chm-meta (v1) and chm-meta-v2 (v2) — only the
repo changes. All parameters are read from predictor_config.json.

Usage
-----
    from huggingface_hub import snapshot_download
    import sys

    local_dir = snapshot_download("WEO-SAS/chm-meta")  # or chm-meta-v2
    sys.path.insert(0, local_dir)

    from model import Model

    model = Model(local_dir=local_dir)

    # Array inference: (3, H, W) float32 in [0, 1] → (H, W) metres
    chm = model.predict(image)

    # GeoTIFF pipeline
    model.predict_tif("input.tif", "chm_output.tif")
"""

from __future__ import annotations

import importlib.util
import json
import os
import sys
from typing import List, Optional

import numpy as np


def _load_module(name: str, path: str):
    """Execute the Python source file at *path* and register it as *name*.

    The module object is placed in ``sys.modules`` before its code runs, so
    code inside it can find itself by name. If execution fails, that
    half-initialised entry is removed again (or the previously registered
    module under *name* is restored) and the original exception propagates.

    Parameters
    ----------
    name : module name to register in ``sys.modules``
    path : filesystem path of the source file to execute

    Returns
    -------
    The executed module object.

    Raises
    ------
    ImportError
        If importlib cannot find a loader for *path* (e.g. an unrecognised
        file suffix).
    """
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        # Without this check the caller would get an opaque AttributeError
        # from module_from_spec / spec.loader.
        raise ImportError(
            f"Cannot load module {name!r} from {path!r}", name=name, path=path
        )
    module = importlib.util.module_from_spec(spec)

    previous = sys.modules.get(name)
    sys.modules[name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a broken, half-executed module registered under *name*.
        if previous is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous
        raise
    return module


class Model:
    """
    Public CHM model interface for HuggingFace Hub users.

    Parameters
    ----------
    local_dir : str
        Path to the directory returned by ``snapshot_download(repo_id)``.
    **overrides
        Optionally override any value from predictor_config.json, e.g.
        ``Model(local_dir=d, patch_size=448, stride=224)``.

    Attributes
    ----------
    description : str
        The ``"description"`` entry of the config ("" when absent).
    """

    def __init__(self, local_dir: str, **overrides) -> None:
        config_path = os.path.join(local_dir, "predictor_config.json")
        # JSON is UTF-8 by specification; don't depend on the locale encoding.
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        # Keyword overrides take precedence over values read from the file.
        config.update(overrides)

        # Put the snapshot directory on the import path, presumably so that
        # chm_pt.py can import its sibling files — TODO confirm.
        if local_dir not in sys.path:
            sys.path.insert(0, local_dir)

        # Registered as "chm_pt"; loading another snapshot replaces that
        # sys.modules entry, but this instance keeps its own module's class.
        chm_pt = _load_module("chm_pt", os.path.join(local_dir, "chm_pt.py"))
        self._model = chm_pt.CHMModelPT(local_dir=local_dir, config=config)
        self.description = config.get("description", "")

    def predict(self, image: np.ndarray) -> np.ndarray:
        """
        Run CHM inference on a single image.

        Parameters
        ----------
        image : (3, H, W) float32 numpy array, values in [0, 1]

        Returns
        -------
        (H, W) float32 numpy array — canopy height in metres
        """
        return self._model.predict(image)

    def predict_tif(
        self,
        input_path: str,
        output_path: str,
        bands: Optional[List[int]] = None,
    ) -> None:
        """
        Full GeoTIFF CHM pipeline.

        Parameters
        ----------
        input_path  : path to input RGB or multi-band GeoTIFF
        output_path : output path for the CHM GeoTIFF (1 band, metres)
        bands       : 0-based band indices to use as RGB. ``None`` (the
                      default) is passed through to the backend, which is
                      documented to use [0, 1, 2].
        """
        self._model.predict_tif(input_path, output_path, bands)