| """HF Inference Endpoint handler for Prithvi-EO-2.0-300M + TerraMind-1.0-Base. |
| |
| Deployed to pokkiri/eo-multibackbone-endpoint (framework="custom"). |
| |
| Request format: |
| {"model": "prithvi"|"terramind", "inputs": <array (B,T,C,H,W) or (B,C,H,W)>} |
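| where inputs are normalised float32 arrays, sent either as nested lists or as |
| the raw bytes of a .npy file; "model" defaults to "prithvi". |
| |
| For Prithvi: 6 channels in order [B02, B03, B04, B05, B06, B07], normalised |
| For TerraMind: 12 Sentinel-2 L2A band channels, normalised; for 5-D input |
| only the first time step (T index 0) is used |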
| |
| Response format: |
| {"embeddings": [[float, ...], ...]} shape (B, embed_dim) |
| {"error": "message"} on failure |
| |
| Prithvi embed_dim = 1024 (mean-pooled spatial tokens from last encoder block) |
| TerraMind embed_dim = 768 (mean-pooled output tokens) |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| import sys |
| from io import BytesIO |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
|
|
class EndpointHandler:
    """Custom HF Inference Endpoint handler serving two EO backbones.

    Both backbones are loaded eagerly in ``__init__``. Loading is best-effort:
    a backbone that fails to load is left as ``None``, and requests for it get
    an ``{"error": ...}`` response instead of crashing the endpoint.
    """

    def __init__(self, path: str = ""):
        """Load Prithvi (on CPU) and TerraMind (on CUDA when available).

        Args:
            path: Model repository directory given by the endpoint runtime.
                It is searched for ``prithvi_mae.py``, ``config.json`` and the
                Prithvi checkpoint. May be empty.
        """
        self._path = path
        self._prithvi = None
        self._terramind = None
        # Set by the loaders on success; None until (or unless) they succeed.
        self._prithvi_embed_dim = None
        self._terramind_embed_dim = None
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[handler] device: {self._device}")

        self._load_prithvi(path)
        self._load_terramind()

    def _load_prithvi(self, path: str) -> None:
        """Build PrithviMAE from ``config.json`` and load the 300M checkpoint.

        On any failure the error is printed and ``self._prithvi`` stays None.
        """
        try:
            # prithvi_mae.py ships next to this handler rather than as an
            # installed package, so make the candidate directories importable.
            search_dirs = [path, "/app", os.path.dirname(os.path.abspath(__file__))]
            for sp in search_dirs:
                if sp and sp not in sys.path:
                    sys.path.insert(0, sp)

            from prithvi_mae import PrithviMAE

            # First existing config.json wins; otherwise fall back to a
            # CWD-relative "config.json" (open() below raises if it is absent).
            for cfg_dir in search_dirs:
                cfg_path = os.path.join(cfg_dir, "config.json") if cfg_dir else "config.json"
                if os.path.exists(cfg_path):
                    break
            else:
                cfg_path = "config.json"

            with open(cfg_path) as fh:
                cfg = json.load(fh)
            pc = cfg["pretrained_cfg"]

            model = PrithviMAE(
                img_size=pc["img_size"],
                num_frames=pc["num_frames"],
                patch_size=pc["patch_size"],
                in_chans=pc["in_chans"],
                embed_dim=pc["embed_dim"],
                depth=pc["depth"],
                num_heads=pc["num_heads"],
                decoder_embed_dim=pc["decoder_embed_dim"],
                decoder_depth=pc["decoder_depth"],
                decoder_num_heads=pc["decoder_num_heads"],
                mlp_ratio=pc["mlp_ratio"],
                coords_encoding=pc.get("coords_encoding", []),
                coords_scale_learn=pc.get("coords_scale_learn", False),
                mask_ratio=pc.get("mask_ratio", 0.75),
            )

            # Prefer a checkpoint bundled in the repo; otherwise fetch it from the Hub.
            weights_local = os.path.join(path, "Prithvi_EO_V2_300M.pt") if path else ""
            if weights_local and os.path.exists(weights_local):
                weights_path = weights_local
            else:
                print("[handler] downloading Prithvi weights from ibm-nasa-geospatial/Prithvi-EO-2.0-300M …")
                from huggingface_hub import hf_hub_download
                weights_path = hf_hub_download(
                    "ibm-nasa-geospatial/Prithvi-EO-2.0-300M",
                    "Prithvi_EO_V2_300M.pt",
                )

            try:
                state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
            except TypeError:
                # Older torch versions do not accept the ``weights_only`` keyword.
                state_dict = torch.load(weights_path, map_location="cpu")

            # Drop positional-embedding entries so the values PrithviMAE built
            # itself are kept (presumably fixed tables sized for this config —
            # confirm against prithvi_mae.py).
            state_dict = {k: v for k, v in state_dict.items() if "pos_embed" not in k}

            # strict=False is needed because of the dropped pos_embed keys, but
            # it would also silently accept a checkpoint that does not match at
            # all. Report any other mismatch so random weights do not go unnoticed.
            result = model.load_state_dict(state_dict, strict=False)
            missing = [k for k in result.missing_keys if "pos_embed" not in k]
            unexpected = list(result.unexpected_keys)
            if missing or unexpected:
                print(
                    "[handler] Prithvi checkpoint mismatch: "
                    f"{len(missing)} missing keys (e.g. {missing[:5]}), "
                    f"{len(unexpected)} unexpected keys (e.g. {unexpected[:5]})"
                )
            model.eval()
            # Prithvi is kept on CPU (not self._device); _run_prithvi feeds it
            # CPU tensors to match.
            model = model.to(torch.device("cpu"))
            self._prithvi = model
            self._prithvi_embed_dim = pc["embed_dim"]
            print(f"[handler] Prithvi-EO-2.0-300M ready (embed_dim={pc['embed_dim']}, CPU)")
        except Exception as exc:
            print(f"[handler] Prithvi load failed: {exc}")
            self._prithvi = None

    def _load_terramind(self) -> None:
        """Build the pretrained TerraMind-1.0-Base backbone for the S2L2A modality.

        On any failure (e.g. terratorch missing) the error is printed and
        ``self._terramind`` stays None.
        """
        try:
            # Imported for its side effect; presumably this registers the
            # TerraMind backbones with BACKBONE_REGISTRY.
            import terratorch.models.backbones.terramind  # noqa: F401
            from terratorch.registry import BACKBONE_REGISTRY

            model = BACKBONE_REGISTRY.build(
                "terramind_v1_base",
                pretrained=True,
                modalities=["S2L2A"],
            )
            model.eval().to(self._device)
            self._terramind = model
            self._terramind_embed_dim = 768
            print(f"[handler] TerraMind-1.0-Base ready (embed_dim=768, {self._device})")
        except Exception as exc:
            print(f"[handler] TerraMind load failed: {exc}")
            self._terramind = None

    def __call__(self, data: dict) -> dict:
        """Handle one request and return embeddings or an error dict.

        Args:
            data: Request payload. ``"model"`` selects the backbone (default
                ``"prithvi"``). ``"inputs"`` is a nested numeric list or the
                raw bytes of a ``.npy`` file. If ``"inputs"`` is absent, the
                payload itself is used as the input. A non-dict payload is
                treated as the inputs for the default model.

        Returns:
            ``{"embeddings": [[...], ...]}`` on success, or ``{"error": msg}``.
            Malformed input is reported this way rather than raised.
        """
        if not isinstance(data, dict):
            data = {"inputs": data}
        model_name = data.get("model", "prithvi")
        raw = data.get("inputs", data)

        if isinstance(raw, (bytes, bytearray)):
            try:
                # allow_pickle stays at its default (False), so pickled payloads
                # are rejected instead of being unpickled from request bytes.
                arr = np.load(BytesIO(raw)).astype(np.float32)
            except Exception as exc:
                return {"error": f"cannot parse bytes: {exc}"}
        else:
            try:
                arr = np.array(raw, dtype=np.float32)
            except Exception as exc:
                # Ragged lists, non-numeric strings and dict payloads land here.
                return {"error": f"cannot parse inputs: {exc}"}

        if model_name == "prithvi":
            return self._run_prithvi(arr)
        elif model_name == "terramind":
            return self._run_terramind(arr)
        else:
            return {"error": f"unknown model: {model_name}"}

    def _run_prithvi(self, arr: np.ndarray) -> dict:
        """Embed ``arr`` with Prithvi on CPU.

        Args:
            arr: float32 array shaped (B, C, H, W) or (B, T, C, H, W).

        Returns:
            ``{"embeddings": ...}`` of shape (B, embed_dim): the mean over all
            tokens except index 0 of the last ``forward_features`` output.
            Otherwise ``{"error": msg}``.
        """
        if self._prithvi is None:
            return {"error": "Prithvi not loaded"}
        if arr.ndim not in (4, 5):
            return {
                "error": "Prithvi expects a 4D (B,C,H,W) or 5D (B,T,C,H,W) array, "
                f"got shape {arr.shape}"
            }

        try:
            if arr.ndim == 4:
                # Single acquisition: add a length-1 time axis -> (B, C, 1, H, W).
                arr = arr[:, :, np.newaxis, :, :]
            else:
                # Requests arrive as (B, T, C, H, W); this handler feeds the
                # model (B, C, T, H, W).
                arr = arr.transpose(0, 2, 1, 3, 4)

            tensor = torch.from_numpy(np.ascontiguousarray(arr))
            with torch.no_grad():
                features = self._prithvi.forward_features(tensor)

            # Token 0 is skipped (presumably the CLS token) and the remaining
            # tokens of the last feature map are averaged.
            emb = features[-1][:, 1:, :].mean(dim=1)
            return {"embeddings": emb.cpu().numpy().tolist()}
        except Exception as exc:
            return {"error": f"Prithvi inference failed: {exc}"}

    def _run_terramind(self, arr: np.ndarray) -> dict:
        """Embed ``arr`` with TerraMind as the S2L2A modality.

        Args:
            arr: float32 array shaped (B, C, H, W) or (B, T, C, H, W). For 5-D
                input only the first time step is used.

        Returns:
            ``{"embeddings": ...}``: the token-mean of the last output when it
            is 3-D, otherwise that output unchanged. Or ``{"error": msg}``.
        """
        if self._terramind is None:
            return {"error": "TerraMind not loaded (terratorch unavailable)"}
        if arr.ndim not in (4, 5):
            return {
                "error": "TerraMind expects a 4D (B,C,H,W) or 5D (B,T,C,H,W) array, "
                f"got shape {arr.shape}"
            }

        try:
            if arr.ndim == 5:
                arr = arr[:, 0, :, :, :]

            tensor = torch.from_numpy(np.ascontiguousarray(arr)).to(self._device)
            with torch.no_grad():
                out = self._terramind({"S2L2A": tensor})

            # The backbone may return per-layer outputs; use the last one.
            last = out[-1] if isinstance(out, (list, tuple)) else out
            # Token sequences (B, N, D) are mean-pooled over N; other ranks pass through.
            emb = last.mean(dim=1) if last.ndim == 3 else last
            return {"embeddings": emb.cpu().numpy().tolist()}
        except Exception as exc:
            return {"error": f"TerraMind inference failed: {exc}"}
|
|