---
license: cc-by-3.0
patch_size: 256
clip_size: 32
device: cuda
features:
  - s2med_harvest:B04
  - s2med_harvest:B03
  - s2med_harvest:B02
  - s2med_harvest:B08
  - s2med_planting:B04
  - s2med_planting:B03
  - s2med_planting:B02
  - s2med_planting:B08
labels:
  - non_field_background
  - field
  - field_boundaries
actor: semantic_segmentation
max_batch_size: 128
merge_mode: weighted_average
---
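The `features` list fixes the order of the model's 8 input channels: the four 10 m Sentinel-2 bands (red, green, blue, NIR) from a harvest-window composite, followed by the same four bands from a planting-window composite. A minimal sketch of that ordering (the helper names here are illustrative, not part of the model):

```python
# Illustrative only: reconstructs the channel order implied by the
# "features" list in the metadata above.
WINDOWS = ["s2med_harvest", "s2med_planting"]  # two Sentinel-2 median composites
BANDS = ["B04", "B03", "B02", "B08"]           # red, green, blue, NIR

# Channel index -> feature name, matching the 8-channel input the model expects.
CHANNELS = [f"{w}:{b}" for w in WINDOWS for b in BANDS]

print(len(CHANNELS))   # 8 input channels
print(CHANNELS[0])     # s2med_harvest:B04
```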

# Model Card

The checkpoint was exported to `model.pt2` using the following code:

```python
from pathlib import Path

import torch
import torchvision.transforms.v2 as T
from stac_model.torch.export import save
import segmentation_models_pytorch as smp

# Load the Lightning checkpoint and recover its hyperparameters.
path = "FTW-Release-Full-3-class-unet-efficientnetb5-weight0.75-3xlonger.ckpt"
ckpt = torch.load(path, map_location="cpu", weights_only=False)
hparams = ckpt["hyper_parameters"]

# Strip the "model." prefix added by the LightningModule wrapper and drop
# the loss weights, which belong to the criterion rather than the network.
state_dict = {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()}
del state_dict["criterion.weight"]

# Rebuild the U-Net with the same architecture the checkpoint was trained with.
model = smp.Unet(
    encoder_name=hparams["backbone"],
    encoder_weights=None,
    in_channels=hparams["in_channels"],
    classes=hparams["num_classes"],
)
model.load_state_dict(state_dict, strict=True)

# Preprocessing baked into the exported artifact: 2x bilinear upsampling
# followed by rescaling raw Sentinel-2 digital numbers by 1/3000.
transforms = torch.nn.Sequential(
    torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
    T.Normalize(mean=[0.0], std=[3000.0]),
)

# Export with dynamic batch, height, and width (-1) and a fixed channel count.
save(
    output_file=Path("model.pt2"),
    input_shape=[-1, hparams["in_channels"], -1, -1],
    model=model,
    transforms=transforms,
    metadata=None,
    device="cpu",
    dtype=torch.float32,
    aoti_compile_and_package=False,
)
```
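The bundled `Normalize(mean=[0.0], std=[3000.0])` transform simply divides raw Sentinel-2 digital numbers by 3000, so typical surface-reflectance values land roughly in `[0, 1]`. A dependency-free sketch of the same per-pixel operation (the helper name is illustrative):

```python
def normalize_dn(values, mean=0.0, std=3000.0):
    """Apply the same affine rescaling as T.Normalize: (x - mean) / std."""
    return [(v - mean) / std for v in values]

# Raw Sentinel-2 digital numbers for a few pixels of one band.
raw = [0, 1500, 3000, 6000]
print(normalize_dn(raw))  # → [0.0, 0.5, 1.0, 2.0]
```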