import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms, models

# Load the id -> class-name mapping produced during training, and build the
# inverse mapping (class name -> id); its length sizes the classifier head.
with open('id_2_class_89.json') as json_file:
    id_2_class = json.load(json_file)

class_2_id = {value: key for key, value in id_2_class.items()}

# Preprocessing used at inference time: 224x224, scaled to [0, 1] by
# ToTensor, then normalized to roughly [-1, 1] per channel.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


class MaxViT(nn.Module):
    """MaxViT-Tiny backbone with the final linear layer replaced so its
    output dimension matches the number of classes in ``id_2_class``."""

    def __init__(self):
        super().__init__()
        model = models.maxvit_t(weights="DEFAULT")
        num_ftrs = model.classifier[5].in_features
        model.classifier[5] = nn.Linear(num_ftrs, len(class_2_id))
        self.model = model

    def forward(self, x):
        return self.model(x)


# Instantiate the model and load the fine-tuned weights.
# map_location="cpu" lets the checkpoint load even on a CPU-only machine
# when it was saved from a GPU run (the bare torch.load would raise there).
model = MaxViT()
model.load_state_dict(torch.load('best_model_89.pth', map_location='cpu'))
model.eval()


def inference(image_path, CONFIDENT_THRESHOLD=None):
    """Classify a single image file.

    Args:
        image_path: path to an image readable by PIL.
        CONFIDENT_THRESHOLD: optional float in [0, 1]. If given and the
            softmax confidence of the top class is below it, the label
            "UNKNOWN_CLASS" is returned instead of a class name.

    Returns:
        ``(label, confidence)`` — ``label`` is a class name from
        ``id_2_class`` (or ``"UNKNOWN_CLASS"``), ``confidence`` a float.
    """
    # Grayscale round-trip before RGB: presumably matches how training
    # images were prepared — TODO confirm against the training pipeline.
    img = Image.open(image_path).convert("L").convert("RGB")
    img = test_transform(img)
    img = img.unsqueeze(0)  # add batch dimension -> (1, 3, 224, 224)

    with torch.no_grad():
        output = F.softmax(model(img), dim=1)
        # Use the tensor directly; the legacy `.data` accessor is
        # unnecessary (and discouraged) inside torch.no_grad().
        confidence, predicted = torch.max(output, 1)

    if CONFIDENT_THRESHOLD is not None and confidence.item() < CONFIDENT_THRESHOLD:
        return "UNKNOWN_CLASS", confidence.item()

    return id_2_class[str(predicted.item())], confidence.item()
# Sanity-check the classifier on a sample image: first with no threshold
# (always returns the top class), then with a 0.9 confidence cutoff that
# would map low-confidence predictions to "UNKNOWN_CLASS".
inference("images/7820.jpg")
inference("images/7820.jpg", 0.9) #0.9 should be good enough