This is an example of a single PyTorch .pt2 model archive that packages STAC Machine Learning Model (MLM) metadata.
MLM metadata is stored as YAML in the pt2's extra properties.
See https://docs.pytorch.org/docs/2.9/export/pt2_archive.html for more info on the pt2 archive spec.
See https://github.com/stac-extensions/mlm/blob/main/examples/torch/mlm-metadata.yaml for an example of MLM YAML metadata.
This model was exported with the script below. Download the original checkpoint here: https://huggingface.co/torchgeo/ftw/blob/main/commercial/3-class/sentinel2_unet_effb3-5d591cbb.pth
from pathlib import Path
import torch
import torchvision.transforms.v2 as T
from stac_model.torch.export import save
import segmentation_models_pytorch as smp
path = "sentinel2_unet_effb3-ed36f465.pth"
ckpt = torch.load(path, map_location="cpu", weights_only=False)
hparams = ckpt["hyper_parameters"]
state_dict = {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()}
del state_dict["criterion.weight"]
model = smp.Unet(
    encoder_name=hparams["backbone"],
    encoder_weights=None,
    in_channels=hparams["in_channels"],
    classes=hparams["num_classes"],
)
model.load_state_dict(state_dict, strict=True)
transforms = torch.nn.Sequential(
    torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
    T.Normalize(mean=[0.0], std=[3000.0]),
)
metadata_path = "mlm.yaml"
save(
    output_file=Path("model.pt2"),
    input_shape=[-1, hparams["in_channels"], -1, -1],
    model=model,
    transforms=transforms,
    metadata=metadata_path,
    device="cpu",
    dtype=torch.float32,
    aoti_compile_and_package=False,
)
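To check what ended up in the exported archive, the sketch below lists and decodes the extra files using only the Python standard library. It relies on two assumptions that are not confirmed by this page: that the .pt2 archive is a regular zip file, and that extra files sit under a folder named extra inside it, as the pt2 archive spec linked above describes. The helper name read_extra_files is hypothetical and is not part of stac_model or PyTorch.

```python
import zipfile
from pathlib import PurePosixPath


def read_extra_files(archive_path):
    """Return {entry name: decoded text} for files under an 'extra' folder.

    Assumption: the archive is a zip file and stores extra files in a
    directory named 'extra' (hypothetical helper, not a library API).
    """
    extras = {}
    with zipfile.ZipFile(archive_path) as archive:
        for name in archive.namelist():
            parts = PurePosixPath(name).parts
            # Skip directory entries and anything outside an 'extra' folder.
            if name.endswith("/") or "extra" not in parts[:-1]:
                continue
            extras[name] = archive.read(name).decode("utf-8", errors="replace")
    return extras


# Example usage, assuming model.pt2 was produced by the export script:
# for name, text in read_extra_files("model.pt2").items():
#     print(name)
#     print(text)
```

The exact entry name used for the MLM YAML depends on how stac_model writes it, so the helper returns every file it finds under extra rather than looking up one fixed name.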