"""
predictor.py
============
Unified CHM inference for WEO-SAS/chm-meta (v1) and WEO-SAS/chm-meta-v2 (v2).

Both versions expose the same interface — only the model directory changes:

    predictor = CHMPredictor("./chm-meta")      # v1: SSL ViT-H + DPT
    predictor = CHMPredictor("./chm-meta-v2")   # v2: DINOv3 ViT-L + DPT

    chm = predictor.predict(image)              # (3,H,W) float32 → (H,W) metres
    predictor.predict_tif("in.tif", "out.tif")  # full GeoTIFF pipeline

When called from chm_pt.py the pre-built model is injected via model= so that
weights are not loaded twice.

Requirements
------------
Both versions: torch, numpy, Pillow
               rasterio (only needed by predict_tif)
v1 only      : pytorch_lightning  (pip install pytorch_lightning)
v2 only      : transformers with CHMv2 support
               (pip install git+https://github.com/huggingface/transformers.git)
"""

from __future__ import annotations

import json
import sys
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import torch

try:
    import rasterio
except ImportError:  # only predict_tif() needs rasterio; predict() works without it
    rasterio = None


class CHMPredictor:
    """
    Canopy Height Model predictor — works with both chm-meta (v1) and
    chm-meta-v2 (v2) by reading predictor_config.json from the model directory.

    Parameters
    ----------
    model_dir : local path to a downloaded WEO-SAS CHM model repo
    device    : torch device (auto-detected if None)
    model     : pre-built model; bypasses weights loading (used by chm_pt.py)
    processor : pre-built HF processor for v2; bypasses processor loading
    """

    def __init__(
        self,
        model_dir: str,
        device: Optional[torch.device] = None,
        model=None,
        processor=None,
    ):
        model_dir = Path(model_dir)
        with open(model_dir / "predictor_config.json", encoding="utf-8") as f:
            cfg = json.load(f)

        self.model_dir = model_dir
        self.patch_size = cfg["patch_size"]
        self.stride = cfg["stride"]
        # Stored as (3, 1, 1) so they broadcast over (3, H, W) tiles.
        self.image_mean = np.array(cfg["image_mean"], dtype=np.float32).reshape(3, 1, 1)
        self.image_std = np.array(cfg["image_std"], dtype=np.float32).reshape(3, 1, 1)
        self.model_version = cfg["model_version"]
        self.description = cfg.get("description", "")

        if device is not None:
            # Accept strings such as "cuda:0" as well as torch.device objects.
            self.device = torch.device(device)
        else:
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )

        if model is not None:
            # Pre-built model injected by chm_pt.py — skip weight loading
            self.model = model.to(self.device)
            self.processor = processor
        elif self.model_version == "v1":
            self._load_v1(model_dir, model_dir / cfg["weights_file"])
        elif self.model_version == "v2":
            self._load_v2(model_dir)
        else:
            raise ValueError(
                f"Unknown model_version '{self.model_version}' in predictor_config.json"
            )

        self.model.eval()

    # ------------------------------------------------------------------
    # Model loading  (only used when model= is not injected)
    # ------------------------------------------------------------------

    def _load_v1(self, model_dir: Path, weights_path: Path) -> None:
        """Import ``model_chm`` from *model_dir* and build the v1 SSLModule."""
        # model_chm.py ships inside the model repo, so the repo directory has
        # to be importable.  Insert it only once so that building several
        # predictors does not keep growing sys.path.
        model_dir_str = str(model_dir)
        if model_dir_str not in sys.path:
            sys.path.insert(0, model_dir_str)
        from model_chm import SSLModule  # noqa: PLC0415

        self.model = SSLModule(ssl_path=str(weights_path), local_path=str(weights_path))
        self.processor = None
        self.model.to(self.device)

    def _load_v2(self, model_dir: Path) -> None:
        """Load the v2 model and its image processor from *model_dir*.

        Raises
        ------
        ImportError : the installed transformers build has no CHMv2 classes
        """
        try:
            from transformers import CHMv2ForDepthEstimation, CHMv2ImageProcessorFast
        except ImportError as exc:
            raise ImportError(
                "v2 requires a transformers build that includes CHMv2. "
                "Install: pip install git+https://github.com/huggingface/transformers.git"
            ) from exc

        self.model = CHMv2ForDepthEstimation.from_pretrained(str(model_dir))
        self.processor = CHMv2ImageProcessorFast.from_pretrained(str(model_dir))
        self.model.to(self.device)

    # ------------------------------------------------------------------
    # Per-tile inference
    # ------------------------------------------------------------------

    def _infer_tile_v1(self, tile: np.ndarray) -> np.ndarray:
        """tile: (3, patch_size, patch_size) float32 in [0, 1] → (patch_size, patch_size)"""
        normalised = (tile - self.image_mean) / self.image_std
        x = torch.from_numpy(normalised).float().unsqueeze(0).to(self.device)
        with torch.no_grad():
            out = self.model(x)  # (1, 1, H, W)
        return out.squeeze().cpu().numpy()

    def _infer_tile_v2(self, tile: np.ndarray) -> np.ndarray:
        """tile: (3, patch_size, patch_size) float32 in [0, 1] → (patch_size, patch_size)"""
        if self.processor is None:
            # Happens when model= was injected without processor=; fail with a
            # clear message instead of "'NoneType' object is not callable".
            raise RuntimeError(
                "v2 inference requires an image processor; "
                "pass processor= together with model="
            )
        from PIL import Image  # noqa: PLC0415

        # The HF processor takes an 8-bit HWC image and does its own normalisation.
        arr_hwc = (tile * 255).clip(0, 255).astype(np.uint8).transpose(1, 2, 0)
        pil_img = Image.fromarray(arr_hwc)
        H, W = pil_img.height, pil_img.width

        inputs = self.processor(images=pil_img, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)

        depth = self.processor.post_process_depth_estimation(
            outputs, target_sizes=[(H, W)]
        )[0]["predicted_depth"]
        return depth.cpu().numpy()  # (H, W)

    def _infer_tile(self, tile: np.ndarray) -> np.ndarray:
        """Dispatch one tile to the v1 path, or to the v2 path for anything else."""
        if self.model_version == "v1":
            return self._infer_tile_v1(tile)
        return self._infer_tile_v2(tile)

    # ------------------------------------------------------------------
    # Tiled inference
    # ------------------------------------------------------------------

    @staticmethod
    def _tile_origins(length: int, patch: int, stride: int) -> List[int]:
        """Return the sorted tile start offsets covering ``[0, length)`` on one axis."""
        origins = set(range(0, length - patch, stride))
        # The last tile is anchored to the far edge so the border is always
        # covered, even when (length - patch) is not a multiple of the stride.
        origins.add(max(0, length - patch))
        return sorted(origins)

    def predict(self, image: np.ndarray) -> np.ndarray:
        """
        Run CHM inference on a (3, H, W) float32 image.

        Images that fit into a single patch are zero-padded to
        patch_size × patch_size and inferred in one pass. Larger images are
        processed with a sliding window that advances by ``self.stride``
        pixels; predictions of overlapping tiles are averaged.

        Parameters
        ----------
        image : (3, H, W) float32, values in [0, 1]

        Returns
        -------
        (H, W) float32 — canopy height in metres

        Raises
        ------
        ValueError : *image* is not (3, H, W), or tiling is required and the
                     stride is not in ``1..patch_size`` (a larger stride would
                     leave gaps between tiles)
        """
        if image.ndim != 3 or image.shape[0] != 3:
            raise ValueError(f"Expected (3, H, W), got {image.shape}")

        _, H, W = image.shape
        ps, st = self.patch_size, self.stride

        if H <= ps and W <= ps:
            pad = np.zeros((3, ps, ps), dtype=np.float32)
            pad[:, :H, :W] = image
            return self._infer_tile(pad)[:H, :W].astype(np.float32, copy=False)

        if not 0 < st <= ps:
            raise ValueError(
                f"stride must be in 1..patch_size ({ps}) for tiled inference, got {st}"
            )

        # Accumulate in float64 so averaging many overlapping tiles stays exact.
        output = np.zeros((H, W), dtype=np.float64)
        count = np.zeros((H, W), dtype=np.float64)

        ys = self._tile_origins(H, ps, st)
        xs = self._tile_origins(W, ps, st)

        for y in ys:
            for x in xs:
                y2, x2 = min(y + ps, H), min(x + ps, W)
                th, tw = y2 - y, x2 - x
                # Tiles are always full patch size; an axis shorter than the
                # patch is zero-padded and the padding is cropped again below.
                tile = np.zeros((3, ps, ps), dtype=np.float32)
                tile[:, :th, :tw] = image[:, y:y2, x:x2]
                pred = self._infer_tile(tile)
                output[y:y2, x:x2] += pred[:th, :tw]
                count[y:y2, x:x2] += 1.0

        return (output / np.maximum(count, 1)).astype(np.float32)

    # ------------------------------------------------------------------
    # GeoTIFF pipeline
    # ------------------------------------------------------------------

    @staticmethod
    def _stretch_bands(arr: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Percentile-stretch every band of *arr* to [0, 1].

        Each band is rescaled between its 1st and 99th percentile (NaNs are
        ignored for the statistics) and clipped to [0, 1].

        Parameters
        ----------
        arr : (C, H, W) float array, may contain NaN

        Returns
        -------
        stretched : (C, H, W) array of the same dtype; NaN pixels replaced by 0
        invalid   : (H, W) bool mask, True where any input band was NaN
        """
        invalid = np.isnan(arr).any(axis=0)
        stretched = np.zeros_like(arr)
        for b in range(arr.shape[0]):
            band = arr[b]
            if np.isnan(band).all():
                # Empty or all-NaN band: nothing to stretch, leave it at zero.
                continue
            vmin, vmax = (float(v) for v in np.nanpercentile(band, [1, 99]))
            stretched[b] = np.clip((band - vmin) / max(vmax - vmin, 1e-6), 0.0, 1.0)
        # A single NaN fed to the network would poison the whole tile.
        return np.nan_to_num(stretched, nan=0.0), invalid

    def predict_tif(
        self,
        input_path: str,
        output_path: str,
        bands: Optional[List[int]] = None,
    ) -> None:
        """
        Full GeoTIFF CHM pipeline.

        Reads the input GeoTIFF, stretches each selected band to [0, 1]
        (1st–99th percentile), runs CHM inference, and writes a single-band
        float32 GeoTIFF with canopy height in metres at the same resolution.
        Pixels that are NaN in any input band are written as NaN, and NaN is
        then declared as the output nodata value.

        Parameters
        ----------
        input_path  : path to input RGB or multi-band GeoTIFF
        output_path : output path for the CHM GeoTIFF (1 band, metres)
        bands       : 0-based band indices to use as RGB (default: [0, 1, 2])

        Raises
        ------
        ImportError : rasterio is not installed
        """
        if rasterio is None:
            raise ImportError(
                "predict_tif requires rasterio. Install: pip install rasterio"
            )

        bands = bands or [0, 1, 2]

        with rasterio.open(input_path) as src:
            # rasterio band indices are 1-based.
            arr = src.read([b + 1 for b in bands]).astype(np.float32)
            profile = src.profile.copy()

        arr, invalid = self._stretch_bands(arr)
        has_invalid = bool(invalid.any())

        print(f"CHM inference model={self.model_version} input={arr.shape} {input_path}")
        chm = self.predict(arr)
        if has_invalid:
            # Do not report a height where the input carried no data.
            chm = chm.copy()
            chm[invalid] = np.nan

        finite = chm[np.isfinite(chm)]
        lo, hi = (finite.min(), finite.max()) if finite.size else (np.nan, np.nan)
        print(f"Output shape {chm.shape} range [{lo:.2f}, {hi:.2f}] m")

        out_profile = profile.copy()
        out_profile.update(
            count=1,
            dtype="float32",
            compress="lzw",
            nodata=float("nan") if has_invalid else None,
        )
        # An RGB photometric tag is meaningless for a single-band raster.
        out_profile.pop("photometric", None)

        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        with rasterio.open(output_path, "w", **out_profile) as dst:
            dst.write(chm[np.newaxis])

        print(f"Written: {output_path}")