"""HF Inference Endpoint custom handler for Prithvi-EO-2.0-300M.
Uploaded to pokkiri/prithvi-eo-2-bench alongside prithvi_mae.py + config.json.
Weights are downloaded from the original IBM/NASA HF repo at startup (public model).

Input (via inference_runner.py):
  - application/octet-stream: numpy bytes of shape (B, T, C, H, W)  [strategy 1]
  - application/json: {"inputs": [[[[...]]]]}  [strategy 2]
Prithvi uses 6 bands (B02 B03 B04 B05 B06 B07) in that order.

Output:
  {"embeddings": [[float, ...], ...]} — mean-pooled patch-token embedding per batch item.
"""
from __future__ import annotations

import json
import os
import sys
from io import BytesIO
from pathlib import Path

import numpy as np
import torch
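
# Example request for the octet-stream input described in the module docstring (strategy 1).
# This is a client-side sketch only; ENDPOINT_URL and HF_TOKEN are placeholders, and the exact
# serialization used by inference_runner.py is assumed, not shown in this file:
#
#     import io
#     import numpy as np
#     import requests
#
#     chips = np.zeros((1, 1, 6, 224, 224), dtype=np.float32)  # (B, T, C, H, W), 6 bands
#     buf = io.BytesIO()
#     np.save(buf, chips)  # __call__ below reads this back with np.load()
#     resp = requests.post(
#         ENDPOINT_URL,  # placeholder for the deployed HF Inference Endpoint URL
#         data=buf.getvalue(),
#         headers={"Content-Type": "application/octet-stream",
#                  "Authorization": f"Bearer {HF_TOKEN}"},  # placeholder token
#     )
#     embeddings = resp.json()["embeddings"]  # one embed_dim-length vector per batch item
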
class EndpointHandler:
    def __init__(self, path: str = ""):
        # Make the repo directory importable so we can do `from prithvi_mae import PrithviMAE`
        sys.path.insert(0, path)
        from prithvi_mae import PrithviMAE  # noqa: PLC0415 (inside __init__ by design)

        # Read architecture hyper-parameters from config.json
        cfg_path = os.path.join(path, "config.json")
        with open(cfg_path) as fh:
            cfg = json.load(fh)
        pc = cfg["pretrained_cfg"]

        self.model = PrithviMAE(
            img_size=pc["img_size"],
            num_frames=pc["num_frames"],
            patch_size=pc["patch_size"],
            in_chans=pc["in_chans"],
            embed_dim=pc["embed_dim"],
            depth=pc["depth"],
            num_heads=pc["num_heads"],
            decoder_embed_dim=pc["decoder_embed_dim"],
            decoder_depth=pc["decoder_depth"],
            decoder_num_heads=pc["decoder_num_heads"],
            mlp_ratio=pc["mlp_ratio"],
            coords_encoding=pc.get("coords_encoding", []),
            coords_scale_learn=pc.get("coords_scale_learn", False),
            mask_ratio=pc.get("mask_ratio", 0.75),
        )
        # Load weights — try local path first, fall back to downloading from IBM/NASA HF repo
        weights_local = os.path.join(path, "Prithvi_EO_V2_300M.pt")
        if os.path.exists(weights_local):
            weights_path = weights_local
        else:
            print("[handler] Prithvi_EO_V2_300M.pt not in repo dir — downloading from IBM/NASA HF …")
            from huggingface_hub import hf_hub_download
            weights_path = hf_hub_download(
                "ibm-nasa-geospatial/Prithvi-EO-2.0-300M",
                "Prithvi_EO_V2_300M.pt",
            )
            print(f"[handler] weights downloaded to {weights_path}")

        try:
            state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        except TypeError:
            # weights_only param not available in older PyTorch
            state_dict = torch.load(weights_path, map_location="cpu")

        # Discard fixed positional embeddings (interpolated from grid at runtime)
        for k in list(state_dict.keys()):
            if "pos_embed" in k:
                del state_dict[k]
        self.model.load_state_dict(state_dict, strict=False)
        self.model.eval()

        # Force CPU: prithvi_mae uses get_3d_sincos_pos_embed (numpy-sourced tensors) which
        # land on CPU at runtime. Running the model on GPU then causes a cross-device error.
        # CPU is sufficient for 224×224 patch inference at benchmark scale.
        self.device = torch.device("cpu")
        self.model = self.model.to(self.device)
        print(f"[handler] Prithvi-EO-2.0-300M ready on {self.device}")

    def __call__(self, data: dict) -> dict:
        inputs = data.get("inputs", data)

        # Deserialise input
        if isinstance(inputs, (bytes, bytearray)):
            try:
                arr = np.load(BytesIO(inputs)).astype(np.float32)
            except Exception as exc:
                return {"error": f"cannot parse input bytes as numpy array: {exc}"}
        else:
            arr = np.array(inputs, dtype=np.float32)

        # Shape normalisation → Prithvi expects (B, C, T, H, W)
        # inference_runner sends (1, 1, C, H, W) for "B T C H W" models,
        # meaning (batch=1, time=1, channels, H, W) — transpose axes 1 and 2
        if arr.ndim == 4:
            # (B, C, H, W) → (B, C, 1, H, W)
            arr = arr[:, :, np.newaxis, :, :]
        elif arr.ndim == 5:
            # (B, T, C, H, W) → (B, C, T, H, W)
            arr = arr.transpose(0, 2, 1, 3, 4)

        tensor = torch.from_numpy(arr).to(self.device)
        with torch.no_grad():
            features = self.model.forward_features(tensor)

        # features is a list of (B, 1+num_tokens, embed_dim) tensors, one per block.
        # Take the last (normalised) block, mean-pool over spatial tokens (skip CLS at 0).
        last = features[-1]  # (B, 1+num_tokens, embed_dim)
        embedding = last[:, 1:, :].mean(dim=1)  # (B, embed_dim)
        return {"embeddings": embedding.cpu().numpy().tolist()}
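

if __name__ == "__main__":
    # Local smoke test: a minimal sketch, not part of the endpoint contract. It assumes
    # prithvi_mae.py and config.json sit in the current directory and that the weights are
    # either present locally or downloadable, then pushes one random 6-band, single-frame
    # 224×224 chip through the handler via the JSON-style path (strategy 2 in the docstring).
    handler = EndpointHandler(path=".")
    dummy = np.random.rand(1, 1, 6, 224, 224).astype(np.float32)  # (B, T, C, H, W)
    result = handler({"inputs": dummy.tolist()})
    print(f"[smoke-test] {len(result['embeddings'])} embedding(s), dim {len(result['embeddings'][0])}")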