|
|
--- |
|
|
license: apache-2.0 |
|
|
--- |
|
|
|
|
|
This is an example of a single PyTorch .pt2 model archive that packages [STAC Machine Learning Model](https://github.com/stac-extensions/mlm?tab=readme-ov-file#machine-learning-model-extension-specification) (MLM) metadata. |
|
|
|
|
|
MLM metadata is stored as YAML in the pt2's `extra` properties. |
|
|
|
|
|
See https://docs.pytorch.org/docs/2.9/export/pt2_archive.html for more info on the pt2 archive spec. |
|
|
|
|
|
See https://github.com/stac-extensions/mlm/blob/main/examples/torch/mlm-metadata.yaml for an example of MLM YAML metadata. |
|
|
|
|
|
This model was exported with the script below. Download the original checkpoint here: https://huggingface.co/torchgeo/ftw/blob/main/commercial/3-class/sentinel2_unet_effb3-5d591cbb.pth |
|
|
|
|
|
```python |
|
|
from pathlib import Path |
|
|
import torch |
|
|
import torchvision.transforms.v2 as T |
|
|
from stac_model.torch.export import save |
|
|
import segmentation_models_pytorch as smp |
|
|
|
|
|
|
|
|
# Export script: load a training checkpoint, rebuild the U-Net it describes,
# and package the model, its preprocessing transforms and the MLM metadata
# into a single .pt2 archive.

# The checkpoint is expected to hold "hyper_parameters" and "state_dict".
# NOTE(review): the hash in this filename differs from the one in the
# download link given in the text above this script -- confirm which file
# is intended.
path = "sentinel2_unet_effb3-ed36f465.pth"

# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
# objects, so only run this on a checkpoint obtained from a trusted source.
ckpt = torch.load(path, map_location="cpu", weights_only=False)
hparams = ckpt["hyper_parameters"]

# Strip only a *leading* "model." from each key so the names line up with a
# bare smp.Unet; str.replace would also rewrite any "model." that occurs
# later inside a parameter name.
state_dict = {
    key.removeprefix("model."): value
    for key, value in ckpt["state_dict"].items()
}

# "criterion.weight" is not a key of smp.Unet (presumably it belongs to the
# training loss -- confirm), and strict loading rejects unexpected keys.
# pop() with a default tolerates checkpoints that do not contain it.
state_dict.pop("criterion.weight", None)

# Rebuild the architecture from the stored hyper-parameters; no pretrained
# encoder weights are fetched because the checkpoint supplies all weights.
model = smp.Unet(
    encoder_name=hparams["backbone"],
    encoder_weights=None,
    in_channels=hparams["in_channels"],
    classes=hparams["num_classes"],
)
model.load_state_dict(state_dict, strict=True)

# Preprocessing packaged with the model: 2x bilinear upsampling, then
# (x - 0.0) / 3000.0 normalization.
transforms = torch.nn.Sequential(
    torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
    T.Normalize(mean=[0.0], std=[3000.0]),
)

# Path to the MLM metadata YAML file that is handed to save().
metadata_path = "mlm.yaml"

save(
    output_file=Path("model.pt2"),
    # Only the channel dimension is fixed; the -1 entries presumably mark
    # dynamic batch/height/width dimensions -- confirm against stac_model.
    input_shape=[-1, hparams["in_channels"], -1, -1],
    model=model,
    transforms=transforms,
    metadata=metadata_path,
    device="cpu",
    dtype=torch.float32,
    aoti_compile_and_package=False,
)
|
|
|
|
|
``` |