"""HF Inference Endpoint handler for Prithvi-EO-2.0-300M + TerraMind-1.0-Base.

Deployed to pokkiri/eo-multibackbone-endpoint (framework="custom").

Request format:
    {"model": "prithvi" | "terramind", "inputs": <array>}

``inputs`` is a normalised float32 array, sent either as nested lists or as
the raw bytes of a ``.npy`` file:
    Prithvi:   6 channels in order [B02, B03, B04, B05, B06, B07], normalised
    TerraMind: 12 channels Sentinel-2 L2A bands, normalised

Accepted array layouts (both models):
    (C, H, W)         single image; a batch axis of size 1 is added
    (B, C, H, W)      batch of single-date images
    (B, T, C, H, W)   batch of time series (TerraMind only uses frame 0)

Response format:
    {"embeddings": [[float, ...], ...]}   shape (B, embed_dim)
    {"error": "message"}                  on failure

Prithvi embed_dim   = 1024 (mean-pooled spatial tokens from last encoder block)
TerraMind embed_dim = 768  (mean-pooled output tokens)
"""

from __future__ import annotations

import json
import os
import sys
from io import BytesIO
from pathlib import Path

import numpy as np
import torch

_PRITHVI_REPO = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M"
_PRITHVI_WEIGHTS = "Prithvi_EO_V2_300M.pt"


