| """HF Inference Endpoint custom handler for Prithvi-EO-2.0-300M. |
| |
| Uploaded to pokkiri/prithvi-eo-2-bench alongside prithvi_mae.py + config.json. |
| Weights are downloaded from the original IBM/NASA HF repo at startup (public model). |
| |
| Input (via inference_runner.py): |
| - application/octet-stream: numpy bytes of shape (B, T, C, H, W) [strategy 1] |
| - application/json: {"inputs": [[[[...]]]]} [strategy 2] |
| Prithvi uses 6 bands (B02 B03 B04 B05 B06 B07) in that order. |
| |
| Output: |
| {"embeddings": [[float, ...], ...]} — mean-pooled patch-token embedding per batch item. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| import sys |
| from io import BytesIO |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
|
|
class EndpointHandler:
    """HF Inference Endpoint handler wrapping Prithvi-EO-2.0-300M.

    Builds the model architecture from the repo-local ``prithvi_mae.py`` and
    ``config.json``, then loads weights either from a checkpoint bundled in
    the repo directory or, failing that, by downloading them from the public
    IBM/NASA HF model repo.
    """

    def __init__(self, path: str = ""):
        """Construct and load the model.

        Args:
            path: Directory of the deployed endpoint repo, containing
                ``prithvi_mae.py``, ``config.json`` and (optionally) the
                ``Prithvi_EO_V2_300M.pt`` checkpoint.
        """
        # prithvi_mae.py ships alongside this handler in the endpoint repo;
        # put the repo dir on sys.path so it is importable regardless of CWD.
        sys.path.insert(0, path)
        from prithvi_mae import PrithviMAE

        cfg_path = os.path.join(path, "config.json")
        with open(cfg_path) as fh:
            cfg = json.load(fh)
        pc = cfg["pretrained_cfg"]

        self.model = PrithviMAE(
            img_size=pc["img_size"],
            num_frames=pc["num_frames"],
            patch_size=pc["patch_size"],
            in_chans=pc["in_chans"],
            embed_dim=pc["embed_dim"],
            depth=pc["depth"],
            num_heads=pc["num_heads"],
            decoder_embed_dim=pc["decoder_embed_dim"],
            decoder_depth=pc["decoder_depth"],
            decoder_num_heads=pc["decoder_num_heads"],
            mlp_ratio=pc["mlp_ratio"],
            coords_encoding=pc.get("coords_encoding", []),
            coords_scale_learn=pc.get("coords_scale_learn", False),
            mask_ratio=pc.get("mask_ratio", 0.75),
        )

        # Prefer weights bundled with the endpoint repo; otherwise pull them
        # from the original public model repo at startup.
        weights_local = os.path.join(path, "Prithvi_EO_V2_300M.pt")
        if os.path.exists(weights_local):
            weights_path = weights_local
        else:
            # NOTE: original message contained mojibake ("β", "β¦") from a
            # double-encoded em dash/ellipsis; replaced with plain ASCII.
            print("[handler] Prithvi_EO_V2_300M.pt not in repo dir -- downloading from IBM/NASA HF ...")
            from huggingface_hub import hf_hub_download
            weights_path = hf_hub_download(
                "ibm-nasa-geospatial/Prithvi-EO-2.0-300M",
                "Prithvi_EO_V2_300M.pt",
            )
            print(f"[handler] weights downloaded to {weights_path}")

        try:
            # weights_only=True is the safe deserialization path (torch >= 1.13).
            state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        except TypeError:
            # Older torch without the weights_only kwarg.
            state_dict = torch.load(weights_path, map_location="cpu")

        # Drop checkpoint positional embeddings (they may not match the
        # configured input geometry); load_state_dict(strict=False) lets the
        # model keep its freshly initialized ones.
        for k in list(state_dict.keys()):
            if "pos_embed" in k:
                del state_dict[k]

        self.model.load_state_dict(state_dict, strict=False)
        self.model.eval()

        self.device = torch.device("cpu")
        self.model = self.model.to(self.device)
        print(f"[handler] Prithvi-EO-2.0-300M ready on {self.device}")

    def __call__(self, data: dict) -> dict:
        """Run a forward pass and return one pooled embedding per batch item.

        Args:
            data: Request payload. ``data["inputs"]`` (or ``data`` itself) is
                either raw ``bytes`` of a serialized numpy array or a nested
                list. 5-D input is interpreted as (B, T, C, H, W); 4-D input
                gets a singleton time axis (assumes (B, C, H, W) — the model
                then sees T=1; confirm with the caller's serialization).

        Returns:
            ``{"embeddings": [[float, ...], ...]}`` on success, or
            ``{"error": "..."}`` when the payload cannot be parsed or has an
            unsupported rank.
        """
        inputs = data.get("inputs", data)

        if isinstance(inputs, (bytes, bytearray)):
            try:
                arr = np.load(BytesIO(inputs)).astype(np.float32)
            except Exception as exc:
                return {"error": f"cannot parse input bytes as numpy array: {exc}"}
        else:
            arr = np.array(inputs, dtype=np.float32)

        if arr.ndim == 4:
            # (B, C, H, W) -> (B, C, 1, H, W): single-frame input.
            arr = arr[:, :, np.newaxis, :, :]
        elif arr.ndim == 5:
            # (B, T, C, H, W) -> (B, C, T, H, W): channels-first time series.
            arr = arr.transpose(0, 2, 1, 3, 4)
        else:
            # Fail fast with a structured error instead of letting the model
            # raise a cryptic shape error deep inside the forward pass.
            return {"error": f"expected a 4-D or 5-D array, got ndim={arr.ndim}"}

        tensor = torch.from_numpy(arr).to(self.device)

        with torch.no_grad():
            features = self.model.forward_features(tensor)

        # features[-1] is taken as the final token sequence (B, tokens, dim);
        # presumably forward_features returns per-layer outputs — confirm
        # against prithvi_mae.py. Drop the leading token (assumed CLS) and
        # mean-pool the remaining patch tokens.
        last = features[-1]
        embedding = last[:, 1:, :].mean(dim=1)

        return {"embeddings": embedding.cpu().numpy().tolist()}
|
|