Depth Estimation
Transformers
Safetensors
English
chmv2
dinov3
canopy-height
chm
Eval Results (legacy)
Instructions to use WEO-SAS/chm-meta-v2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use WEO-SAS/chm-meta-v2 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("depth-estimation", model="WEO-SAS/chm-meta-v2")
```

```python
# Load model directly
from transformers import AutoModelForDepthEstimation

model = AutoModelForDepthEstimation.from_pretrained("WEO-SAS/chm-meta-v2", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| predictor.py | |
| ============ | |
| Unified CHM inference for WEO-SAS/chm-meta (v1) and WEO-SAS/chm-meta-v2 (v2). | |
| Both versions expose the same interface — only the model directory changes: | |
| predictor = CHMPredictor("./chm-meta") # v1: SSL ViT-H + DPT | |
| predictor = CHMPredictor("./chm-meta-v2") # v2: DINOv3 ViT-L + DPT | |
| chm = predictor.predict(image) # (3,H,W) float32 → (H,W) metres | |
| predictor.predict_tif("in.tif", "out.tif") # full GeoTIFF pipeline | |
| When called from chm_pt.py the pre-built model is injected via model= so that | |
| weights are not loaded twice. | |
| Requirements | |
| ------------ | |
| Both versions: torch, numpy, rasterio, Pillow | |
| v1 only : pytorch_lightning (pip install pytorch_lightning) | |
| v2 only : transformers with CHMv2 support | |
| (pip install git+https://github.com/huggingface/transformers.git) | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import sys | |
| from pathlib import Path | |
| from typing import List, Optional | |
| import numpy as np | |
| import torch | |
| import rasterio | |
class CHMPredictor:
    """
    Canopy Height Model predictor — works with both chm-meta (v1) and
    chm-meta-v2 (v2) by reading predictor_config.json from the model directory.

    The config supplies ``patch_size``, ``stride``, ``image_mean``,
    ``image_std`` and ``model_version`` ("v1" or "v2"); ``weights_file`` is
    additionally read for v1, and ``description`` is optional.

    Parameters
    ----------
    model_dir : local path to a downloaded WEO-SAS CHM model repo
    device    : torch device (auto-detected if None)
    model     : pre-built model; bypasses weights loading (used by chm_pt.py)
    processor : pre-built HF processor for v2; bypasses processor loading

    Raises
    ------
    ValueError : ``model_version`` is neither "v1" nor "v2" and no model was
                 injected.
    """

    def __init__(
        self,
        model_dir: str,
        device: Optional[torch.device] = None,
        model=None,
        processor=None,
    ):
        model_dir = Path(model_dir)
        with open(model_dir / "predictor_config.json", encoding="utf-8") as f:
            cfg = json.load(f)

        self.model_dir = model_dir
        self.patch_size = cfg["patch_size"]
        self.stride = cfg["stride"]
        # Reshaped to (3, 1, 1) so they broadcast over a (3, H, W) tile.
        self.image_mean = np.array(cfg["image_mean"], dtype=np.float32).reshape(3, 1, 1)
        self.image_std = np.array(cfg["image_std"], dtype=np.float32).reshape(3, 1, 1)
        self.model_version = cfg["model_version"]
        self.description = cfg.get("description", "")

        if device is not None:
            self.device = device
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if model is not None:
            # Pre-built model injected by chm_pt.py — skip weight loading
            self.model = model.to(self.device)
            self.processor = processor
        elif self.model_version == "v1":
            self._load_v1(model_dir, model_dir / cfg["weights_file"])
        elif self.model_version == "v2":
            self._load_v2(model_dir)
        else:
            raise ValueError(f"Unknown model_version '{self.model_version}' in predictor_config.json")

        self.model.eval()

    # ------------------------------------------------------------------
    # Model loading (only used when model= is not injected)
    # ------------------------------------------------------------------
    def _load_v1(self, model_dir: Path, weights_path: Path) -> None:
        """Import ``model_chm.SSLModule`` from *model_dir* and build the v1 model."""
        # model_chm.py ships inside the model repo, so the directory has to be
        # importable. Only add it once to avoid growing sys.path per instance.
        dir_str = str(model_dir)
        if dir_str not in sys.path:
            sys.path.insert(0, dir_str)
        from model_chm import SSLModule  # noqa: PLC0415

        self.model = SSLModule(ssl_path=str(weights_path), local_path=str(weights_path))
        self.processor = None
        self.model.to(self.device)

    def _load_v2(self, model_dir: Path) -> None:
        """Load the CHMv2 model and its image processor from *model_dir*.

        Raises ImportError with an install hint when the installed
        transformers build does not expose the CHMv2 classes.
        """
        try:
            from transformers import CHMv2ForDepthEstimation, CHMv2ImageProcessorFast
        except ImportError as exc:
            raise ImportError(
                "v2 requires a transformers build that includes CHMv2. "
                "Install: pip install git+https://github.com/huggingface/transformers.git"
            ) from exc

        self.model = CHMv2ForDepthEstimation.from_pretrained(str(model_dir))
        self.processor = CHMv2ImageProcessorFast.from_pretrained(str(model_dir))
        self.model.to(self.device)

    # ------------------------------------------------------------------
    # Per-tile inference
    # ------------------------------------------------------------------
    def _infer_tile_v1(self, tile: np.ndarray) -> np.ndarray:
        """tile: (3, patch_size, patch_size) float32 in [0, 1] → (patch_size, patch_size)"""
        normalised = (tile - self.image_mean) / self.image_std
        x = torch.from_numpy(normalised).float().unsqueeze(0).to(self.device)
        with torch.no_grad():
            out = self.model(x)  # (1, 1, H, W)
        return out.squeeze().cpu().numpy()

    def _infer_tile_v2(self, tile: np.ndarray) -> np.ndarray:
        """tile: (3, patch_size, patch_size) float32 in [0, 1] → (patch_size, patch_size)

        Raises RuntimeError when no processor is available (a v2 model was
        injected via ``model=`` without ``processor=``).
        """
        if self.processor is None:
            # Fail with an actionable message instead of
            # "'NoneType' object is not callable" further down.
            raise RuntimeError(
                "v2 inference needs an image processor; pass processor= "
                "together with model= when injecting a pre-built v2 model"
            )

        from PIL import Image  # noqa: PLC0415

        # The HF processor is fed an 8-bit RGB PIL image in HWC layout.
        arr_hwc = (tile * 255).clip(0, 255).astype(np.uint8).transpose(1, 2, 0)
        pil_img = Image.fromarray(arr_hwc)
        H, W = pil_img.height, pil_img.width

        inputs = self.processor(images=pil_img, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Resize the prediction back to the tile's own size.
        depth = self.processor.post_process_depth_estimation(
            outputs, target_sizes=[(H, W)]
        )[0]["predicted_depth"]
        return depth.cpu().numpy()  # (H, W)

    def _infer_tile(self, tile: np.ndarray) -> np.ndarray:
        """Dispatch one tile to the v1 path, or to the v2 path for any other version."""
        if self.model_version == "v1":
            return self._infer_tile_v1(tile)
        return self._infer_tile_v2(tile)

    # ------------------------------------------------------------------
    # Tiled inference
    # ------------------------------------------------------------------
    @staticmethod
    def _tile_origins(extent: int, patch: int, stride: int) -> List[int]:
        """Return sorted tile start offsets along one axis.

        Offsets advance by *stride*; a final offset flush with the far edge
        (or 0 when the axis is shorter than *patch*) is always included.
        """
        return sorted({*range(0, extent - patch, stride), max(0, extent - patch)})

    def predict(self, image: np.ndarray) -> np.ndarray:
        """
        Run CHM inference on a (3, H, W) float32 image.

        Images that fit inside one patch are zero-padded to
        patch_size × patch_size and inferred in a single pass. Larger images
        use a sliding window that advances by the configured stride, with
        simple averaging wherever tiles overlap.

        Parameters
        ----------
        image : (3, H, W) float32, values in [0, 1]

        Returns
        -------
        (H, W) float32 — canopy height in metres

        Raises
        ------
        ValueError : *image* is not (3, H, W), or tiling is required and the
                     configured stride is not in 1..patch_size.
        """
        if image.ndim != 3 or image.shape[0] != 3:
            raise ValueError(f"Expected (3, H, W), got {image.shape}")

        _, H, W = image.shape
        ps, st = self.patch_size, self.stride

        if H <= ps and W <= ps:
            pad = np.zeros((3, ps, ps), dtype=np.float32)
            pad[:, :H, :W] = image
            return self._infer_tile(pad)[:H, :W]

        # A non-positive stride or one larger than the patch would leave
        # pixels that no tile covers, which would silently come back as 0 m.
        if not 0 < st <= ps:
            raise ValueError(
                f"stride must be in 1..patch_size ({ps}) for tiled inference, got {st}"
            )

        output = np.zeros((H, W), dtype=np.float64)
        count = np.zeros((H, W), dtype=np.float64)
        ys = self._tile_origins(H, ps, st)
        xs = self._tile_origins(W, ps, st)

        for y in ys:
            for x in xs:
                y2, x2 = min(y + ps, H), min(x + ps, W)
                th, tw = y2 - y, x2 - x
                # Tiles cut short by the image border are zero-padded to full size.
                tile = np.zeros((3, ps, ps), dtype=np.float32)
                tile[:, :th, :tw] = image[:, y:y2, x:x2]
                pred = self._infer_tile(tile)
                output[y:y2, x:x2] += pred[:th, :tw]
                count[y:y2, x:x2] += 1.0

        return (output / np.maximum(count, 1)).astype(np.float32)

    # ------------------------------------------------------------------
    # GeoTIFF pipeline
    # ------------------------------------------------------------------
    @staticmethod
    def _normalise_bands(arr: np.ndarray) -> np.ndarray:
        """Stretch each band of a (C, H, W) float array to [0, 1], in place.

        Each band is scaled between its own 1st and 99th percentile (NaNs
        ignored) and clipped. NaN pixels are written back as 0.0 so they never
        reach the model; a band that is entirely NaN becomes all zeros.
        Returns *arr* for convenience.
        """
        for band in arr:
            nan_mask = np.isnan(band)
            if nan_mask.all():
                band[:] = 0.0
                continue
            vmin = float(np.nanpercentile(band, 1))
            vmax = float(np.nanpercentile(band, 99))
            # The 1e-6 floor guards against division by zero on constant bands.
            band[:] = np.clip((band - vmin) / max(vmax - vmin, 1e-6), 0.0, 1.0)
            band[nan_mask] = 0.0
        return arr

    def predict_tif(
        self,
        input_path: str,
        output_path: str,
        bands: Optional[List[int]] = None,
    ) -> None:
        """
        Full GeoTIFF CHM pipeline.

        Reads the input GeoTIFF, percentile-stretches the selected bands to
        [0, 1], runs CHM inference, and writes a single-band float32 GeoTIFF
        with canopy height in metres at the same resolution. The output
        reuses the input's rasterio profile with LZW compression and no
        nodata value; parent directories of *output_path* are created.

        Parameters
        ----------
        input_path  : path to input RGB or multi-band GeoTIFF
        output_path : output path for the CHM GeoTIFF (1 band, metres)
        bands       : 0-based band indices to use as RGB (default: [0, 1, 2])
        """
        bands = bands or [0, 1, 2]

        with rasterio.open(input_path) as src:
            # rasterio band indices are 1-based.
            arr = src.read([b + 1 for b in bands]).astype(np.float32)
            profile = src.profile.copy()

        arr = self._normalise_bands(arr)

        print(f"CHM inference model={self.model_version} input={arr.shape} {input_path}")
        chm = self.predict(arr)
        print(f"Output shape {chm.shape} range [{chm.min():.2f}, {chm.max():.2f}] m")

        out_profile = profile.copy()
        out_profile.update(count=1, dtype="float32", compress="lzw", nodata=None)
        # Drop the input's photometric tag, which no longer describes a 1-band output.
        out_profile.pop("photometric", None)

        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        with rasterio.open(output_path, "w", **out_profile) as dst:
            dst.write(chm[np.newaxis])
        print(f"Written: {output_path}")