class EndpointHandler:
    """Custom HF Inference Endpoint serving Prithvi and TerraMind embeddings.

    Both backbones are loaded once in ``__init__``. A backbone that fails to
    load is left as ``None``; requests for it then receive an
    ``{"error": ...}`` payload instead of crashing the endpoint.
    """

    def __init__(self, path: str = ""):
        """Pick a device and load both backbones.

        Args:
            path: model directory handed over by the endpoint runtime; searched
                for ``prithvi_mae.py``, ``config.json`` and the Prithvi
                checkpoint. May be empty.
        """
        self._path = path
        self._prithvi = None
        self._terramind = None
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[handler] device: {self._device}")
        self._load_prithvi(path)
        self._load_terramind()

    # ------------------------------------------------------------------ loading

    @staticmethod
    def _search_dirs(path: str) -> list[str]:
        """Return the directories searched for Prithvi code/config, highest priority first."""
        # prithvi_mae.py lives alongside handler.py in /app/ when path is empty
        return [path, "/app", os.path.dirname(os.path.abspath(__file__))]

    @staticmethod
    def _find_config(search_dirs: list[str]) -> str:
        """Return the first existing config.json in *search_dirs*, else "config.json" (CWD)."""
        for cfg_dir in search_dirs:
            cfg_path = os.path.join(cfg_dir, "config.json") if cfg_dir else "config.json"
            if os.path.exists(cfg_path):
                return cfg_path
        return "config.json"

    @staticmethod
    def _resolve_prithvi_weights(path: str) -> str:
        """Return a local Prithvi checkpoint path, downloading from the Hub if *path* has none."""
        weights_local = os.path.join(path, _PRITHVI_WEIGHTS) if path else ""
        if weights_local and os.path.exists(weights_local):
            return weights_local
        print(f"[handler] downloading Prithvi weights from {_PRITHVI_REPO} …")
        from huggingface_hub import hf_hub_download  # noqa: PLC0415

        return hf_hub_download(_PRITHVI_REPO, _PRITHVI_WEIGHTS)

    @staticmethod
    def _load_checkpoint(weights_path: str):
        """Load a checkpoint onto CPU, preferring the safer ``weights_only`` loader."""
        try:
            return torch.load(weights_path, map_location="cpu", weights_only=True)
        except TypeError:
            # presumably an older torch whose torch.load has no weights_only keyword
            return torch.load(weights_path, map_location="cpu")

    def _load_prithvi(self, path: str) -> None:
        """Build PrithviMAE from config.json and load its pretrained weights on CPU.

        Any failure is logged and leaves ``self._prithvi`` as ``None``.
        """
        try:
            search_dirs = self._search_dirs(path)
            for sp in search_dirs:
                if sp and sp not in sys.path:
                    sys.path.insert(0, sp)
            from prithvi_mae import PrithviMAE  # noqa: PLC0415

            with open(self._find_config(search_dirs), encoding="utf-8") as fh:
                cfg = json.load(fh)
            pc = cfg["pretrained_cfg"]
            model = PrithviMAE(
                img_size=pc["img_size"],
                num_frames=pc["num_frames"],
                patch_size=pc["patch_size"],
                in_chans=pc["in_chans"],
                embed_dim=pc["embed_dim"],
                depth=pc["depth"],
                num_heads=pc["num_heads"],
                decoder_embed_dim=pc["decoder_embed_dim"],
                decoder_depth=pc["decoder_depth"],
                decoder_num_heads=pc["decoder_num_heads"],
                mlp_ratio=pc["mlp_ratio"],
                coords_encoding=pc.get("coords_encoding", []),
                coords_scale_learn=pc.get("coords_scale_learn", False),
                mask_ratio=pc.get("mask_ratio", 0.75),
            )

            state_dict = self._load_checkpoint(self._resolve_prithvi_weights(path))
            # Drop checkpoint positional embeddings. Presumably the model builds
            # its own sin-cos embeddings for the configured img_size, so loading
            # them could only cause shape mismatches (TODO confirm). Deleted in
            # place to keep any state-dict metadata intact.
            for key in [k for k in state_dict if "pos_embed" in k]:
                del state_dict[key]
            incompatible = model.load_state_dict(state_dict, strict=False)
            # strict=False would otherwise hide a checkpoint/architecture
            # mismatch (e.g. renamed keys) and silently serve random weights.
            missing = [k for k in incompatible.missing_keys if "pos_embed" not in k]
            unexpected = list(incompatible.unexpected_keys)
            if missing or unexpected:
                print(
                    f"[handler] Prithvi weight mismatch: {len(missing)} missing "
                    f"(e.g. {missing[:3]}), {len(unexpected)} unexpected "
                    f"(e.g. {unexpected[:3]})"
                )
            model.eval()
            # Keep on CPU: prithvi_mae's sincos pos_embed runs on CPU via numpy
            model = model.to(torch.device("cpu"))
            self._prithvi = model
            self._prithvi_embed_dim = pc["embed_dim"]
            print(f"[handler] Prithvi-EO-2.0-300M ready (embed_dim={pc['embed_dim']}, CPU)")
        except Exception as exc:
            print(f"[handler] Prithvi load failed: {exc}")
            self._prithvi = None

    def _load_terramind(self) -> None:
        """Build the pretrained TerraMind-1.0-Base S2L2A backbone via terratorch.

        Any failure (e.g. terratorch not installed) is logged and leaves
        ``self._terramind`` as ``None``.
        """
        try:
            # Import only the terramind submodule to trigger registry side-effects
            # without loading torchgeo-dependent backbones (avoids torchvision dep chain)
            import terratorch.models.backbones.terramind  # noqa: F401
            from terratorch.registry import BACKBONE_REGISTRY

            model = BACKBONE_REGISTRY.build(
                "terramind_v1_base",
                pretrained=True,
                modalities=["S2L2A"],
            )
            model.eval().to(self._device)
            self._terramind = model
            self._terramind_embed_dim = 768
            print(f"[handler] TerraMind-1.0-Base ready (embed_dim=768, {self._device})")
        except Exception as exc:
            print(f"[handler] TerraMind load failed: {exc}")
            self._terramind = None

    # ---------------------------------------------------------------- inference

    def __call__(self, data: dict) -> dict:
        """Decode one request and dispatch it to the requested backbone.

        Args:
            data: request payload. ``data["model"]`` selects the backbone
                (default ``"prithvi"``); ``data["inputs"]`` holds the array as
                nested lists or ``.npy`` bytes.

        Returns:
            ``{"embeddings": ...}`` on success, ``{"error": "..."}`` otherwise.
            Malformed input never raises.
        """
        model_name = data.get("model", "prithvi")
        raw = data.get("inputs", data)

        # Deserialise input
        if isinstance(raw, (bytes, bytearray)):
            try:
                # np.load defaults to allow_pickle=False, so an untrusted .npy
                # payload cannot trigger pickle deserialisation.
                arr = np.load(BytesIO(raw)).astype(np.float32)
            except Exception as exc:
                return {"error": f"cannot parse bytes: {exc}"}
        else:
            try:
                arr = np.array(raw, dtype=np.float32)
            except Exception as exc:
                # e.g. "inputs" missing (the fallback is the request dict
                # itself), ragged nested lists, or non-numeric values.
                return {"error": f"cannot parse inputs: {exc}"}

        if model_name == "prithvi":
            return self._run_prithvi(arr)
        if model_name == "terramind":
            return self._run_terramind(arr)
        return {"error": f"unknown model: {model_name}"}

    @staticmethod
    def _to_prithvi_layout(arr: np.ndarray) -> np.ndarray:
        """Return *arr* as a contiguous (B, C, T, H, W) array; raise ValueError on other ranks."""
        if arr.ndim == 3:
            arr = arr[np.newaxis, :, np.newaxis, :, :]  # (C,H,W) → (1,C,1,H,W)
        elif arr.ndim == 4:
            arr = arr[:, :, np.newaxis, :, :]  # (B,C,H,W) → (B,C,1,H,W)
        elif arr.ndim == 5:
            arr = arr.transpose(0, 2, 1, 3, 4)  # (B,T,C,H,W) → (B,C,T,H,W)
        else:
            raise ValueError(f"expected a 3-, 4- or 5-D array, got shape {arr.shape}")
        return np.ascontiguousarray(arr)

    @staticmethod
    def _to_terramind_layout(arr: np.ndarray) -> np.ndarray:
        """Return *arr* as a contiguous (B, C, H, W) array; raise ValueError on other ranks."""
        if arr.ndim == 3:
            arr = arr[np.newaxis]  # (C,H,W) → (1,C,H,W)
        elif arr.ndim == 5:
            # TerraMind's S2L2A input is 4-D: collapse time by taking the first frame
            arr = arr[:, 0]  # (B,T,C,H,W) → (B,C,H,W)
        elif arr.ndim != 4:
            raise ValueError(f"expected a 3-, 4- or 5-D array, got shape {arr.shape}")
        return np.ascontiguousarray(arr)

    def _run_prithvi(self, arr: np.ndarray) -> dict:
        """Embed *arr* with Prithvi: mean of the last block's tokens, excluding token 0."""
        if self._prithvi is None:
            return {"error": "Prithvi not loaded"}
        try:
            tensor = torch.from_numpy(self._to_prithvi_layout(arr))  # stays on CPU
            with torch.no_grad():
                features = self._prithvi.forward_features(tensor)
                last = features[-1]  # (B, 1+N_tokens, embed_dim)
                emb = last[:, 1:, :].mean(dim=1)  # mean-pool spatial tokens → (B, embed_dim)
            return {"embeddings": emb.cpu().numpy().tolist()}
        except Exception as exc:
            return {"error": f"Prithvi inference failed: {exc}"}

    def _run_terramind(self, arr: np.ndarray) -> dict:
        """Embed *arr* with TerraMind: mean of the output tokens of the last block."""
        if self._terramind is None:
            return {"error": "TerraMind not loaded (terratorch unavailable)"}
        try:
            tensor = torch.from_numpy(self._to_terramind_layout(arr)).to(self._device)
            with torch.no_grad():
                out = self._terramind({"S2L2A": tensor})
            # out may be a list of tensors, or a single tensor
            last = out[-1] if isinstance(out, (list, tuple)) else out
            # (B, N_tokens, embed_dim) → mean-pool → (B, embed_dim)
            emb = last.mean(dim=1) if last.ndim == 3 else last
            return {"embeddings": emb.cpu().numpy().tolist()}
        except Exception as exc:
            return {"error": f"TerraMind inference failed: {exc}"}