rybavery commited on
Commit
d3856cd
·
verified ·
1 Parent(s): a28b2da

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +72 -1
README.md CHANGED
@@ -14,5 +14,76 @@ max_batch_size: 64
14
  merge_mode: weighted_average
15
  ---
16
 
17
- # Model Card
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
 
14
  merge_mode: weighted_average
15
  ---
16
 
17
+ First run the following to setup the environment and get the official model code
18
+
19
+ ```bash
20
+ # Clone the official repo
21
+ git clone git@github.com:facebookresearch/HighResCanopyHeight.git
22
+
23
+ # Install dependencies
24
+ pip install stac-model[torch]
25
+
26
+ # Download the official pretrained checkpoints
27
+ mkdir checkpoints && aws s3 --no-sign-request sync s3://dataforgood-fb-data/forests/v1/models/saved_checkpoints/ checkpoints/
28
+ ```
29
+
30
+ Export the model using the following:
31
+
32
+ ```python
33
+ from pathlib import Path
34
+ import sys
35
+ sys.path.append("HighResCanopyHeight")
36
+
37
+ import torch
38
+ import torch.nn as nn
39
+ import torchvision.transforms.v2 as T
40
+ from stac_model.torch.export import export, package
41
+
42
+ import src.transforms
43
+ from inference import SSLAE
44
+
45
+
46
+ # Create model and load checkpoint
47
+ class TreeCanopyHeightModel(nn.Module):
48
+ def __init__(self, classify=True, huge=True):
49
+ super().__init__()
50
+ self.model = SSLAE(pretrained=None, classify=classify, huge=huge, n_bins=256)
51
+
52
+ def forward(self, x):
53
+ outputs = self.model(x)
54
+ pred = 10 * outputs + 0.001
55
+ return pred.relu()
56
+
57
+ path = "checkpoints/SSLhuge_satellite.pth"
58
+ ckpt = torch.load(path, map_location="cpu", weights_only=False)
59
+ state_dict = {f"model.{k}": v for k, v in ckpt["state_dict"].items()}
60
+ model = TreeCanopyHeightModel()
61
+ model.load_state_dict(state_dict)
62
+
63
+ # Create exportable transforms
64
+ original_transform = src.transforms.SSLNorm().Trans
65
+ norm = original_transform.transforms[-1]
66
+
67
+ transforms = nn.Sequential(
68
+ T.Normalize(mean=[0], std=[255]), # replace ToTensor() with normalize to 0-1
69
+ T.Normalize(mean=norm.mean, std=norm.std)
70
+ )
71
+
72
+ # Export and save to pt2
73
+ model_program, transforms_program = export(
74
+ input_shape=[-1, 3, 224, 224],
75
+ model=model,
76
+ transforms=transforms,
77
+ device="cpu",
78
+ dtype=torch.float32,
79
+ )
80
+ package(
81
+ output_file=Path("model.pt2"),
82
+ model_program=model_program,
83
+ transforms_program=transforms_program,
84
+ metadata_properties=None,
85
+ aoti_compile_and_package=False
86
+ )
87
+ ```
88
+
89