Presto (code, colab, google_earth, paper)
Browse files- .gitattributes +3 -0
- 20240418-Presto-TFWorkingGroup.pptx +3 -0
- Lightweight, Pre-trained Transformers for Remote Sensing Timeseries.pdf +3 -0
- code/LEM Data Processing.pptx +3 -0
- code/presto [DurojaiyeAbisoye] +6 -5.zip +3 -0
- code/presto-ndws.zip +3 -0
- code/presto-worldcereal.zip +3 -0
- code/presto.zip +3 -0
- colab/1_Presto_to_VertexAI.ipynb +634 -0
- google_earth/1d196e8466506239c4780585c0e28d26/script.txt +166 -0
- google_earth/1d196e8466506239c4780585c0e28d26/source.txt +1 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
20240418-Presto-TFWorkingGroup.pptx filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
code/LEM[[:space:]]Data[[:space:]]Processing.pptx filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
Lightweight,[[:space:]]Pre-trained[[:space:]]Transformers[[:space:]]for[[:space:]]Remote[[:space:]]Sensing[[:space:]]Timeseries.pdf filter=lfs diff=lfs merge=lfs -text
|
20240418-Presto-TFWorkingGroup.pptx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f7ab6859bee5653ed738b21cff701605c71768e453e2c6c81ac6c7609dbf4d39
|
| 3 |
+
size 21067278
|
Lightweight, Pre-trained Transformers for Remote Sensing Timeseries.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8324993896a3c346def6a5a0f923561d44f1c3aed270ff19bb29ac45986128f2
|
| 3 |
+
size 3957097
|
code/LEM Data Processing.pptx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:22e30405ed3bb718599afa342ccb857a41d64d7fb58c2bd582b45078b2b65b87
|
| 3 |
+
size 1296790
|
code/presto [DurojaiyeAbisoye] +6 -5.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c5015cb3128495d8c40b2d77a1ac5d93060469a3d6ff662cd57e4ee0668ac649
|
| 3 |
+
size 35240122
|
code/presto-ndws.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4ad48fd3638c91ad7b6856cd953dcba8065d98ba9a940de14784b9f654be0d0e
|
| 3 |
+
size 263605
|
code/presto-worldcereal.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:466dedf973084219798dcc0b5a4f24241f39bd7af20a6d40acc838275435c4ec
|
| 3 |
+
size 71758689
|
code/presto.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6d299c366a348393c54054f3b7d4994cd0ce3facafcd9cbf4b1adb6f26095fb7
|
| 3 |
+
size 45533344
|
colab/1_Presto_to_VertexAI.ipynb
ADDED
|
@@ -0,0 +1,634 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {
|
| 6 |
+
"id": "_SVA9v_JTq_-"
|
| 7 |
+
},
|
| 8 |
+
"source": [
|
| 9 |
+
"# 1. Presto to Vertex AI\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/nasaharvest/presto/blob/main/deploy/1_Presto_to_VertexAI.ipynb\">\n",
|
| 12 |
+
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
|
| 13 |
+
"</a>\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"**Authors**: Ivan Zvonkov, Gabriel Tseng, (additional credits: [Earth_Engine_PyTorch_Vertex_AI](https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_PyTorch_Vertex_AI.ipynb))\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"**Description**: The notebook Deploys Presto to Vertex AI. This is a prerequisite to generating Presto embeddings on Google Earth Engine using\n",
|
| 18 |
+
"[ee.Model.fromVertexAi](https://developers.google.com/earth-engine/apidocs/ee-model-fromvertexai).\n",
|
| 19 |
+
"\n",
|
| 20 |
+
"Once the model is deployed this [GEE script](https://code.earthengine.google.com/df6348b8d47cd751eb5164dccb7b26a9) can be used to generate Presto embeddings.\n",
|
| 21 |
+
"\n",
|
| 22 |
+
"**Steps**:\n",
|
| 23 |
+
"1. Set up environment\n",
|
| 24 |
+
"2. Load default Presto model\n",
|
| 25 |
+
"3. Transform Presto model into TorchScript\n",
|
| 26 |
+
"4. Package TorchScript model into TorchServe\n",
|
| 27 |
+
"5. Deploy and use Vertex AI\n",
|
| 28 |
+
"\n",
|
| 29 |
+
" 5a. Upload TorchServe model to Vertex AI Model Registry [Free]\n",
|
| 30 |
+
"\n",
|
| 31 |
+
" 5b. Create a Vertex AI Endpoint [Free]\n",
|
| 32 |
+
"\n",
|
| 33 |
+
" 5c. Deploy model to endpoint [Cost depends on Minimum Replica Count parameter]\n",
|
| 34 |
+
"\n",
|
| 35 |
+
" 5d. Generate embeddings in Google Earth Engine [Cost depends on region size]\n",
|
| 36 |
+
"\n",
|
| 37 |
+
" 5e. Undeploy model from endpoint [Free]\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"**Cost Breakdown**:\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"*5a. Upload TorchServe model to Vertex AI Model Registry [Free]*\n",
|
| 42 |
+
"- Model files are uploaded to Cloud Storage but are lightweight (3.37 Mb total) and thus easily fall into Google Cloud's 5GB/month Storage [Free Tier](https://cloud.google.com/storage/pricing#cloud-storage-always-free)\n",
|
| 43 |
+
"- There is no cost to storing models in Vertex AI Model Registry ([source](https://cloud.google.com/vertex-ai/pricing#modelregistry))\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"*5b. Create a Vertex AI Endpoint [Free]*\n",
|
| 46 |
+
"- There is no cost to creating an endpoint. Costs start when a model is deployed to that endpoint\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"*5c. Deploy model to endpoint [Cost depends on Minimum Replica Count parameter]*\n",
|
| 49 |
+
"- The `Minimum Replica Count` represents the minimum amount of compute nodes started when a model is deployed is e2-standard-2 machine (\\$0.0771/node hour in us-central-1)\n",
|
| 50 |
+
"- So as long as the endpoint is active you will be paying \\$0.0771/hour even if no predictions are made\n",
|
| 51 |
+
"\n",
|
| 52 |
+
"*5d. Generate embeddings in Google Earth Engine [Cost depends on region size]*\n",
|
| 53 |
+
"- Once a model is deployed and `ee.model.fromVertexAi` is used Vertex AI scales the amount of nodes based on amount of data (size of the region)\n",
|
| 54 |
+
"- Our current embedding generation cost estimates are <strong>\\$5.37 - \\$10.14 per 1000 km<sup>2</sup> </strong>\n",
|
| 55 |
+
"- We compute a cost estimate for your ROI in our Google Earth Engine script\n",
|
| 56 |
+
"\n",
|
| 57 |
+
"*5e. Undeploy model from endpoint [Free]*\n",
|
| 58 |
+
"- Necessary to stop incurring charges from 5c"
|
| 59 |
+
]
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"cell_type": "markdown",
|
| 63 |
+
"metadata": {
|
| 64 |
+
"id": "hzb1bwgTUZU0"
|
| 65 |
+
},
|
| 66 |
+
"source": [
|
| 67 |
+
"## 1. Set up environment"
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"cell_type": "code",
|
| 72 |
+
"execution_count": null,
|
| 73 |
+
"metadata": {
|
| 74 |
+
"id": "KuoEjld3TTLO"
|
| 75 |
+
},
|
| 76 |
+
"outputs": [],
|
| 77 |
+
"source": [
|
| 78 |
+
"from google.colab import auth\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"auth.authenticate_user()"
|
| 81 |
+
]
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"cell_type": "code",
|
| 85 |
+
"execution_count": null,
|
| 86 |
+
"metadata": {
|
| 87 |
+
"id": "cUI_5pWJ3V4s"
|
| 88 |
+
},
|
| 89 |
+
"outputs": [],
|
| 90 |
+
"source": [
|
| 91 |
+
"PROJECT = '<YOUR CLOUD PROJECT>'\n",
|
| 92 |
+
"!gcloud config set project {PROJECT}"
|
| 93 |
+
]
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"cell_type": "code",
|
| 97 |
+
"execution_count": null,
|
| 98 |
+
"metadata": {
|
| 99 |
+
"id": "MRGjYjltsm6-"
|
| 100 |
+
},
|
| 101 |
+
"outputs": [],
|
| 102 |
+
"source": [
|
| 103 |
+
"!git clone https://github.com/nasaharvest/presto.git"
|
| 104 |
+
]
|
| 105 |
+
},
|
| 106 |
+
{
|
| 107 |
+
"cell_type": "markdown",
|
| 108 |
+
"metadata": {
|
| 109 |
+
"id": "P1zGbf2KIhLA"
|
| 110 |
+
},
|
| 111 |
+
"source": [
|
| 112 |
+
"## 2. Load default Presto model"
|
| 113 |
+
]
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"cell_type": "code",
|
| 117 |
+
"execution_count": null,
|
| 118 |
+
"metadata": {
|
| 119 |
+
"id": "UKgCxBNnYJIB"
|
| 120 |
+
},
|
| 121 |
+
"outputs": [],
|
| 122 |
+
"source": [
|
| 123 |
+
"# Navigate inside of the repository to import Presto\n",
|
| 124 |
+
"%cd /content/presto\n",
|
| 125 |
+
"\n",
|
| 126 |
+
"import torch\n",
|
| 127 |
+
"from single_file_presto import Presto\n",
|
| 128 |
+
"\n",
|
| 129 |
+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 130 |
+
"\n",
|
| 131 |
+
"model = Presto.construct()\n",
|
| 132 |
+
"model.load_state_dict(torch.load(\"data/default_model.pt\", map_location=device))\n",
|
| 133 |
+
"model.eval();\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"# Navigate back to main directory\n",
|
| 136 |
+
"%cd /content"
|
| 137 |
+
]
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"cell_type": "markdown",
|
| 141 |
+
"metadata": {
|
| 142 |
+
"id": "5v7gJoysFTsf"
|
| 143 |
+
},
|
| 144 |
+
"source": [
|
| 145 |
+
"## 3. Transform Presto model into TorchScript\n",
|
| 146 |
+
"> TorchScript is a way to create serializable and optimizable models from PyTorch code.\n",
|
| 147 |
+
"https://docs.pytorch.org/docs/stable/jit.html"
|
| 148 |
+
]
|
| 149 |
+
},
|
| 150 |
+
{
|
| 151 |
+
"cell_type": "code",
|
| 152 |
+
"execution_count": null,
|
| 153 |
+
"metadata": {
|
| 154 |
+
"id": "hDKqqzi9F7T4"
|
| 155 |
+
},
|
| 156 |
+
"outputs": [],
|
| 157 |
+
"source": [
|
| 158 |
+
"# Construct input manually\n",
|
| 159 |
+
"batch_size = 256\n",
|
| 160 |
+
"NUM_TIMESTEPS = 12\n",
|
| 161 |
+
"X_tensor = torch.zeros([batch_size, NUM_TIMESTEPS, 17])\n",
|
| 162 |
+
"latlons_tensor = torch.zeros([batch_size, 2])\n",
|
| 163 |
+
"\n",
|
| 164 |
+
"dw_empty = torch.full([batch_size, NUM_TIMESTEPS], 9, device=device).long()\n",
|
| 165 |
+
"month_tensor = torch.full([batch_size], 1, device=device)\n",
|
| 166 |
+
"\n",
|
| 167 |
+
"# [0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 16 17 ]\n",
|
| 168 |
+
"# [VV, VH, B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12, temp, precip, elev, slope, NDVI]\n",
|
| 169 |
+
"mask = torch.zeros(X_tensor.shape, device=device).float()"
|
| 170 |
+
]
|
| 171 |
+
},
|
| 172 |
+
{
|
| 173 |
+
"cell_type": "code",
|
| 174 |
+
"execution_count": null,
|
| 175 |
+
"metadata": {
|
| 176 |
+
"id": "mdSXGZuikHUk"
|
| 177 |
+
},
|
| 178 |
+
"outputs": [],
|
| 179 |
+
"source": [
|
| 180 |
+
"# Verify forward pass with regular model\n",
|
| 181 |
+
"with torch.no_grad():\n",
|
| 182 |
+
" preds = model.encoder(\n",
|
| 183 |
+
" x=X_tensor,\n",
|
| 184 |
+
" dynamic_world=dw_empty,\n",
|
| 185 |
+
" latlons=latlons_tensor,\n",
|
| 186 |
+
" mask=mask,\n",
|
| 187 |
+
" month=month_tensor\n",
|
| 188 |
+
" )"
|
| 189 |
+
]
|
| 190 |
+
},
|
| 191 |
+
{
|
| 192 |
+
"cell_type": "code",
|
| 193 |
+
"execution_count": null,
|
| 194 |
+
"metadata": {
|
| 195 |
+
"id": "nFRvUVkHowKr"
|
| 196 |
+
},
|
| 197 |
+
"outputs": [],
|
| 198 |
+
"source": [
|
| 199 |
+
"# Make model torchscriptable\n",
|
| 200 |
+
"example_kwargs = {\n",
|
| 201 |
+
" 'x': X_tensor,\n",
|
| 202 |
+
" 'dynamic_world': dw_empty,\n",
|
| 203 |
+
" 'latlons': latlons_tensor,\n",
|
| 204 |
+
" 'mask': mask,\n",
|
| 205 |
+
" 'month': month_tensor\n",
|
| 206 |
+
"}\n",
|
| 207 |
+
"sm = torch.jit.trace(model.encoder, example_kwarg_inputs=example_kwargs)\n",
|
| 208 |
+
"\n",
|
| 209 |
+
"!mkdir -p pytorch_model\n",
|
| 210 |
+
"sm.save('pytorch_model/model.pt')"
|
| 211 |
+
]
|
| 212 |
+
},
|
| 213 |
+
{
|
| 214 |
+
"cell_type": "code",
|
| 215 |
+
"execution_count": null,
|
| 216 |
+
"metadata": {
|
| 217 |
+
"id": "cYuSPOyp1A0K"
|
| 218 |
+
},
|
| 219 |
+
"outputs": [],
|
| 220 |
+
"source": [
|
| 221 |
+
"jit_model = torch.jit.load('pytorch_model/model.pt')\n",
|
| 222 |
+
"jit_model(**example_kwargs).shape"
|
| 223 |
+
]
|
| 224 |
+
},
|
| 225 |
+
{
|
| 226 |
+
"cell_type": "markdown",
|
| 227 |
+
"metadata": {
|
| 228 |
+
"id": "3b0LsZpqnByv"
|
| 229 |
+
},
|
| 230 |
+
"source": [
|
| 231 |
+
"## 4. Package TorchScript model into TorchServe\n",
|
| 232 |
+
"> TorchServe is a performant, flexible and easy to use tool for serving PyTorch models in production.\n",
|
| 233 |
+
"https://docs.pytorch.org/serve/"
|
| 234 |
+
]
|
| 235 |
+
},
|
| 236 |
+
{
|
| 237 |
+
"cell_type": "code",
|
| 238 |
+
"execution_count": null,
|
| 239 |
+
"metadata": {
|
| 240 |
+
"id": "i70o_BZml9vs"
|
| 241 |
+
},
|
| 242 |
+
"outputs": [],
|
| 243 |
+
"source": [
|
| 244 |
+
"!pip install torchserve torch-model-archiver -q"
|
| 245 |
+
]
|
| 246 |
+
},
|
| 247 |
+
{
|
| 248 |
+
"cell_type": "code",
|
| 249 |
+
"execution_count": null,
|
| 250 |
+
"metadata": {
|
| 251 |
+
"id": "htq2Ac95FJlk"
|
| 252 |
+
},
|
| 253 |
+
"outputs": [],
|
| 254 |
+
"source": [
|
| 255 |
+
"%%writefile pytorch_model/custom_handler.py\n",
|
| 256 |
+
"import logging\n",
|
| 257 |
+
"import torch\n",
|
| 258 |
+
"from ts.torch_handler.base_handler import BaseHandler\n",
|
| 259 |
+
"import numpy as np\n",
|
| 260 |
+
"\n",
|
| 261 |
+
"# UPDATE BASED ON YOUR NEEDS\n",
|
| 262 |
+
"########################################\n",
|
| 263 |
+
"VERSION = \"v1\"\n",
|
| 264 |
+
"START_MONTH = 3\n",
|
| 265 |
+
"BATCH_SIZE = 256\n",
|
| 266 |
+
"########################################\n",
|
| 267 |
+
"\n",
|
| 268 |
+
"def printh(text):\n",
|
| 269 |
+
" # Prepends HANDLER to each print statement to make it easier to find in logs.\n",
|
| 270 |
+
" print(f\"HANDLER {VERSION}: {text}\")\n",
|
| 271 |
+
"\n",
|
| 272 |
+
"# Custom TorchServe handler for the Presto model\n",
|
| 273 |
+
"class ClassifierHandler(BaseHandler):\n",
|
| 274 |
+
"\n",
|
| 275 |
+
" def inference(self, data):\n",
|
| 276 |
+
" printh(\"Inference begin\")\n",
|
| 277 |
+
"\n",
|
| 278 |
+
" # Data shape: [ num_pixels, composite_bands, 1, 1 ]\n",
|
| 279 |
+
" data = data[:, :, 0, 0]\n",
|
| 280 |
+
" printh(f\"Data shape {data.shape}\")\n",
|
| 281 |
+
"\n",
|
| 282 |
+
" num_bands = 17\n",
|
| 283 |
+
" printh(f\"Num_bands {num_bands}\")\n",
|
| 284 |
+
"\n",
|
| 285 |
+
" # Subtract first two latlon\n",
|
| 286 |
+
" num_timesteps = (data.shape[1] - 2) // num_bands\n",
|
| 287 |
+
" printh(f\"Num_timesteps {num_timesteps}\")\n",
|
| 288 |
+
"\n",
|
| 289 |
+
" with torch.no_grad():\n",
|
| 290 |
+
"\n",
|
| 291 |
+
" batches = torch.split(data, BATCH_SIZE, dim=0)\n",
|
| 292 |
+
"\n",
|
| 293 |
+
" # month: An int or torch.Tensor describing the first month of the instances being passed. If an int, all instances in the batch are assumed to have the same starting month.\n",
|
| 294 |
+
" month_tensor = torch.full([BATCH_SIZE], START_MONTH, device=self.device)\n",
|
| 295 |
+
" printh(f\"Month: {START_MONTH}\")\n",
|
| 296 |
+
"\n",
|
| 297 |
+
" # dynamic_world: torch.Tensor of shape [BATCH_SIZE, num_timesteps]. If no Dynamic World classes are available, this tensor should be filled with the value DynamicWorld2020_2021.class_amount (i.e. 9), in which case it is ignored.\n",
|
| 298 |
+
" dw_empty = torch.full([BATCH_SIZE, num_timesteps], 9, device=self.device).long()\n",
|
| 299 |
+
" printh(f\"DW {dw_empty[0]}\")\n",
|
| 300 |
+
"\n",
|
| 301 |
+
" # mask: An optional torch.Tensor of shape [BATCH_SIZE, num_timesteps, bands]. mask[i, j, k] == 1 means x[i, j, k] is considered masked. If the mask is None, no values in x are ignored.\n",
|
| 302 |
+
" mask = torch.zeros((BATCH_SIZE, num_timesteps, num_bands), device=self.device).float()\n",
|
| 303 |
+
" printh(f\"Mask sample one timestep: {mask[0, 0]}\")\n",
|
| 304 |
+
"\n",
|
| 305 |
+
" preds_list = []\n",
|
| 306 |
+
" for batch in batches:\n",
|
| 307 |
+
" padding = 0\n",
|
| 308 |
+
" if batch.shape[0] < BATCH_SIZE:\n",
|
| 309 |
+
" padding = BATCH_SIZE - batch.shape[0]\n",
|
| 310 |
+
" batch = torch.cat([batch, torch.zeros([padding, batch.shape[1]], device=self.device)])\n",
|
| 311 |
+
"\n",
|
| 312 |
+
" # x: torch.Tensor of shape [BATCH_SIZE, num_timesteps, bands] where bands is described by NORMED_BANDS.\n",
|
| 313 |
+
" X_tensor = batch[:, 2:]\n",
|
| 314 |
+
" printh(f\"X {X_tensor.shape}\")\n",
|
| 315 |
+
"\n",
|
| 316 |
+
" X_tensor_reshaped = X_tensor.reshape(BATCH_SIZE, num_timesteps, num_bands)\n",
|
| 317 |
+
" printh(f\"X sample one timestep: {X_tensor_reshaped[0, 0]}\")\n",
|
| 318 |
+
"\n",
|
| 319 |
+
" # latlons: torch.Tensor of shape [BATCH_SIZE, 2] describing the latitude and longitude of each input instance.\n",
|
| 320 |
+
" latlons_tensor = batch[:, :2]\n",
|
| 321 |
+
"\n",
|
| 322 |
+
" printh(\"SHAPES\")\n",
|
| 323 |
+
" printh(f\"X {X_tensor_reshaped.shape}\")\n",
|
| 324 |
+
" printh(f\"DW {dw_empty.shape}\")\n",
|
| 325 |
+
" printh(f\"Latlons {latlons_tensor.shape}\")\n",
|
| 326 |
+
" printh(f\"Mask {mask.shape}\")\n",
|
| 327 |
+
" printh(f\"Month {month_tensor.shape}\")\n",
|
| 328 |
+
"\n",
|
| 329 |
+
" pred = self.model(\n",
|
| 330 |
+
" x=X_tensor_reshaped,\n",
|
| 331 |
+
" dynamic_world=dw_empty,\n",
|
| 332 |
+
" latlons=latlons_tensor,\n",
|
| 333 |
+
" mask=mask,\n",
|
| 334 |
+
" month=month_tensor\n",
|
| 335 |
+
" )\n",
|
| 336 |
+
" pred_np = np.expand_dims(pred.numpy(), axis=[1,2])\n",
|
| 337 |
+
" if padding == 0:\n",
|
| 338 |
+
" preds_list.append(pred_np[:])\n",
|
| 339 |
+
" else:\n",
|
| 340 |
+
" preds_list.append(pred_np[:-padding])\n",
|
| 341 |
+
"\n",
|
| 342 |
+
" [printh(f\"{p.shape}\") for p in preds_list]\n",
|
| 343 |
+
" preds = np.concatenate(preds_list)\n",
|
| 344 |
+
" printh(f\"Preds shape {preds.shape}\")\n",
|
| 345 |
+
" return preds\n",
|
| 346 |
+
"\n",
|
| 347 |
+
" def handle(self, data, context):\n",
|
| 348 |
+
" self.context = context\n",
|
| 349 |
+
" printh(f\"Handle begin\")\n",
|
| 350 |
+
" input_tensor = self.preprocess(data)\n",
|
| 351 |
+
" printh(f\"Input_tensor shape {input_tensor.shape}\")\n",
|
| 352 |
+
" pred_out = self.inference(input_tensor)\n",
|
| 353 |
+
" printh(f\"Inference complete\")\n",
|
| 354 |
+
" return self.postprocess(pred_out)"
|
| 355 |
+
]
|
| 356 |
+
},
|
| 357 |
+
{
|
| 358 |
+
"cell_type": "code",
|
| 359 |
+
"execution_count": null,
|
| 360 |
+
"metadata": {
|
| 361 |
+
"id": "a3Dgq5Ob5b1i"
|
| 362 |
+
},
|
| 363 |
+
"outputs": [],
|
| 364 |
+
"source": [
|
| 365 |
+
"import importlib\n",
|
| 366 |
+
"import pytorch_model.custom_handler\n",
|
| 367 |
+
"\n",
|
| 368 |
+
"importlib.reload(pytorch_model.custom_handler)\n",
|
| 369 |
+
"\n",
|
| 370 |
+
"from pytorch_model.custom_handler import ClassifierHandler, VERSION\n",
|
| 371 |
+
"\n",
|
| 372 |
+
"# Test output\n",
|
| 373 |
+
"data = torch.zeros([713, 206, 1, 1])\n",
|
| 374 |
+
"handler = ClassifierHandler()\n",
|
| 375 |
+
"handler.model = jit_model\n",
|
| 376 |
+
"preds = handler.handle(data=data, context=None)"
|
| 377 |
+
]
|
| 378 |
+
},
|
| 379 |
+
{
|
| 380 |
+
"cell_type": "code",
|
| 381 |
+
"execution_count": null,
|
| 382 |
+
"metadata": {
|
| 383 |
+
"id": "90TXNAnfF-TD"
|
| 384 |
+
},
|
| 385 |
+
"outputs": [],
|
| 386 |
+
"source": [
|
| 387 |
+
"!torch-model-archiver -f \\\n",
|
| 388 |
+
" --model-name model \\\n",
|
| 389 |
+
" --version 1.0 \\\n",
|
| 390 |
+
" --serialized-file 'pytorch_model/model.pt' \\\n",
|
| 391 |
+
" --handler 'pytorch_model/custom_handler.py' \\\n",
|
| 392 |
+
" --export-path pytorch_model/"
|
| 393 |
+
]
|
| 394 |
+
},
|
| 395 |
+
{
|
| 396 |
+
"cell_type": "markdown",
|
| 397 |
+
"metadata": {
|
| 398 |
+
"id": "PYK9g3r1qVru"
|
| 399 |
+
},
|
| 400 |
+
"source": [
|
| 401 |
+
"## 5. Deploy and use Vertex AI"
|
| 402 |
+
]
|
| 403 |
+
},
|
| 404 |
+
{
|
| 405 |
+
"cell_type": "markdown",
|
| 406 |
+
"metadata": {
|
| 407 |
+
"id": "CH0JO_Jww5Ok"
|
| 408 |
+
},
|
| 409 |
+
"source": [
|
| 410 |
+
"### 5a. Upload TorchServe model to Vertex AI Model Registry\n",
|
| 411 |
+
"> The Vertex AI Model Registry is a central repository where you can manage the lifecycle of your ML models.\n",
|
| 412 |
+
"https://cloud.google.com/vertex-ai/docs/model-registry/introduction"
|
| 413 |
+
]
|
| 414 |
+
},
|
| 415 |
+
{
|
| 416 |
+
"cell_type": "code",
|
| 417 |
+
"execution_count": null,
|
| 418 |
+
"metadata": {
|
| 419 |
+
"id": "lEUqcAqsTYpn"
|
| 420 |
+
},
|
| 421 |
+
"outputs": [],
|
| 422 |
+
"source": [
|
| 423 |
+
"REGION = 'us-central1'\n",
|
| 424 |
+
"BUCKET_NAME = \"<YOUR CLOUD BUCKET>\""
|
| 425 |
+
]
|
| 426 |
+
},
|
| 427 |
+
{
|
| 428 |
+
"cell_type": "code",
|
| 429 |
+
"execution_count": null,
|
| 430 |
+
"metadata": {
|
| 431 |
+
"id": "OiyyCNa-TBGQ"
|
| 432 |
+
},
|
| 433 |
+
"outputs": [],
|
| 434 |
+
"source": [
|
| 435 |
+
"# Create bucket to store model artifcats if it doesn't exist\n",
|
| 436 |
+
"!gcloud storage buckets create gs://{BUCKET_NAME} --location={REGION}"
|
| 437 |
+
]
|
| 438 |
+
},
|
| 439 |
+
{
|
| 440 |
+
"cell_type": "code",
|
| 441 |
+
"execution_count": null,
|
| 442 |
+
"metadata": {
|
| 443 |
+
"id": "n8m9UBy3GEvZ"
|
| 444 |
+
},
|
| 445 |
+
"outputs": [],
|
| 446 |
+
"source": [
|
| 447 |
+
"MODEL_DIR = f'gs://{BUCKET_NAME}/{VERSION}'\n",
|
| 448 |
+
"!gsutil cp -r pytorch_model {MODEL_DIR}"
|
| 449 |
+
]
|
| 450 |
+
},
|
| 451 |
+
{
|
| 452 |
+
"cell_type": "code",
|
| 453 |
+
"execution_count": null,
|
| 454 |
+
"metadata": {
|
| 455 |
+
"id": "AetRF8dcGraC"
|
| 456 |
+
},
|
| 457 |
+
"outputs": [],
|
| 458 |
+
"source": [
|
| 459 |
+
"# Can take 2 minutes\n",
|
| 460 |
+
"MODEL_NAME = f'model_{VERSION}'\n",
|
| 461 |
+
"CONTAINER_IMAGE = 'us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-4:latest'\n",
|
| 462 |
+
"\n",
|
| 463 |
+
"!gcloud ai models upload \\\n",
|
| 464 |
+
" --artifact-uri={MODEL_DIR} \\\n",
|
| 465 |
+
" --region={REGION} \\\n",
|
| 466 |
+
" --container-image-uri={CONTAINER_IMAGE} \\\n",
|
| 467 |
+
" --description={MODEL_NAME} \\\n",
|
| 468 |
+
" --display-name={MODEL_NAME} \\\n",
|
| 469 |
+
" --model-id={MODEL_NAME}"
|
| 470 |
+
]
|
| 471 |
+
},
|
| 472 |
+
{
|
| 473 |
+
"cell_type": "markdown",
|
| 474 |
+
"metadata": {
|
| 475 |
+
"id": "_BXDtoITxY2T"
|
| 476 |
+
},
|
| 477 |
+
"source": [
|
| 478 |
+
"### 5b. Create a Vertex AI Endpoint\n",
|
| 479 |
+
"> To deploy a model for online prediction, you need an endpoint.\n",
|
| 480 |
+
"https://cloud.google.com/vertex-ai/docs/predictions/choose-endpoint-type\n"
|
| 481 |
+
]
|
| 482 |
+
},
|
| 483 |
+
{
|
| 484 |
+
"cell_type": "code",
|
| 485 |
+
"execution_count": null,
|
| 486 |
+
"metadata": {
|
| 487 |
+
"id": "XhMK9aA73FzI"
|
| 488 |
+
},
|
| 489 |
+
"outputs": [],
|
| 490 |
+
"source": [
|
| 491 |
+
"ENDPOINT_NAME = 'vertex-pytorch-presto-endpoint'\n",
|
| 492 |
+
"\n",
|
| 493 |
+
"endpoints = !gcloud ai endpoints list --region={REGION} --format='get(DISPLAY_NAME)'\n",
|
| 494 |
+
"\n",
|
| 495 |
+
"if ENDPOINT_NAME in endpoints:\n",
|
| 496 |
+
" print(f\"Endpoint: '{ENDPOINT_NAME}' already exists skipping endpoint creation.\")\n",
|
| 497 |
+
"else:\n",
|
| 498 |
+
" print(f\"Endpoint: '{ENDPOINT_NAME}' does not exist, creating... (~3 minutes)\")\n",
|
| 499 |
+
" !gcloud ai endpoints create \\\n",
|
| 500 |
+
" --display-name={ENDPOINT_NAME} \\\n",
|
| 501 |
+
" --endpoint-id={ENDPOINT_NAME} \\\n",
|
| 502 |
+
" --region={REGION}"
|
| 503 |
+
]
|
| 504 |
+
},
|
| 505 |
+
{
|
| 506 |
+
"cell_type": "markdown",
|
| 507 |
+
"metadata": {
|
| 508 |
+
"id": "dWODP6ccx8-3"
|
| 509 |
+
},
|
| 510 |
+
"source": [
|
| 511 |
+
"### 5c. Deploy model to endpoint\n",
|
| 512 |
+
"> Deploying a model associates physical resources with the model so that it can serve online predictions with low latency.\n",
|
| 513 |
+
"https://cloud.google.com/vertex-ai/docs/general/deployment\n",
|
| 514 |
+
"\n",
|
| 515 |
+
"⚠️ The `Minimum Replica Count` represents the minimum amount of compute nodes started when a model is deployed is e2-standard-2 machine (\\$0.0771/node hour in us-central-1). So as long as the endpoint is active you will be paying \\$0.0771/hour even if no predictions are made."
|
| 516 |
+
]
|
| 517 |
+
},
|
| 518 |
+
{
|
| 519 |
+
"cell_type": "code",
|
| 520 |
+
"execution_count": null,
|
| 521 |
+
"metadata": {
|
| 522 |
+
"id": "sI297v1mjtR8"
|
| 523 |
+
},
|
| 524 |
+
"outputs": [],
|
| 525 |
+
"source": [
|
| 526 |
+
"# Deploy model to endpoint, this will start an e2-standard-2 machine which costs money\n",
|
| 527 |
+
"print(\"Track model deployment progress and prediction logs:\")\n",
|
| 528 |
+
"print(f\"https://console.cloud.google.com/vertex-ai/online-prediction/locations/{REGION}/endpoints/{ENDPOINT_NAME}?project={PROJECT}\\n\")\n",
|
| 529 |
+
"\n",
|
| 530 |
+
"# If using for large region, set min-replica-count higher to save scaling time\n",
|
| 531 |
+
"# Can take from 4-27 minutes\n",
|
| 532 |
+
"# Relevant quota: \"Custom model serving CPUs per region\"\n",
|
| 533 |
+
"!gcloud ai endpoints deploy-model {ENDPOINT_NAME} \\\n",
|
| 534 |
+
" --region={REGION} \\\n",
|
| 535 |
+
" --model={MODEL_NAME} \\\n",
|
| 536 |
+
" --display-name={MODEL_NAME} \\\n",
|
| 537 |
+
" --machine-type=\"e2-standard-2\" \\\n",
|
| 538 |
+
" --min-replica-count='1' \\\n",
|
| 539 |
+
" --max-replica-count=\"100\""
|
| 540 |
+
]
|
| 541 |
+
},
|
| 542 |
+
{
|
| 543 |
+
"cell_type": "markdown",
|
| 544 |
+
"metadata": {
|
| 545 |
+
"id": "r2VtGUv9JOI9"
|
| 546 |
+
},
|
| 547 |
+
"source": [
|
| 548 |
+
"### 5d. Generate embeddings in Google Earth Engine\n"
|
| 549 |
+
]
|
| 550 |
+
},
|
| 551 |
+
{
|
| 552 |
+
"cell_type": "code",
|
| 553 |
+
"execution_count": null,
|
| 554 |
+
"metadata": {
|
| 555 |
+
"id": "160PNcRRJMzn"
|
| 556 |
+
},
|
| 557 |
+
"outputs": [],
|
| 558 |
+
"source": [
|
| 559 |
+
"GEE_SCRIPT_URL = \"https://code.earthengine.google.com/c239905f788f67ecf0cee42753893d1c\"\n",
|
| 560 |
+
"print(f\"Open this script: {GEE_SCRIPT_URL}\")\n",
|
| 561 |
+
"print(\"Use the below string for the ENDPOINT variable\")\n",
|
| 562 |
+
"print(f\"projects/{PROJECT}/locations/{REGION}/endpoints/{ENDPOINT_NAME}\")"
|
| 563 |
+
]
|
| 564 |
+
},
|
| 565 |
+
{
|
| 566 |
+
"cell_type": "markdown",
|
| 567 |
+
"metadata": {
|
| 568 |
+
"id": "8PnhA3gfHSrY"
|
| 569 |
+
},
|
| 570 |
+
"source": [
|
| 571 |
+
"### 5e. Undeploy model from endpoint\n",
|
| 572 |
+
"\n",
|
| 573 |
+
"Once predictions are made, you must <strong>undeploy your model</strong> to stop incurring further charges.\n",
|
| 574 |
+
"\n",
|
| 575 |
+
"This can be done using the below code or by using the Google Cloud console directly."
|
| 576 |
+
]
|
| 577 |
+
},
|
| 578 |
+
{
|
| 579 |
+
"cell_type": "code",
|
| 580 |
+
"execution_count": null,
|
| 581 |
+
"metadata": {
|
| 582 |
+
"id": "VvOO_sfQDWPt"
|
| 583 |
+
},
|
| 584 |
+
"outputs": [],
|
| 585 |
+
"source": [
|
| 586 |
+
"def get_deployed_model():\n",
|
| 587 |
+
" deployed_models = !gcloud ai endpoints describe {ENDPOINT_NAME} --region={REGION} --format 'get(deployedModels)'\n",
|
| 588 |
+
" if deployed_models[1] == '':\n",
|
| 589 |
+
" print(\"No models deployed\")\n",
|
| 590 |
+
" else:\n",
|
| 591 |
+
" print(deployed_model_id)\n",
|
| 592 |
+
" return eval(deployed_models[1])['id']\n",
|
| 593 |
+
"\n",
|
| 594 |
+
"deployed_model_id = get_deployed_model()"
|
| 595 |
+
]
|
| 596 |
+
},
|
| 597 |
+
{
|
| 598 |
+
"cell_type": "code",
|
| 599 |
+
"execution_count": null,
|
| 600 |
+
"metadata": {
|
| 601 |
+
"id": "vswFTu9kFeHy"
|
| 602 |
+
},
|
| 603 |
+
"outputs": [],
|
| 604 |
+
"source": [
|
| 605 |
+
"!gcloud ai endpoints undeploy-model {ENDPOINT_NAME} --region={REGION} --deployed-model-id={deployed_model_id}"
|
| 606 |
+
]
|
| 607 |
+
},
|
| 608 |
+
{
|
| 609 |
+
"cell_type": "code",
|
| 610 |
+
"execution_count": null,
|
| 611 |
+
"metadata": {
|
| 612 |
+
"id": "w9Fq6yspF2ye"
|
| 613 |
+
},
|
| 614 |
+
"outputs": [],
|
| 615 |
+
"source": [
|
| 616 |
+
"get_deployed_model()"
|
| 617 |
+
]
|
| 618 |
+
}
|
| 619 |
+
],
|
| 620 |
+
"metadata": {
|
| 621 |
+
"colab": {
|
| 622 |
+
"provenance": []
|
| 623 |
+
},
|
| 624 |
+
"kernelspec": {
|
| 625 |
+
"display_name": "Python 3",
|
| 626 |
+
"name": "python3"
|
| 627 |
+
},
|
| 628 |
+
"language_info": {
|
| 629 |
+
"name": "python"
|
| 630 |
+
}
|
| 631 |
+
},
|
| 632 |
+
"nbformat": 4,
|
| 633 |
+
"nbformat_minor": 0
|
| 634 |
+
}
|
google_earth/1d196e8466506239c4780585c0e28d26/script.txt
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//------------------------------------------------------------------------------------
|
| 2 |
+
// Script for generating Presto embeddings using Vertex AI
|
| 3 |
+
// Author: Ivan Zvonkov (izvonkov@umd.edu)
|
| 4 |
+
//------------------------------------------------------------------------------------
|
| 5 |
+
// 1. Presto embedding generation parameters (set parameters according to your needs)
|
| 6 |
+
//------------------------------------------------------------------------------------
|
| 7 |
+
var roi = ee.FeatureCollection("FAO/GAUL/2015/level2").filter(ee.Filter.eq('ADM2_NAME', 'Haho'));
|
| 8 |
+
var PROJ = 'EPSG:25231'
|
| 9 |
+
|
| 10 |
+
var rangeStart = ee.Date('2019-03-01')
|
| 11 |
+
var rangeEnd = ee.Date('2020-03-01')
|
| 12 |
+
|
| 13 |
+
var ENDPOINT = 'projects/presto-deployment/locations/us-central1/endpoints/vertex-pytorch-presto-endpoint'
|
| 14 |
+
var RUN_VERTEX_AI = false // Leave this as false to get a cost estimate first
|
| 15 |
+
//------------------------------------------------------------------------------------
|
| 16 |
+
|
| 17 |
+
Map.centerObject(roi, 10)
|
| 18 |
+
Map.addLayer(roi, {}, "Region of Interest")
|
| 19 |
+
Map.setOptions('satellite')
|
| 20 |
+
|
| 21 |
+
// 2. Cost Computation
|
| 22 |
+
var roiAreaKM2 = roi.geometry().area().divide(1e6)
|
| 23 |
+
function estimate(cost){return roiAreaKM2.divide(1000).multiply(cost).toInt().getInfo()}
|
| 24 |
+
print("ROI Area: " + roiAreaKM2.toInt().getInfo() + " km2")
|
| 25 |
+
print("Embedding Generation Estimates\nCost: $" + estimate(5.37) + "-" + estimate(10.14))
|
| 26 |
+
if (!RUN_VERTEX_AI)
|
| 27 |
+
print("If you are ready to generate embeddings,\nchange RUN_VERTEX_AI variable to true")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
// 3. Obtain monthly Sentinel-1 composites
|
| 31 |
+
var S1_BANDS = ["VV", "VH"]
|
| 32 |
+
var S1_all = ee.ImageCollection('COPERNICUS/S1_GRD').filterBounds(roi)
|
| 33 |
+
.filterDate(ee.Date(rangeStart).advance(-31, 'days'), ee.Date(rangeEnd).advance(31, 'days'))
|
| 34 |
+
|
| 35 |
+
var S1 = S1_all
|
| 36 |
+
.filter(ee.Filter.eq("orbitProperties_pass", S1_all.first().get("orbitProperties_pass")))
|
| 37 |
+
.filter(ee.Filter.eq("instrumentMode", "IW"))
|
| 38 |
+
var S1_VV = S1.filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VV"))
|
| 39 |
+
var S1_VH = S1.filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VH"))
|
| 40 |
+
|
| 41 |
+
function getCloseImages(middleDate, imageCollection){
|
| 42 |
+
var fromMiddleDate = imageCollection.map(function(img){
|
| 43 |
+
var dateDist = ee.Number(img.get("system:time_start")).subtract(middleDate.millis()).abs()
|
| 44 |
+
return img.set("dateDist", dateDist)
|
| 45 |
+
}).sort({property: "dateDist", ascending: true})
|
| 46 |
+
var fifteenDaysInMs = ee.Number(1296000000)
|
| 47 |
+
var maxDiff = ee.Number(fromMiddleDate.first().get("dateDist")).max(fifteenDaysInMs)
|
| 48 |
+
return fromMiddleDate.filterMetadata("dateDist", "not_greater_than", maxDiff)
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
function S1_img(date1, date2){
|
| 52 |
+
var startDate = ee.Date(date1)
|
| 53 |
+
var daysBetween = ee.Date(date2).difference(startDate, 'days')
|
| 54 |
+
var middleDate = startDate.advance(daysBetween.divide(2), 'days')
|
| 55 |
+
var kept_vv = getCloseImages(middleDate, S1_VV).select("VV")
|
| 56 |
+
var kept_vh = getCloseImages(middleDate, S1_VH).select("VH")
|
| 57 |
+
var S1_composite = ee.Image.cat([kept_vv.median(), kept_vh.median()])
|
| 58 |
+
return S1_composite.select(S1_BANDS).add(25.0).divide(25.0) // S1 ranges from -50 to 1
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
// 4. Obtain monthly Sentinel-2 composites
|
| 63 |
+
var S2_BANDS = ["B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B11", "B12"]
|
| 64 |
+
var S2 = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED").filterBounds(roi).filterDate(rangeStart, rangeEnd)
|
| 65 |
+
var csPlus = ee.ImageCollection('GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED').filterBounds(roi).filterDate(rangeStart, rangeEnd)
|
| 66 |
+
var QA_BAND = 'cs_cdf'; // Better than cs here
|
| 67 |
+
var S2_cf = S2.linkCollection(csPlus, [QA_BAND])
|
| 68 |
+
|
| 69 |
+
function S2_img(date1, date2){
|
| 70 |
+
return S2_cf.filterDate(date1, date2)
|
| 71 |
+
.qualityMosaic(QA_BAND)
|
| 72 |
+
.select(S2_BANDS)
|
| 73 |
+
.divide(ee.Image(1e4))
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
// 5. Obtain monthly ERA5 composites
|
| 78 |
+
var ERA5_BANDS = ["temperature_2m", "total_precipitation_sum"]
|
| 79 |
+
var ERA5 = ee.ImageCollection("ECMWF/ERA5_LAND/MONTHLY_AGGR").filterBounds(roi).filterDate(rangeStart, rangeEnd)
|
| 80 |
+
function ERA5_img(date1, date2){
|
| 81 |
+
return ERA5.filterDate(date1, date2)
|
| 82 |
+
.select(ERA5_BANDS)
|
| 83 |
+
.mean()
|
| 84 |
+
.add([-272.15, 0]).divide([35, 0.03])
|
| 85 |
+
}
|
| 86 |
+
//var ERA5_temp = ee.Image([0,0]).rename(ERA5_BANDS).clip(roi)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
// 6. Obtain SRTM Data
|
| 90 |
+
var SRTM_BANDS = ["elevation", "slope"]
|
| 91 |
+
var elevation = ee.Image("USGS/SRTMGL1_003").clip(roi).select("elevation")
|
| 92 |
+
var slope = ee.Terrain.slope(elevation)
|
| 93 |
+
var SRTM_img = ee.Image.cat([elevation, slope]).toDouble().divide([2000, 50])
|
| 94 |
+
//var SRTM_temp = ee.Image([0,0]).rename(SRTM_BANDS).clip(roi)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
// 7. Combine all data into a monthly CropHarvest-style monthly composite
|
| 98 |
+
function cropharvest_img(d1, d2){
|
| 99 |
+
var img = ee.Image.cat([S1_img(d1, d2), S2_img(d1, d2), ERA5_img(d1, d2), SRTM_img])
|
| 100 |
+
var ndvi = img.normalizedDifference(['B8', 'B4']).rename("NDVI")
|
| 101 |
+
// toFloat Necessary for tensor conversion
|
| 102 |
+
return img.addBands(ndvi).clip(roi).toFloat()
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
// 8. Create and visualize Presto input
|
| 107 |
+
var latlons = ee.Image.pixelLonLat().clip(roi).select("latitude", "longitude")
|
| 108 |
+
var imgs = [latlons]
|
| 109 |
+
var numMonths = rangeEnd.difference(rangeStart, 'month').toInt().getInfo()
|
| 110 |
+
var ERA5Palette = [
|
| 111 |
+
'000080', '0000d9', '4000ff', '8000ff', '0080ff', '00ffff', '00ff80', '80ff00', 'daff00',
|
| 112 |
+
'ffff00', 'fff500', 'ffda00','ffb000', 'ffa400', 'ff4f00', 'ff2500', 'ff0a00', 'ff00ff'
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
for (var i = 0; i < numMonths; i++){
|
| 116 |
+
var monthStart = rangeStart.advance(i, 'month')
|
| 117 |
+
var monthEnd = monthStart.advance(1, 'month')
|
| 118 |
+
var img = cropharvest_img(monthStart, monthEnd)
|
| 119 |
+
imgs.push(img)
|
| 120 |
+
|
| 121 |
+
var monthName = monthStart.format("YY/MM").getInfo()
|
| 122 |
+
Map.addLayer(img, {bands: ["VV", "VH", "VV"], min: [0, -0.2, 0.4], max: [1.0, 0.8, 1.2]}, monthName + " S1", false)
|
| 123 |
+
Map.addLayer(img, {bands: ['B4', 'B3', 'B2'], min: 0, max: 0.25 }, monthName + " S2", false)
|
| 124 |
+
Map.addLayer(img, {bands: ['temperature_2m'], min: 0, max: 1, palette: ERA5Palette}, monthName + " ERA5", false)
|
| 125 |
+
}
|
| 126 |
+
Map.addLayer(imgs[1], {bands: ["slope"], min: 0, max: 0.3 }, "SRTM", false)
|
| 127 |
+
|
| 128 |
+
var composite = ee.ImageCollection.fromImages(imgs).toBands()
|
| 129 |
+
|
| 130 |
+
// 9. Make predictions using Presto on Vertex AI
|
| 131 |
+
var vertex_model = ee.Model.fromVertexAi({
|
| 132 |
+
endpoint: ENDPOINT,
|
| 133 |
+
inputTileSize: [1,1],
|
| 134 |
+
proj: ee.Projection('EPSG:4326').atScale(10),
|
| 135 |
+
fixInputProj: true,
|
| 136 |
+
outputTileSize: [1,1],
|
| 137 |
+
outputBands: {'p': { 'type': ee.PixelType.float(), 'dimensions': 1}},
|
| 138 |
+
payloadFormat: 'ND_ARRAYS',
|
| 139 |
+
maxPayloadBytes: 5242880 // 5.24mb [MAX]
|
| 140 |
+
})
|
| 141 |
+
|
| 142 |
+
if (RUN_VERTEX_AI){
|
| 143 |
+
|
| 144 |
+
// Create band names for embeddingsArrayImage
|
| 145 |
+
var bandNames = []
|
| 146 |
+
for (var i=0; i<128; i++){ bandNames.push("b" + i + "") }
|
| 147 |
+
|
| 148 |
+
// embeddingsArrayImage is a single band image where each pixel contains an array
|
| 149 |
+
var embeddingsArrayImage = vertex_model.predictImage(composite).clip(roi)
|
| 150 |
+
var embeddingsMultiBandImage = embeddingsArrayImage.arrayFlatten([bandNames])
|
| 151 |
+
|
| 152 |
+
// Only smaller size embeddings can be directly viewed in GEE immediatley larger ones require the batch task
|
| 153 |
+
// Map.addLayer(embeddingsMultiBandImage, {min: 0, max: 1},'embeddingsMultiBandImage')
|
| 154 |
+
|
| 155 |
+
Export.image.toAsset({
|
| 156 |
+
image: embeddingsMultiBandImage,
|
| 157 |
+
description: 'Presto_embeddings',
|
| 158 |
+
assetId: 'Togo/Presto_test_embeddings_v2025_04_23',
|
| 159 |
+
region: roi,
|
| 160 |
+
scale: 10,
|
| 161 |
+
maxPixels: 1e12,
|
| 162 |
+
crs: 'EPSG:25231'
|
| 163 |
+
});
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
|
google_earth/1d196e8466506239c4780585c0e28d26/source.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
https://code.earthengine.google.com/1d196e8466506239c4780585c0e28d26
|