isaaccorley committed on
Commit
eb2e36c
·
verified ·
1 Parent(s): 4c939c3

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +194 -3
README.md CHANGED
@@ -1,3 +1,194 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
4
+
5
+ The following provides minimal code for loading and exporting the [Tessera geospatial foundation model](https://github.com/ucam-eo/tessera). The original checkpoint file `best_model_fsdp_20250427_084307.pt` hosted on [Google Drive](https://drive.google.com/drive/folders/18RPptbUkCIgUfw1aMdMeOrFML_ZVMszn?usp=sharing) was ~7 GB; however, repackaging only the model weights results in a 337 MB checkpoint file `model.pt`. Furthermore, the model is also exported with torch.export to `model_exported_program.pt2` so that the model code itself is not needed to run inference.
6
+
7
+ ```python
8
+ import torch
9
+
10
+
11
class AttentionPooling(torch.nn.Module):
    """Collapse a (B, seq_len, dim) sequence into a (B, dim) vector.

    A learned linear scorer assigns one scalar per timestep; the scores are
    softmax-normalized over the sequence axis and used as weights for a
    convex combination of the input vectors.
    """

    def __init__(self, input_dim):
        super().__init__()
        # One attention score per timestep.
        self.query = torch.nn.Linear(input_dim, 1)

    def forward(self, x):
        # x: (B, seq_len, dim)
        scores = self.query(x)                  # (B, seq_len, 1)
        weights = torch.softmax(scores, dim=1)  # normalize over the time axis
        pooled = torch.sum(weights * x, dim=1)  # weighted average -> (B, dim)
        return pooled
20
+
21
+
22
class TemporalAwarePooling(torch.nn.Module):
    """Attention pooling whose scores are computed from a GRU-contextualized sequence.

    The attention weights are derived from GRU hidden states (temporal
    context), but the weighted sum is taken over the *raw* inputs.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.query = torch.nn.Linear(input_dim, 1)
        self.temporal_context = torch.nn.GRU(input_dim, input_dim, batch_first=True)

    def forward(self, x):
        # x: (B, seq_len, dim)
        # Run the sequence through the GRU to capture temporal context.
        context, _ = self.temporal_context(x)
        # Score each timestep from its contextualized representation.
        weights = torch.softmax(self.query(context), dim=1)  # (B, seq_len, 1)
        # Pool the ORIGINAL inputs with the context-derived weights.
        return torch.sum(weights * x, dim=1)
34
+
35
+
36
class TemporalEncoding(torch.nn.Module):
    """Fourier-style encoding of day-of-year (DOY) with learnable frequencies.

    DOY values are mapped to multi-frequency sin/cos features whose
    frequencies are learnable parameters (initialized log-uniformly over
    [1, 365]); the concatenated features are projected to ``d_model``.

    Args:
        d_model: output feature dimension.
        num_freqs: number of frequency bands (output of sin/cos is 2*num_freqs).
    """

    def __init__(self, d_model, num_freqs=64):
        super().__init__()
        import math  # stdlib, used once below for the log-spaced init

        self.num_freqs = num_freqs
        self.d_model = d_model

        # Learnable frequency parameters, initialized log-uniformly in [1, 365].
        # Fix: the original called np.log without importing numpy, which raised
        # NameError at construction time; math.log is the stdlib equivalent.
        self.freqs = torch.nn.Parameter(
            torch.exp(torch.linspace(0.0, math.log(365.0), num_freqs))
        )

        # Project the concatenated Fourier features to the target dimension.
        self.proj = torch.nn.Linear(2 * num_freqs, d_model)
        # Learnable phase offsets. forward() slices the first 2*num_freqs
        # entries, so this assumes d_model >= 2*num_freqs — TODO confirm.
        self.phase = torch.nn.Parameter(torch.zeros(1, 1, d_model))

    def forward(self, doy):
        """doy: (B, seq_len, 1) day-of-year values -> (B, seq_len, d_model)."""
        t = doy / 365.0 * 2 * torch.pi  # normalize DOY to the 0-2π range

        # Generate multi-frequency sine/cosine features.
        t_scaled = t * self.freqs.view(1, 1, -1)  # (B, seq_len, num_freqs)
        sin = torch.sin(t_scaled + self.phase[..., :self.num_freqs])
        cos = torch.cos(t_scaled + self.phase[..., self.num_freqs:2*self.num_freqs])

        # Concatenate and project to the target dimension.
        encoding = torch.cat([sin, cos], dim=-1)  # (B, seq_len, 2*num_freqs)
        return self.proj(encoding)  # (B, seq_len, d_model)
61
+
62
+
63
class TemporalPositionalEncoder(torch.nn.Module):
    """Classic fixed sinusoidal positional encoding keyed on day-of-year values."""

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

    def forward(self, doy):
        """doy: (B, T) of DOY values (0-365) -> (B, T, d_model).

        Even channels carry sin, odd channels cos, with the standard
        Transformer 10000^(-2i/d_model) frequency schedule. Assumes an
        even d_model (odd widths would misalign the two interleaves).
        """
        position = doy.unsqueeze(-1).float()  # (B, T, 1), ensure float
        exponent = torch.arange(0, self.d_model, 2, dtype=torch.float)
        scale = -torch.log(torch.tensor(10000.0)) / self.d_model
        inv_freq = torch.exp(exponent * scale).to(doy.device)

        angles = position * inv_freq  # (B, T, d_model/2)
        encoding = torch.zeros(doy.shape[0], doy.shape[1], self.d_model, device=doy.device)
        encoding[:, :, 0::2] = torch.sin(angles)
        encoding[:, :, 1::2] = torch.cos(angles)
        return encoding
78
+
79
+
80
class TransformerEncoder(torch.nn.Module):
    """Transformer over a band time series with DOY-based positional encoding.

    Each timestep carries ``band_num`` spectral bands plus one trailing
    day-of-year (DOY) channel. Bands are embedded by an MLP, sinusoidal
    DOY encodings are added, the sequence runs through a stack of
    Transformer encoder layers, and a temporal-aware attention pooling
    collapses it to one vector per sample.
    """

    def __init__(self, band_num, latent_dim, nhead=8, num_encoder_layers=4,
                 dim_feedforward=512, dropout=0.1, max_seq_len=20):
        super().__init__()
        d_model = latent_dim * 4  # working width of the whole stack

        # Two-layer MLP lifting raw bands to the model width.
        self.embedding = torch.nn.Sequential(
            torch.nn.Linear(band_num, d_model),
            torch.nn.ReLU(),
            torch.nn.Linear(d_model, d_model),
        )

        # DOY values act as positions (instead of sequence indices).
        self.temporal_encoder = TemporalPositionalEncoder(d_model=d_model)

        # Stacked self-attention encoder.
        layer = torch.nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="relu",
            batch_first=True,
        )
        self.transformer_encoder = torch.nn.TransformerEncoder(layer, num_layers=num_encoder_layers)

        # Sequence -> vector pooling with GRU-informed attention weights.
        self.attn_pool = TemporalAwarePooling(d_model)
        # NOTE(review): max_seq_len is accepted but never used in this class.

    def forward(self, x):
        """x: (B, seq_len, band_num + 1); last channel is DOY. Returns (B, latent_dim*4)."""
        bands = x[:, :, :-1]  # all channels except the trailing DOY
        doy = x[:, :, -1]     # (B, seq_len) DOY values
        tokens = self.embedding(bands) + self.temporal_encoder(doy)
        encoded = self.transformer_encoder(tokens)
        return self.attn_pool(encoded)
124
+
125
+
126
class Tessera(torch.nn.Module):
    """Dual-branch Sentinel-2 / Sentinel-1 encoder producing a fused 128-d embedding."""

    def __init__(self):
        super().__init__()
        # Sentinel-2 branch: 10 spectral bands (+1 DOY channel consumed by the backbone).
        self.s2_backbone = TransformerEncoder(
            band_num=10,
            latent_dim=128,
            nhead=8,
            num_encoder_layers=8,
            dim_feedforward=4096,
            dropout=0.1,
            max_seq_len=40,
        )
        # Sentinel-1 branch: 2 SAR bands (+1 DOY channel).
        self.s1_backbone = TransformerEncoder(
            band_num=2,
            latent_dim=128,
            nhead=8,
            num_encoder_layers=8,
            dim_feedforward=4096,
            dropout=0.1,
            max_seq_len=40,
        )
        # Each backbone emits latent_dim*4 = 512 features; the concatenated
        # 1024-d vector is reduced to the final 128-d embedding.
        self.dim_reducer = torch.nn.Sequential(torch.nn.Linear(128 * 8, 128))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tensor of shape (b, t, c) where c=14; the first 11 channels are
                Sentinel-2 (10 bands + 1 DOY feature) and the last 3 channels
                are Sentinel-1 (2 bands + 1 DOY feature).

        Returns:
            (b, 128) fused embedding.
        """
        assert x.shape[-1] == 14
        s2_input = x[..., :11]
        s1_input = x[..., 11:]
        branch_feats = [self.s2_backbone(s2_input), self.s1_backbone(s1_input)]  # each (b, 512)
        fused = torch.cat(branch_feats, dim=-1)  # (b, 1024)
        return self.dim_reducer(fused)           # (b, 128)
163
+
164
+
165
# Load the pretrained model for inference only without the projection using the pretrained config
model = Tessera()
model.eval()

# Smoke-test the forward pass on random inputs.
batch, steps = 2, 10
s2_bands = torch.randn(batch, steps, 10)
s2_days = torch.randint(1, 365, (batch, steps, 1))
s1_bands = torch.randn(batch, steps, 2)
s1_days = torch.randint(1, 365, (batch, steps, 1))

x = torch.cat([s2_bands, s2_days, s1_bands, s1_days], dim=-1)
print(model(x).shape)

# Load and extract only the model state dict then save to model.pt
path = "best_model_fsdp_20250427_084307.pt"
# SECURITY: weights_only=False unpickles arbitrary objects — only load
# checkpoints obtained from a trusted source.
ckpt = torch.load(path, map_location="cpu", weights_only=False)
wanted = ["s2_backbone", "s1_backbone", "dim_reducer"]
# Strip torch.compile's "_orig_mod." prefix, then keep only backbone/reducer weights.
state_dict = {
    key.replace("_orig_mod.", ""): value
    for key, value in ckpt["model_state_dict"].items()
}
state_dict = {key: value for key, value in state_dict.items() if key.split(".")[0] in wanted}
model.load_state_dict(state_dict, strict=True)
torch.save(model.state_dict(), "model.pt")

# Export the model and save to model_exported_program.pt2
from torch.export.dynamic_shapes import Dim

# Batch and time dimensions stay dynamic; channel count is pinned to 14.
example_inputs = torch.randn(1, 10, 14)
dims = (Dim.AUTO, Dim.AUTO, 14)
model_program = torch.export.export(mod=model, args=(example_inputs,), dynamic_shapes={"x": dims})
torch.export.save(model_program, 'model_exported_program.pt2')
194
+ ```