import torch
from transformers import ViTForImageClassification
import torch.nn as nn

# 1. Model definition (must exactly match the class used during training,
#    otherwise load_state_dict will fail on mismatched keys/shapes).
class ViTBinaryClassifier(nn.Module):
    """ViT-base backbone with a small MLP head emitting a single logit for binary classification."""

    def __init__(self, pretrained_model="google/vit-base-patch16-224", freeze_base=False):
        """
        Args:
            pretrained_model: HuggingFace checkpoint id for the ViT backbone.
            freeze_base: if True, freeze the transformer encoder so only the head trains.
        """
        super().__init__()
        # num_labels=1 + ignore_mismatched_sizes lets the 1000-class checkpoint load;
        # the stock classifier is discarded and replaced with our own head below.
        self.vit = ViTForImageClassification.from_pretrained(
            pretrained_model,
            num_labels=1,
            ignore_mismatched_sizes=True,
        )
        # Custom binary head: hidden_size (768) -> 256 -> 1, with dropout for regularization.
        head = nn.Sequential(
            nn.Linear(self.vit.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1),
        )
        self.vit.classifier = head
        if freeze_base:
            # self.vit.vit is the ViT encoder inside ViTForImageClassification.
            for p in self.vit.vit.parameters():
                p.requires_grad = False

    def forward(self, pixel_values):
        """Return raw (un-sigmoided) logits of shape (batch, 1)."""
        return self.vit(pixel_values).logits
# 2. Instantiate the model and load the trained weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViTBinaryClassifier().to(device)
# weights_only=True restricts unpickling to tensors/containers, preventing
# arbitrary code execution from a tampered checkpoint (and silences the
# FutureWarning emitted by modern PyTorch for the old default).
# NOTE(review): hardcoded absolute path — TODO make it configurable or relative.
CHECKPOINT_PATH = r"D:\SonCode\Taosafescan\vit_binary_classifier_best.pt"
state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout so the exported graph is deterministic

# 3. Dummy input with the shape ViT-base expects: (batch, channels, height, width).
dummy_input = torch.randn(1, 3, 224, 224, device=device)

# 4. Export to ONNX with a dynamic batch dimension on both input and output.
torch.onnx.export(
    model,
    dummy_input,
    "vit_binary_classification.onnx",
    export_params=True,       # embed trained weights in the ONNX file
    opset_version=14,         # opset 14+ covers the ops ViT needs (e.g. scaled dot-product pieces)
    do_constant_folding=True,
    input_names=["pixel_values"],
    output_names=["logits"],
    dynamic_axes={
        "pixel_values": {0: "batch_size"},
        "logits": {0: "batch_size"},
    },
)
print("Conversion to ONNX completed successfully!")