Update README.md
Browse files
README.md
CHANGED
|
@@ -23,3 +23,42 @@ merge_mode: weighted_average
|
|
| 23 |
|
| 24 |
# Model Card
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
# Model Card
|
| 25 |
|
| 26 |
+
Exported using the following code:
|
| 27 |
+
|
| 28 |
+
```python
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
import torch
|
| 31 |
+
import torchvision.transforms.v2 as T
|
| 32 |
+
from stac_model.torch.export import save
|
| 33 |
+
import segmentation_models_pytorch as smp
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
path = "FTW-Release-Full-3-class-unet-efficientnetb5-weight0.75-3xlonger.ckpt"
|
| 37 |
+
ckpt = torch.load(path, map_location="cpu", weights_only=False)
|
| 38 |
+
hparams = ckpt["hyper_parameters"]
|
| 39 |
+
state_dict = {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()}
|
| 40 |
+
del state_dict["criterion.weight"]
|
| 41 |
+
model = smp.Unet(
|
| 42 |
+
encoder_name=hparams["backbone"],
|
| 43 |
+
encoder_weights=None,
|
| 44 |
+
in_channels=hparams["in_channels"],
|
| 45 |
+
classes=hparams["num_classes"],
|
| 46 |
+
)
|
| 47 |
+
model.load_state_dict(state_dict, strict=True)
|
| 48 |
+
|
| 49 |
+
transforms = torch.nn.Sequential(
|
| 50 |
+
torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
|
| 51 |
+
T.Normalize(mean=[0.0], std=[3000.0])
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
save(
|
| 55 |
+
output_file=Path("model.pt2"),
|
| 56 |
+
input_shape=[-1, hparams["in_channels"], -1, -1],
|
| 57 |
+
model=model,
|
| 58 |
+
transforms=transforms,
|
| 59 |
+
metadata=None,
|
| 60 |
+
device="cpu",
|
| 61 |
+
dtype=torch.float32,
|
| 62 |
+
aoti_compile_and_package=False
|
| 63 |
+
)
|
| 64 |
+
```
|