# Author: pokkiri
# Commit: fix: terramind working - correct input shape (4D) and import path
# Revision: f4bb0f7 (verified)
"""HF Inference Endpoint handler for Prithvi-EO-2.0-300M + TerraMind-1.0-Base.
Deployed to pokkiri/eo-multibackbone-endpoint (framework="custom").
Request format:
{"model": "prithvi"|"terramind", "inputs": <array (B,T,C,H,W) or (B,C,H,W)>}
where inputs are normalised float32 arrays.
For Prithvi: 6 channels in order [B02, B03, B04, B05, B06, B07], normalised
For TerraMind: 12 channels Sentinel-2 L2A bands, normalised
Response format:
{"embeddings": [[float, ...], ...]} shape (B, embed_dim)
{"error": "message"} on failure
Prithvi embed_dim = 1024 (mean-pooled spatial tokens from last encoder block)
TerraMind embed_dim = 768 (mean-pooled output tokens)
"""
from __future__ import annotations
import json
import os
import sys
from io import BytesIO
from pathlib import Path
import numpy as np
import torch
class EndpointHandler:
    """Custom inference handler exposing two Earth-observation backbones.

    Both models are loaded once at construction time. A model that fails to
    load is left as ``None`` (the failure is printed) and requests routed to
    it are answered with an ``{"error": ...}`` dict, so one broken backbone
    does not take the whole endpoint down.

    Every request is answered with either ``{"embeddings": [...]}`` or
    ``{"error": "message"}``; request-level failures are never raised.
    """

    _PRITHVI_WEIGHTS_FILE = "Prithvi_EO_V2_300M.pt"

    def __init__(self, path: str = ""):
        """Load both backbones.

        Args:
            path: Directory holding the repository files (``config.json``,
                ``prithvi_mae.py`` and optionally the Prithvi weights). May be
                empty, in which case ``/app`` and this file's directory are
                searched instead.
        """
        self._path = path
        # Initialise every attribute up front so that a failed load never
        # leaves the instance with missing attributes.
        self._prithvi = None
        self._prithvi_embed_dim = None
        self._terramind = None
        self._terramind_embed_dim = None
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[handler] device: {self._device}")
        self._load_prithvi(path)
        self._load_terramind()

    @staticmethod
    def _search_dirs(path: str) -> list:
        """Return the directories searched for Prithvi code, config and weights."""
        here = os.path.dirname(os.path.abspath(__file__))
        return [path, "/app", here]

    def _load_prithvi(self, path: str) -> None:
        """Build the Prithvi MAE model from ``config.json`` and load its weights.

        On any failure the error is printed and ``self._prithvi`` stays ``None``.

        Args:
            path: Preferred directory for ``prithvi_mae.py``, ``config.json``
                and the weights file; may be empty.
        """
        try:
            search_dirs = self._search_dirs(path)

            # prithvi_mae.py lives alongside handler.py in /app/ when path is empty
            for candidate in search_dirs:
                if candidate and candidate not in sys.path:
                    sys.path.insert(0, candidate)
            from prithvi_mae import PrithviMAE  # noqa: PLC0415

            # Try config.json in path first, then /app/, then next to this file.
            for cfg_dir in search_dirs:
                cfg_path = os.path.join(cfg_dir, "config.json") if cfg_dir else "config.json"
                if os.path.exists(cfg_path):
                    break
            else:
                # Nothing found: fall back to the working directory so the
                # resulting error names a sensible file.
                cfg_path = "config.json"
            with open(cfg_path, encoding="utf-8") as fh:
                cfg = json.load(fh)
            pc = cfg["pretrained_cfg"]

            model = PrithviMAE(
                img_size=pc["img_size"],
                num_frames=pc["num_frames"],
                patch_size=pc["patch_size"],
                in_chans=pc["in_chans"],
                embed_dim=pc["embed_dim"],
                depth=pc["depth"],
                num_heads=pc["num_heads"],
                decoder_embed_dim=pc["decoder_embed_dim"],
                decoder_depth=pc["decoder_depth"],
                decoder_num_heads=pc["decoder_num_heads"],
                mlp_ratio=pc["mlp_ratio"],
                coords_encoding=pc.get("coords_encoding", []),
                coords_scale_learn=pc.get("coords_scale_learn", False),
                mask_ratio=pc.get("mask_ratio", 0.75),
            )

            # Prefer a local copy of the weights from any of the search
            # directories (same places the config is looked up) before
            # falling back to a hub download.
            weights_path = None
            for weights_dir in search_dirs:
                if not weights_dir:
                    continue
                candidate = os.path.join(weights_dir, self._PRITHVI_WEIGHTS_FILE)
                if os.path.exists(candidate):
                    weights_path = candidate
                    break
            if weights_path is None:
                print("[handler] downloading Prithvi weights from ibm-nasa-geospatial/Prithvi-EO-2.0-300M …")
                from huggingface_hub import hf_hub_download
                weights_path = hf_hub_download(
                    "ibm-nasa-geospatial/Prithvi-EO-2.0-300M",
                    "Prithvi_EO_V2_300M.pt",
                )

            try:
                state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
            except TypeError:
                # Older torch without the ``weights_only`` keyword.
                # NOTE(review): this path unpickles the checkpoint without
                # restrictions; only acceptable for trusted weight files.
                state_dict = torch.load(weights_path, map_location="cpu")

            # Positional embeddings are dropped from the checkpoint and left
            # as initialised by the model constructor.
            for key in [k for k in state_dict if "pos_embed" in k]:
                del state_dict[key]

            load_result = model.load_state_dict(state_dict, strict=False)
            # strict=False hides key mismatches; surface them (other than the
            # deliberately removed pos_embed entries) so a wrong checkpoint
            # does not go unnoticed.
            missing = [k for k in load_result.missing_keys if "pos_embed" not in k]
            unexpected = list(load_result.unexpected_keys)
            if missing or unexpected:
                print(
                    f"[handler] Prithvi state_dict mismatch: "
                    f"{len(missing)} missing, {len(unexpected)} unexpected keys"
                )

            model.eval()
            # Keep on CPU: prithvi_mae's sincos pos_embed runs on CPU via numpy
            model = model.to(torch.device("cpu"))
            self._prithvi = model
            self._prithvi_embed_dim = pc["embed_dim"]
            print(f"[handler] Prithvi-EO-2.0-300M ready (embed_dim={pc['embed_dim']}, CPU)")
        except Exception as exc:
            print(f"[handler] Prithvi load failed: {exc}")
            self._prithvi = None

    def _load_terramind(self) -> None:
        """Build the TerraMind backbone through the terratorch registry.

        On any failure (including terratorch not being installed) the error is
        printed and ``self._terramind`` stays ``None``.
        """
        try:
            # Import only the terramind submodule to trigger registry side-effects
            # without loading torchgeo-dependent backbones (avoids torchvision dep chain)
            import terratorch.models.backbones.terramind  # noqa: F401
            from terratorch.registry import BACKBONE_REGISTRY

            model = BACKBONE_REGISTRY.build(
                "terramind_v1_base",
                pretrained=True,
                modalities=["S2L2A"],
            )
            model.eval().to(self._device)
            self._terramind = model
            self._terramind_embed_dim = 768
            print(f"[handler] TerraMind-1.0-Base ready (embed_dim=768, {self._device})")
        except Exception as exc:
            print(f"[handler] TerraMind load failed: {exc}")
            self._terramind = None

    def __call__(self, data: dict) -> dict:
        """Handle one request.

        Args:
            data: Normally ``{"model": "prithvi"|"terramind", "inputs": ...}``.
                ``"model"`` defaults to ``"prithvi"``. ``"inputs"`` is either a
                nested list of numbers or the bytes of a ``.npy`` file. A
                non-dict payload is treated as the inputs for Prithvi.

        Returns:
            ``{"embeddings": [[float, ...], ...]}`` on success, otherwise
            ``{"error": "message"}``.
        """
        if isinstance(data, dict):
            model_name = data.get("model", "prithvi")
            raw = data.get("inputs", data)
        else:
            model_name, raw = "prithvi", data

        # Reject unknown models before spending time parsing the payload.
        if model_name == "prithvi":
            runner = self._run_prithvi
        elif model_name == "terramind":
            runner = self._run_terramind
        else:
            return {"error": f"unknown model: {model_name}"}

        # Deserialise input; every parse failure becomes an error response.
        if isinstance(raw, (bytes, bytearray)):
            try:
                arr = np.load(BytesIO(raw)).astype(np.float32)
            except Exception as exc:
                return {"error": f"cannot parse bytes: {exc}"}
        elif isinstance(raw, dict):
            return {"error": 'cannot parse inputs: request has no "inputs" field'}
        else:
            try:
                arr = np.array(raw, dtype=np.float32)
            except Exception as exc:
                # Ragged nested lists, strings, and other non-numeric payloads.
                return {"error": f"cannot parse inputs: {exc}"}

        return runner(arr)

    def _run_prithvi(self, arr: np.ndarray) -> dict:
        """Embed a batch with Prithvi.

        Args:
            arr: float32 array shaped (B,C,H,W) or (B,T,C,H,W).

        Returns:
            ``{"embeddings": ...}`` holding, per sample, the mean over all
            tokens except the first of the last feature map returned by
            ``forward_features``; ``{"error": ...}`` on failure.
        """
        if self._prithvi is None:
            return {"error": "Prithvi not loaded"}
        try:
            # Normalise shape → (B, C, T, H, W)
            if arr.ndim == 4:
                arr = arr[:, :, np.newaxis, :, :]  # (B,C,H,W) → (B,C,1,H,W)
            elif arr.ndim == 5:
                arr = arr.transpose(0, 2, 1, 3, 4)  # (B,T,C,H,W) → (B,C,T,H,W)
            else:
                raise ValueError(
                    f"expected 4-D (B,C,H,W) or 5-D (B,T,C,H,W) input, got shape {arr.shape}"
                )
            # The Prithvi model is kept on CPU, so the input stays there too.
            tensor = torch.from_numpy(arr).to(torch.device("cpu"))
            with torch.no_grad():
                features = self._prithvi.forward_features(tensor)
            last = features[-1]  # (B, 1+N_tokens, embed_dim)
            emb = last[:, 1:, :].mean(dim=1)  # mean-pool spatial tokens → (B, embed_dim)
            return {"embeddings": emb.cpu().numpy().tolist()}
        except Exception as exc:
            return {"error": f"Prithvi inference failed: {exc}"}

    def _run_terramind(self, arr: np.ndarray) -> dict:
        """Embed a batch with TerraMind.

        Args:
            arr: float32 array shaped (B,C,H,W) or (B,T,C,H,W); for 5-D input
                only the first time step is used.

        Returns:
            ``{"embeddings": ...}`` with the token-mean of the model output
            (or the output itself when it is not 3-D); ``{"error": ...}`` on
            failure.
        """
        if self._terramind is None:
            return {"error": "TerraMind not loaded (terratorch unavailable)"}
        try:
            # TerraMind ViT encoder_embeddings expects {"S2L2A": tensor (B, C, H, W)} 4D
            # If caller sends (B, T, C, H, W), collapse time by taking the first frame
            if arr.ndim == 5:
                arr = arr[:, 0, :, :, :]  # (B,T,C,H,W) → (B,C,H,W)
            elif arr.ndim != 4:
                raise ValueError(
                    f"expected 4-D (B,C,H,W) or 5-D (B,T,C,H,W) input, got shape {arr.shape}"
                )
            tensor = torch.from_numpy(arr).to(self._device)
            with torch.no_grad():
                out = self._terramind({"S2L2A": tensor})
            # out may be a list of tensors, or a single tensor
            last = out[-1] if isinstance(out, (list, tuple)) else out
            # (B, N_tokens, embed_dim) → mean-pool → (B, embed_dim)
            emb = last.mean(dim=1) if last.ndim == 3 else last
            return {"embeddings": emb.cpu().numpy().tolist()}
        except Exception as exc:
            return {"error": f"TerraMind inference failed: {exc}"}