rybavery commited on
Commit
b20b523
·
verified ·
1 Parent(s): fc8e551

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +39 -0
README.md CHANGED
@@ -23,3 +23,42 @@ merge_mode: weighted_average
23
 
24
  # Model Card
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Model Card
25
 
26
+ Exported using the following code:
27
+
28
+ ```python
29
+ from pathlib import Path
30
+ import torch
31
+ import torchvision.transforms.v2 as T
32
+ from stac_model.torch.export import save
33
+ import segmentation_models_pytorch as smp
34
+
35
+
36
+ path = "FTW-Release-Full-3-class-unet-efficientnetb5-weight0.75-3xlonger.ckpt"
37
+ ckpt = torch.load(path, map_location="cpu", weights_only=False)
38
+ hparams = ckpt["hyper_parameters"]
39
+ state_dict = {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()}
40
+ del state_dict["criterion.weight"]
41
+ model = smp.Unet(
42
+ encoder_name=hparams["backbone"],
43
+ encoder_weights=None,
44
+ in_channels=hparams["in_channels"],
45
+ classes=hparams["num_classes"],
46
+ )
47
+ model.load_state_dict(state_dict, strict=True)
48
+
49
+ transforms = torch.nn.Sequential(
50
+ torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
51
+ T.Normalize(mean=[0.0], std=[3000.0])
52
+ )
53
+
54
+ save(
55
+ output_file=Path("model.pt2"),
56
+ input_shape=[-1, hparams["in_channels"], -1, -1],
57
+ model=model,
58
+ transforms=transforms,
59
+ metadata=None,
60
+ device="cpu",
61
+ dtype=torch.float32,
62
+ aoti_compile_and_package=False
63
+ )
64
+ ```