File size: 2,823 Bytes
e5a23d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "Follow these steps to use SuperMiniVIT\n",
        "----------------------------------------"
      ],
      "metadata": {
        "id": "kFgTthdhoCFa"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "_mE-O3tynzyy"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from MiniVIT import *"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "model = MiniVisualTransformers()"
      ],
      "metadata": {
        "id": "pCLQjtDxoPI0"
      },
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "checkpoint = torch.load('VIT_Encoder.pth')\n",
        "model.load_state_dict(checkpoint)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "R-EJhGpbolbl",
        "outputId": "3dc9d1aa-49eb-4516-b16c-2305677bec63"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<All keys matched successfully>"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "model(torch.randn(1,3,144,144))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "QFXaAujfouIi",
        "outputId": "6a7ef31e-ef73-4f08-b68d-cd635174057c"
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[[-0.1631,  1.8401,  1.4494,  ..., -1.2810, -2.9000,  0.1974],\n",
              "         [-0.7426,  1.0433,  0.3615,  ..., -1.4665, -0.2818,  2.2017],\n",
              "         [ 1.3605,  1.0501,  0.9630,  ...,  2.9057,  1.3372, -2.2445],\n",
              "         ...,\n",
              "         [-1.6320,  0.7411, -0.3816,  ..., -1.9780,  1.6325, -0.0490],\n",
              "         [-1.2490, -0.6153,  0.8643,  ..., -0.8104,  1.2853, -0.2412],\n",
              "         [-1.7517, -1.4150, -0.2602,  ..., -2.3606, -0.3698,  1.9745]]],\n",
              "       grad_fn=<AddBackward0>)"
            ]
          },
          "metadata": {},
          "execution_count": 7
        }
      ]
    }
  ]
}