niobures committed on
Commit
b090fcc
·
verified ·
1 Parent(s): 35d10aa

Presto (code, colab, google_earth, paper)

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ 20240418-Presto-TFWorkingGroup.pptx filter=lfs diff=lfs merge=lfs -text
37
+ code/LEM[[:space:]]Data[[:space:]]Processing.pptx filter=lfs diff=lfs merge=lfs -text
38
+ Lightweight,[[:space:]]Pre-trained[[:space:]]Transformers[[:space:]]for[[:space:]]Remote[[:space:]]Sensing[[:space:]]Timeseries.pdf filter=lfs diff=lfs merge=lfs -text
20240418-Presto-TFWorkingGroup.pptx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7ab6859bee5653ed738b21cff701605c71768e453e2c6c81ac6c7609dbf4d39
3
+ size 21067278
Lightweight, Pre-trained Transformers for Remote Sensing Timeseries.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8324993896a3c346def6a5a0f923561d44f1c3aed270ff19bb29ac45986128f2
3
+ size 3957097
code/LEM Data Processing.pptx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22e30405ed3bb718599afa342ccb857a41d64d7fb58c2bd582b45078b2b65b87
3
+ size 1296790
code/presto [DurojaiyeAbisoye] +6 -5.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5015cb3128495d8c40b2d77a1ac5d93060469a3d6ff662cd57e4ee0668ac649
3
+ size 35240122
code/presto-ndws.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ad48fd3638c91ad7b6856cd953dcba8065d98ba9a940de14784b9f654be0d0e
3
+ size 263605
code/presto-worldcereal.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:466dedf973084219798dcc0b5a4f24241f39bd7af20a6d40acc838275435c4ec
3
+ size 71758689
code/presto.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d299c366a348393c54054f3b7d4994cd0ce3facafcd9cbf4b1adb6f26095fb7
3
+ size 45533344
colab/1_Presto_to_VertexAI.ipynb ADDED
@@ -0,0 +1,634 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "_SVA9v_JTq_-"
7
+ },
8
+ "source": [
9
+ "# 1. Presto to Vertex AI\n",
10
+ "\n",
11
+ "<a target=\"_blank\" href=\"https://colab.research.google.com/github/nasaharvest/presto/blob/main/deploy/1_Presto_to_VertexAI.ipynb\">\n",
12
+ " <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
13
+ "</a>\n",
14
+ "\n",
15
+ "**Authors**: Ivan Zvonkov, Gabriel Tseng, (additional credits: [Earth_Engine_PyTorch_Vertex_AI](https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_PyTorch_Vertex_AI.ipynb))\n",
16
+ "\n",
17
+ "**Description**: The notebook Deploys Presto to Vertex AI. This is a prerequisite to generating Presto embeddings on Google Earth Engine using\n",
18
+ "[ee.Model.fromVertexAi](https://developers.google.com/earth-engine/apidocs/ee-model-fromvertexai).\n",
19
+ "\n",
20
+ "Once the model is deployed this [GEE script](https://code.earthengine.google.com/df6348b8d47cd751eb5164dccb7b26a9) can be used to generate Presto embeddings.\n",
21
+ "\n",
22
+ "**Steps**:\n",
23
+ "1. Set up environment\n",
24
+ "2. Load default Presto model\n",
25
+ "3. Transform Presto model into TorchScript\n",
26
+ "4. Package TorchScript model into TorchServe\n",
27
+ "5. Deploy and use Vertex AI\n",
28
+ "\n",
29
+ " 5a. Upload TorchServe model to Vertex AI Model Registry [Free]\n",
30
+ "\n",
31
+ " 5b. Create a Vertex AI Endpoint [Free]\n",
32
+ "\n",
33
+ " 5c. Deploy model to endpoint [Cost depends on Minimum Replica Count parameter]\n",
34
+ "\n",
35
+ " 5d. Generate embeddings in Google Earth Engine [Cost depends on region size]\n",
36
+ "\n",
37
+ " 5e. Undeploy model from endpoint [Free]\n",
38
+ "\n",
39
+ "**Cost Breakdown**:\n",
40
+ "\n",
41
+ "*5a. Upload TorchServe model to Vertex AI Model Registry [Free]*\n",
42
+ "- Model files are uploaded to Cloud Storage but are lightweight (3.37 Mb total) and thus easily fall into Google Cloud's 5GB/month Storage [Free Tier](https://cloud.google.com/storage/pricing#cloud-storage-always-free)\n",
43
+ "- There is no cost to storing models in Vertex AI Model Registry ([source](https://cloud.google.com/vertex-ai/pricing#modelregistry))\n",
44
+ "\n",
45
+ "*5b. Create a Vertex AI Endpoint [Free]*\n",
46
+ "- There is no cost to creating an endpoint. Costs start when a model is deployed to that endpoint\n",
47
+ "\n",
48
+ "*5c. Deploy model to endpoint [Cost depends on Minimum Replica Count parameter]*\n",
49
+ "- The `Minimum Replica Count` represents the minimum amount of compute nodes started when a model is deployed is e2-standard-2 machine (\\$0.0771/node hour in us-central-1)\n",
50
+ "- So as long as the endpoint is active you will be paying \\$0.0771/hour even if no predictions are made\n",
51
+ "\n",
52
+ "*5d. Generate embeddings in Google Earth Engine [Cost depends on region size]*\n",
53
+ "- Once a model is deployed and `ee.model.fromVertexAi` is used Vertex AI scales the amount of nodes based on amount of data (size of the region)\n",
54
+ "- Our current embedding generation cost estimates are <strong>\\$5.37 - \\$10.14 per 1000 km<sup>2</sup> </strong>\n",
55
+ "- We compute a cost estimate for your ROI in our Google Earth Engine script\n",
56
+ "\n",
57
+ "*5e. Undeploy model from endpoint [Free]*\n",
58
+ "- Necessary to stop incurring charges from 5c"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "markdown",
63
+ "metadata": {
64
+ "id": "hzb1bwgTUZU0"
65
+ },
66
+ "source": [
67
+ "## 1. Set up environment"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": null,
73
+ "metadata": {
74
+ "id": "KuoEjld3TTLO"
75
+ },
76
+ "outputs": [],
77
+ "source": [
78
+ "from google.colab import auth\n",
79
+ "\n",
80
+ "auth.authenticate_user()"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": null,
86
+ "metadata": {
87
+ "id": "cUI_5pWJ3V4s"
88
+ },
89
+ "outputs": [],
90
+ "source": [
91
+ "PROJECT = '<YOUR CLOUD PROJECT>'\n",
92
+ "!gcloud config set project {PROJECT}"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": null,
98
+ "metadata": {
99
+ "id": "MRGjYjltsm6-"
100
+ },
101
+ "outputs": [],
102
+ "source": [
103
+ "!git clone https://github.com/nasaharvest/presto.git"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "markdown",
108
+ "metadata": {
109
+ "id": "P1zGbf2KIhLA"
110
+ },
111
+ "source": [
112
+ "## 2. Load default Presto model"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "metadata": {
119
+ "id": "UKgCxBNnYJIB"
120
+ },
121
+ "outputs": [],
122
+ "source": [
123
+ "# Navigate inside of the repository to import Presto\n",
124
+ "%cd /content/presto\n",
125
+ "\n",
126
+ "import torch\n",
127
+ "from single_file_presto import Presto\n",
128
+ "\n",
129
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
130
+ "\n",
131
+ "model = Presto.construct()\n",
132
+ "model.load_state_dict(torch.load(\"data/default_model.pt\", map_location=device))\n",
133
+ "model.eval();\n",
134
+ "\n",
135
+ "# Navigate back to main directory\n",
136
+ "%cd /content"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "markdown",
141
+ "metadata": {
142
+ "id": "5v7gJoysFTsf"
143
+ },
144
+ "source": [
145
+ "## 3. Transform Presto model into TorchScript\n",
146
+ "> TorchScript is a way to create serializable and optimizable models from PyTorch code.\n",
147
+ "https://docs.pytorch.org/docs/stable/jit.html"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": null,
153
+ "metadata": {
154
+ "id": "hDKqqzi9F7T4"
155
+ },
156
+ "outputs": [],
157
+ "source": [
158
+ "# Construct input manually\n",
159
+ "batch_size = 256\n",
160
+ "NUM_TIMESTEPS = 12\n",
161
+ "X_tensor = torch.zeros([batch_size, NUM_TIMESTEPS, 17])\n",
162
+ "latlons_tensor = torch.zeros([batch_size, 2])\n",
163
+ "\n",
164
+ "dw_empty = torch.full([batch_size, NUM_TIMESTEPS], 9, device=device).long()\n",
165
+ "month_tensor = torch.full([batch_size], 1, device=device)\n",
166
+ "\n",
167
+ "# [0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 ]\n",
168
+ "# [VV, VH, B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12, temp, precip, elev, slope, NDVI]\n",
169
+ "mask = torch.zeros(X_tensor.shape, device=device).float()"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": null,
175
+ "metadata": {
176
+ "id": "mdSXGZuikHUk"
177
+ },
178
+ "outputs": [],
179
+ "source": [
180
+ "# Verify forward pass with regular model\n",
181
+ "with torch.no_grad():\n",
182
+ " preds = model.encoder(\n",
183
+ " x=X_tensor,\n",
184
+ " dynamic_world=dw_empty,\n",
185
+ " latlons=latlons_tensor,\n",
186
+ " mask=mask,\n",
187
+ " month=month_tensor\n",
188
+ " )"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": null,
194
+ "metadata": {
195
+ "id": "nFRvUVkHowKr"
196
+ },
197
+ "outputs": [],
198
+ "source": [
199
+ "# Make model torchscriptable\n",
200
+ "example_kwargs = {\n",
201
+ " 'x': X_tensor,\n",
202
+ " 'dynamic_world': dw_empty,\n",
203
+ " 'latlons': latlons_tensor,\n",
204
+ " 'mask': mask,\n",
205
+ " 'month': month_tensor\n",
206
+ "}\n",
207
+ "sm = torch.jit.trace(model.encoder, example_kwarg_inputs=example_kwargs)\n",
208
+ "\n",
209
+ "!mkdir -p pytorch_model\n",
210
+ "sm.save('pytorch_model/model.pt')"
211
+ ]
212
+ },
213
+ {
214
+ "cell_type": "code",
215
+ "execution_count": null,
216
+ "metadata": {
217
+ "id": "cYuSPOyp1A0K"
218
+ },
219
+ "outputs": [],
220
+ "source": [
221
+ "jit_model = torch.jit.load('pytorch_model/model.pt')\n",
222
+ "jit_model(**example_kwargs).shape"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "markdown",
227
+ "metadata": {
228
+ "id": "3b0LsZpqnByv"
229
+ },
230
+ "source": [
231
+ "## 4. Package TorchScript model into TorchServe\n",
232
+ "> TorchServe is a performant, flexible and easy to use tool for serving PyTorch models in production.\n",
233
+ "https://docs.pytorch.org/serve/"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": null,
239
+ "metadata": {
240
+ "id": "i70o_BZml9vs"
241
+ },
242
+ "outputs": [],
243
+ "source": [
244
+ "!pip install torchserve torch-model-archiver -q"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "code",
249
+ "execution_count": null,
250
+ "metadata": {
251
+ "id": "htq2Ac95FJlk"
252
+ },
253
+ "outputs": [],
254
+ "source": [
255
+ "%%writefile pytorch_model/custom_handler.py\n",
256
+ "import logging\n",
257
+ "import torch\n",
258
+ "from ts.torch_handler.base_handler import BaseHandler\n",
259
+ "import numpy as np\n",
260
+ "\n",
261
+ "# UPDATE BASED ON YOUR NEEDS\n",
262
+ "########################################\n",
263
+ "VERSION = \"v1\"\n",
264
+ "START_MONTH = 3\n",
265
+ "BATCH_SIZE = 256\n",
266
+ "########################################\n",
267
+ "\n",
268
+ "def printh(text):\n",
269
+ " # Prepends HANDLER to each print statement to make it easier to find in logs.\n",
270
+ " print(f\"HANDLER {VERSION}: {text}\")\n",
271
+ "\n",
272
+ "# Custom TorchServe handler for the Presto model\n",
273
+ "class ClassifierHandler(BaseHandler):\n",
274
+ "\n",
275
+ " def inference(self, data):\n",
276
+ " printh(\"Inference begin\")\n",
277
+ "\n",
278
+ " # Data shape: [ num_pixels, composite_bands, 1, 1 ]\n",
279
+ " data = data[:, :, 0, 0]\n",
280
+ " printh(f\"Data shape {data.shape}\")\n",
281
+ "\n",
282
+ " num_bands = 17\n",
283
+ " printh(f\"Num_bands {num_bands}\")\n",
284
+ "\n",
285
+ " # Subtract first two latlon\n",
286
+ " num_timesteps = (data.shape[1] - 2) // num_bands\n",
287
+ " printh(f\"Num_timesteps {num_timesteps}\")\n",
288
+ "\n",
289
+ " with torch.no_grad():\n",
290
+ "\n",
291
+ " batches = torch.split(data, BATCH_SIZE, dim=0)\n",
292
+ "\n",
293
+ " # month: An int or torch.Tensor describing the first month of the instances being passed. If an int, all instances in the batch are assumed to have the same starting month.\n",
294
+ " month_tensor = torch.full([BATCH_SIZE], START_MONTH, device=self.device)\n",
295
+ " printh(f\"Month: {START_MONTH}\")\n",
296
+ "\n",
297
+ " # dynamic_world: torch.Tensor of shape [BATCH_SIZE, num_timesteps]. If no Dynamic World classes are available, this tensor should be filled with the value DynamicWorld2020_2021.class_amount (i.e. 9), in which case it is ignored.\n",
298
+ " dw_empty = torch.full([BATCH_SIZE, num_timesteps], 9, device=self.device).long()\n",
299
+ " printh(f\"DW {dw_empty[0]}\")\n",
300
+ "\n",
301
+ " # mask: An optional torch.Tensor of shape [BATCH_SIZE, num_timesteps, bands]. mask[i, j, k] == 1 means x[i, j, k] is considered masked. If the mask is None, no values in x are ignored.\n",
302
+ " mask = torch.zeros((BATCH_SIZE, num_timesteps, num_bands), device=self.device).float()\n",
303
+ " printh(f\"Mask sample one timestep: {mask[0, 0]}\")\n",
304
+ "\n",
305
+ " preds_list = []\n",
306
+ " for batch in batches:\n",
307
+ " padding = 0\n",
308
+ " if batch.shape[0] < BATCH_SIZE:\n",
309
+ " padding = BATCH_SIZE - batch.shape[0]\n",
310
+ " batch = torch.cat([batch, torch.zeros([padding, batch.shape[1]], device=self.device)])\n",
311
+ "\n",
312
+ " # x: torch.Tensor of shape [BATCH_SIZE, num_timesteps, bands] where bands is described by NORMED_BANDS.\n",
313
+ " X_tensor = batch[:, 2:]\n",
314
+ " printh(f\"X {X_tensor.shape}\")\n",
315
+ "\n",
316
+ " X_tensor_reshaped = X_tensor.reshape(BATCH_SIZE, num_timesteps, num_bands)\n",
317
+ " printh(f\"X sample one timestep: {X_tensor_reshaped[0, 0]}\")\n",
318
+ "\n",
319
+ " # latlons: torch.Tensor of shape [BATCH_SIZE, 2] describing the latitude and longitude of each input instance.\n",
320
+ " latlons_tensor = batch[:, :2]\n",
321
+ "\n",
322
+ " printh(\"SHAPES\")\n",
323
+ " printh(f\"X {X_tensor_reshaped.shape}\")\n",
324
+ " printh(f\"DW {dw_empty.shape}\")\n",
325
+ " printh(f\"Latlons {latlons_tensor.shape}\")\n",
326
+ " printh(f\"Mask {mask.shape}\")\n",
327
+ " printh(f\"Month {month_tensor.shape}\")\n",
328
+ "\n",
329
+ " pred = self.model(\n",
330
+ " x=X_tensor_reshaped,\n",
331
+ " dynamic_world=dw_empty,\n",
332
+ " latlons=latlons_tensor,\n",
333
+ " mask=mask,\n",
334
+ " month=month_tensor\n",
335
+ " )\n",
336
+ " pred_np = np.expand_dims(pred.numpy(), axis=[1,2])\n",
337
+ " if padding == 0:\n",
338
+ " preds_list.append(pred_np[:])\n",
339
+ " else:\n",
340
+ " preds_list.append(pred_np[:-padding])\n",
341
+ "\n",
342
+ " [printh(f\"{p.shape}\") for p in preds_list]\n",
343
+ " preds = np.concatenate(preds_list)\n",
344
+ " printh(f\"Preds shape {preds.shape}\")\n",
345
+ " return preds\n",
346
+ "\n",
347
+ " def handle(self, data, context):\n",
348
+ " self.context = context\n",
349
+ " printh(f\"Handle begin\")\n",
350
+ " input_tensor = self.preprocess(data)\n",
351
+ " printh(f\"Input_tensor shape {input_tensor.shape}\")\n",
352
+ " pred_out = self.inference(input_tensor)\n",
353
+ " printh(f\"Inference complete\")\n",
354
+ " return self.postprocess(pred_out)"
355
+ ]
356
+ },
357
+ {
358
+ "cell_type": "code",
359
+ "execution_count": null,
360
+ "metadata": {
361
+ "id": "a3Dgq5Ob5b1i"
362
+ },
363
+ "outputs": [],
364
+ "source": [
365
+ "import importlib\n",
366
+ "import pytorch_model.custom_handler\n",
367
+ "\n",
368
+ "importlib.reload(pytorch_model.custom_handler)\n",
369
+ "\n",
370
+ "from pytorch_model.custom_handler import ClassifierHandler, VERSION\n",
371
+ "\n",
372
+ "# Test output\n",
373
+ "data = torch.zeros([713, 206, 1, 1])\n",
374
+ "handler = ClassifierHandler()\n",
375
+ "handler.model = jit_model\n",
376
+ "preds = handler.handle(data=data, context=None)"
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "execution_count": null,
382
+ "metadata": {
383
+ "id": "90TXNAnfF-TD"
384
+ },
385
+ "outputs": [],
386
+ "source": [
387
+ "!torch-model-archiver -f \\\n",
388
+ " --model-name model \\\n",
389
+ " --version 1.0 \\\n",
390
+ " --serialized-file 'pytorch_model/model.pt' \\\n",
391
+ " --handler 'pytorch_model/custom_handler.py' \\\n",
392
+ " --export-path pytorch_model/"
393
+ ]
394
+ },
395
+ {
396
+ "cell_type": "markdown",
397
+ "metadata": {
398
+ "id": "PYK9g3r1qVru"
399
+ },
400
+ "source": [
401
+ "## 5. Deploy and use Vertex AI"
402
+ ]
403
+ },
404
+ {
405
+ "cell_type": "markdown",
406
+ "metadata": {
407
+ "id": "CH0JO_Jww5Ok"
408
+ },
409
+ "source": [
410
+ "### 5a. Upload TorchServe model to Vertex AI Model Registry\n",
411
+ "> The Vertex AI Model Registry is a central repository where you can manage the lifecycle of your ML models.\n",
412
+ "https://cloud.google.com/vertex-ai/docs/model-registry/introduction"
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "code",
417
+ "execution_count": null,
418
+ "metadata": {
419
+ "id": "lEUqcAqsTYpn"
420
+ },
421
+ "outputs": [],
422
+ "source": [
423
+ "REGION = 'us-central1'\n",
424
+ "BUCKET_NAME = \"<YOUR CLOUD BUCKET>\""
425
+ ]
426
+ },
427
+ {
428
+ "cell_type": "code",
429
+ "execution_count": null,
430
+ "metadata": {
431
+ "id": "OiyyCNa-TBGQ"
432
+ },
433
+ "outputs": [],
434
+ "source": [
435
+ "# Create bucket to store model artifacts if it doesn't exist\n",
436
+ "!gcloud storage buckets create gs://{BUCKET_NAME} --location={REGION}"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": null,
442
+ "metadata": {
443
+ "id": "n8m9UBy3GEvZ"
444
+ },
445
+ "outputs": [],
446
+ "source": [
447
+ "MODEL_DIR = f'gs://{BUCKET_NAME}/{VERSION}'\n",
448
+ "!gsutil cp -r pytorch_model {MODEL_DIR}"
449
+ ]
450
+ },
451
+ {
452
+ "cell_type": "code",
453
+ "execution_count": null,
454
+ "metadata": {
455
+ "id": "AetRF8dcGraC"
456
+ },
457
+ "outputs": [],
458
+ "source": [
459
+ "# Can take 2 minutes\n",
460
+ "MODEL_NAME = f'model_{VERSION}'\n",
461
+ "CONTAINER_IMAGE = 'us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-4:latest'\n",
462
+ "\n",
463
+ "!gcloud ai models upload \\\n",
464
+ " --artifact-uri={MODEL_DIR} \\\n",
465
+ " --region={REGION} \\\n",
466
+ " --container-image-uri={CONTAINER_IMAGE} \\\n",
467
+ " --description={MODEL_NAME} \\\n",
468
+ " --display-name={MODEL_NAME} \\\n",
469
+ " --model-id={MODEL_NAME}"
470
+ ]
471
+ },
472
+ {
473
+ "cell_type": "markdown",
474
+ "metadata": {
475
+ "id": "_BXDtoITxY2T"
476
+ },
477
+ "source": [
478
+ "### 5b. Create a Vertex AI Endpoint\n",
479
+ "> To deploy a model for online prediction, you need an endpoint.\n",
480
+ "https://cloud.google.com/vertex-ai/docs/predictions/choose-endpoint-type\n"
481
+ ]
482
+ },
483
+ {
484
+ "cell_type": "code",
485
+ "execution_count": null,
486
+ "metadata": {
487
+ "id": "XhMK9aA73FzI"
488
+ },
489
+ "outputs": [],
490
+ "source": [
491
+ "ENDPOINT_NAME = 'vertex-pytorch-presto-endpoint'\n",
492
+ "\n",
493
+ "endpoints = !gcloud ai endpoints list --region={REGION} --format='get(DISPLAY_NAME)'\n",
494
+ "\n",
495
+ "if ENDPOINT_NAME in endpoints:\n",
496
+ " print(f\"Endpoint: '{ENDPOINT_NAME}' already exists skipping endpoint creation.\")\n",
497
+ "else:\n",
498
+ " print(f\"Endpoint: '{ENDPOINT_NAME}' does not exist, creating... (~3 minutes)\")\n",
499
+ " !gcloud ai endpoints create \\\n",
500
+ " --display-name={ENDPOINT_NAME} \\\n",
501
+ " --endpoint-id={ENDPOINT_NAME} \\\n",
502
+ " --region={REGION}"
503
+ ]
504
+ },
505
+ {
506
+ "cell_type": "markdown",
507
+ "metadata": {
508
+ "id": "dWODP6ccx8-3"
509
+ },
510
+ "source": [
511
+ "### 5c. Deploy model to endpoint\n",
512
+ "> Deploying a model associates physical resources with the model so that it can serve online predictions with low latency.\n",
513
+ "https://cloud.google.com/vertex-ai/docs/general/deployment\n",
514
+ "\n",
515
+ "⚠️ The `Minimum Replica Count` represents the minimum amount of compute nodes started when a model is deployed is e2-standard-2 machine (\\$0.0771/node hour in us-central-1). So as long as the endpoint is active you will be paying \\$0.0771/hour even if no predictions are made."
516
+ ]
517
+ },
518
+ {
519
+ "cell_type": "code",
520
+ "execution_count": null,
521
+ "metadata": {
522
+ "id": "sI297v1mjtR8"
523
+ },
524
+ "outputs": [],
525
+ "source": [
526
+ "# Deploy model to endpoint, this will start an e2-standard-2 machine which costs money\n",
527
+ "print(\"Track model deployment progress and prediction logs:\")\n",
528
+ "print(f\"https://console.cloud.google.com/vertex-ai/online-prediction/locations/{REGION}/endpoints/{ENDPOINT_NAME}?project={PROJECT}\\n\")\n",
529
+ "\n",
530
+ "# If using for large region, set min-replica-count higher to save scaling time\n",
531
+ "# Can take from 4-27 minutes\n",
532
+ "# Relevant quota: \"Custom model serving CPUs per region\"\n",
533
+ "!gcloud ai endpoints deploy-model {ENDPOINT_NAME} \\\n",
534
+ " --region={REGION} \\\n",
535
+ " --model={MODEL_NAME} \\\n",
536
+ " --display-name={MODEL_NAME} \\\n",
537
+ " --machine-type=\"e2-standard-2\" \\\n",
538
+ " --min-replica-count='1' \\\n",
539
+ " --max-replica-count=\"100\""
540
+ ]
541
+ },
542
+ {
543
+ "cell_type": "markdown",
544
+ "metadata": {
545
+ "id": "r2VtGUv9JOI9"
546
+ },
547
+ "source": [
548
+ "### 5d. Generate embeddings in Google Earth Engine\n"
549
+ ]
550
+ },
551
+ {
552
+ "cell_type": "code",
553
+ "execution_count": null,
554
+ "metadata": {
555
+ "id": "160PNcRRJMzn"
556
+ },
557
+ "outputs": [],
558
+ "source": [
559
+ "GEE_SCRIPT_URL = \"https://code.earthengine.google.com/c239905f788f67ecf0cee42753893d1c\"\n",
560
+ "print(f\"Open this script: {GEE_SCRIPT_URL}\")\n",
561
+ "print(\"Use the below string for the ENDPOINT variable\")\n",
562
+ "print(f\"projects/{PROJECT}/locations/{REGION}/endpoints/{ENDPOINT_NAME}\")"
563
+ ]
564
+ },
565
+ {
566
+ "cell_type": "markdown",
567
+ "metadata": {
568
+ "id": "8PnhA3gfHSrY"
569
+ },
570
+ "source": [
571
+ "### 5e. Undeploy model from endpoint\n",
572
+ "\n",
573
+ "Once predictions are made, you must <strong>undeploy your model</strong> to stop incurring further charges.\n",
574
+ "\n",
575
+ "This can be done using the below code or by using the Google Cloud console directly."
576
+ ]
577
+ },
578
+ {
579
+ "cell_type": "code",
580
+ "execution_count": null,
581
+ "metadata": {
582
+ "id": "VvOO_sfQDWPt"
583
+ },
584
+ "outputs": [],
585
+ "source": [
586
+ "def get_deployed_model():\n",
587
+ " deployed_models = !gcloud ai endpoints describe {ENDPOINT_NAME} --region={REGION} --format 'get(deployedModels)'\n",
588
+ " if deployed_models[1] == '':\n",
589
+ " print(\"No models deployed\")\n",
590
+ " else:\n",
591
+ " print(deployed_model_id)\n",
592
+ " return eval(deployed_models[1])['id']\n",
593
+ "\n",
594
+ "deployed_model_id = get_deployed_model()"
595
+ ]
596
+ },
597
+ {
598
+ "cell_type": "code",
599
+ "execution_count": null,
600
+ "metadata": {
601
+ "id": "vswFTu9kFeHy"
602
+ },
603
+ "outputs": [],
604
+ "source": [
605
+ "!gcloud ai endpoints undeploy-model {ENDPOINT_NAME} --region={REGION} --deployed-model-id={deployed_model_id}"
606
+ ]
607
+ },
608
+ {
609
+ "cell_type": "code",
610
+ "execution_count": null,
611
+ "metadata": {
612
+ "id": "w9Fq6yspF2ye"
613
+ },
614
+ "outputs": [],
615
+ "source": [
616
+ "get_deployed_model()"
617
+ ]
618
+ }
619
+ ],
620
+ "metadata": {
621
+ "colab": {
622
+ "provenance": []
623
+ },
624
+ "kernelspec": {
625
+ "display_name": "Python 3",
626
+ "name": "python3"
627
+ },
628
+ "language_info": {
629
+ "name": "python"
630
+ }
631
+ },
632
+ "nbformat": 4,
633
+ "nbformat_minor": 0
634
+ }
google_earth/1d196e8466506239c4780585c0e28d26/script.txt ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //------------------------------------------------------------------------------------
2
+ // Script for generating Presto embeddings using Vertex AI
3
+ // Author: Ivan Zvonkov (izvonkov@umd.edu)
4
+ //------------------------------------------------------------------------------------
5
+ // 1. Presto embedding generation parameters (set parameters according to your needs)
6
+ //------------------------------------------------------------------------------------
7
+ var roi = ee.FeatureCollection("FAO/GAUL/2015/level2").filter(ee.Filter.eq('ADM2_NAME', 'Haho'));
8
+ var PROJ = 'EPSG:25231'
9
+
10
+ var rangeStart = ee.Date('2019-03-01')
11
+ var rangeEnd = ee.Date('2020-03-01')
12
+
13
+ var ENDPOINT = 'projects/presto-deployment/locations/us-central1/endpoints/vertex-pytorch-presto-endpoint'
14
+ var RUN_VERTEX_AI = false // Leave this as false to get a cost estimate first
15
+ //------------------------------------------------------------------------------------
16
+
17
+ Map.centerObject(roi, 10)
18
+ Map.addLayer(roi, {}, "Region of Interest")
19
+ Map.setOptions('satellite')
20
+
21
+ // 2. Cost Computation
22
+ var roiAreaKM2 = roi.geometry().area().divide(1e6)
23
+ function estimate(cost){return roiAreaKM2.divide(1000).multiply(cost).toInt().getInfo()}
24
+ print("ROI Area: " + roiAreaKM2.toInt().getInfo() + " km2")
25
+ print("Embedding Generation Estimates\nCost: $" + estimate(5.37) + "-" + estimate(10.14))
26
+ if (!RUN_VERTEX_AI)
27
+ print("If you are ready to generate embeddings,\nchange RUN_VERTEX_AI variable to true")
28
+
29
+
30
+ // 3. Obtain monthly Sentinel-1 composites
31
+ var S1_BANDS = ["VV", "VH"]
32
+ var S1_all = ee.ImageCollection('COPERNICUS/S1_GRD').filterBounds(roi)
33
+ .filterDate(ee.Date(rangeStart).advance(-31, 'days'), ee.Date(rangeEnd).advance(31, 'days'))
34
+
35
+ var S1 = S1_all
36
+ .filter(ee.Filter.eq("orbitProperties_pass", S1_all.first().get("orbitProperties_pass")))
37
+ .filter(ee.Filter.eq("instrumentMode", "IW"))
38
+ var S1_VV = S1.filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VV"))
39
+ var S1_VH = S1.filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VH"))
40
+
41
+ function getCloseImages(middleDate, imageCollection){
42
+ var fromMiddleDate = imageCollection.map(function(img){
43
+ var dateDist = ee.Number(img.get("system:time_start")).subtract(middleDate.millis()).abs()
44
+ return img.set("dateDist", dateDist)
45
+ }).sort({property: "dateDist", ascending: true})
46
+ var fifteenDaysInMs = ee.Number(1296000000)
47
+ var maxDiff = ee.Number(fromMiddleDate.first().get("dateDist")).max(fifteenDaysInMs)
48
+ return fromMiddleDate.filterMetadata("dateDist", "not_greater_than", maxDiff)
49
+ }
50
+
51
+ function S1_img(date1, date2){
52
+ var startDate = ee.Date(date1)
53
+ var daysBetween = ee.Date(date2).difference(startDate, 'days')
54
+ var middleDate = startDate.advance(daysBetween.divide(2), 'days')
55
+ var kept_vv = getCloseImages(middleDate, S1_VV).select("VV")
56
+ var kept_vh = getCloseImages(middleDate, S1_VH).select("VH")
57
+ var S1_composite = ee.Image.cat([kept_vv.median(), kept_vh.median()])
58
+ return S1_composite.select(S1_BANDS).add(25.0).divide(25.0) // S1 ranges from -50 to 1
59
+ }
60
+
61
+
62
+ // 4. Obtain monthly Sentinel-2 composites
63
+ var S2_BANDS = ["B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B11", "B12"]
64
+ var S2 = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED").filterBounds(roi).filterDate(rangeStart, rangeEnd)
65
+ var csPlus = ee.ImageCollection('GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED').filterBounds(roi).filterDate(rangeStart, rangeEnd)
66
+ var QA_BAND = 'cs_cdf'; // Better than cs here
67
+ var S2_cf = S2.linkCollection(csPlus, [QA_BAND])
68
+
69
+ function S2_img(date1, date2){
70
+ return S2_cf.filterDate(date1, date2)
71
+ .qualityMosaic(QA_BAND)
72
+ .select(S2_BANDS)
73
+ .divide(ee.Image(1e4))
74
+ }
75
+
76
+
77
+ // 5. Obtain monthly ERA5 composites
78
+ var ERA5_BANDS = ["temperature_2m", "total_precipitation_sum"]
79
+ var ERA5 = ee.ImageCollection("ECMWF/ERA5_LAND/MONTHLY_AGGR").filterBounds(roi).filterDate(rangeStart, rangeEnd)
80
+ function ERA5_img(date1, date2){
81
+ return ERA5.filterDate(date1, date2)
82
+ .select(ERA5_BANDS)
83
+ .mean()
84
+ .add([-272.15, 0]).divide([35, 0.03])
85
+ }
86
+ //var ERA5_temp = ee.Image([0,0]).rename(ERA5_BANDS).clip(roi)
87
+
88
+
89
+ // 6. Obtain SRTM Data
90
+ var SRTM_BANDS = ["elevation", "slope"]
91
+ var elevation = ee.Image("USGS/SRTMGL1_003").clip(roi).select("elevation")
92
+ var slope = ee.Terrain.slope(elevation)
93
+ var SRTM_img = ee.Image.cat([elevation, slope]).toDouble().divide([2000, 50])
94
+ //var SRTM_temp = ee.Image([0,0]).rename(SRTM_BANDS).clip(roi)
95
+
96
+
97
+ // 7. Combine all data into a monthly CropHarvest-style monthly composite
98
+ function cropharvest_img(d1, d2){
99
+ var img = ee.Image.cat([S1_img(d1, d2), S2_img(d1, d2), ERA5_img(d1, d2), SRTM_img])
100
+ var ndvi = img.normalizedDifference(['B8', 'B4']).rename("NDVI")
101
+ // toFloat Necessary for tensor conversion
102
+ return img.addBands(ndvi).clip(roi).toFloat()
103
+ }
104
+
105
+
106
+ // 8. Create and visualize Presto input
107
+ var latlons = ee.Image.pixelLonLat().clip(roi).select("latitude", "longitude")
108
+ var imgs = [latlons]
109
+ var numMonths = rangeEnd.difference(rangeStart, 'month').toInt().getInfo()
110
+ var ERA5Palette = [
111
+ '000080', '0000d9', '4000ff', '8000ff', '0080ff', '00ffff', '00ff80', '80ff00', 'daff00',
112
+ 'ffff00', 'fff500', 'ffda00','ffb000', 'ffa400', 'ff4f00', 'ff2500', 'ff0a00', 'ff00ff'
113
+ ]
114
+
115
+ for (var i = 0; i < numMonths; i++){
116
+ var monthStart = rangeStart.advance(i, 'month')
117
+ var monthEnd = monthStart.advance(1, 'month')
118
+ var img = cropharvest_img(monthStart, monthEnd)
119
+ imgs.push(img)
120
+
121
+ var monthName = monthStart.format("YY/MM").getInfo()
122
+ Map.addLayer(img, {bands: ["VV", "VH", "VV"], min: [0, -0.2, 0.4], max: [1.0, 0.8, 1.2]}, monthName + " S1", false)
123
+ Map.addLayer(img, {bands: ['B4', 'B3', 'B2'], min: 0, max: 0.25 }, monthName + " S2", false)
124
+ Map.addLayer(img, {bands: ['temperature_2m'], min: 0, max: 1, palette: ERA5Palette}, monthName + " ERA5", false)
125
+ }
126
+ Map.addLayer(imgs[1], {bands: ["slope"], min: 0, max: 0.3 }, "SRTM", false)
127
+
128
+ var composite = ee.ImageCollection.fromImages(imgs).toBands()
129
+
130
+ // 9. Make predictions using Presto on Vertex AI
131
+ var vertex_model = ee.Model.fromVertexAi({
132
+ endpoint: ENDPOINT,
133
+ inputTileSize: [1,1],
134
+ proj: ee.Projection('EPSG:4326').atScale(10),
135
+ fixInputProj: true,
136
+ outputTileSize: [1,1],
137
+ outputBands: {'p': { 'type': ee.PixelType.float(), 'dimensions': 1}},
138
+ payloadFormat: 'ND_ARRAYS',
139
+ maxPayloadBytes: 5242880 // 5.24mb [MAX]
140
+ })
141
+
142
+ if (RUN_VERTEX_AI){
143
+
144
+ // Create band names for embeddingsArrayImage
145
+ var bandNames = []
146
+ for (var i=0; i<128; i++){ bandNames.push("b" + i + "") }
147
+
148
+ // embeddingsArrayImage is a single band image where each pixel contains an array
149
+ var embeddingsArrayImage = vertex_model.predictImage(composite).clip(roi)
150
+ var embeddingsMultiBandImage = embeddingsArrayImage.arrayFlatten([bandNames])
151
+
152
+ // Only smaller size embeddings can be directly viewed in GEE immediately; larger ones require the batch task
153
+ // Map.addLayer(embeddingsMultiBandImage, {min: 0, max: 1},'embeddingsMultiBandImage')
154
+
155
+ Export.image.toAsset({
156
+ image: embeddingsMultiBandImage,
157
+ description: 'Presto_embeddings',
158
+ assetId: 'Togo/Presto_test_embeddings_v2025_04_23',
159
+ region: roi,
160
+ scale: 10,
161
+ maxPixels: 1e12,
162
+ crs: 'EPSG:25231'
163
+ });
164
+ }
165
+
166
+
google_earth/1d196e8466506239c4780585c0e28d26/source.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ https://code.earthengine.google.com/1d196e8466506239c4780585c0e28d26