File size: 4,960 Bytes
0d1bd4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
801c795
 
 
 
0d1bd4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""HF Inference Endpoint custom handler for Prithvi-EO-2.0-300M.

Uploaded to pokkiri/prithvi-eo-2-bench alongside prithvi_mae.py + config.json.
Weights are downloaded from the original IBM/NASA HF repo at startup (public model).

Input (via inference_runner.py):
    - application/octet-stream: numpy bytes of shape (B, T, C, H, W)  [strategy 1]
    - application/json:         {"inputs": [[[[...]]]]}  [strategy 2]
    Prithvi uses 6 bands (B02 B03 B04 B05 B06 B07) in that order.

Output:
    {"embeddings": [[float, ...], ...]}  — mean-pooled patch-token embedding per batch item.
"""

from __future__ import annotations

import json
import os
import sys
from io import BytesIO
from pathlib import Path

import numpy as np
import torch


class EndpointHandler:
    """Custom HF Inference Endpoint handler for Prithvi-EO-2.0-300M.

    On startup this builds the PrithviMAE architecture from the co-located
    ``prithvi_mae.py`` + ``config.json``, loads pretrained weights (local file
    if present, otherwise downloaded from the public IBM/NASA HF repo), and
    then serves mean-pooled patch-token embeddings per request.
    """

    def __init__(self, path: str = ""):
        """Build the model once at endpoint startup.

        Args:
            path: Directory of the deployed endpoint repo. Must contain
                ``prithvi_mae.py`` and ``config.json``; may optionally contain
                ``Prithvi_EO_V2_300M.pt`` to skip the Hub download.

        Side effects:
            Prepends ``path`` to ``sys.path``, may download weights over the
            network, and prints progress to stdout (endpoint logs).
        """
        # Make the repo directory importable so we can do `from prithvi_mae import PrithviMAE`
        sys.path.insert(0, path)
        from prithvi_mae import PrithviMAE  # noqa: PLC0415 (inside __init__ by design)

        # Read architecture hyper-parameters from config.json
        cfg_path = os.path.join(path, "config.json")
        with open(cfg_path) as fh:
            cfg = json.load(fh)
        # NOTE(review): assumes config.json follows the timm-style layout with a
        # "pretrained_cfg" section holding all architecture hyper-parameters.
        pc = cfg["pretrained_cfg"]

        # Required keys raise KeyError immediately (fail fast on a bad config);
        # optional coord/mask settings fall back to library-style defaults.
        self.model = PrithviMAE(
            img_size=pc["img_size"],
            num_frames=pc["num_frames"],
            patch_size=pc["patch_size"],
            in_chans=pc["in_chans"],
            embed_dim=pc["embed_dim"],
            depth=pc["depth"],
            num_heads=pc["num_heads"],
            decoder_embed_dim=pc["decoder_embed_dim"],
            decoder_depth=pc["decoder_depth"],
            decoder_num_heads=pc["decoder_num_heads"],
            mlp_ratio=pc["mlp_ratio"],
            coords_encoding=pc.get("coords_encoding", []),
            coords_scale_learn=pc.get("coords_scale_learn", False),
            mask_ratio=pc.get("mask_ratio", 0.75),
        )

        # Load weights — try local path first, fall back to downloading from IBM/NASA HF repo
        weights_local = os.path.join(path, "Prithvi_EO_V2_300M.pt")
        if os.path.exists(weights_local):
            weights_path = weights_local
        else:
            print("[handler] Prithvi_EO_V2_300M.pt not in repo dir — downloading from IBM/NASA HF …")
            from huggingface_hub import hf_hub_download
            weights_path = hf_hub_download(
                "ibm-nasa-geospatial/Prithvi-EO-2.0-300M",
                "Prithvi_EO_V2_300M.pt",
            )
            print(f"[handler] weights downloaded to {weights_path}")

        try:
            # weights_only=True is the safe deserialization path (no pickle code exec)
            state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        except TypeError:
            # weights_only param not available in older PyTorch
            state_dict = torch.load(weights_path, map_location="cpu")

        # Discard fixed positional embeddings (interpolated from grid at runtime);
        # iterate over a snapshot of the keys since we delete while scanning.
        for k in list(state_dict.keys()):
            if "pos_embed" in k:
                del state_dict[k]

        # strict=False: pos_embed keys were removed above, so missing keys are expected.
        self.model.load_state_dict(state_dict, strict=False)
        self.model.eval()

        # Force CPU: prithvi_mae uses get_3d_sincos_pos_embed (numpy-sourced tensors) which
        # land on CPU at runtime. Running model on GPU then causes a cross-device error.
        # CPU is sufficient for 224×224 patch inference at benchmark scale.
        self.device = torch.device("cpu")
        self.model = self.model.to(self.device)
        print(f"[handler] Prithvi-EO-2.0-300M ready on {self.device}")

    def __call__(self, data: dict) -> dict:
        """Handle one inference request.

        Args:
            data: HF endpoint payload. ``data["inputs"]`` (or ``data`` itself
                when the key is absent) is either raw ``.npy`` bytes
                (octet-stream path) or a nested list of floats (JSON path).

        Returns:
            ``{"embeddings": [[float, ...], ...]}`` — one mean-pooled embedding
            vector per batch item — or ``{"error": msg}`` if the byte payload
            cannot be parsed as a numpy array.
        """
        inputs = data.get("inputs", data)

        # Deserialise input
        if isinstance(inputs, (bytes, bytearray)):
            try:
                # np.load with default allow_pickle=False: plain arrays only.
                arr = np.load(BytesIO(inputs)).astype(np.float32)
            except Exception as exc:
                # Return a structured error rather than crashing the endpoint worker.
                return {"error": f"cannot parse input bytes as numpy array: {exc}"}
        else:
            arr = np.array(inputs, dtype=np.float32)

        # Shape normalisation → Prithvi expects (B, C, T, H, W)
        # inference_runner sends (1, 1, C, H, W) for "B T C H W" models
        # meaning (batch=1, time=1, channels, H, W) — transpose axes 1 and 2
        # NOTE(review): inputs with ndim other than 4 or 5 pass through
        # unmodified and will likely fail inside the model — confirm callers
        # only ever send 4-D or 5-D arrays.
        if arr.ndim == 4:
            # (B, C, H, W) → (B, C, 1, H, W)
            arr = arr[:, :, np.newaxis, :, :]
        elif arr.ndim == 5:
            # (B, T, C, H, W) → (B, C, T, H, W)
            arr = arr.transpose(0, 2, 1, 3, 4)

        tensor = torch.from_numpy(arr).to(self.device)

        with torch.no_grad():
            features = self.model.forward_features(tensor)

        # features is a list of (B, 1+num_tokens, embed_dim) tensors, one per block.
        # Take the last (normalised) block, mean-pool over spatial tokens (skip CLS at 0).
        # NOTE(review): the per-block list layout and CLS-at-index-0 convention
        # come from prithvi_mae.forward_features — verify against that module.
        last = features[-1]          # (B, 1+num_tokens, embed_dim)
        embedding = last[:, 1:, :].mean(dim=1)   # (B, embed_dim)

        return {"embeddings": embedding.cpu().numpy().tolist()}