# chm-meta-v2 / predictor.py
# RhodWeo's picture
# Add/update predictor.py
# 84591e4 verified
"""
predictor.py
============
Unified CHM inference for WEO-SAS/chm-meta (v1) and WEO-SAS/chm-meta-v2 (v2).
Both versions expose the same interface — only the model directory changes:
predictor = CHMPredictor("./chm-meta") # v1: SSL ViT-H + DPT
predictor = CHMPredictor("./chm-meta-v2") # v2: DINOv3 ViT-L + DPT
chm = predictor.predict(image) # (3,H,W) float32 → (H,W) metres
predictor.predict_tif("in.tif", "out.tif") # full GeoTIFF pipeline
When called from chm_pt.py the pre-built model is injected via model= so that
weights are not loaded twice.
Requirements
------------
Both versions: torch, numpy, rasterio, Pillow
v1 only : pytorch_lightning (pip install pytorch_lightning)
v2 only : transformers with CHMv2 support
(pip install git+https://github.com/huggingface/transformers.git)
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
from typing import List, Optional
import numpy as np
import torch
import rasterio
class CHMPredictor:
    """
    Canopy Height Model predictor — works with both chm-meta (v1) and
    chm-meta-v2 (v2) by reading predictor_config.json from the model directory.

    Parameters
    ----------
    model_dir : local path to a downloaded WEO-SAS CHM model repo
    device    : torch device, or a device string such as "cpu" / "cuda"
                (auto-detected if None)
    model     : pre-built model; bypasses weights loading (used by chm_pt.py)
    processor : pre-built HF processor for v2; bypasses processor loading.
                For v2 the processor is loaded from model_dir whenever it is
                not given, including when ``model`` is injected.

    Raises
    ------
    ValueError
        If predictor_config.json declares an unknown ``model_version`` or a
        patch_size/stride pair that cannot tile an image
        (0 < stride <= patch_size is required).
    """

    _SUPPORTED_VERSIONS = ("v1", "v2")

    def __init__(
        self,
        model_dir: str,
        device: Optional[torch.device] = None,
        model=None,
        processor=None,
    ):
        model_dir = Path(model_dir)
        with open(model_dir / "predictor_config.json") as f:
            cfg = json.load(f)

        self.model_dir = model_dir
        self.patch_size = cfg["patch_size"]
        self.stride = cfg["stride"]
        # A stride larger than the patch leaves strips that no window covers;
        # predict() would silently return 0 m there, so reject it up front.
        if not 0 < self.stride <= self.patch_size:
            raise ValueError(
                "predictor_config.json must satisfy 0 < stride <= patch_size, "
                f"got patch_size={self.patch_size}, stride={self.stride}"
            )
        self.image_mean = np.array(cfg["image_mean"], dtype=np.float32).reshape(3, 1, 1)
        self.image_std = np.array(cfg["image_std"], dtype=np.float32).reshape(3, 1, 1)
        self.model_version = cfg["model_version"]
        # Validate even for injected models: _infer_tile would otherwise route
        # any unknown version through the v2 path.
        if self.model_version not in self._SUPPORTED_VERSIONS:
            raise ValueError(
                f"Unknown model_version '{self.model_version}' in predictor_config.json"
            )
        self.description = cfg.get("description", "")

        # torch.device() normalises both strings ("cuda") and device objects.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.processor = processor
        if model is not None:
            # Pre-built model injected by chm_pt.py — skip weight loading.
            self.model = model.to(self.device)
            if self.model_version == "v2" and self.processor is None:
                # _infer_tile_v2 cannot run without a processor.
                self.processor = self._load_v2_processor(model_dir)
        elif self.model_version == "v1":
            self._load_v1(model_dir, model_dir / cfg["weights_file"])
        else:
            self._load_v2(model_dir)
        self.model.eval()

    # ------------------------------------------------------------------
    # Model loading (only used when model= is not injected)
    # ------------------------------------------------------------------
    def _load_v1(self, model_dir: Path, weights_path: Path) -> None:
        """Build the v1 model via the repo-local ``model_chm.SSLModule``."""
        repo_path = str(model_dir)
        # model_chm.py is imported from the model directory; add it to the
        # import path only once so repeated instantiation does not keep
        # growing sys.path.
        if repo_path not in sys.path:
            sys.path.insert(0, repo_path)
        from model_chm import SSLModule  # noqa: PLC0415

        self.model = SSLModule(ssl_path=str(weights_path), local_path=str(weights_path))
        self.processor = None  # v1 normalises tiles itself in _infer_tile_v1
        self.model.to(self.device)

    @staticmethod
    def _import_v2_classes():
        """Import the CHMv2 model/processor classes, with an actionable error."""
        try:
            from transformers import (  # noqa: PLC0415
                CHMv2ForDepthEstimation,
                CHMv2ImageProcessorFast,
            )
        except ImportError as exc:
            raise ImportError(
                "v2 requires a transformers build that includes CHMv2. "
                "Install: pip install git+https://github.com/huggingface/transformers.git"
            ) from exc
        return CHMv2ForDepthEstimation, CHMv2ImageProcessorFast

    @classmethod
    def _load_v2_processor(cls, model_dir: Path):
        """Load the CHMv2 image processor stored in ``model_dir``."""
        _, processor_cls = cls._import_v2_classes()
        return processor_cls.from_pretrained(str(model_dir))

    def _load_v2(self, model_dir: Path) -> None:
        """Load the v2 model, and its processor unless one was injected."""
        model_cls, _ = self._import_v2_classes()
        self.model = model_cls.from_pretrained(str(model_dir))
        if getattr(self, "processor", None) is None:
            self.processor = self._load_v2_processor(model_dir)
        self.model.to(self.device)

    # ------------------------------------------------------------------
    # Per-tile inference
    # ------------------------------------------------------------------
    def _infer_tile_v1(self, tile: np.ndarray) -> np.ndarray:
        """tile: (3, patch_size, patch_size) float32 in [0, 1] → (patch_size, patch_size)"""
        # Per-channel normalisation with image_mean / image_std from the config.
        normalised = (tile - self.image_mean) / self.image_std
        x = torch.from_numpy(normalised).float().unsqueeze(0).to(self.device)
        with torch.no_grad():
            out = self.model(x)  # (1, 1, H, W)
        return out.squeeze().cpu().numpy()

    def _infer_tile_v2(self, tile: np.ndarray) -> np.ndarray:
        """tile: (3, patch_size, patch_size) float32 in [0, 1] → (patch_size, patch_size)"""
        from PIL import Image  # noqa: PLC0415

        # The HF processor takes an image, so convert to 8-bit HWC
        # (values are truncated, not rounded, by astype).
        arr_hwc = (tile * 255).clip(0, 255).astype(np.uint8).transpose(1, 2, 0)
        pil_img = Image.fromarray(arr_hwc)
        H, W = pil_img.height, pil_img.width
        inputs = self.processor(images=pil_img, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Resize the prediction back to the tile size.
        depth = self.processor.post_process_depth_estimation(
            outputs, target_sizes=[(H, W)]
        )[0]["predicted_depth"]
        return depth.cpu().numpy()  # (H, W)

    def _infer_tile(self, tile: np.ndarray) -> np.ndarray:
        """Dispatch a single padded tile to the version-specific inference."""
        if self.model_version == "v1":
            return self._infer_tile_v1(tile)
        return self._infer_tile_v2(tile)

    # ------------------------------------------------------------------
    # Tiled inference
    # ------------------------------------------------------------------
    @staticmethod
    def _window_starts(length: int, ps: int, st: int) -> List[int]:
        """Sorted window start offsets along one axis of size ``length``.

        Starts advance by ``st``; one extra start at ``length - ps`` (clamped
        to 0) makes the final window end flush with the image edge.
        """
        return sorted(set(range(0, length - ps, st)) | {max(0, length - ps)})

    def predict(self, image: np.ndarray) -> np.ndarray:
        """
        Run CHM inference on a (3, H, W) float32 image.

        Images up to patch_size × patch_size are zero-padded and run as one
        tile. Larger images use a sliding window of patch_size advancing by
        ``stride`` (overlap = patch_size - stride pixels, e.g. 50 % when
        stride = patch_size / 2); overlapping predictions are averaged.
        Windows that extend past the image are zero-padded and cropped.

        Parameters
        ----------
        image : (3, H, W) float32, values in [0, 1]

        Returns
        -------
        (H, W) float32 — canopy height in metres

        Raises
        ------
        ValueError : if ``image`` is not shaped (3, H, W)
        """
        if image.ndim != 3 or image.shape[0] != 3:
            raise ValueError(f"Expected (3, H, W), got {image.shape}")
        _, H, W = image.shape
        ps, st = self.patch_size, self.stride

        if H <= ps and W <= ps:
            pad = np.zeros((3, ps, ps), dtype=np.float32)
            pad[:, :H, :W] = image
            # Cast so both paths return float32 regardless of backend dtype.
            return np.asarray(self._infer_tile(pad)[:H, :W], dtype=np.float32)

        output = np.zeros((H, W), dtype=np.float64)
        count = np.zeros((H, W), dtype=np.float64)
        ys = self._window_starts(H, ps, st)
        xs = self._window_starts(W, ps, st)
        for y in ys:
            for x in xs:
                y2, x2 = min(y + ps, H), min(x + ps, W)
                th, tw = y2 - y, x2 - x
                tile = np.zeros((3, ps, ps), dtype=np.float32)
                tile[:, :th, :tw] = image[:, y:y2, x:x2]
                pred = self._infer_tile(tile)
                output[y:y2, x:x2] += pred[:th, :tw]
                count[y:y2, x:x2] += 1.0
        return (output / np.maximum(count, 1)).astype(np.float32)

    # ------------------------------------------------------------------
    # GeoTIFF pipeline
    # ------------------------------------------------------------------
    def predict_tif(
        self,
        input_path: str,
        output_path: str,
        bands: Optional[List[int]] = None,
    ) -> None:
        """
        Full GeoTIFF CHM pipeline.

        Reads the input GeoTIFF, stretches each selected band to [0, 1] using
        its 1st/99th percentiles over valid pixels, runs CHM inference, and
        writes a single-band float32 GeoTIFF with canopy height in metres at
        the same resolution.

        Pixels that rasterio reports as masked (declared nodata / internal
        mask) or that are NaN are excluded from the stretch statistics and fed
        to the model as 0; their output heights are not meaningful.

        Parameters
        ----------
        input_path  : path to input RGB or multi-band GeoTIFF
        output_path : output path for the CHM GeoTIFF (1 band, metres)
        bands       : 0-based band indices to use as RGB (default: [0, 1, 2])

        Raises
        ------
        ValueError : if ``bands`` does not select exactly three bands
        """
        bands = list(bands or [0, 1, 2])
        if len(bands) != 3:
            raise ValueError(f"Expected exactly 3 band indices for RGB, got {bands}")

        with rasterio.open(input_path) as src:
            # masked=True marks nodata pixels so they can be turned into NaN
            # and excluded from the percentile stretch below.
            arr = (
                src.read([b + 1 for b in bands], masked=True)
                .astype(np.float32)
                .filled(np.nan)
            )
            profile = src.profile.copy()

        for band in arr:  # each `band` is a view, so edits write into arr
            finite = band[np.isfinite(band)]
            if finite.size == 0:
                band[...] = 0.0
                continue
            vmin, vmax = (float(v) for v in np.percentile(finite, [1, 99]))
            band[...] = np.clip((band - vmin) / max(vmax - vmin, 1e-6), 0.0, 1.0)
        # NaN would propagate through the network and, via overlap averaging,
        # into neighbouring pixels; feed zeros instead.
        np.nan_to_num(arr, copy=False, nan=0.0)

        print(f"CHM inference model={self.model_version} input={arr.shape} {input_path}")
        chm = self.predict(arr)
        print(f"Output shape {chm.shape} range [{chm.min():.2f}, {chm.max():.2f}] m")

        out_profile = profile.copy()
        out_profile.update(count=1, dtype="float32", compress="lzw", nodata=None)
        out_profile.pop("photometric", None)  # RGB photometric is invalid for 1 band
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        with rasterio.open(output_path, "w", **out_profile) as dst:
            dst.write(chm[np.newaxis])
        print(f"Written: {output_path}")