File size: 2,823 Bytes
e5a23d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"Follow these steps to use SuperMiniVIT\n",
"----------------------------------------"
],
"metadata": {
"id": "kFgTthdhoCFa"
}
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "_mE-O3tynzyy"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Import only the name this notebook uses (instead of `from MiniVIT import *`)\n",
"# so readers can see where `MiniVisualTransformers` comes from.\n",
"from MiniVIT import MiniVisualTransformers\n",
"\n",
"# Seed PyTorch's RNG so the random demo input further down is reproducible.\n",
"SEED = 0\n",
"torch.manual_seed(SEED);"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "pCLQjtDxoPI0"
},
"outputs": [],
"source": [
"# Build the network with its default constructor arguments;\n",
"# the trained weights are loaded into it in the next cell.\n",
"model = MiniVisualTransformers()"
]
},
{
"cell_type": "code",
"source": [
"# map_location='cpu' lets a checkpoint that was saved from a GPU load on a\n",
"# CPU-only runtime; load_state_dict copies the tensors into `model`'s parameters.\n",
"# weights_only=True restricts unpickling to tensors and plain containers, so a\n",
"# tampered .pth file cannot run arbitrary code (needs PyTorch >= 1.13).\n",
"checkpoint = torch.load('VIT_Encoder.pth', map_location='cpu', weights_only=True)\n",
"model.load_state_dict(checkpoint)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 0
},
"id": "R-EJhGpbolbl",
"outputId": "3dc9d1aa-49eb-4516-b16c-2305677bec63"
},
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"source": [
"# Smoke test: run one random input tensor of shape (1, 3, 144, 144) through the model.\n",
"# eval() puts any dropout / batch-norm layers into inference mode, and\n",
"# no_grad() avoids building an autograd graph we never backpropagate through.\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    output = model(torch.randn(1, 3, 144, 144))\n",
"output"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 0
},
"id": "QFXaAujfouIi",
"outputId": "6a7ef31e-ef73-4f08-b68d-cd635174057c"
},
"execution_count": null,
"outputs": []
}
]
} |