{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "## A01 - Simple usage of pretrained SatCLIP embeddings\n", "\n", "To obtain pretrained **SatCLIP** embeddings, first install the repository." ], "metadata": { "id": "ngz8zz9Gvbxh" } }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "tD7wze7andRh", "outputId": "a8f55f89-e313-4792-c34b-c33026ed2dc0" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Cloning into '.'...\n", "remote: Enumerating objects: 131, done.\u001b[K\n", "remote: Counting objects: 100% (36/36), done.\u001b[K\n", "remote: Compressing objects: 100% (32/32), done.\u001b[K\n", "remote: Total 131 (delta 14), reused 16 (delta 4), pack-reused 95\u001b[K\n", "Receiving objects: 100% (131/131), 9.01 MiB | 21.25 MiB/s, done.\n", "Resolving deltas: 100% (48/48), done.\n" ] } ], "source": [ "!rm -r sample_data .config # Empty current directory\n", "!git clone https://github.com/microsoft/satclip.git . # Clone SatCLIP repository" ] }, { "cell_type": "markdown", "source": [ "Now install the required Python packages." 
], "metadata": { "id": "hOEZnNt1v2Bl" } }, { "cell_type": "code", "source": [ "!pip install lightning --quiet\n", "!pip install rasterio --quiet\n", "!pip install torchgeo --quiet" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Q72Ypu0Cr3Sc", "outputId": "d54f24b8-26b1-4f86-ffc8-23da0cc63710" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m13.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m806.1/806.1 kB\u001b[0m \u001b[31m40.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m776.9/776.9 kB\u001b[0m \u001b[31m42.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20.6/20.6 MB\u001b[0m \u001b[31m27.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m378.5/378.5 kB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.6/44.6 kB\u001b[0m \u001b[31m5.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m705.7/705.7 kB\u001b[0m \u001b[31m37.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m670.7/670.7 kB\u001b[0m \u001b[31m7.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m488.6/488.6 kB\u001b[0m \u001b[31m41.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m 
\u001b[32m106.7/106.7 kB\u001b[0m \u001b[31m11.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m66.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m154.5/154.5 kB\u001b[0m \u001b[31m15.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m137.6/137.6 kB\u001b[0m \u001b[31m13.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m92.6/92.6 MB\u001b[0m \u001b[31m7.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m189.7/189.7 kB\u001b[0m \u001b[31m14.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m79.5/79.5 kB\u001b[0m \u001b[31m8.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m101.7/101.7 kB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m117.0/117.0 kB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m594.2/594.2 kB\u001b[0m \u001b[31m44.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Building wheel for efficientnet-pytorch (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Building wheel for pretrainedmodels (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Building wheel for antlr4-python3-runtime (setup.py) ... \u001b[?25l\u001b[?25hdone\n" ] } ] }, { "cell_type": "markdown", "source": [ "You can find a list of pretrained SatCLIP models [here](https://github.com/microsoft/satclip#pretrained-models). Let's download a SatCLIP model that uses a ResNet18 vision encoder and $L=10$ Legendre polynomials for the spherical harmonics calculation in the location encoder." ], "metadata": { "id": "pmcZKqU9wOTk" } }, { "cell_type": "code", "source": [ "!wget 'https://satclip.z13.web.core.windows.net/satclip/satclip-resnet18-l10.ckpt'" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "N4S7_2fqqWF1", "outputId": "6875c479-e70e-488f-bb1d-ccbca0105ba1" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "--2023-12-07 14:55:28-- https://satclip.z13.web.core.windows.net/satclip/satclip-resnet18-l10.ckpt\n", "Resolving satclip.z13.web.core.windows.net (satclip.z13.web.core.windows.net)... 52.239.221.231\n", "Connecting to satclip.z13.web.core.windows.net (satclip.z13.web.core.windows.net)|52.239.221.231|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 57201927 (55M) [application/zip]\n", "Saving to: ‘satclip-resnet18-l10.ckpt’\n", "\n", "satclip-resnet18-l1 100%[===================>] 54.55M 91.7MB/s in 0.6s \n", "\n", "2023-12-07 14:55:29 (91.7 MB/s) - ‘satclip-resnet18-l10.ckpt’ saved [57201927/57201927]\n", "\n" ] } ] }, { "cell_type": "markdown", "source": [ "Load required packages." 
], "metadata": { "id": "P_FUpXihwwpB" } }, { "cell_type": "code", "source": [ "import sys\n", "sys.path.append('./satclip')\n", "\n", "import torch\n", "from load import get_satclip" ], "metadata": { "id": "grEIwoFjoHvu" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "And now we can sample some random lon/lat locations for which we obtain SatCLIP embeddings." ], "metadata": { "id": "mXdw7j9cw1sK" } }, { "cell_type": "code", "source": [ "satclip_path = 'satclip-resnet18-l10.ckpt'\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "torch.manual_seed(0)  # Fixed seed so the sampled locations are reproducible\n", "\n", "# Batch of 32 random locations as (lon, lat) in degrees:\n", "# lon is uniform in [-180, 180), lat is uniform in [-90, 90).\n", "# torch.randn would cluster every point within a few degrees of (0, 0).\n", "c = (torch.rand(32, 2) - 0.5) * torch.tensor([360.0, 180.0])\n", "\n", "model = get_satclip(satclip_path, device=device)  # Only loads location encoder by default\n", "model.eval()\n", "with torch.no_grad():  # Inference only, so no autograd graph is built and no detach is needed\n", "    emb = model(c.double().to(device)).cpu()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "_oOATwQHn5Pc", "outputId": "97178d9f-adeb-41c5-aa21-7ae22439caa4" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "using pretrained moco resnet18\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "Downloading: \"https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/resolve/main/resnet18_sentinel2_all_moco-59bfdff9.pth\" to /root/.cache/torch/hub/checkpoints/resnet18_sentinel2_all_moco-59bfdff9.pth\n", "100%|██████████| 42.8M/42.8M [00:00<00:00, 59.4MB/s]\n" ] } ] }, { "cell_type": "code", "source": [ "emb.shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "DY-yS8cNru7w", "outputId": "5efd6439-3f3a-47f3-f50c-707224069b04" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([32, 256])" ] }, "metadata": {}, "execution_count": 6 } ] } ] }