{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "source": [ "## C02 - Use GeoCLIP embeddings\n", "\n", "A simple example of how to obtain location embeddings from the pretrained GeoCLIP location encoder. Read the paper here: [https://arxiv.org/abs/2309.16020](https://arxiv.org/abs/2309.16020). First, install the geoclip package (see [https://github.com/VicenteVivan/geo-clip](https://github.com/VicenteVivan/geo-clip))." ], "metadata": { "id": "ngz8zz9Gvbxh" } }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "tD7wze7andRh", "outputId": "c49cc55d-9eab-452d-cfd1-50975664678f" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting geoclip\n", " Downloading geoclip-1.1.0-py3-none-any.whl (40.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.3/40.3 MB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from geoclip) (2.1.0+cu121)\n", "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (from geoclip) (0.16.0+cu121)\n", "Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from geoclip) (9.4.0)\n", "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from geoclip) (4.35.2)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from geoclip) (1.5.3)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from geoclip) (1.23.5)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->geoclip) (2.8.2)\n", "Requirement already 
satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->geoclip) (2023.3.post1)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->geoclip) (3.13.1)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch->geoclip) (4.5.0)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->geoclip) (1.12)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->geoclip) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->geoclip) (3.1.3)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->geoclip) (2023.6.0)\n", "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch->geoclip) (2.1.0)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torchvision->geoclip) (2.31.0)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers->geoclip) (0.20.2)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers->geoclip) (23.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers->geoclip) (6.0.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->geoclip) (2023.6.3)\n", "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers->geoclip) (0.15.0)\n", "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers->geoclip) (0.4.1)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from 
transformers->geoclip) (4.66.1)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->geoclip) (1.16.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->geoclip) (2.1.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->geoclip) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->geoclip) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->geoclip) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision->geoclip) (2023.11.17)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->geoclip) (1.3.0)\n", "Installing collected packages: geoclip\n", "Successfully installed geoclip-1.1.0\n" ] } ], "source": [ "!pip install geoclip" ] }, { "cell_type": "markdown", "source": [ "Load the pretrained location encoder directly. By default, `LocationEncoder()` loads the pretrained GeoCLIP weights." ], "metadata": { "id": "nPc2yi39yRje" } }, { "cell_type": "code", "source": [ "from geoclip import LocationEncoder\n", "import torch\n", "\n", "model = LocationEncoder()  # pretrained weights are loaded by default" ], "metadata": { "id": "Q72Ypu0Cr3Sc" }, "execution_count": 5, "outputs": [] }, { "cell_type": "markdown", "source": [ "Obtain GeoCLIP location embeddings. The encoder expects GPS coordinates as `(latitude, longitude)` in degrees, so the `(lon, lat)` batch below is flipped with `c.flip(1)` before it is passed to the model." 
], "metadata": { "id": "Y8XaPTs6yUu9" } }, { "cell_type": "code", "source": [ "# A batch of 32 random locations as (lon, lat) in degrees.\n", "# Sample uniformly over the valid ranges (torch.randn would cluster points near (0, 0)).\n", "lon = torch.rand(32) * 360 - 180  # longitude in [-180, 180)\n", "lat = torch.rand(32) * 180 - 90   # latitude in [-90, 90)\n", "c = torch.stack([lon, lat], dim=1)  # shape (32, 2)\n", "\n", "model.eval()\n", "with torch.no_grad():\n", "    # LocationEncoder expects (lat, lon), so swap the two columns\n", "    emb = model(c.flip(1).float()).cpu()" ], "metadata": { "id": "eWV6S2SmsX_O" }, "execution_count": 6, "outputs": [] }, { "cell_type": "code", "source": [ "emb.shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3rib-U9ztCCg", "outputId": "67cd8c6e-d53d-4317-b5aa-9076df069e96" }, "execution_count": 8, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([32, 512])" ] }, "metadata": {}, "execution_count": 8 } ] } ] }