{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:\n",
"- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([1]) in the model instantiated\n",
"- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([1, 768]) in the model instantiated\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Conversion to ONNX completed successfully!\n"
]
}
],
"source": [
"import torch\n",
"from transformers import ViTForImageClassification\n",
"import torch.nn as nn\n",
"\n",
"# 1. Re-declare the model class. Its module layout must match the one used for\n",
"#    training so that the keys of the saved state_dict line up.\n",
"class ViTBinaryClassifier(nn.Module):\n",
"    \"\"\"ViT image classifier whose head is replaced by a small MLP ending in one output unit.\n",
"\n",
"    Args:\n",
"        pretrained_model: Checkpoint name or path handed to\n",
"            ``ViTForImageClassification.from_pretrained``.\n",
"        freeze_base: When True, gradients are disabled for every parameter of the\n",
"            inner ``self.vit.vit`` backbone; the replacement head stays trainable.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, pretrained_model=\"google/vit-base-patch16-224\", freeze_base=False):\n",
"        super().__init__()\n",
"        # num_labels=1 does not match the checkpoint's classifier shape, so\n",
"        # ignore_mismatched_sizes=True is needed for the load to go through; the\n",
"        # mismatched head is re-initialised (and replaced just below anyway).\n",
"        self.vit = ViTForImageClassification.from_pretrained(\n",
"            pretrained_model,\n",
"            num_labels=1,\n",
"            ignore_mismatched_sizes=True,\n",
"        )\n",
"        hidden_size = self.vit.config.hidden_size\n",
"        self.vit.classifier = nn.Sequential(\n",
"            nn.Linear(hidden_size, 256),\n",
"            nn.ReLU(),\n",
"            nn.Dropout(0.1),\n",
"            nn.Linear(256, 1),\n",
"        )\n",
"        if freeze_base:\n",
"            # Same effect as setting requires_grad = False on each backbone parameter.\n",
"            self.vit.vit.requires_grad_(False)\n",
"\n",
"    def forward(self, pixel_values):\n",
"        \"\"\"Run the wrapped model on ``pixel_values`` and return its ``logits`` field.\"\"\"\n",
"        return self.vit(pixel_values).logits\n",
"\n",
"# 2. Instantiate the model and load the trained weights\n",
"# NOTE(review): CHECKPOINT_PATH is a machine-specific absolute path; adjust before running elsewhere.\n",
"CHECKPOINT_PATH = r\"D:\\SonCode\\Taosafescan\\vit_binary_classifier_best.pt\"\n",
"ONNX_PATH = \"vit_binary_classification.onnx\"\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model = ViTBinaryClassifier().to(device)\n",
"# weights_only=True restricts unpickling to tensors and plain containers, so a\n",
"# tampered checkpoint file cannot execute arbitrary code while it is being loaded.\n",
"state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)\n",
"model.load_state_dict(state_dict)\n",
"model.eval()  # switch to inference mode (turns off the dropout layer in the head)\n",
"\n",
"# 3. Build a dummy input of shape (batch_size, channels, height, width);\n",
"#    224x224 images are used for this ViT-base checkpoint.\n",
"dummy_input = torch.randn(1, 3, 224, 224, device=device)\n",
"\n",
"# 4. Export to ONNX\n",
"torch.onnx.export(\n",
"    model,\n",
"    dummy_input,\n",
"    ONNX_PATH,\n",
"    export_params=True,\n",
"    opset_version=14,  # opset was changed to 14 here; NOTE(review): reason not recorded - confirm\n",
"    do_constant_folding=True,\n",
"    input_names=[\"pixel_values\"],\n",
"    output_names=[\"logits\"],\n",
"    # Mark axis 0 of both tensors as dynamic so the exported graph accepts any batch size.\n",
"    dynamic_axes={\n",
"        \"pixel_values\": {0: \"batch_size\"},\n",
"        \"logits\": {0: \"batch_size\"},\n",
"    },\n",
")\n",
"print(\"Conversion to ONNX completed successfully!\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}