Spaces:
Sleeping
Sleeping
sakshamlakhera
commited on
Commit
·
1265dde
1
Parent(s):
b274faf
fixing scripts
Browse files- scripts/CV/Part1.ipynb +17 -41
- scripts/CV/compression.ipynb +17 -35
- scripts/CV/script_onion.ipynb +14 -28
- scripts/CV/script_pear.ipynb +17 -43
- scripts/CV/script_strawberry.ipynb +21 -43
- scripts/CV/script_tomato.ipynb +24 -72
scripts/CV/Part1.ipynb
CHANGED
|
@@ -66,10 +66,10 @@
|
|
| 66 |
"def augment_rotations(X, y):\n",
|
| 67 |
" X_aug = []\n",
|
| 68 |
" y_aug = []\n",
|
| 69 |
-
" for k in [1, 2, 3]:
|
| 70 |
-
" X_rot = torch.rot90(X, k=k, dims=[2, 3])
|
| 71 |
" X_aug.append(X_rot)\n",
|
| 72 |
-
" y_aug.append(y.clone())
|
| 73 |
" return torch.cat(X_aug), torch.cat(y_aug)\n"
|
| 74 |
]
|
| 75 |
},
|
|
@@ -165,7 +165,6 @@
|
|
| 165 |
" plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
|
| 166 |
" plt.show()\n",
|
| 167 |
"\n",
|
| 168 |
-
"# Display for each class\n",
|
| 169 |
"for class_name, image_array in datasets.items():\n",
|
| 170 |
" show_random_samples(image_array, class_name)\n"
|
| 171 |
]
|
|
@@ -189,7 +188,7 @@
|
|
| 189 |
"\n",
|
| 190 |
"for ax, (class_name, images) in zip(axes, datasets.items()):\n",
|
| 191 |
" plot_rgb_histogram_subplot(ax, images, class_name)\n",
|
| 192 |
-
" ax.label_outer()
|
| 193 |
"\n",
|
| 194 |
"plt.tight_layout()\n",
|
| 195 |
"plt.show()\n"
|
|
@@ -305,7 +304,7 @@
|
|
| 305 |
"class_names = list(datasets.keys())\n",
|
| 306 |
"num_classes = len(class_names)\n",
|
| 307 |
"\n",
|
| 308 |
-
"fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4))
|
| 309 |
"\n",
|
| 310 |
"for i, (class_name, images) in enumerate(datasets.items()):\n",
|
| 311 |
" avg_img = np.mean(images.astype(np.float32), axis=0)\n",
|
|
@@ -371,7 +370,6 @@
|
|
| 371 |
"from sklearn.model_selection import train_test_split\n",
|
| 372 |
"from torchvision import transforms\n",
|
| 373 |
"\n",
|
| 374 |
-
"# Combine data\n",
|
| 375 |
"X = np.concatenate([onion_images, strawberry_images, pear_images, tomato_images], axis=0)\n",
|
| 376 |
"y = (\n",
|
| 377 |
" ['onion'] * len(onion_images) +\n",
|
|
@@ -380,16 +378,14 @@
|
|
| 380 |
" ['tomato'] * len(tomato_images)\n",
|
| 381 |
")\n",
|
| 382 |
"\n",
|
| 383 |
-
"# Normalizing image\n",
|
| 384 |
"X = X.astype(np.float32) / 255.0\n",
|
| 385 |
-
"X = np.transpose(X, (0, 3, 1, 2))
|
| 386 |
"X_tensor = torch.tensor(X)\n",
|
| 387 |
"\n",
|
| 388 |
"le = LabelEncoder()\n",
|
| 389 |
"y_encoded = le.fit_transform(y)\n",
|
| 390 |
"y_tensor = torch.tensor(y_encoded)\n",
|
| 391 |
"\n",
|
| 392 |
-
"# splitting data into 50:25:25 (train, validation, test)\n",
|
| 393 |
"X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.5, stratify=y_tensor, random_state=42)\n",
|
| 394 |
"X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
|
| 395 |
]
|
|
@@ -403,13 +399,10 @@
|
|
| 403 |
"source": [
|
| 404 |
"batch_size = 32\n",
|
| 405 |
"\n",
|
| 406 |
-
"# Create new training dataset and loader\n",
|
| 407 |
"train_dataset = TensorDataset(X_train, y_train)\n",
|
| 408 |
"val_dataset = TensorDataset(X_val, y_val)\n",
|
| 409 |
"test_dataset = TensorDataset(X_test, y_test)\n",
|
| 410 |
"\n",
|
| 411 |
-
"# DataLoaders\n",
|
| 412 |
-
"\n",
|
| 413 |
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
|
| 414 |
"val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
|
| 415 |
"test_loader = DataLoader(test_dataset, batch_size=batch_size)"
|
|
@@ -432,9 +425,9 @@
|
|
| 432 |
"metadata": {},
|
| 433 |
"outputs": [],
|
| 434 |
"source": [
|
| 435 |
-
"print(f\"
|
| 436 |
-
"print(f\"
|
| 437 |
-
"print(f\"
|
| 438 |
]
|
| 439 |
},
|
| 440 |
{
|
|
@@ -521,8 +514,6 @@
|
|
| 521 |
" optimizer.step()\n",
|
| 522 |
"\n",
|
| 523 |
" total_train_loss += loss.item()\n",
|
| 524 |
-
"\n",
|
| 525 |
-
" # Track training accuracy\n",
|
| 526 |
" pred_labels = preds.argmax(dim=1)\n",
|
| 527 |
" train_correct += (pred_labels == batch_y).sum().item()\n",
|
| 528 |
" train_total += batch_y.size(0)\n",
|
|
@@ -546,7 +537,6 @@
|
|
| 546 |
" val_accuracy = val_correct / val_total\n",
|
| 547 |
" validation_loss = criterion(model(val_x), val_y).item()\n",
|
| 548 |
"\n",
|
| 549 |
-
" # After calculating val_accuracy\n",
|
| 550 |
" val_losses.append(validation_loss)\n",
|
| 551 |
" val_accs.append(val_accuracy)\n",
|
| 552 |
"\n",
|
|
@@ -637,7 +627,7 @@
|
|
| 637 |
"\n",
|
| 638 |
"print(f\"\\nTest Accuracy: {test_accuracy:.4f}\")\n",
|
| 639 |
"\n",
|
| 640 |
-
"target_names = le.classes_
|
| 641 |
"print(\"\\nClassification Report:\\n\")\n",
|
| 642 |
"print(classification_report(all_targets, all_preds, target_names=target_names))\n",
|
| 643 |
"\n",
|
|
@@ -726,7 +716,7 @@
|
|
| 726 |
" h.remove()\n",
|
| 727 |
"\n",
|
| 728 |
" for layer_name, fmap in activations.items():\n",
|
| 729 |
-
" fmap = fmap.squeeze(0)
|
| 730 |
" num_channels = min(fmap.shape[0], max_channels)\n",
|
| 731 |
"\n",
|
| 732 |
" plt.figure(figsize=(num_channels * 2, 2.5))\n",
|
|
@@ -758,29 +748,22 @@
|
|
| 758 |
" activations[name] = output.detach().cpu()\n",
|
| 759 |
" return hook\n",
|
| 760 |
"\n",
|
| 761 |
-
" # Register hooks for all layers in model.features\n",
|
| 762 |
" hooks = []\n",
|
| 763 |
" for i in range(len(model.features)):\n",
|
| 764 |
" layer = model.features[i]\n",
|
| 765 |
" hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
|
| 766 |
"\n",
|
| 767 |
" with torch.no_grad():\n",
|
| 768 |
-
" _ = model(image_tensor.unsqueeze(0))
|
| 769 |
"\n",
|
| 770 |
" for h in hooks:\n",
|
| 771 |
" h.remove()\n",
|
| 772 |
"\n",
|
| 773 |
" for layer_name, fmap in activations.items():\n",
|
| 774 |
-
" fmap = fmap.squeeze(0)
|
| 775 |
-
"\n",
|
| 776 |
-
" # Compute mean activation per channel\n",
|
| 777 |
-
" channel_scores = fmap.mean(dim=(1, 2)) # [C]\n",
|
| 778 |
-
"\n",
|
| 779 |
-
" # Get indices of top-k channels\n",
|
| 780 |
" topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
|
| 781 |
" top_indices = topk.indices\n",
|
| 782 |
-
"\n",
|
| 783 |
-
" # Plot top-k channels\n",
|
| 784 |
" plt.figure(figsize=(max_channels * 2, 2.5))\n",
|
| 785 |
" for idx, ch in enumerate(top_indices):\n",
|
| 786 |
" plt.subplot(1, max_channels, idx + 1)\n",
|
|
@@ -821,14 +804,11 @@
|
|
| 821 |
"\n",
|
| 822 |
"img = Image.open(\"dataset/Onion_512/Whole/image_0001.jpg\").convert(\"RGB\")\n",
|
| 823 |
"\n",
|
| 824 |
-
"# Preprocessing (must match model requirements)\n",
|
| 825 |
"transform = transforms.Compose([\n",
|
| 826 |
" transforms.Resize((224, 224)),\n",
|
| 827 |
" transforms.ToTensor()\n",
|
| 828 |
"])\n",
|
| 829 |
-
"img_tensor = transform(img)
|
| 830 |
-
"\n",
|
| 831 |
-
"# Visualize feature maps\n",
|
| 832 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 833 |
]
|
| 834 |
},
|
|
@@ -849,14 +829,12 @@
|
|
| 849 |
"source": [
|
| 850 |
"img = Image.open(\"dataset/Pear_512/Whole/image_0089.jpg\").convert(\"RGB\")\n",
|
| 851 |
"\n",
|
| 852 |
-
"# Preprocessing (must match model requirements)\n",
|
| 853 |
"transform = transforms.Compose([\n",
|
| 854 |
" transforms.Resize((224, 224)),\n",
|
| 855 |
" transforms.ToTensor()\n",
|
| 856 |
"])\n",
|
| 857 |
-
"img_tensor = transform(img)
|
| 858 |
"\n",
|
| 859 |
-
"# Visualize feature maps\n",
|
| 860 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 861 |
]
|
| 862 |
},
|
|
@@ -877,14 +855,12 @@
|
|
| 877 |
"source": [
|
| 878 |
"img = Image.open(\"dataset/Tomato_512/Whole/image_0001.jpg\").convert(\"RGB\")\n",
|
| 879 |
"\n",
|
| 880 |
-
"# Preprocessing (must match model requirements)\n",
|
| 881 |
"transform = transforms.Compose([\n",
|
| 882 |
" transforms.Resize((224, 224)),\n",
|
| 883 |
" transforms.ToTensor()\n",
|
| 884 |
"])\n",
|
| 885 |
-
"img_tensor = transform(img)
|
| 886 |
"\n",
|
| 887 |
-
"# Visualize feature maps\n",
|
| 888 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 889 |
]
|
| 890 |
},
|
|
|
|
| 66 |
"def augment_rotations(X, y):\n",
|
| 67 |
" X_aug = []\n",
|
| 68 |
" y_aug = []\n",
|
| 69 |
+
" for k in [1, 2, 3]: \n",
|
| 70 |
+
" X_rot = torch.rot90(X, k=k, dims=[2, 3])\n",
|
| 71 |
" X_aug.append(X_rot)\n",
|
| 72 |
+
" y_aug.append(y.clone())\n",
|
| 73 |
" return torch.cat(X_aug), torch.cat(y_aug)\n"
|
| 74 |
]
|
| 75 |
},
|
|
|
|
| 165 |
" plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
|
| 166 |
" plt.show()\n",
|
| 167 |
"\n",
|
|
|
|
| 168 |
"for class_name, image_array in datasets.items():\n",
|
| 169 |
" show_random_samples(image_array, class_name)\n"
|
| 170 |
]
|
|
|
|
| 188 |
"\n",
|
| 189 |
"for ax, (class_name, images) in zip(axes, datasets.items()):\n",
|
| 190 |
" plot_rgb_histogram_subplot(ax, images, class_name)\n",
|
| 191 |
+
" ax.label_outer()\n",
|
| 192 |
"\n",
|
| 193 |
"plt.tight_layout()\n",
|
| 194 |
"plt.show()\n"
|
|
|
|
| 304 |
"class_names = list(datasets.keys())\n",
|
| 305 |
"num_classes = len(class_names)\n",
|
| 306 |
"\n",
|
| 307 |
+
"fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4))\n",
|
| 308 |
"\n",
|
| 309 |
"for i, (class_name, images) in enumerate(datasets.items()):\n",
|
| 310 |
" avg_img = np.mean(images.astype(np.float32), axis=0)\n",
|
|
|
|
| 370 |
"from sklearn.model_selection import train_test_split\n",
|
| 371 |
"from torchvision import transforms\n",
|
| 372 |
"\n",
|
|
|
|
| 373 |
"X = np.concatenate([onion_images, strawberry_images, pear_images, tomato_images], axis=0)\n",
|
| 374 |
"y = (\n",
|
| 375 |
" ['onion'] * len(onion_images) +\n",
|
|
|
|
| 378 |
" ['tomato'] * len(tomato_images)\n",
|
| 379 |
")\n",
|
| 380 |
"\n",
|
|
|
|
| 381 |
"X = X.astype(np.float32) / 255.0\n",
|
| 382 |
+
"X = np.transpose(X, (0, 3, 1, 2)) \n",
|
| 383 |
"X_tensor = torch.tensor(X)\n",
|
| 384 |
"\n",
|
| 385 |
"le = LabelEncoder()\n",
|
| 386 |
"y_encoded = le.fit_transform(y)\n",
|
| 387 |
"y_tensor = torch.tensor(y_encoded)\n",
|
| 388 |
"\n",
|
|
|
|
| 389 |
"X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.5, stratify=y_tensor, random_state=42)\n",
|
| 390 |
"X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
|
| 391 |
]
|
|
|
|
| 399 |
"source": [
|
| 400 |
"batch_size = 32\n",
|
| 401 |
"\n",
|
|
|
|
| 402 |
"train_dataset = TensorDataset(X_train, y_train)\n",
|
| 403 |
"val_dataset = TensorDataset(X_val, y_val)\n",
|
| 404 |
"test_dataset = TensorDataset(X_test, y_test)\n",
|
| 405 |
"\n",
|
|
|
|
|
|
|
| 406 |
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
|
| 407 |
"val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
|
| 408 |
"test_loader = DataLoader(test_dataset, batch_size=batch_size)"
|
|
|
|
| 425 |
"metadata": {},
|
| 426 |
"outputs": [],
|
| 427 |
"source": [
|
| 428 |
+
"print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
|
| 429 |
+
"print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
|
| 430 |
+
"print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
|
| 431 |
]
|
| 432 |
},
|
| 433 |
{
|
|
|
|
| 514 |
" optimizer.step()\n",
|
| 515 |
"\n",
|
| 516 |
" total_train_loss += loss.item()\n",
|
|
|
|
|
|
|
| 517 |
" pred_labels = preds.argmax(dim=1)\n",
|
| 518 |
" train_correct += (pred_labels == batch_y).sum().item()\n",
|
| 519 |
" train_total += batch_y.size(0)\n",
|
|
|
|
| 537 |
" val_accuracy = val_correct / val_total\n",
|
| 538 |
" validation_loss = criterion(model(val_x), val_y).item()\n",
|
| 539 |
"\n",
|
|
|
|
| 540 |
" val_losses.append(validation_loss)\n",
|
| 541 |
" val_accs.append(val_accuracy)\n",
|
| 542 |
"\n",
|
|
|
|
| 627 |
"\n",
|
| 628 |
"print(f\"\\nTest Accuracy: {test_accuracy:.4f}\")\n",
|
| 629 |
"\n",
|
| 630 |
+
"target_names = le.classes_ \n",
|
| 631 |
"print(\"\\nClassification Report:\\n\")\n",
|
| 632 |
"print(classification_report(all_targets, all_preds, target_names=target_names))\n",
|
| 633 |
"\n",
|
|
|
|
| 716 |
" h.remove()\n",
|
| 717 |
"\n",
|
| 718 |
" for layer_name, fmap in activations.items():\n",
|
| 719 |
+
" fmap = fmap.squeeze(0)\n",
|
| 720 |
" num_channels = min(fmap.shape[0], max_channels)\n",
|
| 721 |
"\n",
|
| 722 |
" plt.figure(figsize=(num_channels * 2, 2.5))\n",
|
|
|
|
| 748 |
" activations[name] = output.detach().cpu()\n",
|
| 749 |
" return hook\n",
|
| 750 |
"\n",
|
|
|
|
| 751 |
" hooks = []\n",
|
| 752 |
" for i in range(len(model.features)):\n",
|
| 753 |
" layer = model.features[i]\n",
|
| 754 |
" hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
|
| 755 |
"\n",
|
| 756 |
" with torch.no_grad():\n",
|
| 757 |
+
" _ = model(image_tensor.unsqueeze(0))\n",
|
| 758 |
"\n",
|
| 759 |
" for h in hooks:\n",
|
| 760 |
" h.remove()\n",
|
| 761 |
"\n",
|
| 762 |
" for layer_name, fmap in activations.items():\n",
|
| 763 |
+
" fmap = fmap.squeeze(0)\n",
|
| 764 |
+
" channel_scores = fmap.mean(dim=(1, 2)) \n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 765 |
" topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
|
| 766 |
" top_indices = topk.indices\n",
|
|
|
|
|
|
|
| 767 |
" plt.figure(figsize=(max_channels * 2, 2.5))\n",
|
| 768 |
" for idx, ch in enumerate(top_indices):\n",
|
| 769 |
" plt.subplot(1, max_channels, idx + 1)\n",
|
|
|
|
| 804 |
"\n",
|
| 805 |
"img = Image.open(\"dataset/Onion_512/Whole/image_0001.jpg\").convert(\"RGB\")\n",
|
| 806 |
"\n",
|
|
|
|
| 807 |
"transform = transforms.Compose([\n",
|
| 808 |
" transforms.Resize((224, 224)),\n",
|
| 809 |
" transforms.ToTensor()\n",
|
| 810 |
"])\n",
|
| 811 |
+
"img_tensor = transform(img)\n",
|
|
|
|
|
|
|
| 812 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 813 |
]
|
| 814 |
},
|
|
|
|
| 829 |
"source": [
|
| 830 |
"img = Image.open(\"dataset/Pear_512/Whole/image_0089.jpg\").convert(\"RGB\")\n",
|
| 831 |
"\n",
|
|
|
|
| 832 |
"transform = transforms.Compose([\n",
|
| 833 |
" transforms.Resize((224, 224)),\n",
|
| 834 |
" transforms.ToTensor()\n",
|
| 835 |
"])\n",
|
| 836 |
+
"img_tensor = transform(img)\n",
|
| 837 |
"\n",
|
|
|
|
| 838 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 839 |
]
|
| 840 |
},
|
|
|
|
| 855 |
"source": [
|
| 856 |
"img = Image.open(\"dataset/Tomato_512/Whole/image_0001.jpg\").convert(\"RGB\")\n",
|
| 857 |
"\n",
|
|
|
|
| 858 |
"transform = transforms.Compose([\n",
|
| 859 |
" transforms.Resize((224, 224)),\n",
|
| 860 |
" transforms.ToTensor()\n",
|
| 861 |
"])\n",
|
| 862 |
+
"img_tensor = transform(img)\n",
|
| 863 |
"\n",
|
|
|
|
| 864 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 865 |
]
|
| 866 |
},
|
scripts/CV/compression.ipynb
CHANGED
|
@@ -18,8 +18,8 @@
|
|
| 18 |
"import os\n",
|
| 19 |
"from PIL import Image, ImageOps\n",
|
| 20 |
"\n",
|
| 21 |
-
"input_root = 'Tomato'
|
| 22 |
-
"output_root = 'Tomato_512'
|
| 23 |
"os.makedirs(output_root, exist_ok=True)\n",
|
| 24 |
"\n",
|
| 25 |
"def process_image(input_path, output_path, size=(512, 512)):\n",
|
|
@@ -27,14 +27,12 @@
|
|
| 27 |
" with Image.open(input_path) as img:\n",
|
| 28 |
" img = img.convert(\"RGB\")\n",
|
| 29 |
"\n",
|
| 30 |
-
" # Resize while preserving aspect ratio, then pad to 512x512\n",
|
| 31 |
" img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
|
| 32 |
" os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
|
| 33 |
" img.save(output_path, \"JPEG\", quality=95)\n",
|
| 34 |
" except Exception as e:\n",
|
| 35 |
-
" print(f\"
|
| 36 |
"\n",
|
| 37 |
-
"# Recursively walk through input_root\n",
|
| 38 |
"for root, _, files in os.walk(input_root):\n",
|
| 39 |
" for file in files:\n",
|
| 40 |
" if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
|
|
@@ -43,7 +41,7 @@
|
|
| 43 |
" output_path = os.path.join(output_root, rel_path)\n",
|
| 44 |
" process_image(input_path, output_path)\n",
|
| 45 |
"\n",
|
| 46 |
-
"print(\"
|
| 47 |
]
|
| 48 |
},
|
| 49 |
{
|
|
@@ -56,23 +54,20 @@
|
|
| 56 |
"import os\n",
|
| 57 |
"from PIL import Image, ImageOps\n",
|
| 58 |
"\n",
|
| 59 |
-
"input_root = 'Onion'
|
| 60 |
-
"output_root = 'Onion_512'
|
| 61 |
"os.makedirs(output_root, exist_ok=True)\n",
|
| 62 |
"\n",
|
| 63 |
"def process_image(input_path, output_path, size=(512, 512)):\n",
|
| 64 |
" try:\n",
|
| 65 |
" with Image.open(input_path) as img:\n",
|
| 66 |
" img = img.convert(\"RGB\")\n",
|
| 67 |
-
"\n",
|
| 68 |
-
" # Resize while preserving aspect ratio, then pad to 512x512\n",
|
| 69 |
" img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
|
| 70 |
" os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
|
| 71 |
" img.save(output_path, \"JPEG\", quality=95)\n",
|
| 72 |
" except Exception as e:\n",
|
| 73 |
-
" print(f\"
|
| 74 |
"\n",
|
| 75 |
-
"# Recursively walk through input_root\n",
|
| 76 |
"for root, _, files in os.walk(input_root):\n",
|
| 77 |
" for file in files:\n",
|
| 78 |
" if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
|
|
@@ -81,7 +76,7 @@
|
|
| 81 |
" output_path = os.path.join(output_root, rel_path)\n",
|
| 82 |
" process_image(input_path, output_path)\n",
|
| 83 |
"\n",
|
| 84 |
-
"print(\"
|
| 85 |
]
|
| 86 |
},
|
| 87 |
{
|
|
@@ -94,23 +89,20 @@
|
|
| 94 |
"import os\n",
|
| 95 |
"from PIL import Image, ImageOps\n",
|
| 96 |
"\n",
|
| 97 |
-
"input_root = 'Pear'
|
| 98 |
-
"output_root = 'Pear_512'
|
| 99 |
"os.makedirs(output_root, exist_ok=True)\n",
|
| 100 |
"\n",
|
| 101 |
"def process_image(input_path, output_path, size=(512, 512)):\n",
|
| 102 |
" try:\n",
|
| 103 |
" with Image.open(input_path) as img:\n",
|
| 104 |
" img = img.convert(\"RGB\")\n",
|
| 105 |
-
"\n",
|
| 106 |
-
" # Resize while preserving aspect ratio, then pad to 512x512\n",
|
| 107 |
" img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
|
| 108 |
" os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
|
| 109 |
" img.save(output_path, \"JPEG\", quality=95)\n",
|
| 110 |
" except Exception as e:\n",
|
| 111 |
-
" print(f\"
|
| 112 |
"\n",
|
| 113 |
-
"# Recursively walk through input_root\n",
|
| 114 |
"for root, _, files in os.walk(input_root):\n",
|
| 115 |
" for file in files:\n",
|
| 116 |
" if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
|
|
@@ -119,7 +111,7 @@
|
|
| 119 |
" output_path = os.path.join(output_root, rel_path)\n",
|
| 120 |
" process_image(input_path, output_path)\n",
|
| 121 |
"\n",
|
| 122 |
-
"print(\"
|
| 123 |
]
|
| 124 |
},
|
| 125 |
{
|
|
@@ -132,23 +124,21 @@
|
|
| 132 |
"import os\n",
|
| 133 |
"from PIL import Image, ImageOps\n",
|
| 134 |
"\n",
|
| 135 |
-
"input_root = 'Strawberry'
|
| 136 |
-
"output_root = 'Strawberry_512'
|
| 137 |
"os.makedirs(output_root, exist_ok=True)\n",
|
| 138 |
"\n",
|
| 139 |
"def process_image(input_path, output_path, size=(512, 512)):\n",
|
| 140 |
" try:\n",
|
| 141 |
" with Image.open(input_path) as img:\n",
|
| 142 |
" img = img.convert(\"RGB\")\n",
|
| 143 |
-
"\n",
|
| 144 |
-
" # Resize while preserving aspect ratio, then pad to 512x512\n",
|
| 145 |
" img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
|
| 146 |
" os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
|
| 147 |
" img.save(output_path, \"JPEG\", quality=95)\n",
|
| 148 |
" except Exception as e:\n",
|
| 149 |
-
" print(f\"
|
|
|
|
| 150 |
"\n",
|
| 151 |
-
"# Recursively walk through input_root\n",
|
| 152 |
"for root, _, files in os.walk(input_root):\n",
|
| 153 |
" for file in files:\n",
|
| 154 |
" if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
|
|
@@ -157,16 +147,8 @@
|
|
| 157 |
" output_path = os.path.join(output_root, rel_path)\n",
|
| 158 |
" process_image(input_path, output_path)\n",
|
| 159 |
"\n",
|
| 160 |
-
"print(\"
|
| 161 |
]
|
| 162 |
-
},
|
| 163 |
-
{
|
| 164 |
-
"cell_type": "code",
|
| 165 |
-
"execution_count": null,
|
| 166 |
-
"id": "fd49ae48",
|
| 167 |
-
"metadata": {},
|
| 168 |
-
"outputs": [],
|
| 169 |
-
"source": []
|
| 170 |
}
|
| 171 |
],
|
| 172 |
"metadata": {
|
|
|
|
| 18 |
"import os\n",
|
| 19 |
"from PIL import Image, ImageOps\n",
|
| 20 |
"\n",
|
| 21 |
+
"input_root = 'Tomato' \n",
|
| 22 |
+
"output_root = 'Tomato_512' \n",
|
| 23 |
"os.makedirs(output_root, exist_ok=True)\n",
|
| 24 |
"\n",
|
| 25 |
"def process_image(input_path, output_path, size=(512, 512)):\n",
|
|
|
|
| 27 |
" with Image.open(input_path) as img:\n",
|
| 28 |
" img = img.convert(\"RGB\")\n",
|
| 29 |
"\n",
|
|
|
|
| 30 |
" img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
|
| 31 |
" os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
|
| 32 |
" img.save(output_path, \"JPEG\", quality=95)\n",
|
| 33 |
" except Exception as e:\n",
|
| 34 |
+
" print(f\"Error processing {input_path}: {e}\")\n",
|
| 35 |
"\n",
|
|
|
|
| 36 |
"for root, _, files in os.walk(input_root):\n",
|
| 37 |
" for file in files:\n",
|
| 38 |
" if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
|
|
|
|
| 41 |
" output_path = os.path.join(output_root, rel_path)\n",
|
| 42 |
" process_image(input_path, output_path)\n",
|
| 43 |
"\n",
|
| 44 |
+
"print(\"All images processed and saved in\", output_root)\n"
|
| 45 |
]
|
| 46 |
},
|
| 47 |
{
|
|
|
|
| 54 |
"import os\n",
|
| 55 |
"from PIL import Image, ImageOps\n",
|
| 56 |
"\n",
|
| 57 |
+
"input_root = 'Onion' \n",
|
| 58 |
+
"output_root = 'Onion_512' \n",
|
| 59 |
"os.makedirs(output_root, exist_ok=True)\n",
|
| 60 |
"\n",
|
| 61 |
"def process_image(input_path, output_path, size=(512, 512)):\n",
|
| 62 |
" try:\n",
|
| 63 |
" with Image.open(input_path) as img:\n",
|
| 64 |
" img = img.convert(\"RGB\")\n",
|
|
|
|
|
|
|
| 65 |
" img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
|
| 66 |
" os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
|
| 67 |
" img.save(output_path, \"JPEG\", quality=95)\n",
|
| 68 |
" except Exception as e:\n",
|
| 69 |
+
" print(f\"Error processing {input_path}: {e}\")\n",
|
| 70 |
"\n",
|
|
|
|
| 71 |
"for root, _, files in os.walk(input_root):\n",
|
| 72 |
" for file in files:\n",
|
| 73 |
" if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
|
|
|
|
| 76 |
" output_path = os.path.join(output_root, rel_path)\n",
|
| 77 |
" process_image(input_path, output_path)\n",
|
| 78 |
"\n",
|
| 79 |
+
"print(\"All images processed and saved in\", output_root)\n"
|
| 80 |
]
|
| 81 |
},
|
| 82 |
{
|
|
|
|
| 89 |
"import os\n",
|
| 90 |
"from PIL import Image, ImageOps\n",
|
| 91 |
"\n",
|
| 92 |
+
"input_root = 'Pear' \n",
|
| 93 |
+
"output_root = 'Pear_512' \n",
|
| 94 |
"os.makedirs(output_root, exist_ok=True)\n",
|
| 95 |
"\n",
|
| 96 |
"def process_image(input_path, output_path, size=(512, 512)):\n",
|
| 97 |
" try:\n",
|
| 98 |
" with Image.open(input_path) as img:\n",
|
| 99 |
" img = img.convert(\"RGB\")\n",
|
|
|
|
|
|
|
| 100 |
" img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
|
| 101 |
" os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
|
| 102 |
" img.save(output_path, \"JPEG\", quality=95)\n",
|
| 103 |
" except Exception as e:\n",
|
| 104 |
+
" print(f\"Error processing {input_path}: {e}\")\n",
|
| 105 |
"\n",
|
|
|
|
| 106 |
"for root, _, files in os.walk(input_root):\n",
|
| 107 |
" for file in files:\n",
|
| 108 |
" if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
|
|
|
|
| 111 |
" output_path = os.path.join(output_root, rel_path)\n",
|
| 112 |
" process_image(input_path, output_path)\n",
|
| 113 |
"\n",
|
| 114 |
+
"print(\"All images processed and saved in\", output_root)\n"
|
| 115 |
]
|
| 116 |
},
|
| 117 |
{
|
|
|
|
| 124 |
"import os\n",
|
| 125 |
"from PIL import Image, ImageOps\n",
|
| 126 |
"\n",
|
| 127 |
+
"input_root = 'Strawberry' \n",
|
| 128 |
+
"output_root = 'Strawberry_512' \n",
|
| 129 |
"os.makedirs(output_root, exist_ok=True)\n",
|
| 130 |
"\n",
|
| 131 |
"def process_image(input_path, output_path, size=(512, 512)):\n",
|
| 132 |
" try:\n",
|
| 133 |
" with Image.open(input_path) as img:\n",
|
| 134 |
" img = img.convert(\"RGB\")\n",
|
|
|
|
|
|
|
| 135 |
" img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
|
| 136 |
" os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
|
| 137 |
" img.save(output_path, \"JPEG\", quality=95)\n",
|
| 138 |
" except Exception as e:\n",
|
| 139 |
+
" print(f\"Error processing {input_path}: {e}\")\n",
|
| 140 |
+
"\n",
|
| 141 |
"\n",
|
|
|
|
| 142 |
"for root, _, files in os.walk(input_root):\n",
|
| 143 |
" for file in files:\n",
|
| 144 |
" if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
|
|
|
|
| 147 |
" output_path = os.path.join(output_root, rel_path)\n",
|
| 148 |
" process_image(input_path, output_path)\n",
|
| 149 |
"\n",
|
| 150 |
+
"print(\"All images processed and saved in\", output_root)\n"
|
| 151 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
}
|
| 153 |
],
|
| 154 |
"metadata": {
|
scripts/CV/script_onion.ipynb
CHANGED
|
@@ -59,10 +59,10 @@
|
|
| 59 |
"def augment_rotations(X, y):\n",
|
| 60 |
" X_aug = []\n",
|
| 61 |
" y_aug = []\n",
|
| 62 |
-
" for k in [1, 2, 3]
|
| 63 |
-
" X_rot = torch.rot90(X, k=k, dims=[2, 3])
|
| 64 |
" X_aug.append(X_rot)\n",
|
| 65 |
-
" y_aug.append(y.clone())
|
| 66 |
" return torch.cat(X_aug), torch.cat(y_aug)"
|
| 67 |
]
|
| 68 |
},
|
|
@@ -120,7 +120,6 @@
|
|
| 120 |
" plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
|
| 121 |
" plt.show()\n",
|
| 122 |
"\n",
|
| 123 |
-
"# Display for each class\n",
|
| 124 |
"for class_name, image_array in datasets.items():\n",
|
| 125 |
" show_random_samples(image_array, class_name)\n"
|
| 126 |
]
|
|
@@ -136,7 +135,7 @@
|
|
| 136 |
"\n",
|
| 137 |
"for ax, (class_name, images) in zip(axes, datasets.items()):\n",
|
| 138 |
" plot_rgb_histogram_subplot(ax, images, class_name)\n",
|
| 139 |
-
" ax.label_outer()
|
| 140 |
"\n",
|
| 141 |
"plt.tight_layout()\n",
|
| 142 |
"plt.show()"
|
|
@@ -152,7 +151,7 @@
|
|
| 152 |
"class_names = list(datasets.keys())\n",
|
| 153 |
"num_classes = len(class_names)\n",
|
| 154 |
"\n",
|
| 155 |
-
"fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4))
|
| 156 |
"\n",
|
| 157 |
"for i, (class_name, images) in enumerate(datasets.items()):\n",
|
| 158 |
" avg_img = np.mean(images.astype(np.float32), axis=0)\n",
|
|
@@ -209,7 +208,6 @@
|
|
| 209 |
"\n",
|
| 210 |
"X_augmented, y_augmented = augment_rotations(X_train, y_train)\n",
|
| 211 |
"\n",
|
| 212 |
-
"# Combine original and augmented data\n",
|
| 213 |
"X_train_combined = torch.cat([X_train, X_augmented])\n",
|
| 214 |
"y_train_combined = torch.cat([y_train, y_augmented])\n",
|
| 215 |
"\n",
|
|
@@ -230,9 +228,9 @@
|
|
| 230 |
"metadata": {},
|
| 231 |
"outputs": [],
|
| 232 |
"source": [
|
| 233 |
-
"print(f\"
|
| 234 |
-
"print(f\"
|
| 235 |
-
"print(f\"
|
| 236 |
]
|
| 237 |
},
|
| 238 |
{
|
|
@@ -480,29 +478,24 @@
|
|
| 480 |
" activations[name] = output.detach().cpu()\n",
|
| 481 |
" return hook\n",
|
| 482 |
"\n",
|
| 483 |
-
" # Register hooks for all layers in model.features\n",
|
| 484 |
" hooks = []\n",
|
| 485 |
" for i in range(len(model.features)):\n",
|
| 486 |
" layer = model.features[i]\n",
|
| 487 |
" hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
|
| 488 |
"\n",
|
| 489 |
" with torch.no_grad():\n",
|
| 490 |
-
" _ = model(image_tensor.unsqueeze(0))
|
| 491 |
"\n",
|
| 492 |
" for h in hooks:\n",
|
| 493 |
" h.remove()\n",
|
| 494 |
"\n",
|
| 495 |
" for layer_name, fmap in activations.items():\n",
|
| 496 |
-
" fmap = fmap.squeeze(0)
|
|
|
|
| 497 |
"\n",
|
| 498 |
-
" # Compute mean activation per channel\n",
|
| 499 |
-
" channel_scores = fmap.mean(dim=(1, 2)) # [C]\n",
|
| 500 |
-
"\n",
|
| 501 |
-
" # Get indices of top-k channels\n",
|
| 502 |
" topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
|
| 503 |
" top_indices = topk.indices\n",
|
| 504 |
"\n",
|
| 505 |
-
" # Plot top-k channels\n",
|
| 506 |
" plt.figure(figsize=(max_channels * 2, 2.5))\n",
|
| 507 |
" for idx, ch in enumerate(top_indices):\n",
|
| 508 |
" plt.subplot(1, max_channels, idx + 1)\n",
|
|
@@ -535,14 +528,12 @@
|
|
| 535 |
"\n",
|
| 536 |
"img = Image.open(\"dataset/Onion_512/Whole/image_0001.jpg\").convert(\"RGB\")\n",
|
| 537 |
"\n",
|
| 538 |
-
"# Preprocessing (must match model requirements)\n",
|
| 539 |
"transform = transforms.Compose([\n",
|
| 540 |
" transforms.Resize((224, 224)),\n",
|
| 541 |
" transforms.ToTensor()\n",
|
| 542 |
"])\n",
|
| 543 |
-
"img_tensor = transform(img)
|
| 544 |
"\n",
|
| 545 |
-
"# Visualize feature maps\n",
|
| 546 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 547 |
]
|
| 548 |
},
|
|
@@ -556,14 +547,11 @@
|
|
| 556 |
"\n",
|
| 557 |
"img = Image.open(\"dataset/Onion_512/Halved/image_0880.jpg\").convert(\"RGB\")\n",
|
| 558 |
"\n",
|
| 559 |
-
"# Preprocessing (must match model requirements)\n",
|
| 560 |
"transform = transforms.Compose([\n",
|
| 561 |
" transforms.Resize((224, 224)),\n",
|
| 562 |
" transforms.ToTensor()\n",
|
| 563 |
"])\n",
|
| 564 |
-
"img_tensor = transform(img)
|
| 565 |
-
"\n",
|
| 566 |
-
"# Visualize feature maps\n",
|
| 567 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 568 |
]
|
| 569 |
},
|
|
@@ -576,14 +564,12 @@
|
|
| 576 |
"source": [
|
| 577 |
"img = Image.open(\"dataset/Onion_512/Sliced/image_0772.jpg\").convert(\"RGB\")\n",
|
| 578 |
"\n",
|
| 579 |
-
"# Preprocessing (must match model requirements)\n",
|
| 580 |
"transform = transforms.Compose([\n",
|
| 581 |
" transforms.Resize((224, 224)),\n",
|
| 582 |
" transforms.ToTensor()\n",
|
| 583 |
"])\n",
|
| 584 |
-
"img_tensor = transform(img)
|
| 585 |
"\n",
|
| 586 |
-
"# Visualize feature maps\n",
|
| 587 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 588 |
]
|
| 589 |
},
|
|
|
|
| 59 |
"def augment_rotations(X, y):\n",
|
| 60 |
" X_aug = []\n",
|
| 61 |
" y_aug = []\n",
|
| 62 |
+
" for k in [1, 2, 3]:\n",
|
| 63 |
+
" X_rot = torch.rot90(X, k=k, dims=[2, 3])\n",
|
| 64 |
" X_aug.append(X_rot)\n",
|
| 65 |
+
" y_aug.append(y.clone()) \n",
|
| 66 |
" return torch.cat(X_aug), torch.cat(y_aug)"
|
| 67 |
]
|
| 68 |
},
|
|
|
|
| 120 |
" plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
|
| 121 |
" plt.show()\n",
|
| 122 |
"\n",
|
|
|
|
| 123 |
"for class_name, image_array in datasets.items():\n",
|
| 124 |
" show_random_samples(image_array, class_name)\n"
|
| 125 |
]
|
|
|
|
| 135 |
"\n",
|
| 136 |
"for ax, (class_name, images) in zip(axes, datasets.items()):\n",
|
| 137 |
" plot_rgb_histogram_subplot(ax, images, class_name)\n",
|
| 138 |
+
" ax.label_outer()\n",
|
| 139 |
"\n",
|
| 140 |
"plt.tight_layout()\n",
|
| 141 |
"plt.show()"
|
|
|
|
| 151 |
"class_names = list(datasets.keys())\n",
|
| 152 |
"num_classes = len(class_names)\n",
|
| 153 |
"\n",
|
| 154 |
+
"fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) \n",
|
| 155 |
"\n",
|
| 156 |
"for i, (class_name, images) in enumerate(datasets.items()):\n",
|
| 157 |
" avg_img = np.mean(images.astype(np.float32), axis=0)\n",
|
|
|
|
| 208 |
"\n",
|
| 209 |
"X_augmented, y_augmented = augment_rotations(X_train, y_train)\n",
|
| 210 |
"\n",
|
|
|
|
| 211 |
"X_train_combined = torch.cat([X_train, X_augmented])\n",
|
| 212 |
"y_train_combined = torch.cat([y_train, y_augmented])\n",
|
| 213 |
"\n",
|
|
|
|
| 228 |
"metadata": {},
|
| 229 |
"outputs": [],
|
| 230 |
"source": [
|
| 231 |
+
"print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
|
| 232 |
+
"print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
|
| 233 |
+
"print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
|
| 234 |
]
|
| 235 |
},
|
| 236 |
{
|
|
|
|
| 478 |
" activations[name] = output.detach().cpu()\n",
|
| 479 |
" return hook\n",
|
| 480 |
"\n",
|
|
|
|
| 481 |
" hooks = []\n",
|
| 482 |
" for i in range(len(model.features)):\n",
|
| 483 |
" layer = model.features[i]\n",
|
| 484 |
" hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
|
| 485 |
"\n",
|
| 486 |
" with torch.no_grad():\n",
|
| 487 |
+
" _ = model(image_tensor.unsqueeze(0)) \n",
|
| 488 |
"\n",
|
| 489 |
" for h in hooks:\n",
|
| 490 |
" h.remove()\n",
|
| 491 |
"\n",
|
| 492 |
" for layer_name, fmap in activations.items():\n",
|
| 493 |
+
" fmap = fmap.squeeze(0) \n",
|
| 494 |
+
" channel_scores = fmap.mean(dim=(1, 2))\n",
|
| 495 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 496 |
" topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
|
| 497 |
" top_indices = topk.indices\n",
|
| 498 |
"\n",
|
|
|
|
| 499 |
" plt.figure(figsize=(max_channels * 2, 2.5))\n",
|
| 500 |
" for idx, ch in enumerate(top_indices):\n",
|
| 501 |
" plt.subplot(1, max_channels, idx + 1)\n",
|
|
|
|
| 528 |
"\n",
|
| 529 |
"img = Image.open(\"dataset/Onion_512/Whole/image_0001.jpg\").convert(\"RGB\")\n",
|
| 530 |
"\n",
|
|
|
|
| 531 |
"transform = transforms.Compose([\n",
|
| 532 |
" transforms.Resize((224, 224)),\n",
|
| 533 |
" transforms.ToTensor()\n",
|
| 534 |
"])\n",
|
| 535 |
+
"img_tensor = transform(img)\n",
|
| 536 |
"\n",
|
|
|
|
| 537 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 538 |
]
|
| 539 |
},
|
|
|
|
| 547 |
"\n",
|
| 548 |
"img = Image.open(\"dataset/Onion_512/Halved/image_0880.jpg\").convert(\"RGB\")\n",
|
| 549 |
"\n",
|
|
|
|
| 550 |
"transform = transforms.Compose([\n",
|
| 551 |
" transforms.Resize((224, 224)),\n",
|
| 552 |
" transforms.ToTensor()\n",
|
| 553 |
"])\n",
|
| 554 |
+
"img_tensor = transform(img)\n",
|
|
|
|
|
|
|
| 555 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 556 |
]
|
| 557 |
},
|
|
|
|
| 564 |
"source": [
|
| 565 |
"img = Image.open(\"dataset/Onion_512/Sliced/image_0772.jpg\").convert(\"RGB\")\n",
|
| 566 |
"\n",
|
|
|
|
| 567 |
"transform = transforms.Compose([\n",
|
| 568 |
" transforms.Resize((224, 224)),\n",
|
| 569 |
" transforms.ToTensor()\n",
|
| 570 |
"])\n",
|
| 571 |
+
"img_tensor = transform(img) \n",
|
| 572 |
"\n",
|
|
|
|
| 573 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 574 |
]
|
| 575 |
},
|
scripts/CV/script_pear.ipynb
CHANGED
|
@@ -59,10 +59,10 @@
|
|
| 59 |
"def augment_rotations(X, y):\n",
|
| 60 |
" X_aug = []\n",
|
| 61 |
" y_aug = []\n",
|
| 62 |
-
" for k in [1, 2, 3]:
|
| 63 |
-
" X_rot = torch.rot90(X, k=k, dims=[2, 3])
|
| 64 |
" X_aug.append(X_rot)\n",
|
| 65 |
-
" y_aug.append(y.clone())
|
| 66 |
" return torch.cat(X_aug), torch.cat(y_aug)"
|
| 67 |
]
|
| 68 |
},
|
|
@@ -122,7 +122,6 @@
|
|
| 122 |
" plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
|
| 123 |
" plt.show()\n",
|
| 124 |
"\n",
|
| 125 |
-
"# Display for each class\n",
|
| 126 |
"for class_name, image_array in datasets.items():\n",
|
| 127 |
" show_random_samples(image_array, class_name)\n"
|
| 128 |
]
|
|
@@ -138,7 +137,7 @@
|
|
| 138 |
"\n",
|
| 139 |
"for ax, (class_name, images) in zip(axes, datasets.items()):\n",
|
| 140 |
" plot_rgb_histogram_subplot(ax, images, class_name)\n",
|
| 141 |
-
" ax.label_outer()
|
| 142 |
"\n",
|
| 143 |
"plt.tight_layout()\n",
|
| 144 |
"plt.show()"
|
|
@@ -154,7 +153,7 @@
|
|
| 154 |
"class_names = list(datasets.keys())\n",
|
| 155 |
"num_classes = len(class_names)\n",
|
| 156 |
"\n",
|
| 157 |
-
"fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4))
|
| 158 |
"\n",
|
| 159 |
"for i, (class_name, images) in enumerate(datasets.items()):\n",
|
| 160 |
" avg_img = np.mean(images.astype(np.float32), axis=0)\n",
|
|
@@ -180,7 +179,6 @@
|
|
| 180 |
" \"whole\": pear_whole_images\n",
|
| 181 |
"}\n",
|
| 182 |
"\n",
|
| 183 |
-
"# Combine data\n",
|
| 184 |
"X = np.concatenate([pear_halved_images, pear_sliced_images, pear_whole_images], axis=0)\n",
|
| 185 |
"y = (\n",
|
| 186 |
" ['halved'] * len(pear_halved_images) +\n",
|
|
@@ -188,17 +186,14 @@
|
|
| 188 |
" ['whole'] * len(pear_whole_images)\n",
|
| 189 |
")\n",
|
| 190 |
"\n",
|
| 191 |
-
"# Normalize and convert to torch tensors\n",
|
| 192 |
"X = X.astype(np.float32) / 255.0\n",
|
| 193 |
-
"X = np.transpose(X, (0, 3, 1, 2))
|
| 194 |
"X_tensor = torch.tensor(X)\n",
|
| 195 |
"\n",
|
| 196 |
-
"# Encode labels\n",
|
| 197 |
"le = LabelEncoder()\n",
|
| 198 |
"y_encoded = le.fit_transform(y)\n",
|
| 199 |
"y_tensor = torch.tensor(y_encoded)\n",
|
| 200 |
"\n",
|
| 201 |
-
"# Train/val/test split\n",
|
| 202 |
"X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.4, stratify=y_tensor, random_state=42)\n",
|
| 203 |
"X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
|
| 204 |
]
|
|
@@ -215,7 +210,6 @@
|
|
| 215 |
"val_dataset = TensorDataset(X_val, y_val)\n",
|
| 216 |
"test_dataset = TensorDataset(X_test, y_test)\n",
|
| 217 |
"\n",
|
| 218 |
-
"# DataLoaders\n",
|
| 219 |
"batch_size = 32\n",
|
| 220 |
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
|
| 221 |
"val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
|
|
@@ -229,9 +223,9 @@
|
|
| 229 |
"metadata": {},
|
| 230 |
"outputs": [],
|
| 231 |
"source": [
|
| 232 |
-
"print(f\"
|
| 233 |
-
"print(f\"
|
| 234 |
-
"print(f\"
|
| 235 |
]
|
| 236 |
},
|
| 237 |
{
|
|
@@ -249,8 +243,6 @@
|
|
| 249 |
"\n",
|
| 250 |
"def get_efficientnet_model(num_classes):\n",
|
| 251 |
" model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)\n",
|
| 252 |
-
"\n",
|
| 253 |
-
" # Replace classifier head with custom head\n",
|
| 254 |
" model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)\n",
|
| 255 |
"\n",
|
| 256 |
" return model\n",
|
|
@@ -266,10 +258,10 @@
|
|
| 266 |
"source": [
|
| 267 |
"if torch.backends.mps.is_available():\n",
|
| 268 |
" device = torch.device(\"mps\")\n",
|
| 269 |
-
" print(\"
|
| 270 |
"else:\n",
|
| 271 |
" device = torch.device(\"cpu\")\n",
|
| 272 |
-
" print(\"
|
| 273 |
"\n",
|
| 274 |
"model = get_efficientnet_model(num_classes=3).to(device)\n",
|
| 275 |
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
|
|
@@ -312,7 +304,6 @@
|
|
| 312 |
"\n",
|
| 313 |
" total_train_loss += loss.item()\n",
|
| 314 |
"\n",
|
| 315 |
-
" # Track training accuracy\n",
|
| 316 |
" pred_labels = preds.argmax(dim=1)\n",
|
| 317 |
" train_correct += (pred_labels == batch_y).sum().item()\n",
|
| 318 |
" train_total += batch_y.size(0)\n",
|
|
@@ -367,7 +358,6 @@
|
|
| 367 |
"\n",
|
| 368 |
"plt.figure(figsize=(12, 5))\n",
|
| 369 |
"\n",
|
| 370 |
-
"# Plot Loss\n",
|
| 371 |
"plt.subplot(1, 2, 1)\n",
|
| 372 |
"plt.plot(epochs, train_losses, label='Train Loss', marker='o')\n",
|
| 373 |
"plt.plot(epochs, val_losses, label='Validation Loss', marker='s')\n",
|
|
@@ -377,7 +367,6 @@
|
|
| 377 |
"plt.legend()\n",
|
| 378 |
"plt.grid(True)\n",
|
| 379 |
"\n",
|
| 380 |
-
"# Plot Accuracy\n",
|
| 381 |
"plt.subplot(1, 2, 2)\n",
|
| 382 |
"plt.plot(epochs, train_accs, label='Train Accuracy', marker='o')\n",
|
| 383 |
"plt.plot(epochs, val_accs, label='Validation Accuracy', marker='s')\n",
|
|
@@ -490,29 +479,24 @@
|
|
| 490 |
" activations[name] = output.detach().cpu()\n",
|
| 491 |
" return hook\n",
|
| 492 |
"\n",
|
| 493 |
-
" # Register hooks for all layers in model.features\n",
|
| 494 |
" hooks = []\n",
|
| 495 |
" for i in range(len(model.features)):\n",
|
| 496 |
" layer = model.features[i]\n",
|
| 497 |
" hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
|
| 498 |
"\n",
|
| 499 |
" with torch.no_grad():\n",
|
| 500 |
-
" _ = model(image_tensor.unsqueeze(0))
|
| 501 |
"\n",
|
| 502 |
" for h in hooks:\n",
|
| 503 |
" h.remove()\n",
|
| 504 |
"\n",
|
| 505 |
" for layer_name, fmap in activations.items():\n",
|
| 506 |
-
" fmap = fmap.squeeze(0)
|
| 507 |
-
"\n",
|
| 508 |
-
" # Compute mean activation per channel\n",
|
| 509 |
-
" channel_scores = fmap.mean(dim=(1, 2)) # [C]\n",
|
| 510 |
"\n",
|
| 511 |
-
" # Get indices of top-k channels\n",
|
| 512 |
" topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
|
| 513 |
" top_indices = topk.indices\n",
|
| 514 |
"\n",
|
| 515 |
-
" # Plot top-k channels\n",
|
| 516 |
" plt.figure(figsize=(max_channels * 2, 2.5))\n",
|
| 517 |
" for idx, ch in enumerate(top_indices):\n",
|
| 518 |
" plt.subplot(1, max_channels, idx + 1)\n",
|
|
@@ -545,14 +529,11 @@
|
|
| 545 |
"\n",
|
| 546 |
"img = Image.open(\"dataset/Pear_512/Whole/image_0007.jpg\").convert(\"RGB\")\n",
|
| 547 |
"\n",
|
| 548 |
-
"# Preprocessing (must match model requirements)\n",
|
| 549 |
"transform = transforms.Compose([\n",
|
| 550 |
" transforms.Resize((224, 224)),\n",
|
| 551 |
" transforms.ToTensor()\n",
|
| 552 |
"])\n",
|
| 553 |
-
"img_tensor = transform(img)
|
| 554 |
-
"\n",
|
| 555 |
-
"# Visualize feature maps\n",
|
| 556 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 557 |
]
|
| 558 |
},
|
|
@@ -565,14 +546,11 @@
|
|
| 565 |
"source": [
|
| 566 |
"img = Image.open(\"dataset/Pear_512/Halved/image_0578.jpg\").convert(\"RGB\")\n",
|
| 567 |
"\n",
|
| 568 |
-
"# Preprocessing (must match model requirements)\n",
|
| 569 |
"transform = transforms.Compose([\n",
|
| 570 |
" transforms.Resize((224, 224)),\n",
|
| 571 |
" transforms.ToTensor()\n",
|
| 572 |
"])\n",
|
| 573 |
-
"img_tensor = transform(img)
|
| 574 |
-
"\n",
|
| 575 |
-
"# Visualize feature maps\n",
|
| 576 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 577 |
]
|
| 578 |
},
|
|
@@ -584,15 +562,11 @@
|
|
| 584 |
"outputs": [],
|
| 585 |
"source": [
|
| 586 |
"img = Image.open(\"dataset/Pear_512/Sliced/image_0007.jpg\").convert(\"RGB\")\n",
|
| 587 |
-
"\n",
|
| 588 |
-
"# Preprocessing (must match model requirements)\n",
|
| 589 |
"transform = transforms.Compose([\n",
|
| 590 |
" transforms.Resize((224, 224)),\n",
|
| 591 |
" transforms.ToTensor()\n",
|
| 592 |
"])\n",
|
| 593 |
-
"img_tensor = transform(img)
|
| 594 |
-
"\n",
|
| 595 |
-
"# Visualize feature maps\n",
|
| 596 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 597 |
]
|
| 598 |
},
|
|
|
|
| 59 |
"def augment_rotations(X, y):\n",
|
| 60 |
" X_aug = []\n",
|
| 61 |
" y_aug = []\n",
|
| 62 |
+
" for k in [1, 2, 3]: \n",
|
| 63 |
+
" X_rot = torch.rot90(X, k=k, dims=[2, 3]) \n",
|
| 64 |
" X_aug.append(X_rot)\n",
|
| 65 |
+
" y_aug.append(y.clone()) \n",
|
| 66 |
" return torch.cat(X_aug), torch.cat(y_aug)"
|
| 67 |
]
|
| 68 |
},
|
|
|
|
| 122 |
" plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
|
| 123 |
" plt.show()\n",
|
| 124 |
"\n",
|
|
|
|
| 125 |
"for class_name, image_array in datasets.items():\n",
|
| 126 |
" show_random_samples(image_array, class_name)\n"
|
| 127 |
]
|
|
|
|
| 137 |
"\n",
|
| 138 |
"for ax, (class_name, images) in zip(axes, datasets.items()):\n",
|
| 139 |
" plot_rgb_histogram_subplot(ax, images, class_name)\n",
|
| 140 |
+
" ax.label_outer() \n",
|
| 141 |
"\n",
|
| 142 |
"plt.tight_layout()\n",
|
| 143 |
"plt.show()"
|
|
|
|
| 153 |
"class_names = list(datasets.keys())\n",
|
| 154 |
"num_classes = len(class_names)\n",
|
| 155 |
"\n",
|
| 156 |
+
"fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) \n",
|
| 157 |
"\n",
|
| 158 |
"for i, (class_name, images) in enumerate(datasets.items()):\n",
|
| 159 |
" avg_img = np.mean(images.astype(np.float32), axis=0)\n",
|
|
|
|
| 179 |
" \"whole\": pear_whole_images\n",
|
| 180 |
"}\n",
|
| 181 |
"\n",
|
|
|
|
| 182 |
"X = np.concatenate([pear_halved_images, pear_sliced_images, pear_whole_images], axis=0)\n",
|
| 183 |
"y = (\n",
|
| 184 |
" ['halved'] * len(pear_halved_images) +\n",
|
|
|
|
| 186 |
" ['whole'] * len(pear_whole_images)\n",
|
| 187 |
")\n",
|
| 188 |
"\n",
|
|
|
|
| 189 |
"X = X.astype(np.float32) / 255.0\n",
|
| 190 |
+
"X = np.transpose(X, (0, 3, 1, 2)) \n",
|
| 191 |
"X_tensor = torch.tensor(X)\n",
|
| 192 |
"\n",
|
|
|
|
| 193 |
"le = LabelEncoder()\n",
|
| 194 |
"y_encoded = le.fit_transform(y)\n",
|
| 195 |
"y_tensor = torch.tensor(y_encoded)\n",
|
| 196 |
"\n",
|
|
|
|
| 197 |
"X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.4, stratify=y_tensor, random_state=42)\n",
|
| 198 |
"X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
|
| 199 |
]
|
|
|
|
| 210 |
"val_dataset = TensorDataset(X_val, y_val)\n",
|
| 211 |
"test_dataset = TensorDataset(X_test, y_test)\n",
|
| 212 |
"\n",
|
|
|
|
| 213 |
"batch_size = 32\n",
|
| 214 |
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
|
| 215 |
"val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
|
|
|
|
| 223 |
"metadata": {},
|
| 224 |
"outputs": [],
|
| 225 |
"source": [
|
| 226 |
+
"print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
|
| 227 |
+
"print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
|
| 228 |
+
"print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
|
| 229 |
]
|
| 230 |
},
|
| 231 |
{
|
|
|
|
| 243 |
"\n",
|
| 244 |
"def get_efficientnet_model(num_classes):\n",
|
| 245 |
" model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)\n",
|
|
|
|
|
|
|
| 246 |
" model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)\n",
|
| 247 |
"\n",
|
| 248 |
" return model\n",
|
|
|
|
| 258 |
"source": [
|
| 259 |
"if torch.backends.mps.is_available():\n",
|
| 260 |
" device = torch.device(\"mps\")\n",
|
| 261 |
+
" print(\"Using MPS (Apple GPU)\")\n",
|
| 262 |
"else:\n",
|
| 263 |
" device = torch.device(\"cpu\")\n",
|
| 264 |
+
" print(\"MPS not available. Using CPU\")\n",
|
| 265 |
"\n",
|
| 266 |
"model = get_efficientnet_model(num_classes=3).to(device)\n",
|
| 267 |
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
|
|
|
|
| 304 |
"\n",
|
| 305 |
" total_train_loss += loss.item()\n",
|
| 306 |
"\n",
|
|
|
|
| 307 |
" pred_labels = preds.argmax(dim=1)\n",
|
| 308 |
" train_correct += (pred_labels == batch_y).sum().item()\n",
|
| 309 |
" train_total += batch_y.size(0)\n",
|
|
|
|
| 358 |
"\n",
|
| 359 |
"plt.figure(figsize=(12, 5))\n",
|
| 360 |
"\n",
|
|
|
|
| 361 |
"plt.subplot(1, 2, 1)\n",
|
| 362 |
"plt.plot(epochs, train_losses, label='Train Loss', marker='o')\n",
|
| 363 |
"plt.plot(epochs, val_losses, label='Validation Loss', marker='s')\n",
|
|
|
|
| 367 |
"plt.legend()\n",
|
| 368 |
"plt.grid(True)\n",
|
| 369 |
"\n",
|
|
|
|
| 370 |
"plt.subplot(1, 2, 2)\n",
|
| 371 |
"plt.plot(epochs, train_accs, label='Train Accuracy', marker='o')\n",
|
| 372 |
"plt.plot(epochs, val_accs, label='Validation Accuracy', marker='s')\n",
|
|
|
|
| 479 |
" activations[name] = output.detach().cpu()\n",
|
| 480 |
" return hook\n",
|
| 481 |
"\n",
|
|
|
|
| 482 |
" hooks = []\n",
|
| 483 |
" for i in range(len(model.features)):\n",
|
| 484 |
" layer = model.features[i]\n",
|
| 485 |
" hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
|
| 486 |
"\n",
|
| 487 |
" with torch.no_grad():\n",
|
| 488 |
+
" _ = model(image_tensor.unsqueeze(0)) \n",
|
| 489 |
"\n",
|
| 490 |
" for h in hooks:\n",
|
| 491 |
" h.remove()\n",
|
| 492 |
"\n",
|
| 493 |
" for layer_name, fmap in activations.items():\n",
|
| 494 |
+
" fmap = fmap.squeeze(0) \n",
|
| 495 |
+
" channel_scores = fmap.mean(dim=(1, 2)) \n",
|
|
|
|
|
|
|
| 496 |
"\n",
|
|
|
|
| 497 |
" topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
|
| 498 |
" top_indices = topk.indices\n",
|
| 499 |
"\n",
|
|
|
|
| 500 |
" plt.figure(figsize=(max_channels * 2, 2.5))\n",
|
| 501 |
" for idx, ch in enumerate(top_indices):\n",
|
| 502 |
" plt.subplot(1, max_channels, idx + 1)\n",
|
|
|
|
| 529 |
"\n",
|
| 530 |
"img = Image.open(\"dataset/Pear_512/Whole/image_0007.jpg\").convert(\"RGB\")\n",
|
| 531 |
"\n",
|
|
|
|
| 532 |
"transform = transforms.Compose([\n",
|
| 533 |
" transforms.Resize((224, 224)),\n",
|
| 534 |
" transforms.ToTensor()\n",
|
| 535 |
"])\n",
|
| 536 |
+
"img_tensor = transform(img) \n",
|
|
|
|
|
|
|
| 537 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 538 |
]
|
| 539 |
},
|
|
|
|
| 546 |
"source": [
|
| 547 |
"img = Image.open(\"dataset/Pear_512/Halved/image_0578.jpg\").convert(\"RGB\")\n",
|
| 548 |
"\n",
|
|
|
|
| 549 |
"transform = transforms.Compose([\n",
|
| 550 |
" transforms.Resize((224, 224)),\n",
|
| 551 |
" transforms.ToTensor()\n",
|
| 552 |
"])\n",
|
| 553 |
+
"img_tensor = transform(img) \n",
|
|
|
|
|
|
|
| 554 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 555 |
]
|
| 556 |
},
|
|
|
|
| 562 |
"outputs": [],
|
| 563 |
"source": [
|
| 564 |
"img = Image.open(\"dataset/Pear_512/Sliced/image_0007.jpg\").convert(\"RGB\")\n",
|
|
|
|
|
|
|
| 565 |
"transform = transforms.Compose([\n",
|
| 566 |
" transforms.Resize((224, 224)),\n",
|
| 567 |
" transforms.ToTensor()\n",
|
| 568 |
"])\n",
|
| 569 |
+
"img_tensor = transform(img) \n",
|
|
|
|
|
|
|
| 570 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 571 |
]
|
| 572 |
},
|
scripts/CV/script_strawberry.ipynb
CHANGED
|
@@ -59,10 +59,10 @@
|
|
| 59 |
"def augment_rotations(X, y):\n",
|
| 60 |
" X_aug = []\n",
|
| 61 |
" y_aug = []\n",
|
| 62 |
-
" for k in [1, 2, 3]:
|
| 63 |
-
" X_rot = torch.rot90(X, k=k, dims=[2, 3])
|
| 64 |
" X_aug.append(X_rot)\n",
|
| 65 |
-
" y_aug.append(y.clone())
|
| 66 |
" return torch.cat(X_aug), torch.cat(y_aug)"
|
| 67 |
]
|
| 68 |
},
|
|
@@ -124,7 +124,6 @@
|
|
| 124 |
" plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
|
| 125 |
" plt.show()\n",
|
| 126 |
"\n",
|
| 127 |
-
"# Display for each class\n",
|
| 128 |
"for class_name, image_array in datasets.items():\n",
|
| 129 |
" show_random_samples(image_array, class_name)\n"
|
| 130 |
]
|
|
@@ -140,7 +139,7 @@
|
|
| 140 |
"\n",
|
| 141 |
"for ax, (class_name, images) in zip(axes, datasets.items()):\n",
|
| 142 |
" plot_rgb_histogram_subplot(ax, images, class_name)\n",
|
| 143 |
-
" ax.label_outer()
|
| 144 |
"\n",
|
| 145 |
"plt.tight_layout()\n",
|
| 146 |
"plt.show()"
|
|
@@ -156,7 +155,7 @@
|
|
| 156 |
"class_names = list(datasets.keys())\n",
|
| 157 |
"num_classes = len(class_names)\n",
|
| 158 |
"\n",
|
| 159 |
-
"fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4))
|
| 160 |
"\n",
|
| 161 |
"for i, (class_name, images) in enumerate(datasets.items()):\n",
|
| 162 |
" avg_img = np.mean(images.astype(np.float32), axis=0)\n",
|
|
@@ -181,7 +180,6 @@
|
|
| 181 |
" \"whole\": strawberry_whole_images\n",
|
| 182 |
"}\n",
|
| 183 |
"\n",
|
| 184 |
-
"# Combine data\n",
|
| 185 |
"X = np.concatenate([strawberry_hulled_images, strawberry_sliced_images, strawberry_whole_images], axis=0)\n",
|
| 186 |
"y = (\n",
|
| 187 |
" ['hulled'] * len(strawberry_hulled_images) +\n",
|
|
@@ -189,17 +187,14 @@
|
|
| 189 |
" ['whole'] * len(strawberry_whole_images)\n",
|
| 190 |
")\n",
|
| 191 |
"\n",
|
| 192 |
-
"# Normalize and convert to torch tensors\n",
|
| 193 |
"X = X.astype(np.float32) / 255.0\n",
|
| 194 |
-
"X = np.transpose(X, (0, 3, 1, 2))
|
| 195 |
"X_tensor = torch.tensor(X)\n",
|
| 196 |
"\n",
|
| 197 |
-
"# Encode labels\n",
|
| 198 |
"le = LabelEncoder()\n",
|
| 199 |
"y_encoded = le.fit_transform(y)\n",
|
| 200 |
"y_tensor = torch.tensor(y_encoded)\n",
|
| 201 |
"\n",
|
| 202 |
-
"# Train/val/test split\n",
|
| 203 |
"X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.5, stratify=y_tensor, random_state=42)\n",
|
| 204 |
"X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
|
| 205 |
]
|
|
@@ -215,11 +210,9 @@
|
|
| 215 |
"\n",
|
| 216 |
"X_augmented, y_augmented = augment_rotations(X_train, y_train)\n",
|
| 217 |
"\n",
|
| 218 |
-
"# Combine original and augmented data\n",
|
| 219 |
"X_train_combined = torch.cat([X_train, X_augmented])\n",
|
| 220 |
"y_train_combined = torch.cat([y_train, y_augmented])\n",
|
| 221 |
"\n",
|
| 222 |
-
"\n",
|
| 223 |
"train_dataset = TensorDataset(X_train_combined, y_train_combined)\n",
|
| 224 |
"val_dataset = TensorDataset(X_val, y_val)\n",
|
| 225 |
"test_dataset = TensorDataset(X_test, y_test)\n",
|
|
@@ -236,9 +229,9 @@
|
|
| 236 |
"metadata": {},
|
| 237 |
"outputs": [],
|
| 238 |
"source": [
|
| 239 |
-
"print(f\"
|
| 240 |
-
"print(f\"
|
| 241 |
-
"print(f\"
|
| 242 |
]
|
| 243 |
},
|
| 244 |
{
|
|
@@ -341,7 +334,6 @@
|
|
| 341 |
" val_accuracy = val_correct / val_total\n",
|
| 342 |
" validation_loss = criterion(model(val_x), val_y).item()\n",
|
| 343 |
"\n",
|
| 344 |
-
" # After calculating val_accuracy\n",
|
| 345 |
" val_losses.append(validation_loss)\n",
|
| 346 |
" val_accs.append(val_accuracy)\n",
|
| 347 |
"\n",
|
|
@@ -449,29 +441,25 @@
|
|
| 449 |
"source": [
|
| 450 |
"all_preds = np.array(all_preds)\n",
|
| 451 |
"all_targets = np.array(all_targets)\n",
|
| 452 |
-
"all_images = torch.stack(all_images)
|
| 453 |
"\n",
|
| 454 |
-
"# Per class FP and FN\n",
|
| 455 |
"for class_idx, class_name in enumerate(target_names):\n",
|
| 456 |
-
" print(f\"\\
|
| 457 |
-
"\n",
|
| 458 |
-
" # False Negatives: True label is class_idx, but predicted something else\n",
|
| 459 |
" fn_indices = np.where((all_targets == class_idx) & (all_preds != class_idx))[0]\n",
|
| 460 |
-
" # False Positives: Predicted class_idx, but true label is different\n",
|
| 461 |
" fp_indices = np.where((all_preds == class_idx) & (all_targets != class_idx))[0]\n",
|
| 462 |
"\n",
|
| 463 |
" def show_images(indices, title, max_images=5):\n",
|
| 464 |
" num = min(len(indices), max_images)\n",
|
| 465 |
" if num == 0:\n",
|
| 466 |
-
" print(f\"
|
| 467 |
" return\n",
|
| 468 |
"\n",
|
| 469 |
" plt.figure(figsize=(12, 2))\n",
|
| 470 |
" for i, idx in enumerate(indices[:num]):\n",
|
| 471 |
" img = all_images[idx]\n",
|
| 472 |
-
" img = img.permute(1, 2, 0).numpy()
|
| 473 |
" plt.subplot(1, num, i + 1)\n",
|
| 474 |
-
" plt.imshow((img - img.min()) / (img.max() - img.min()))
|
| 475 |
" plt.axis('off')\n",
|
| 476 |
" plt.title(f\"Pred: {target_names[all_preds[idx]]}\\nTrue: {target_names[all_targets[idx]]}\")\n",
|
| 477 |
" plt.suptitle(f\"{title} for {class_name}\")\n",
|
|
@@ -498,29 +486,25 @@
|
|
| 498 |
" activations[name] = output.detach().cpu()\n",
|
| 499 |
" return hook\n",
|
| 500 |
"\n",
|
| 501 |
-
" # Register hooks for all layers in model.features\n",
|
| 502 |
" hooks = []\n",
|
| 503 |
" for i in range(len(model.features)):\n",
|
| 504 |
" layer = model.features[i]\n",
|
| 505 |
" hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
|
| 506 |
"\n",
|
| 507 |
" with torch.no_grad():\n",
|
| 508 |
-
" _ = model(image_tensor.unsqueeze(0))
|
| 509 |
"\n",
|
| 510 |
" for h in hooks:\n",
|
| 511 |
" h.remove()\n",
|
| 512 |
"\n",
|
| 513 |
" for layer_name, fmap in activations.items():\n",
|
| 514 |
-
" fmap = fmap.squeeze(0)
|
| 515 |
"\n",
|
| 516 |
-
"
|
| 517 |
-
" channel_scores = fmap.mean(dim=(1, 2)) # [C]\n",
|
| 518 |
"\n",
|
| 519 |
-
" # Get indices of top-k channels\n",
|
| 520 |
" topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
|
| 521 |
" top_indices = topk.indices\n",
|
| 522 |
"\n",
|
| 523 |
-
" # Plot top-k channels\n",
|
| 524 |
" plt.figure(figsize=(max_channels * 2, 2.5))\n",
|
| 525 |
" for idx, ch in enumerate(top_indices):\n",
|
| 526 |
" plt.subplot(1, max_channels, idx + 1)\n",
|
|
@@ -553,14 +537,12 @@
|
|
| 553 |
"\n",
|
| 554 |
"img = Image.open(\"dataset/Strawberry_512/Whole/image_0017.jpg\").convert(\"RGB\")\n",
|
| 555 |
"\n",
|
| 556 |
-
"# Preprocessing (must match model requirements)\n",
|
| 557 |
"transform = transforms.Compose([\n",
|
| 558 |
" transforms.Resize((224, 224)),\n",
|
| 559 |
" transforms.ToTensor()\n",
|
| 560 |
"])\n",
|
| 561 |
-
"img_tensor = transform(img)
|
| 562 |
"\n",
|
| 563 |
-
"# Visualize feature maps\n",
|
| 564 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 565 |
]
|
| 566 |
},
|
|
@@ -574,14 +556,12 @@
|
|
| 574 |
"\n",
|
| 575 |
"img = Image.open(\"dataset/Strawberry_512/Hulled/image_0001.jpg\").convert(\"RGB\")\n",
|
| 576 |
"\n",
|
| 577 |
-
"# Preprocessing (must match model requirements)\n",
|
| 578 |
"transform = transforms.Compose([\n",
|
| 579 |
" transforms.Resize((224, 224)),\n",
|
| 580 |
" transforms.ToTensor()\n",
|
| 581 |
"])\n",
|
| 582 |
-
"img_tensor = transform(img)
|
| 583 |
"\n",
|
| 584 |
-
"# Visualize feature maps\n",
|
| 585 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 586 |
]
|
| 587 |
},
|
|
@@ -595,15 +575,13 @@
|
|
| 595 |
"\n",
|
| 596 |
"img = Image.open(\"dataset/Strawberry_512/Sliced/image_0001.jpg\").convert(\"RGB\")\n",
|
| 597 |
"\n",
|
| 598 |
-
"# Preprocessing (must match model requirements)\n",
|
| 599 |
"transform = transforms.Compose([\n",
|
| 600 |
" transforms.Resize((224, 224)),\n",
|
| 601 |
" transforms.ToTensor()\n",
|
| 602 |
"])\n",
|
| 603 |
-
"img_tensor = transform(img)
|
| 604 |
"\n",
|
| 605 |
-
"
|
| 606 |
-
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 607 |
]
|
| 608 |
},
|
| 609 |
{
|
|
|
|
| 59 |
"def augment_rotations(X, y):\n",
|
| 60 |
" X_aug = []\n",
|
| 61 |
" y_aug = []\n",
|
| 62 |
+
" for k in [1, 2, 3]: \n",
|
| 63 |
+
" X_rot = torch.rot90(X, k=k, dims=[2, 3]) \n",
|
| 64 |
" X_aug.append(X_rot)\n",
|
| 65 |
+
" y_aug.append(y.clone()) \n",
|
| 66 |
" return torch.cat(X_aug), torch.cat(y_aug)"
|
| 67 |
]
|
| 68 |
},
|
|
|
|
| 124 |
" plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
|
| 125 |
" plt.show()\n",
|
| 126 |
"\n",
|
|
|
|
| 127 |
"for class_name, image_array in datasets.items():\n",
|
| 128 |
" show_random_samples(image_array, class_name)\n"
|
| 129 |
]
|
|
|
|
| 139 |
"\n",
|
| 140 |
"for ax, (class_name, images) in zip(axes, datasets.items()):\n",
|
| 141 |
" plot_rgb_histogram_subplot(ax, images, class_name)\n",
|
| 142 |
+
" ax.label_outer() \n",
|
| 143 |
"\n",
|
| 144 |
"plt.tight_layout()\n",
|
| 145 |
"plt.show()"
|
|
|
|
| 155 |
"class_names = list(datasets.keys())\n",
|
| 156 |
"num_classes = len(class_names)\n",
|
| 157 |
"\n",
|
| 158 |
+
"fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) \n",
|
| 159 |
"\n",
|
| 160 |
"for i, (class_name, images) in enumerate(datasets.items()):\n",
|
| 161 |
" avg_img = np.mean(images.astype(np.float32), axis=0)\n",
|
|
|
|
| 180 |
" \"whole\": strawberry_whole_images\n",
|
| 181 |
"}\n",
|
| 182 |
"\n",
|
|
|
|
| 183 |
"X = np.concatenate([strawberry_hulled_images, strawberry_sliced_images, strawberry_whole_images], axis=0)\n",
|
| 184 |
"y = (\n",
|
| 185 |
" ['hulled'] * len(strawberry_hulled_images) +\n",
|
|
|
|
| 187 |
" ['whole'] * len(strawberry_whole_images)\n",
|
| 188 |
")\n",
|
| 189 |
"\n",
|
|
|
|
| 190 |
"X = X.astype(np.float32) / 255.0\n",
|
| 191 |
+
"X = np.transpose(X, (0, 3, 1, 2)) \n",
|
| 192 |
"X_tensor = torch.tensor(X)\n",
|
| 193 |
"\n",
|
|
|
|
| 194 |
"le = LabelEncoder()\n",
|
| 195 |
"y_encoded = le.fit_transform(y)\n",
|
| 196 |
"y_tensor = torch.tensor(y_encoded)\n",
|
| 197 |
"\n",
|
|
|
|
| 198 |
"X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.5, stratify=y_tensor, random_state=42)\n",
|
| 199 |
"X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
|
| 200 |
]
|
|
|
|
| 210 |
"\n",
|
| 211 |
"X_augmented, y_augmented = augment_rotations(X_train, y_train)\n",
|
| 212 |
"\n",
|
|
|
|
| 213 |
"X_train_combined = torch.cat([X_train, X_augmented])\n",
|
| 214 |
"y_train_combined = torch.cat([y_train, y_augmented])\n",
|
| 215 |
"\n",
|
|
|
|
| 216 |
"train_dataset = TensorDataset(X_train_combined, y_train_combined)\n",
|
| 217 |
"val_dataset = TensorDataset(X_val, y_val)\n",
|
| 218 |
"test_dataset = TensorDataset(X_test, y_test)\n",
|
|
|
|
| 229 |
"metadata": {},
|
| 230 |
"outputs": [],
|
| 231 |
"source": [
|
| 232 |
+
"print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
|
| 233 |
+
"print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
|
| 234 |
+
"print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
|
| 235 |
]
|
| 236 |
},
|
| 237 |
{
|
|
|
|
| 334 |
" val_accuracy = val_correct / val_total\n",
|
| 335 |
" validation_loss = criterion(model(val_x), val_y).item()\n",
|
| 336 |
"\n",
|
|
|
|
| 337 |
" val_losses.append(validation_loss)\n",
|
| 338 |
" val_accs.append(val_accuracy)\n",
|
| 339 |
"\n",
|
|
|
|
| 441 |
"source": [
|
| 442 |
"all_preds = np.array(all_preds)\n",
|
| 443 |
"all_targets = np.array(all_targets)\n",
|
| 444 |
+
"all_images = torch.stack(all_images) \n",
|
| 445 |
"\n",
|
|
|
|
| 446 |
"for class_idx, class_name in enumerate(target_names):\n",
|
| 447 |
+
" print(f\"\\nShowing False Negatives and False Positives for class: {class_name}\")\n",
|
|
|
|
|
|
|
| 448 |
" fn_indices = np.where((all_targets == class_idx) & (all_preds != class_idx))[0]\n",
|
|
|
|
| 449 |
" fp_indices = np.where((all_preds == class_idx) & (all_targets != class_idx))[0]\n",
|
| 450 |
"\n",
|
| 451 |
" def show_images(indices, title, max_images=5):\n",
|
| 452 |
" num = min(len(indices), max_images)\n",
|
| 453 |
" if num == 0:\n",
|
| 454 |
+
" print(f\"No {title} samples.\")\n",
|
| 455 |
" return\n",
|
| 456 |
"\n",
|
| 457 |
" plt.figure(figsize=(12, 2))\n",
|
| 458 |
" for i, idx in enumerate(indices[:num]):\n",
|
| 459 |
" img = all_images[idx]\n",
|
| 460 |
+
" img = img.permute(1, 2, 0).numpy()\n",
|
| 461 |
" plt.subplot(1, num, i + 1)\n",
|
| 462 |
+
" plt.imshow((img - img.min()) / (img.max() - img.min()))\n",
|
| 463 |
" plt.axis('off')\n",
|
| 464 |
" plt.title(f\"Pred: {target_names[all_preds[idx]]}\\nTrue: {target_names[all_targets[idx]]}\")\n",
|
| 465 |
" plt.suptitle(f\"{title} for {class_name}\")\n",
|
|
|
|
| 486 |
" activations[name] = output.detach().cpu()\n",
|
| 487 |
" return hook\n",
|
| 488 |
"\n",
|
|
|
|
| 489 |
" hooks = []\n",
|
| 490 |
" for i in range(len(model.features)):\n",
|
| 491 |
" layer = model.features[i]\n",
|
| 492 |
" hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
|
| 493 |
"\n",
|
| 494 |
" with torch.no_grad():\n",
|
| 495 |
+
" _ = model(image_tensor.unsqueeze(0)) \n",
|
| 496 |
"\n",
|
| 497 |
" for h in hooks:\n",
|
| 498 |
" h.remove()\n",
|
| 499 |
"\n",
|
| 500 |
" for layer_name, fmap in activations.items():\n",
|
| 501 |
+
" fmap = fmap.squeeze(0) \n",
|
| 502 |
"\n",
|
| 503 |
+
" channel_scores = fmap.mean(dim=(1, 2))\n",
|
|
|
|
| 504 |
"\n",
|
|
|
|
| 505 |
" topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
|
| 506 |
" top_indices = topk.indices\n",
|
| 507 |
"\n",
|
|
|
|
| 508 |
" plt.figure(figsize=(max_channels * 2, 2.5))\n",
|
| 509 |
" for idx, ch in enumerate(top_indices):\n",
|
| 510 |
" plt.subplot(1, max_channels, idx + 1)\n",
|
|
|
|
| 537 |
"\n",
|
| 538 |
"img = Image.open(\"dataset/Strawberry_512/Whole/image_0017.jpg\").convert(\"RGB\")\n",
|
| 539 |
"\n",
|
|
|
|
| 540 |
"transform = transforms.Compose([\n",
|
| 541 |
" transforms.Resize((224, 224)),\n",
|
| 542 |
" transforms.ToTensor()\n",
|
| 543 |
"])\n",
|
| 544 |
+
"img_tensor = transform(img) \n",
|
| 545 |
"\n",
|
|
|
|
| 546 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 547 |
]
|
| 548 |
},
|
|
|
|
| 556 |
"\n",
|
| 557 |
"img = Image.open(\"dataset/Strawberry_512/Hulled/image_0001.jpg\").convert(\"RGB\")\n",
|
| 558 |
"\n",
|
|
|
|
| 559 |
"transform = transforms.Compose([\n",
|
| 560 |
" transforms.Resize((224, 224)),\n",
|
| 561 |
" transforms.ToTensor()\n",
|
| 562 |
"])\n",
|
| 563 |
+
"img_tensor = transform(img) \n",
|
| 564 |
"\n",
|
|
|
|
| 565 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 566 |
]
|
| 567 |
},
|
|
|
|
| 575 |
"\n",
|
| 576 |
"img = Image.open(\"dataset/Strawberry_512/Sliced/image_0001.jpg\").convert(\"RGB\")\n",
|
| 577 |
"\n",
|
|
|
|
| 578 |
"transform = transforms.Compose([\n",
|
| 579 |
" transforms.Resize((224, 224)),\n",
|
| 580 |
" transforms.ToTensor()\n",
|
| 581 |
"])\n",
|
| 582 |
+
"img_tensor = transform(img) \n",
|
| 583 |
"\n",
|
| 584 |
+
"visualize_channels(model, img_tensor, max_channels=16)"
|
|
|
|
| 585 |
]
|
| 586 |
},
|
| 587 |
{
|
scripts/CV/script_tomato.ipynb
CHANGED
|
@@ -13,7 +13,6 @@
|
|
| 13 |
"import matplotlib.pyplot as plt\n",
|
| 14 |
"import random\n",
|
| 15 |
"import torch\n",
|
| 16 |
-
"import numpy as np\n",
|
| 17 |
"from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
|
| 18 |
"from sklearn.preprocessing import LabelEncoder\n",
|
| 19 |
"from sklearn.model_selection import train_test_split\n",
|
|
@@ -59,10 +58,10 @@
|
|
| 59 |
"def augment_rotations(X, y):\n",
|
| 60 |
" X_aug = []\n",
|
| 61 |
" y_aug = []\n",
|
| 62 |
-
" for k in [1, 2, 3]:
|
| 63 |
-
" X_rot = torch.rot90(X, k=k, dims=[2, 3])
|
| 64 |
" X_aug.append(X_rot)\n",
|
| 65 |
-
" y_aug.append(y.clone())
|
| 66 |
" return torch.cat(X_aug), torch.cat(y_aug)"
|
| 67 |
]
|
| 68 |
},
|
|
@@ -103,8 +102,7 @@
|
|
| 103 |
"metadata": {},
|
| 104 |
"outputs": [],
|
| 105 |
"source": [
|
| 106 |
-
"
|
| 107 |
-
"import random\n",
|
| 108 |
"datasets = {\n",
|
| 109 |
" \"diced\": tomato_diced_images,\n",
|
| 110 |
" \"vines\": tomato_vines_images,\n",
|
|
@@ -124,7 +122,6 @@
|
|
| 124 |
" plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
|
| 125 |
" plt.show()\n",
|
| 126 |
"\n",
|
| 127 |
-
"# Display for each class\n",
|
| 128 |
"for class_name, image_array in datasets.items():\n",
|
| 129 |
" show_random_samples(image_array, class_name)\n"
|
| 130 |
]
|
|
@@ -140,7 +137,7 @@
|
|
| 140 |
"\n",
|
| 141 |
"for ax, (class_name, images) in zip(axes, datasets.items()):\n",
|
| 142 |
" plot_rgb_histogram_subplot(ax, images, class_name)\n",
|
| 143 |
-
" ax.label_outer()
|
| 144 |
"\n",
|
| 145 |
"plt.tight_layout()\n",
|
| 146 |
"plt.show()"
|
|
@@ -156,7 +153,7 @@
|
|
| 156 |
"class_names = list(datasets.keys())\n",
|
| 157 |
"num_classes = len(class_names)\n",
|
| 158 |
"\n",
|
| 159 |
-
"fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4))
|
| 160 |
"\n",
|
| 161 |
"for i, (class_name, images) in enumerate(datasets.items()):\n",
|
| 162 |
" avg_img = np.mean(images.astype(np.float32), axis=0)\n",
|
|
@@ -175,20 +172,12 @@
|
|
| 175 |
"metadata": {},
|
| 176 |
"outputs": [],
|
| 177 |
"source": [
|
| 178 |
-
"import torch\n",
|
| 179 |
-
"import numpy as np\n",
|
| 180 |
-
"from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
|
| 181 |
-
"from sklearn.preprocessing import LabelEncoder\n",
|
| 182 |
-
"from sklearn.model_selection import train_test_split\n",
|
| 183 |
-
"from torchvision import transforms\n",
|
| 184 |
-
"\n",
|
| 185 |
"datasets = {\n",
|
| 186 |
" \"diced\": tomato_diced_images,\n",
|
| 187 |
" \"vines\": tomato_vines_images,\n",
|
| 188 |
" \"whole\": tomato_whole_images\n",
|
| 189 |
"}\n",
|
| 190 |
"\n",
|
| 191 |
-
"# Combine data\n",
|
| 192 |
"X = np.concatenate([tomato_diced_images, tomato_vines_images, tomato_whole_images], axis=0)\n",
|
| 193 |
"y = (\n",
|
| 194 |
" ['diced'] * len(tomato_diced_images) +\n",
|
|
@@ -196,17 +185,14 @@
|
|
| 196 |
" ['whole'] * len(tomato_whole_images)\n",
|
| 197 |
")\n",
|
| 198 |
"\n",
|
| 199 |
-
"# Normalize and convert to torch tensors\n",
|
| 200 |
"X = X.astype(np.float32) / 255.0\n",
|
| 201 |
-
"X = np.transpose(X, (0, 3, 1, 2))
|
| 202 |
"X_tensor = torch.tensor(X)\n",
|
| 203 |
"\n",
|
| 204 |
-
"# Encode labels\n",
|
| 205 |
"le = LabelEncoder()\n",
|
| 206 |
"y_encoded = le.fit_transform(y)\n",
|
| 207 |
"y_tensor = torch.tensor(y_encoded)\n",
|
| 208 |
"\n",
|
| 209 |
-
"# Train/val/test split\n",
|
| 210 |
"X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.4, stratify=y_tensor, random_state=42)\n",
|
| 211 |
"X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
|
| 212 |
]
|
|
@@ -222,17 +208,13 @@
|
|
| 222 |
"\n",
|
| 223 |
"X_augmented, y_augmented = augment_rotations(X_train, y_train)\n",
|
| 224 |
"\n",
|
| 225 |
-
"# Combine original and augmented data\n",
|
| 226 |
"X_train_combined = torch.cat([X_train, X_augmented])\n",
|
| 227 |
"y_train_combined = torch.cat([y_train, y_augmented])\n",
|
| 228 |
"\n",
|
| 229 |
-
"# Create new training dataset and loader\n",
|
| 230 |
-
"\n",
|
| 231 |
"train_dataset = TensorDataset(X_train, y_train)\n",
|
| 232 |
"val_dataset = TensorDataset(X_val, y_val)\n",
|
| 233 |
"test_dataset = TensorDataset(X_test, y_test)\n",
|
| 234 |
"\n",
|
| 235 |
-
"# DataLoaders\n",
|
| 236 |
"\n",
|
| 237 |
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
|
| 238 |
"val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
|
|
@@ -246,9 +228,9 @@
|
|
| 246 |
"metadata": {},
|
| 247 |
"outputs": [],
|
| 248 |
"source": [
|
| 249 |
-
"print(f\"
|
| 250 |
-
"print(f\"
|
| 251 |
-
"print(f\"
|
| 252 |
]
|
| 253 |
},
|
| 254 |
{
|
|
@@ -258,18 +240,9 @@
|
|
| 258 |
"metadata": {},
|
| 259 |
"outputs": [],
|
| 260 |
"source": [
|
| 261 |
-
"import torch.nn as nn\n",
|
| 262 |
-
"import torch.nn.functional as F\n",
|
| 263 |
-
"\n",
|
| 264 |
-
"import torch.nn as nn\n",
|
| 265 |
-
"import torchvision.models as models\n",
|
| 266 |
-
"\n",
|
| 267 |
"def get_efficientnet_model(num_classes):\n",
|
| 268 |
" model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)\n",
|
| 269 |
-
"\n",
|
| 270 |
-
" # Replace classifier head with custom head\n",
|
| 271 |
" model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)\n",
|
| 272 |
-
"\n",
|
| 273 |
" return model\n",
|
| 274 |
"\n"
|
| 275 |
]
|
|
@@ -283,10 +256,10 @@
|
|
| 283 |
"source": [
|
| 284 |
"if torch.backends.mps.is_available():\n",
|
| 285 |
" device = torch.device(\"mps\")\n",
|
| 286 |
-
" print(\"
|
| 287 |
"else:\n",
|
| 288 |
" device = torch.device(\"cpu\")\n",
|
| 289 |
-
" print(\"
|
| 290 |
"\n",
|
| 291 |
"model = get_efficientnet_model(num_classes=3).to(device)\n",
|
| 292 |
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
|
|
@@ -329,7 +302,6 @@
|
|
| 329 |
"\n",
|
| 330 |
" total_train_loss += loss.item()\n",
|
| 331 |
"\n",
|
| 332 |
-
" # Track training accuracy\n",
|
| 333 |
" pred_labels = preds.argmax(dim=1)\n",
|
| 334 |
" train_correct += (pred_labels == batch_y).sum().item()\n",
|
| 335 |
" train_total += batch_y.size(0)\n",
|
|
@@ -353,7 +325,6 @@
|
|
| 353 |
" val_accuracy = val_correct / val_total\n",
|
| 354 |
" validation_loss = criterion(model(val_x), val_y).item()\n",
|
| 355 |
"\n",
|
| 356 |
-
" # After calculating val_accuracy\n",
|
| 357 |
" val_losses.append(validation_loss)\n",
|
| 358 |
" val_accs.append(val_accuracy)\n",
|
| 359 |
"\n",
|
|
@@ -381,13 +352,12 @@
|
|
| 381 |
"metadata": {},
|
| 382 |
"outputs": [],
|
| 383 |
"source": [
|
| 384 |
-
"
|
| 385 |
"\n",
|
| 386 |
"epochs = range(1, len(train_losses) + 1)\n",
|
| 387 |
"\n",
|
| 388 |
"plt.figure(figsize=(12, 5))\n",
|
| 389 |
"\n",
|
| 390 |
-
"# Plot Loss\n",
|
| 391 |
"plt.subplot(1, 2, 1)\n",
|
| 392 |
"plt.plot(epochs, train_losses, label='Train Loss', marker='o')\n",
|
| 393 |
"plt.plot(epochs, val_losses, label='Validation Loss', marker='s')\n",
|
|
@@ -397,7 +367,6 @@
|
|
| 397 |
"plt.legend()\n",
|
| 398 |
"plt.grid(True)\n",
|
| 399 |
"\n",
|
| 400 |
-
"# Plot Accuracy\n",
|
| 401 |
"plt.subplot(1, 2, 2)\n",
|
| 402 |
"plt.plot(epochs, train_accs, label='Train Accuracy', marker='o')\n",
|
| 403 |
"plt.plot(epochs, val_accs, label='Validation Accuracy', marker='s')\n",
|
|
@@ -462,35 +431,28 @@
|
|
| 462 |
"metadata": {},
|
| 463 |
"outputs": [],
|
| 464 |
"source": [
|
| 465 |
-
"import torch\n",
|
| 466 |
-
"import numpy as np\n",
|
| 467 |
-
"import matplotlib.pyplot as plt\n",
|
| 468 |
"\n",
|
| 469 |
"all_preds = np.array(all_preds)\n",
|
| 470 |
"all_targets = np.array(all_targets)\n",
|
| 471 |
-
"all_images = torch.stack(all_images)
|
| 472 |
"\n",
|
| 473 |
-
"# Per class FP and FN\n",
|
| 474 |
"for class_idx, class_name in enumerate(target_names):\n",
|
| 475 |
-
" print(f\"\\
|
| 476 |
-
"\n",
|
| 477 |
-
" # False Negatives: True label is class_idx, but predicted something else\n",
|
| 478 |
" fn_indices = np.where((all_targets == class_idx) & (all_preds != class_idx))[0]\n",
|
| 479 |
-
" # False Positives: Predicted class_idx, but true label is different\n",
|
| 480 |
" fp_indices = np.where((all_preds == class_idx) & (all_targets != class_idx))[0]\n",
|
| 481 |
"\n",
|
| 482 |
" def show_images(indices, title, max_images=5):\n",
|
| 483 |
" num = min(len(indices), max_images)\n",
|
| 484 |
" if num == 0:\n",
|
| 485 |
-
" print(f\"
|
| 486 |
" return\n",
|
| 487 |
"\n",
|
| 488 |
" plt.figure(figsize=(12, 2))\n",
|
| 489 |
" for i, idx in enumerate(indices[:num]):\n",
|
| 490 |
" img = all_images[idx]\n",
|
| 491 |
-
" img = img.permute(1, 2, 0).numpy()
|
| 492 |
" plt.subplot(1, num, i + 1)\n",
|
| 493 |
-
" plt.imshow((img - img.min()) / (img.max() - img.min()))
|
| 494 |
" plt.axis('off')\n",
|
| 495 |
" plt.title(f\"Pred: {target_names[all_preds[idx]]}\\nTrue: {target_names[all_targets[idx]]}\")\n",
|
| 496 |
" plt.suptitle(f\"{title} for {class_name}\")\n",
|
|
@@ -517,29 +479,25 @@
|
|
| 517 |
" activations[name] = output.detach().cpu()\n",
|
| 518 |
" return hook\n",
|
| 519 |
"\n",
|
| 520 |
-
" # Register hooks for all layers in model.features\n",
|
| 521 |
" hooks = []\n",
|
| 522 |
" for i in range(len(model.features)):\n",
|
| 523 |
" layer = model.features[i]\n",
|
| 524 |
" hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
|
| 525 |
"\n",
|
| 526 |
" with torch.no_grad():\n",
|
| 527 |
-
" _ = model(image_tensor.unsqueeze(0))
|
| 528 |
"\n",
|
| 529 |
" for h in hooks:\n",
|
| 530 |
" h.remove()\n",
|
| 531 |
"\n",
|
| 532 |
" for layer_name, fmap in activations.items():\n",
|
| 533 |
-
" fmap = fmap.squeeze(0)
|
| 534 |
"\n",
|
| 535 |
-
"
|
| 536 |
-
" channel_scores = fmap.mean(dim=(1, 2)) # [C]\n",
|
| 537 |
"\n",
|
| 538 |
-
" # Get indices of top-k channels\n",
|
| 539 |
" topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
|
| 540 |
" top_indices = topk.indices\n",
|
| 541 |
"\n",
|
| 542 |
-
" # Plot top-k channels\n",
|
| 543 |
" plt.figure(figsize=(max_channels * 2, 2.5))\n",
|
| 544 |
" for idx, ch in enumerate(top_indices):\n",
|
| 545 |
" plt.subplot(1, max_channels, idx + 1)\n",
|
|
@@ -572,14 +530,12 @@
|
|
| 572 |
"\n",
|
| 573 |
"img = Image.open(\"dataset/Tomato_512/Whole/image_0007.jpg\").convert(\"RGB\")\n",
|
| 574 |
"\n",
|
| 575 |
-
"# Preprocessing (must match model requirements)\n",
|
| 576 |
"transform = transforms.Compose([\n",
|
| 577 |
" transforms.Resize((224, 224)),\n",
|
| 578 |
" transforms.ToTensor()\n",
|
| 579 |
"])\n",
|
| 580 |
-
"img_tensor = transform(img)
|
| 581 |
"\n",
|
| 582 |
-
"# Visualize feature maps\n",
|
| 583 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 584 |
]
|
| 585 |
},
|
|
@@ -592,14 +548,12 @@
|
|
| 592 |
"source": [
|
| 593 |
"img = Image.open(\"dataset/Tomato_512/On_the_vines/image_0578.jpg\").convert(\"RGB\")\n",
|
| 594 |
"\n",
|
| 595 |
-
"# Preprocessing (must match model requirements)\n",
|
| 596 |
"transform = transforms.Compose([\n",
|
| 597 |
" transforms.Resize((224, 224)),\n",
|
| 598 |
" transforms.ToTensor()\n",
|
| 599 |
"])\n",
|
| 600 |
-
"img_tensor = transform(img)
|
| 601 |
"\n",
|
| 602 |
-
"# Visualize feature maps\n",
|
| 603 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 604 |
]
|
| 605 |
},
|
|
@@ -612,14 +566,12 @@
|
|
| 612 |
"source": [
|
| 613 |
"img = Image.open(\"dataset/Tomato_512/Diced/image_0578.jpg\").convert(\"RGB\")\n",
|
| 614 |
"\n",
|
| 615 |
-
"# Preprocessing (must match model requirements)\n",
|
| 616 |
"transform = transforms.Compose([\n",
|
| 617 |
" transforms.Resize((224, 224)),\n",
|
| 618 |
" transforms.ToTensor()\n",
|
| 619 |
"])\n",
|
| 620 |
-
"img_tensor = transform(img)
|
| 621 |
"\n",
|
| 622 |
-
"# Visualize feature maps\n",
|
| 623 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 624 |
]
|
| 625 |
},
|
|
|
|
| 13 |
"import matplotlib.pyplot as plt\n",
|
| 14 |
"import random\n",
|
| 15 |
"import torch\n",
|
|
|
|
| 16 |
"from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
|
| 17 |
"from sklearn.preprocessing import LabelEncoder\n",
|
| 18 |
"from sklearn.model_selection import train_test_split\n",
|
|
|
|
| 58 |
"def augment_rotations(X, y):\n",
|
| 59 |
" X_aug = []\n",
|
| 60 |
" y_aug = []\n",
|
| 61 |
+
" for k in [1, 2, 3]: \n",
|
| 62 |
+
" X_rot = torch.rot90(X, k=k, dims=[2, 3]) \n",
|
| 63 |
" X_aug.append(X_rot)\n",
|
| 64 |
+
" y_aug.append(y.clone()) \n",
|
| 65 |
" return torch.cat(X_aug), torch.cat(y_aug)"
|
| 66 |
]
|
| 67 |
},
|
|
|
|
| 102 |
"metadata": {},
|
| 103 |
"outputs": [],
|
| 104 |
"source": [
|
| 105 |
+
"\n",
|
|
|
|
| 106 |
"datasets = {\n",
|
| 107 |
" \"diced\": tomato_diced_images,\n",
|
| 108 |
" \"vines\": tomato_vines_images,\n",
|
|
|
|
| 122 |
" plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
|
| 123 |
" plt.show()\n",
|
| 124 |
"\n",
|
|
|
|
| 125 |
"for class_name, image_array in datasets.items():\n",
|
| 126 |
" show_random_samples(image_array, class_name)\n"
|
| 127 |
]
|
|
|
|
| 137 |
"\n",
|
| 138 |
"for ax, (class_name, images) in zip(axes, datasets.items()):\n",
|
| 139 |
" plot_rgb_histogram_subplot(ax, images, class_name)\n",
|
| 140 |
+
" ax.label_outer() \n",
|
| 141 |
"\n",
|
| 142 |
"plt.tight_layout()\n",
|
| 143 |
"plt.show()"
|
|
|
|
| 153 |
"class_names = list(datasets.keys())\n",
|
| 154 |
"num_classes = len(class_names)\n",
|
| 155 |
"\n",
|
| 156 |
+
"fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) \n",
|
| 157 |
"\n",
|
| 158 |
"for i, (class_name, images) in enumerate(datasets.items()):\n",
|
| 159 |
" avg_img = np.mean(images.astype(np.float32), axis=0)\n",
|
|
|
|
| 172 |
"metadata": {},
|
| 173 |
"outputs": [],
|
| 174 |
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
"datasets = {\n",
|
| 176 |
" \"diced\": tomato_diced_images,\n",
|
| 177 |
" \"vines\": tomato_vines_images,\n",
|
| 178 |
" \"whole\": tomato_whole_images\n",
|
| 179 |
"}\n",
|
| 180 |
"\n",
|
|
|
|
| 181 |
"X = np.concatenate([tomato_diced_images, tomato_vines_images, tomato_whole_images], axis=0)\n",
|
| 182 |
"y = (\n",
|
| 183 |
" ['diced'] * len(tomato_diced_images) +\n",
|
|
|
|
| 185 |
" ['whole'] * len(tomato_whole_images)\n",
|
| 186 |
")\n",
|
| 187 |
"\n",
|
|
|
|
| 188 |
"X = X.astype(np.float32) / 255.0\n",
|
| 189 |
+
"X = np.transpose(X, (0, 3, 1, 2))\n",
|
| 190 |
"X_tensor = torch.tensor(X)\n",
|
| 191 |
"\n",
|
|
|
|
| 192 |
"le = LabelEncoder()\n",
|
| 193 |
"y_encoded = le.fit_transform(y)\n",
|
| 194 |
"y_tensor = torch.tensor(y_encoded)\n",
|
| 195 |
"\n",
|
|
|
|
| 196 |
"X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.4, stratify=y_tensor, random_state=42)\n",
|
| 197 |
"X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
|
| 198 |
]
|
|
|
|
| 208 |
"\n",
|
| 209 |
"X_augmented, y_augmented = augment_rotations(X_train, y_train)\n",
|
| 210 |
"\n",
|
|
|
|
| 211 |
"X_train_combined = torch.cat([X_train, X_augmented])\n",
|
| 212 |
"y_train_combined = torch.cat([y_train, y_augmented])\n",
|
| 213 |
"\n",
|
|
|
|
|
|
|
| 214 |
"train_dataset = TensorDataset(X_train, y_train)\n",
|
| 215 |
"val_dataset = TensorDataset(X_val, y_val)\n",
|
| 216 |
"test_dataset = TensorDataset(X_test, y_test)\n",
|
| 217 |
"\n",
|
|
|
|
| 218 |
"\n",
|
| 219 |
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
|
| 220 |
"val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
|
|
|
|
| 228 |
"metadata": {},
|
| 229 |
"outputs": [],
|
| 230 |
"source": [
|
| 231 |
+
"print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
|
| 232 |
+
"print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
|
| 233 |
+
"print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
|
| 234 |
]
|
| 235 |
},
|
| 236 |
{
|
|
|
|
| 240 |
"metadata": {},
|
| 241 |
"outputs": [],
|
| 242 |
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
"def get_efficientnet_model(num_classes):\n",
|
| 244 |
" model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)\n",
|
|
|
|
|
|
|
| 245 |
" model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)\n",
|
|
|
|
| 246 |
" return model\n",
|
| 247 |
"\n"
|
| 248 |
]
|
|
|
|
| 256 |
"source": [
|
| 257 |
"if torch.backends.mps.is_available():\n",
|
| 258 |
" device = torch.device(\"mps\")\n",
|
| 259 |
+
" print(\"Using MPS (Apple GPU)\")\n",
|
| 260 |
"else:\n",
|
| 261 |
" device = torch.device(\"cpu\")\n",
|
| 262 |
+
" print(\"MPS not available. Using CPU\")\n",
|
| 263 |
"\n",
|
| 264 |
"model = get_efficientnet_model(num_classes=3).to(device)\n",
|
| 265 |
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
|
|
|
|
| 302 |
"\n",
|
| 303 |
" total_train_loss += loss.item()\n",
|
| 304 |
"\n",
|
|
|
|
| 305 |
" pred_labels = preds.argmax(dim=1)\n",
|
| 306 |
" train_correct += (pred_labels == batch_y).sum().item()\n",
|
| 307 |
" train_total += batch_y.size(0)\n",
|
|
|
|
| 325 |
" val_accuracy = val_correct / val_total\n",
|
| 326 |
" validation_loss = criterion(model(val_x), val_y).item()\n",
|
| 327 |
"\n",
|
|
|
|
| 328 |
" val_losses.append(validation_loss)\n",
|
| 329 |
" val_accs.append(val_accuracy)\n",
|
| 330 |
"\n",
|
|
|
|
| 352 |
"metadata": {},
|
| 353 |
"outputs": [],
|
| 354 |
"source": [
|
| 355 |
+
"\n",
|
| 356 |
"\n",
|
| 357 |
"epochs = range(1, len(train_losses) + 1)\n",
|
| 358 |
"\n",
|
| 359 |
"plt.figure(figsize=(12, 5))\n",
|
| 360 |
"\n",
|
|
|
|
| 361 |
"plt.subplot(1, 2, 1)\n",
|
| 362 |
"plt.plot(epochs, train_losses, label='Train Loss', marker='o')\n",
|
| 363 |
"plt.plot(epochs, val_losses, label='Validation Loss', marker='s')\n",
|
|
|
|
| 367 |
"plt.legend()\n",
|
| 368 |
"plt.grid(True)\n",
|
| 369 |
"\n",
|
|
|
|
| 370 |
"plt.subplot(1, 2, 2)\n",
|
| 371 |
"plt.plot(epochs, train_accs, label='Train Accuracy', marker='o')\n",
|
| 372 |
"plt.plot(epochs, val_accs, label='Validation Accuracy', marker='s')\n",
|
|
|
|
| 431 |
"metadata": {},
|
| 432 |
"outputs": [],
|
| 433 |
"source": [
|
|
|
|
|
|
|
|
|
|
| 434 |
"\n",
|
| 435 |
"all_preds = np.array(all_preds)\n",
|
| 436 |
"all_targets = np.array(all_targets)\n",
|
| 437 |
+
"all_images = torch.stack(all_images) \n",
|
| 438 |
"\n",
|
|
|
|
| 439 |
"for class_idx, class_name in enumerate(target_names):\n",
|
| 440 |
+
" print(f\"\\nShowing False Negatives and False Positives for class: {class_name}\")\n",
|
|
|
|
|
|
|
| 441 |
" fn_indices = np.where((all_targets == class_idx) & (all_preds != class_idx))[0]\n",
|
|
|
|
| 442 |
" fp_indices = np.where((all_preds == class_idx) & (all_targets != class_idx))[0]\n",
|
| 443 |
"\n",
|
| 444 |
" def show_images(indices, title, max_images=5):\n",
|
| 445 |
" num = min(len(indices), max_images)\n",
|
| 446 |
" if num == 0:\n",
|
| 447 |
+
" print(f\"No {title} samples.\")\n",
|
| 448 |
" return\n",
|
| 449 |
"\n",
|
| 450 |
" plt.figure(figsize=(12, 2))\n",
|
| 451 |
" for i, idx in enumerate(indices[:num]):\n",
|
| 452 |
" img = all_images[idx]\n",
|
| 453 |
+
" img = img.permute(1, 2, 0).numpy()\n",
|
| 454 |
" plt.subplot(1, num, i + 1)\n",
|
| 455 |
+
" plt.imshow((img - img.min()) / (img.max() - img.min())) \n",
|
| 456 |
" plt.axis('off')\n",
|
| 457 |
" plt.title(f\"Pred: {target_names[all_preds[idx]]}\\nTrue: {target_names[all_targets[idx]]}\")\n",
|
| 458 |
" plt.suptitle(f\"{title} for {class_name}\")\n",
|
|
|
|
| 479 |
" activations[name] = output.detach().cpu()\n",
|
| 480 |
" return hook\n",
|
| 481 |
"\n",
|
|
|
|
| 482 |
" hooks = []\n",
|
| 483 |
" for i in range(len(model.features)):\n",
|
| 484 |
" layer = model.features[i]\n",
|
| 485 |
" hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
|
| 486 |
"\n",
|
| 487 |
" with torch.no_grad():\n",
|
| 488 |
+
" _ = model(image_tensor.unsqueeze(0)) \n",
|
| 489 |
"\n",
|
| 490 |
" for h in hooks:\n",
|
| 491 |
" h.remove()\n",
|
| 492 |
"\n",
|
| 493 |
" for layer_name, fmap in activations.items():\n",
|
| 494 |
+
" fmap = fmap.squeeze(0) \n",
|
| 495 |
"\n",
|
| 496 |
+
" channel_scores = fmap.mean(dim=(1, 2)) \n",
|
|
|
|
| 497 |
"\n",
|
|
|
|
| 498 |
" topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
|
| 499 |
" top_indices = topk.indices\n",
|
| 500 |
"\n",
|
|
|
|
| 501 |
" plt.figure(figsize=(max_channels * 2, 2.5))\n",
|
| 502 |
" for idx, ch in enumerate(top_indices):\n",
|
| 503 |
" plt.subplot(1, max_channels, idx + 1)\n",
|
|
|
|
| 530 |
"\n",
|
| 531 |
"img = Image.open(\"dataset/Tomato_512/Whole/image_0007.jpg\").convert(\"RGB\")\n",
|
| 532 |
"\n",
|
|
|
|
| 533 |
"transform = transforms.Compose([\n",
|
| 534 |
" transforms.Resize((224, 224)),\n",
|
| 535 |
" transforms.ToTensor()\n",
|
| 536 |
"])\n",
|
| 537 |
+
"img_tensor = transform(img) \n",
|
| 538 |
"\n",
|
|
|
|
| 539 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 540 |
]
|
| 541 |
},
|
|
|
|
| 548 |
"source": [
|
| 549 |
"img = Image.open(\"dataset/Tomato_512/On_the_vines/image_0578.jpg\").convert(\"RGB\")\n",
|
| 550 |
"\n",
|
|
|
|
| 551 |
"transform = transforms.Compose([\n",
|
| 552 |
" transforms.Resize((224, 224)),\n",
|
| 553 |
" transforms.ToTensor()\n",
|
| 554 |
"])\n",
|
| 555 |
+
"img_tensor = transform(img) \n",
|
| 556 |
"\n",
|
|
|
|
| 557 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 558 |
]
|
| 559 |
},
|
|
|
|
| 566 |
"source": [
|
| 567 |
"img = Image.open(\"dataset/Tomato_512/Diced/image_0578.jpg\").convert(\"RGB\")\n",
|
| 568 |
"\n",
|
|
|
|
| 569 |
"transform = transforms.Compose([\n",
|
| 570 |
" transforms.Resize((224, 224)),\n",
|
| 571 |
" transforms.ToTensor()\n",
|
| 572 |
"])\n",
|
| 573 |
+
"img_tensor = transform(img) \n",
|
| 574 |
"\n",
|
|
|
|
| 575 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
| 576 |
]
|
| 577 |
},
|