Cập nhật mô hình, xóa các file không cần thiết
Browse files- Code.ipynb +64 -20
- Code2.ipynb +0 -94
- Code3.ipynb +0 -103
Code.ipynb
CHANGED
|
@@ -6,32 +6,76 @@
|
|
| 6 |
"metadata": {},
|
| 7 |
"outputs": [
|
| 8 |
{
|
| 9 |
-
"
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
"
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
"
|
| 20 |
-
"output_type": "
|
|
|
|
|
|
|
|
|
|
| 21 |
}
|
| 22 |
],
|
| 23 |
"source": [
|
| 24 |
-
"
|
| 25 |
-
"from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
"\n",
|
| 27 |
-
"#
|
| 28 |
-
"
|
|
|
|
|
|
|
|
|
|
| 29 |
"\n",
|
| 30 |
-
"#
|
| 31 |
-
"
|
| 32 |
"\n",
|
| 33 |
-
"#
|
| 34 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
]
|
| 36 |
}
|
| 37 |
],
|
|
|
|
| 6 |
"metadata": {},
|
| 7 |
"outputs": [
|
| 8 |
{
|
| 9 |
+
"name": "stderr",
|
| 10 |
+
"output_type": "stream",
|
| 11 |
+
"text": [
|
| 12 |
+
"Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:\n",
|
| 13 |
+
"- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([1]) in the model instantiated\n",
|
| 14 |
+
"- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([1, 768]) in the model instantiated\n",
|
| 15 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"name": "stdout",
|
| 20 |
+
"output_type": "stream",
|
| 21 |
+
"text": [
|
| 22 |
+
"Conversion to ONNX completed successfully!\n"
|
| 23 |
+
]
|
| 24 |
}
|
| 25 |
],
|
| 26 |
"source": [
|
| 27 |
+
"import torch\n",
|
| 28 |
+
"from transformers import ViTForImageClassification\n",
|
| 29 |
+
"import torch.nn as nn\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"# 1. Định nghĩa lại lớp mô hình (phải giống hệt khi training)\n",
|
| 32 |
+
"class ViTBinaryClassifier(nn.Module):\n",
|
| 33 |
+
" def __init__(self, pretrained_model=\"google/vit-base-patch16-224\", freeze_base=False):\n",
|
| 34 |
+
" super().__init__()\n",
|
| 35 |
+
" self.vit = ViTForImageClassification.from_pretrained(\n",
|
| 36 |
+
" pretrained_model,\n",
|
| 37 |
+
" num_labels=1,\n",
|
| 38 |
+
" ignore_mismatched_sizes=True\n",
|
| 39 |
+
" )\n",
|
| 40 |
+
" self.vit.classifier = nn.Sequential(\n",
|
| 41 |
+
" nn.Linear(self.vit.config.hidden_size, 256),\n",
|
| 42 |
+
" nn.ReLU(),\n",
|
| 43 |
+
" nn.Dropout(0.1),\n",
|
| 44 |
+
" nn.Linear(256, 1)\n",
|
| 45 |
+
" )\n",
|
| 46 |
+
" if freeze_base:\n",
|
| 47 |
+
" for param in self.vit.vit.parameters():\n",
|
| 48 |
+
" param.requires_grad = False\n",
|
| 49 |
+
"\n",
|
| 50 |
+
" def forward(self, pixel_values):\n",
|
| 51 |
+
" outputs = self.vit(pixel_values)\n",
|
| 52 |
+
" return outputs.logits\n",
|
| 53 |
"\n",
|
| 54 |
+
"# 2. Khởi tạo và load weights\n",
|
| 55 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 56 |
+
"model = ViTBinaryClassifier().to(device)\n",
|
| 57 |
+
"model.load_state_dict(torch.load(r\"D:\\SonCode\\Taosafescan\\vit_binary_classifier_best.pt\", map_location=device))\n",
|
| 58 |
+
"model.eval() # Chuyển sang chế độ inference\n",
|
| 59 |
"\n",
|
| 60 |
+
"# 3. Tạo dummy input với kích thước phù hợp (batch_size, channels, height, width)\n",
|
| 61 |
+
"dummy_input = torch.randn(1, 3, 224, 224).to(device) # Kích thước ảnh 224x224 cho ViT-base\n",
|
| 62 |
"\n",
|
| 63 |
+
"# 4. Xuất sang ONNX\n",
|
| 64 |
+
"torch.onnx.export(\n",
|
| 65 |
+
" model,\n",
|
| 66 |
+
" dummy_input,\n",
|
| 67 |
+
" \"vit_binary_classification.onnx\",\n",
|
| 68 |
+
" export_params=True,\n",
|
| 69 |
+
" opset_version=14, # Thay đổi tại đây\n",
|
| 70 |
+
" do_constant_folding=True,\n",
|
| 71 |
+
" input_names=[\"pixel_values\"],\n",
|
| 72 |
+
" output_names=[\"logits\"],\n",
|
| 73 |
+
" dynamic_axes={\n",
|
| 74 |
+
" \"pixel_values\": {0: \"batch_size\"},\n",
|
| 75 |
+
" \"logits\": {0: \"batch_size\"}\n",
|
| 76 |
+
" }\n",
|
| 77 |
+
")\n",
|
| 78 |
+
"print(\"Conversion to ONNX completed successfully!\")"
|
| 79 |
]
|
| 80 |
}
|
| 81 |
],
|
Code2.ipynb
DELETED
|
@@ -1,94 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"cells": [
|
| 3 |
-
{
|
| 4 |
-
"cell_type": "code",
|
| 5 |
-
"execution_count": 1,
|
| 6 |
-
"metadata": {},
|
| 7 |
-
"outputs": [
|
| 8 |
-
{
|
| 9 |
-
"name": "stdout",
|
| 10 |
-
"output_type": "stream",
|
| 11 |
-
"text": [
|
| 12 |
-
"Collecting monai\n",
|
| 13 |
-
" Downloading monai-1.4.0-py3-none-any.whl.metadata (11 kB)\n",
|
| 14 |
-
"Requirement already satisfied: numpy<2.0,>=1.24 in c:\\users\\1\\anaconda3\\lib\\site-packages (from monai) (1.26.4)\n",
|
| 15 |
-
"Requirement already satisfied: torch>=1.9 in c:\\users\\1\\anaconda3\\lib\\site-packages (from monai) (2.6.0+cu126)\n",
|
| 16 |
-
"Requirement already satisfied: filelock in c:\\users\\1\\anaconda3\\lib\\site-packages (from torch>=1.9->monai) (3.13.1)\n",
|
| 17 |
-
"Requirement already satisfied: typing-extensions>=4.10.0 in c:\\users\\1\\anaconda3\\lib\\site-packages (from torch>=1.9->monai) (4.11.0)\n",
|
| 18 |
-
"Requirement already satisfied: networkx in c:\\users\\1\\anaconda3\\lib\\site-packages (from torch>=1.9->monai) (3.3)\n",
|
| 19 |
-
"Requirement already satisfied: jinja2 in c:\\users\\1\\anaconda3\\lib\\site-packages (from torch>=1.9->monai) (3.1.4)\n",
|
| 20 |
-
"Requirement already satisfied: fsspec in c:\\users\\1\\anaconda3\\lib\\site-packages (from torch>=1.9->monai) (2024.6.1)\n",
|
| 21 |
-
"Requirement already satisfied: setuptools in c:\\users\\1\\anaconda3\\lib\\site-packages (from torch>=1.9->monai) (75.1.0)\n",
|
| 22 |
-
"Requirement already satisfied: sympy==1.13.1 in c:\\users\\1\\anaconda3\\lib\\site-packages (from torch>=1.9->monai) (1.13.1)\n",
|
| 23 |
-
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in c:\\users\\1\\anaconda3\\lib\\site-packages (from sympy==1.13.1->torch>=1.9->monai) (1.3.0)\n",
|
| 24 |
-
"Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\1\\anaconda3\\lib\\site-packages (from jinja2->torch>=1.9->monai) (2.1.3)\n",
|
| 25 |
-
"Downloading monai-1.4.0-py3-none-any.whl (1.5 MB)\n",
|
| 26 |
-
" ---------------------------------------- 0.0/1.5 MB ? eta -:--:--\n",
|
| 27 |
-
" ---------------------------------------- 0.0/1.5 MB ? eta -:--:--\n",
|
| 28 |
-
" ---------------------------------------- 0.0/1.5 MB ? eta -:--:--\n",
|
| 29 |
-
" ------ --------------------------------- 0.3/1.5 MB ? eta -:--:--\n",
|
| 30 |
-
" ------------- -------------------------- 0.5/1.5 MB 1.0 MB/s eta 0:00:01\n",
|
| 31 |
-
" -------------------- ------------------- 0.8/1.5 MB 1.3 MB/s eta 0:00:01\n",
|
| 32 |
-
" -------------------- ------------------- 0.8/1.5 MB 1.3 MB/s eta 0:00:01\n",
|
| 33 |
-
" --------------------------- ------------ 1.0/1.5 MB 1.0 MB/s eta 0:00:01\n",
|
| 34 |
-
" --------------------------- ------------ 1.0/1.5 MB 1.0 MB/s eta 0:00:01\n",
|
| 35 |
-
" ---------------------------------- ----- 1.3/1.5 MB 871.6 kB/s eta 0:00:01\n",
|
| 36 |
-
" ---------------------------------- ----- 1.3/1.5 MB 871.6 kB/s eta 0:00:01\n",
|
| 37 |
-
" ---------------------------------- ----- 1.3/1.5 MB 871.6 kB/s eta 0:00:01\n",
|
| 38 |
-
" ---------------------------------------- 1.5/1.5 MB 695.8 kB/s eta 0:00:00\n",
|
| 39 |
-
"Installing collected packages: monai\n",
|
| 40 |
-
"Successfully installed monai-1.4.0\n"
|
| 41 |
-
]
|
| 42 |
-
}
|
| 43 |
-
],
|
| 44 |
-
"source": [
|
| 45 |
-
"!pip install monai"
|
| 46 |
-
]
|
| 47 |
-
},
|
| 48 |
-
{
|
| 49 |
-
"cell_type": "code",
|
| 50 |
-
"execution_count": null,
|
| 51 |
-
"metadata": {},
|
| 52 |
-
"outputs": [],
|
| 53 |
-
"source": [
|
| 54 |
-
"from monai.transforms import (\n",
|
| 55 |
-
" Compose,\n",
|
| 56 |
-
" LoadImage,\n",
|
| 57 |
-
" ScaleIntensity,\n",
|
| 58 |
-
" AddChannel\n",
|
| 59 |
-
")\n",
|
| 60 |
-
"\n",
|
| 61 |
-
"# Define transforms for image preprocessing\n",
|
| 62 |
-
"transforms = Compose([\n",
|
| 63 |
-
" LoadImage(image_only=True),\n",
|
| 64 |
-
" AddChannel(),\n",
|
| 65 |
-
" ScaleIntensity()\n",
|
| 66 |
-
"])\n",
|
| 67 |
-
"\n",
|
| 68 |
-
"# Apply transforms to your image\n",
|
| 69 |
-
"image = transforms(image_path)"
|
| 70 |
-
]
|
| 71 |
-
}
|
| 72 |
-
],
|
| 73 |
-
"metadata": {
|
| 74 |
-
"kernelspec": {
|
| 75 |
-
"display_name": "base",
|
| 76 |
-
"language": "python",
|
| 77 |
-
"name": "python3"
|
| 78 |
-
},
|
| 79 |
-
"language_info": {
|
| 80 |
-
"codemirror_mode": {
|
| 81 |
-
"name": "ipython",
|
| 82 |
-
"version": 3
|
| 83 |
-
},
|
| 84 |
-
"file_extension": ".py",
|
| 85 |
-
"mimetype": "text/x-python",
|
| 86 |
-
"name": "python",
|
| 87 |
-
"nbconvert_exporter": "python",
|
| 88 |
-
"pygments_lexer": "ipython3",
|
| 89 |
-
"version": "3.12.3"
|
| 90 |
-
}
|
| 91 |
-
},
|
| 92 |
-
"nbformat": 4,
|
| 93 |
-
"nbformat_minor": 2
|
| 94 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Code3.ipynb
DELETED
|
@@ -1,103 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"cells": [
|
| 3 |
-
{
|
| 4 |
-
"cell_type": "code",
|
| 5 |
-
"execution_count": null,
|
| 6 |
-
"metadata": {},
|
| 7 |
-
"outputs": [
|
| 8 |
-
{
|
| 9 |
-
"name": "stderr",
|
| 10 |
-
"output_type": "stream",
|
| 11 |
-
"text": [
|
| 12 |
-
"Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:\n",
|
| 13 |
-
"- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([1]) in the model instantiated\n",
|
| 14 |
-
"- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([1, 768]) in the model instantiated\n",
|
| 15 |
-
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
| 16 |
-
]
|
| 17 |
-
},
|
| 18 |
-
{
|
| 19 |
-
"name": "stdout",
|
| 20 |
-
"output_type": "stream",
|
| 21 |
-
"text": [
|
| 22 |
-
"Conversion to ONNX completed successfully!\n"
|
| 23 |
-
]
|
| 24 |
-
}
|
| 25 |
-
],
|
| 26 |
-
"source": [
|
| 27 |
-
"import torch\n",
|
| 28 |
-
"from transformers import ViTForImageClassification\n",
|
| 29 |
-
"import torch.nn as nn\n",
|
| 30 |
-
"\n",
|
| 31 |
-
"# 1. Định nghĩa lại lớp mô hình (phải giống hệt khi training)\n",
|
| 32 |
-
"class ViTBinaryClassifier(nn.Module):\n",
|
| 33 |
-
" def __init__(self, pretrained_model=\"google/vit-base-patch16-224\", freeze_base=False):\n",
|
| 34 |
-
" super().__init__()\n",
|
| 35 |
-
" self.vit = ViTForImageClassification.from_pretrained(\n",
|
| 36 |
-
" pretrained_model,\n",
|
| 37 |
-
" num_labels=1,\n",
|
| 38 |
-
" ignore_mismatched_sizes=True\n",
|
| 39 |
-
" )\n",
|
| 40 |
-
" self.vit.classifier = nn.Sequential(\n",
|
| 41 |
-
" nn.Linear(self.vit.config.hidden_size, 256),\n",
|
| 42 |
-
" nn.ReLU(),\n",
|
| 43 |
-
" nn.Dropout(0.1),\n",
|
| 44 |
-
" nn.Linear(256, 1)\n",
|
| 45 |
-
" )\n",
|
| 46 |
-
" if freeze_base:\n",
|
| 47 |
-
" for param in self.vit.vit.parameters():\n",
|
| 48 |
-
" param.requires_grad = False\n",
|
| 49 |
-
"\n",
|
| 50 |
-
" def forward(self, pixel_values):\n",
|
| 51 |
-
" outputs = self.vit(pixel_values)\n",
|
| 52 |
-
" return outputs.logits\n",
|
| 53 |
-
"\n",
|
| 54 |
-
"# 2. Khởi tạo và load weights\n",
|
| 55 |
-
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 56 |
-
"model = ViTBinaryClassifier().to(device)\n",
|
| 57 |
-
"model.load_state_dict(torch.load(r\"D:\\SonCode\\Taosafescan\\vit_binary_classifier_best.pt\", map_location=device))\n",
|
| 58 |
-
"model.eval() # Chuyển sang chế độ inference\n",
|
| 59 |
-
"\n",
|
| 60 |
-
"# 3. Tạo dummy input với kích thước phù hợp (batch_size, channels, height, width)\n",
|
| 61 |
-
"dummy_input = torch.randn(1, 3, 224, 224).to(device) # Kích thước ảnh 224x224 cho ViT-base\n",
|
| 62 |
-
"\n",
|
| 63 |
-
"# 4. Xuất sang ONNX\n",
|
| 64 |
-
"torch.onnx.export(\n",
|
| 65 |
-
" model,\n",
|
| 66 |
-
" dummy_input,\n",
|
| 67 |
-
" \"vit_binary_classification.onnx\",\n",
|
| 68 |
-
" export_params=True,\n",
|
| 69 |
-
" opset_version=14, # Thay đổi tại đây\n",
|
| 70 |
-
" do_constant_folding=True,\n",
|
| 71 |
-
" input_names=[\"pixel_values\"],\n",
|
| 72 |
-
" output_names=[\"logits\"],\n",
|
| 73 |
-
" dynamic_axes={\n",
|
| 74 |
-
" \"pixel_values\": {0: \"batch_size\"},\n",
|
| 75 |
-
" \"logits\": {0: \"batch_size\"}\n",
|
| 76 |
-
" }\n",
|
| 77 |
-
")\n",
|
| 78 |
-
"print(\"Conversion to ONNX completed successfully!\")"
|
| 79 |
-
]
|
| 80 |
-
}
|
| 81 |
-
],
|
| 82 |
-
"metadata": {
|
| 83 |
-
"kernelspec": {
|
| 84 |
-
"display_name": "base",
|
| 85 |
-
"language": "python",
|
| 86 |
-
"name": "python3"
|
| 87 |
-
},
|
| 88 |
-
"language_info": {
|
| 89 |
-
"codemirror_mode": {
|
| 90 |
-
"name": "ipython",
|
| 91 |
-
"version": 3
|
| 92 |
-
},
|
| 93 |
-
"file_extension": ".py",
|
| 94 |
-
"mimetype": "text/x-python",
|
| 95 |
-
"name": "python",
|
| 96 |
-
"nbconvert_exporter": "python",
|
| 97 |
-
"pygments_lexer": "ipython3",
|
| 98 |
-
"version": "3.12.3"
|
| 99 |
-
}
|
| 100 |
-
},
|
| 101 |
-
"nbformat": 4,
|
| 102 |
-
"nbformat_minor": 2
|
| 103 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|