"""HF Inference Endpoint custom handler for Prithvi-EO-2.0-300M.

Uploaded to pokkiri/prithvi-eo-2-bench alongside prithvi_mae.py + config.json.
Weights are downloaded from the original IBM/NASA HF repo at startup (public model).

Input (via inference_runner.py):
  - application/octet-stream: numpy bytes of shape (B, T, C, H, W)  [strategy 1]
  - application/json: {"inputs": [[[[...]]]]}                        [strategy 2]

Prithvi uses 6 bands (B02 B03 B04 B05 B06 B07) in that order.

Output:
  {"embeddings": [[float, ...], ...]} — mean-pooled patch-token embedding per batch item.
  {"error": "<message>"} — when the payload cannot be decoded into a 4-D or 5-D array.
"""

from __future__ import annotations

import json
import os
import sys
from io import BytesIO
from pathlib import Path

import numpy as np
import torch


class EndpointHandler:
    """Callable handler that turns image stacks into Prithvi embeddings.

    The constructor builds the model from ``config.json`` and loads its
    weights; calling the instance runs one forward pass and returns the
    mean-pooled patch-token embedding of each batch item.
    """

    def __init__(self, path: str = ""):
        """Build the model and load its weights.

        Args:
            path: Repository directory that contains ``prithvi_mae.py``,
                ``config.json`` and, optionally, ``Prithvi_EO_V2_300M.pt``.
        """
        # Make the repo directory importable so we can do `from prithvi_mae import PrithviMAE`
        sys.path.insert(0, path)
        from prithvi_mae import PrithviMAE  # noqa: PLC0415 (inside __init__ by design)

        # Read architecture hyper-parameters from config.json.
        # JSON is UTF-8 by specification, so do not depend on the locale default.
        cfg_path = os.path.join(path, "config.json")
        with open(cfg_path, encoding="utf-8") as fh:
            cfg = json.load(fh)
        pc = cfg["pretrained_cfg"]

        self.model = PrithviMAE(
            img_size=pc["img_size"],
            num_frames=pc["num_frames"],
            patch_size=pc["patch_size"],
            in_chans=pc["in_chans"],
            embed_dim=pc["embed_dim"],
            depth=pc["depth"],
            num_heads=pc["num_heads"],
            decoder_embed_dim=pc["decoder_embed_dim"],
            decoder_depth=pc["decoder_depth"],
            decoder_num_heads=pc["decoder_num_heads"],
            mlp_ratio=pc["mlp_ratio"],
            coords_encoding=pc.get("coords_encoding", []),
            coords_scale_learn=pc.get("coords_scale_learn", False),
            mask_ratio=pc.get("mask_ratio", 0.75),
        )

        # Load weights — try local path first, fall back to downloading from IBM/NASA HF repo
        weights_local = os.path.join(path, "Prithvi_EO_V2_300M.pt")
        if os.path.exists(weights_local):
            weights_path = weights_local
        else:
            print("[handler] Prithvi_EO_V2_300M.pt not in repo dir — downloading from IBM/NASA HF …")
            from huggingface_hub import hf_hub_download

            weights_path = hf_hub_download(
                "ibm-nasa-geospatial/Prithvi-EO-2.0-300M",
                "Prithvi_EO_V2_300M.pt",
            )
            print(f"[handler] weights downloaded to {weights_path}")

        try:
            state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        except TypeError:
            # weights_only param not available in older PyTorch
            state_dict = torch.load(weights_path, map_location="cpu")

        # Discard fixed positional embeddings (interpolated from grid at runtime).
        # Collect the keys first so the mapping is not mutated while iterating.
        for key in [k for k in state_dict if "pos_embed" in k]:
            del state_dict[key]

        self.model.load_state_dict(state_dict, strict=False)
        self.model.eval()

        # Force CPU: prithvi_mae uses get_3d_sincos_pos_embed (numpy-sourced tensors) which
        # land on CPU at runtime. Running model on GPU then causes a cross-device error.
        # CPU is sufficient for 224×224 patch inference at benchmark scale.
        self.device = torch.device("cpu")
        self.model = self.model.to(self.device)
        print(f"[handler] Prithvi-EO-2.0-300M ready on {self.device}")

    def __call__(self, data: dict) -> dict:
        """Embed one request payload.

        Args:
            data: Request dict. ``data["inputs"]`` (or ``data`` itself when
                that key is absent) is either ``.npy`` bytes or a nested
                list, shaped (B, T, C, H, W) or (B, C, H, W).

        Returns:
            ``{"embeddings": [[float, ...], ...]}`` on success, or
            ``{"error": "<message>"}`` when the payload cannot be decoded
            into a 4-D or 5-D float array.
        """
        inputs = data.get("inputs", data)

        # Deserialise input. Both transports report malformed payloads the
        # same way — as an error dict — instead of raising.
        if isinstance(inputs, (bytes, bytearray)):
            try:
                arr = np.load(BytesIO(inputs)).astype(np.float32)
            except Exception as exc:  # np.load fails in many ways; report them all
                return {"error": f"cannot parse input bytes as numpy array: {exc}"}
        else:
            try:
                arr = np.array(inputs, dtype=np.float32)
            except (TypeError, ValueError) as exc:
                # Ragged nesting, non-numeric values, or a payload without "inputs".
                return {"error": f"cannot parse inputs as float32 array: {exc}"}

        # Shape normalisation → Prithvi expects (B, C, T, H, W)
        # inference_runner sends (1, 1, C, H, W) for "B T C H W" models
        # meaning (batch=1, time=1, channels, H, W) — transpose axes 1 and 2
        if arr.ndim == 4:
            # (B, C, H, W) → (B, C, 1, H, W)
            arr = arr[:, :, np.newaxis, :, :]
        elif arr.ndim == 5:
            # (B, T, C, H, W) → (B, C, T, H, W)
            arr = arr.transpose(0, 2, 1, 3, 4)
        else:
            # Any other rank cannot be mapped onto the model's layout; fail
            # here with a clear message rather than deep inside the network.
            return {
                "error": (
                    "expected a 4-D (B, C, H, W) or 5-D (B, T, C, H, W) array, "
                    f"got shape {arr.shape}"
                )
            }

        tensor = torch.from_numpy(arr).to(self.device)

        with torch.no_grad():
            features = self.model.forward_features(tensor)

        # features is a list of (B, 1+num_tokens, embed_dim) tensors, one per block.
        # Take the last (normalised) block, mean-pool over spatial tokens (skip CLS at 0).
        last = features[-1]  # (B, 1+num_tokens, embed_dim)
        embedding = last[:, 1:, :].mean(dim=1)  # (B, embed_dim)

        return {"embeddings": embedding.cpu().numpy().tolist()}