# chm-meta-v2 / model.py
# RhodWeo's picture
# Add/update model.py
# 2c63be5 verified
"""
model.py
========
Public entry point for WEO-SAS CHM models stored on HuggingFace Hub.
Works identically for both chm-meta (v1) and chm-meta-v2 (v2) — only the
repo changes. All parameters are read from predictor_config.json.
Usage
-----
from huggingface_hub import snapshot_download
import sys
local_dir = snapshot_download("WEO-SAS/chm-meta") # or chm-meta-v2
sys.path.insert(0, local_dir)
from model import Model
model = Model(local_dir=local_dir)
# Array inference: (3, H, W) float32 in [0, 1] → (H, W) metres
chm = model.predict(image)
# GeoTIFF pipeline
model.predict_tif("input.tif", "chm_output.tif")
"""
from __future__ import annotations
import importlib.util
import json
import os
import sys
from typing import List, Optional
import numpy as np
def _load_module(name: str, path: str):
    """
    Import the Python source file at ``path`` and register it as ``name``.

    Parameters
    ----------
    name : str
        Key under which the module is stored in ``sys.modules``.
    path : str
        Filesystem path of the source file to execute.

    Returns
    -------
    The executed module object.

    Raises
    ------
    ImportError
        If no import spec or loader can be built for ``path`` (for example
        when the file suffix is not a recognised Python module suffix).

    Any exception raised while the file's top-level code runs is propagated
    unchanged; in that case ``sys.modules`` is restored to its prior state.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot create an import spec for module {name!r} from {path!r}",
            name=name,
            path=path,
        )
    module = importlib.util.module_from_spec(spec)

    # Remember what was registered under this name so a failed load can be
    # rolled back instead of leaving a half-initialised module behind.
    had_previous = name in sys.modules
    previous = sys.modules.get(name)

    # Register before executing so the module is resolvable by name while
    # its own top-level code is running.
    sys.modules[name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        if had_previous:
            sys.modules[name] = previous
        else:
            sys.modules.pop(name, None)
        raise
    return module
class Model:
    """
    Public CHM model interface for HuggingFace Hub users.

    Thin facade over ``CHMModelPT`` from the ``chm_pt.py`` file shipped in
    the downloaded repository snapshot.

    Parameters
    ----------
    local_dir : str
        Path to the directory returned by ``snapshot_download(repo_id)``.
        It must contain ``predictor_config.json`` and ``chm_pt.py``.
    **overrides
        Optionally override any value from predictor_config.json, e.g.
        ``Model(local_dir=d, patch_size=448, stride=224)``.
    """

    def __init__(self, local_dir: str, **overrides):
        config_path = os.path.join(local_dir, "predictor_config.json")
        # JSON text is UTF-8; decode it explicitly rather than with the
        # platform's locale encoding, which breaks non-ASCII values on
        # systems whose default is not UTF-8.
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        # Keyword overrides take precedence over the values stored on disk.
        config.update(overrides)

        # NOTE(review): presumably chm_pt.py imports sibling files from the
        # snapshot directory, hence the sys.path entry — confirm.
        if local_dir not in sys.path:
            sys.path.insert(0, local_dir)

        chm_pt = _load_module("chm_pt", os.path.join(local_dir, "chm_pt.py"))
        self._model = chm_pt.CHMModelPT(local_dir=local_dir, config=config)
        # Free-text description from the config; empty string when absent.
        self.description = config.get("description", "")

    def predict(self, image: np.ndarray) -> np.ndarray:
        """
        Run CHM inference on a single image.

        Parameters
        ----------
        image : (3, H, W) float32 numpy array, values in [0, 1]

        Returns
        -------
        (H, W) float32 numpy array — canopy height in metres
        """
        return self._model.predict(image)

    def predict_tif(
        self,
        input_path: str,
        output_path: str,
        bands: Optional[List[int]] = None,
    ) -> None:
        """
        Full GeoTIFF CHM pipeline.

        Parameters
        ----------
        input_path : path to input RGB or multi-band GeoTIFF
        output_path : output path for the CHM GeoTIFF (1 band, metres)
        bands : 0-based band indices to use as RGB (default: [0, 1, 2]).
            ``None`` is forwarded unchanged to the underlying model.
        """
        self._model.predict_tif(input_path, output_path, bands)