{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "source": [ "## C01 - Use CSP embeddings\n", "\n", "A simple example of how to obtain pretrained CSP embeddings. Read the paper here: [https://arxiv.org/abs/2305.01118](https://arxiv.org/abs/2305.01118). Note that this notebook must be run with a GPU enabled. To do this, go to \"Runtime -> Change runtime type\"." ], "metadata": { "id": "ngz8zz9Gvbxh" } }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "tD7wze7andRh", "outputId": "a2cc59f6-00e8-4c73-b95f-0fe737791593" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Cloning into '.'...\n", "remote: Enumerating objects: 86, done.\u001b[K\n", "remote: Counting objects: 100% (86/86), done.\u001b[K\n", "remote: Compressing objects: 100% (69/69), done.\u001b[K\n", "remote: Total 86 (delta 26), reused 75 (delta 15), pack-reused 0\u001b[K\n", "Receiving objects: 100% (86/86), 1.58 MiB | 16.53 MiB/s, done.\n", "Resolving deltas: 100% (26/26), done.\n" ] } ], "source": [ "!rm -r sample_data .config # Empty current directory\n", "!git clone https://github.com/gengchenmai/csp.git . # Clone CSP repository" ] }, { "cell_type": "markdown", "source": [ "Import required packages." ], "metadata": { "id": "drQnlZEDwBvA" } }, { "cell_type": "code", "source": [ "import numpy as np\n", "import pandas as pd\n", "import torch\n", "\n", "import sys\n", "sys.path.append('./main')\n", "\n", "from main.utils import *\n", "from main.models import *" ], "metadata": { "id": "Q72Ypu0Cr3Sc" }, "execution_count": 2, "outputs": [] }, { "cell_type": "markdown", "source": [ "Write a helper function to load CSP models from a checkpoint."
], "metadata": { "id": "UUi_LduKwDpA" } }, { "cell_type": "code", "source": [ "def get_csp(path):\n", "    # Load the checkpoint: a dict holding the training 'params' and the model 'state_dict'\n", "    pretrained_csp = torch.load(path)\n", "\n", "    # Rebuild the location encoder from the hyperparameters stored in the checkpoint\n", "    params = pretrained_csp['params']\n", "    loc_enc = get_model(\n", "        train_locs = None,\n", "        params = params,\n", "        spa_enc_type = params['spa_enc_type'],\n", "        num_inputs = params['num_loc_feats'],\n", "        num_classes = params['num_classes'],\n", "        num_filts = params['num_filts'],\n", "        num_users = params['num_users'],\n", "        device = params['device'])\n", "\n", "    # Wrap it in the location-image encoder and move it to the device stored in params\n", "    model = LocationImageEncoder(loc_enc = loc_enc,\n", "        train_loss = params[\"train_loss\"],\n", "        unsuper_loss = params[\"unsuper_loss\"],\n", "        cnn_feat_dim = params[\"cnn_feat_dim\"],\n", "        spa_enc_type = params[\"spa_enc_type\"]).to(params['device'])\n", "\n", "    # Restore the pretrained weights\n", "    model.load_state_dict(pretrained_csp['state_dict'])\n", "\n", "    return model" ], "metadata": { "id": "eWV6S2SmsX_O" }, "execution_count": 3, "outputs": [] }, { "cell_type": "markdown", "source": [ "Download the pretrained models. For details, see [https://gengchenmai.github.io/csp-website/](https://gengchenmai.github.io/csp-website/)." ], "metadata": { "id": "IUtwfnVKwNsu" } }, { "cell_type": "code", "source": [ "!wget -O model_dir.zip 'https://www.dropbox.com/s/qxr644rj1qxekn2/model_dir.zip?dl=1'" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3rib-U9ztCCg", "outputId": "fd0979ce-7bd3-4882-ebfa-ca1c1533a4aa" }, "execution_count": 4, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "--2024-01-21 01:57:39-- https://www.dropbox.com/s/qxr644rj1qxekn2/model_dir.zip?dl=1\n", "Resolving www.dropbox.com (www.dropbox.com)... 162.125.72.18, 2620:100:6021:18::a27d:4112\n", "Connecting to www.dropbox.com (www.dropbox.com)|162.125.72.18|:443... connected.\n", "HTTP request sent, awaiting response... 
302 Found\n", "Location: /s/dl/qxr644rj1qxekn2/model_dir.zip [following]\n", "--2024-01-21 01:57:39-- https://www.dropbox.com/s/dl/qxr644rj1qxekn2/model_dir.zip\n", "Reusing existing connection to www.dropbox.com:443.\n", "HTTP request sent, awaiting response... 302 Found\n", "Location: https://uc88d2f6a29caa94bff34c8e0b68.dl.dropboxusercontent.com/cd/0/get/CLsPN3FzQu8_q5uGtPvkDlDROOyJ6xA-YyXdunQrn7KAzziriwxL6D5CtJ-GLtERcoOgOxXfm552bAerCPxoiiK73Gcy-_d58UA0Im4LxotstFW-4wLSnulKjrUuv0l_OrxCIyAXp_GP2OHpWKHcgdQN/file?dl=1# [following]\n", "--2024-01-21 01:57:40-- https://uc88d2f6a29caa94bff34c8e0b68.dl.dropboxusercontent.com/cd/0/get/CLsPN3FzQu8_q5uGtPvkDlDROOyJ6xA-YyXdunQrn7KAzziriwxL6D5CtJ-GLtERcoOgOxXfm552bAerCPxoiiK73Gcy-_d58UA0Im4LxotstFW-4wLSnulKjrUuv0l_OrxCIyAXp_GP2OHpWKHcgdQN/file?dl=1\n", "Resolving uc88d2f6a29caa94bff34c8e0b68.dl.dropboxusercontent.com (uc88d2f6a29caa94bff34c8e0b68.dl.dropboxusercontent.com)... 162.125.65.15, 2620:100:6021:15::a27d:410f\n", "Connecting to uc88d2f6a29caa94bff34c8e0b68.dl.dropboxusercontent.com (uc88d2f6a29caa94bff34c8e0b68.dl.dropboxusercontent.com)|162.125.65.15|:443... connected.\n", "HTTP request sent, awaiting response... 
200 OK\n", "Length: 94186705 (90M) [application/binary]\n", "Saving to: ‘model_dir.zip’\n", "\n", "model_dir.zip 100%[===================>] 89.82M 17.4MB/s in 5.5s \n", "\n", "2024-01-21 01:57:46 (16.5 MB/s) - ‘model_dir.zip’ saved [94186705/94186705]\n", "\n" ] } ] }, { "cell_type": "code", "source": [ "!unzip model_dir.zip" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "v3IDc8C9tZZr", "outputId": "882a9228-af13-4849-adf8-179f055b47ca" }, "execution_count": 5, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Archive: model_dir.zip\n", " creating: model_dir/\n", " inflating: __MACOSX/._model_dir \n", " inflating: model_dir/.DS_Store \n", " inflating: __MACOSX/model_dir/._.DS_Store \n", " creating: model_dir/model_inat_2018/\n", " inflating: __MACOSX/model_dir/._model_inat_2018 \n", " creating: model_dir/model_fmow/\n", " inflating: __MACOSX/model_dir/._model_fmow \n", " inflating: model_dir/model_inat_2018/.DS_Store \n", " inflating: __MACOSX/model_dir/model_inat_2018/._.DS_Store \n", " inflating: model_dir/model_inat_2018/model_inat_2018_gridcell_0.0010_32_0.1000000_1_512_leakyrelu_contsoftmax_ratio0.050_0.000500_1.000_1_1.000_TMP20.0000_1.0000_1.0000.pth.tar \n", " inflating: __MACOSX/model_dir/model_inat_2018/._model_inat_2018_gridcell_0.0010_32_0.1000000_1_512_leakyrelu_contsoftmax_ratio0.050_0.000500_1.000_1_1.000_TMP20.0000_1.0000_1.0000.pth.tar \n", " inflating: model_dir/model_inat_2018/model_inat_2018_gridcell_0.0010_32_0.1000000_1_512_leakyrelu_UNSUPER-contsoftmax_0.000500_1.000_1_1.000_TMP20.0000_1.0000_1.0000.pth.tar \n", " inflating: __MACOSX/model_dir/model_inat_2018/._model_inat_2018_gridcell_0.0010_32_0.1000000_1_512_leakyrelu_UNSUPER-contsoftmax_0.000500_1.000_1_1.000_TMP20.0000_1.0000_1.0000.pth.tar \n", " inflating: model_dir/model_fmow/.DS_Store \n", " inflating: __MACOSX/model_dir/model_fmow/._.DS_Store \n", " inflating: 
model_dir/model_fmow/model_fmow_gridcell_0.0010_32_0.1000000_1_512_gelu_UNSUPER-contsoftmax_0.000050_1.000_1_0.100_TMP1.0000_1.0000_1.0000.pth.tar \n", " inflating: __MACOSX/model_dir/model_fmow/._model_fmow_gridcell_0.0010_32_0.1000000_1_512_gelu_UNSUPER-contsoftmax_0.000050_1.000_1_0.100_TMP1.0000_1.0000_1.0000.pth.tar \n", " inflating: model_dir/model_fmow/model_fmow_gridcell_0.0010_32_0.1000000_1_512_gelu_contsoftmax_ratio0.050_0.000050_1.000_1_0.100_TMP1.0000_1.0000_1.0000.pth.tar \n", " inflating: __MACOSX/model_dir/model_fmow/._model_fmow_gridcell_0.0010_32_0.1000000_1_512_gelu_contsoftmax_ratio0.050_0.000050_1.000_1_0.100_TMP1.0000_1.0000_1.0000.pth.tar \n" ] } ] }, { "cell_type": "markdown", "source": [ "Load CSP model." ], "metadata": { "id": "orlY0u8owb7w" } }, { "cell_type": "code", "source": [ "path = '/content/model_dir/model_fmow/model_fmow_gridcell_0.0010_32_0.1000000_1_512_gelu_UNSUPER-contsoftmax_0.000050_1.000_1_0.100_TMP1.0000_1.0000_1.0000.pth.tar'\n", "model = get_csp(path)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0HoKFM2atxM2", "outputId": "3c43c60d-9a5e-4b67-b3bb-75601470fe98" }, "execution_count": 6, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/content/./main/module.py:98: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.\n", " nn.init.xavier_uniform(self.linear.weight)\n" ] } ] }, { "cell_type": "markdown", "source": [ "Use CSP model to obtain location embeddings." 
], "metadata": { "id": "MBWysaAewdpb" } }, { "cell_type": "code", "source": [ "c = torch.randn(32, 2) # Placeholder batch of 32 random (lon, lat) pairs; replace with real coordinates\n", "\n", "model.eval()\n", "with torch.no_grad():\n", "    emb = model.loc_enc(convert_loc_to_tensor(c.numpy()), return_feats=True).detach().cpu()" ], "metadata": { "id": "Ku9kf_0su0id" }, "execution_count": 7, "outputs": [] }, { "cell_type": "code", "source": [ "emb.shape" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "MYtxk8NCvr0M", "outputId": "33647a0f-772c-4f08-fe51-f77ce7117d26" }, "execution_count": 8, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([32, 256])" ] }, "metadata": {}, "execution_count": 8 } ] } ] }