File size: 2,836 Bytes
00c1205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c63be5
 
00c1205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c63be5
 
 
 
 
00c1205
 
 
 
2c63be5
 
 
00c1205
 
 
 
 
 
 
 
 
 
 
 
 
2c63be5
00c1205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c63be5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
model.py
========
Public entry point for WEO-SAS CHM models stored on HuggingFace Hub.

Works identically for both chm-meta (v1) and chm-meta-v2 (v2) — only the
repo changes. All parameters are read from predictor_config.json.

Usage
-----
    from huggingface_hub import snapshot_download
    import sys

    local_dir = snapshot_download("WEO-SAS/chm-meta")   # or chm-meta-v2
    sys.path.insert(0, local_dir)
    from model import Model

    model = Model(local_dir=local_dir)

    # Array inference: (3, H, W) float32 in [0, 1] → (H, W) metres
    chm = model.predict(image)

    # GeoTIFF pipeline
    model.predict_tif("input.tif", "chm_output.tif")
"""

from __future__ import annotations

import importlib.util
import json
import os
import sys
from typing import List, Optional

import numpy as np


def _load_module(name: str, path: str):
    """
    Import the Python source file at ``path`` and register it as ``name``.

    Parameters
    ----------
    name : str
        Key under which the module is stored in ``sys.modules``.
    path : str
        Filesystem path of the source file to execute.

    Returns
    -------
    The freshly executed module object.

    Raises
    ------
    ImportError
        If no import spec (or no loader) can be built for ``path``, e.g.
        because the file suffix is not a recognised Python module suffix.
    Exception
        Anything raised while executing the module's top-level code is
        propagated unchanged; the ``sys.modules`` entry for ``name`` is
        rolled back first so no half-initialised module is left behind.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot load module {name!r} from {path!r}",
            name=name,
            path=path,
        )

    module = importlib.util.module_from_spec(spec)

    # The module is registered *before* execution (same order as the
    # importlib "import a source file directly" recipe) so that code which
    # looks the module up by name while it is being executed can find it.
    previous = sys.modules.get(name)
    sys.modules[name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Undo the registration: keep whatever was there before, or nothing.
        if previous is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous
        raise
    return module


class Model:
    """
    Public CHM model interface for HuggingFace Hub users.

    The class is a thin facade: it reads ``predictor_config.json`` from the
    downloaded snapshot, loads ``chm_pt.py`` from the same directory and
    forwards every prediction call to the ``CHMModelPT`` object defined there.

    Parameters
    ----------
    local_dir : str
        Path to the directory returned by ``snapshot_download(repo_id)``.
    **overrides
        Optionally override any value from predictor_config.json, e.g.
        ``Model(local_dir=d, patch_size=448, stride=224)``.

    Attributes
    ----------
    description : str
        The ``"description"`` entry of the effective config, or ``""`` when
        the key is absent.
    """

    def __init__(self, local_dir: str, **overrides):
        config_path = os.path.join(local_dir, "predictor_config.json")
        # Decode explicitly as UTF-8: the default encoding of open() depends
        # on the platform locale, which can mis-decode or reject non-ASCII
        # text in the config on non-UTF-8 systems.
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        # Keyword overrides take precedence over the values stored on disk.
        config.update(overrides)

        # Make the snapshot directory importable so modules that live next to
        # chm_pt.py can be imported by name.
        if local_dir not in sys.path:
            sys.path.insert(0, local_dir)

        chm_pt = _load_module("chm_pt", os.path.join(local_dir, "chm_pt.py"))
        self._model = chm_pt.CHMModelPT(local_dir=local_dir, config=config)
        self.description = config.get("description", "")

    def predict(self, image: np.ndarray) -> np.ndarray:
        """
        Run CHM inference on a single image.

        Parameters
        ----------
        image : (3, H, W) float32 numpy array, values in [0, 1]

        Returns
        -------
        (H, W) float32 numpy array — canopy height in metres
        """
        return self._model.predict(image)

    def predict_tif(
        self,
        input_path: str,
        output_path: str,
        bands: Optional[List[int]] = None,
    ) -> None:
        """
        Full GeoTIFF CHM pipeline.

        Parameters
        ----------
        input_path  : path to input RGB or multi-band GeoTIFF
        output_path : output path for the CHM GeoTIFF (1 band, metres)
        bands       : 0-based band indices to use as RGB (default: [0, 1, 2]);
                      ``None`` is passed through unchanged to the underlying
                      model
        """
        self._model.predict_tif(input_path, output_path, bands)