biengsen4real commited on
Commit
cf30643
·
verified ·
1 Parent(s): a47c735

Upload 2 files

Browse files
Files changed (2) hide show
  1. predict.ipynb +141 -0
  2. predict.py +113 -0
predict.ipynb ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "code",
19
+ "source": [
20
+ "import os\n",
21
+ "import torch\n",
22
+ "from transformers import ViTForImageClassification, ViTFeatureExtractor\n",
23
+ "from PIL import Image\n",
24
+ "import io\n",
25
+ "import pandas as pd\n",
26
+ "\n",
27
+ "def load_model(model_path):\n",
28
+ " \"\"\"Load the pre-trained model and feature extractor.\"\"\"\n",
29
+ " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
30
+ " feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')\n",
31
+ " model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=13)\n",
32
+ " model.load_state_dict(torch.load(model_path))\n",
33
+ " model = model.to(device)\n",
34
+ " model.eval() # Set the model to evaluation mode\n",
35
+ " return model, feature_extractor, device\n",
36
+ "\n",
37
+ "def safe_load_image(path):\n",
38
+ " \"\"\"Safely load an image, handling possible errors.\"\"\"\n",
39
+ " try:\n",
40
+ " with open(path, 'rb') as f:\n",
41
+ " img = Image.open(io.BytesIO(f.read()))\n",
42
+ " img = img.convert('RGB')\n",
43
+ " return img\n",
44
+ " except Exception as e:\n",
45
+ " print(f\"Error loading image {path}: {e}\")\n",
46
+ " return None\n",
47
+ "\n",
48
+ "def predict_image_class(image_path, model, feature_extractor, device, class_names):\n",
49
+ " \"\"\"Predict the class of a given image.\"\"\"\n",
50
+ " img = safe_load_image(image_path)\n",
51
+ " if img is None:\n",
52
+ " return None, None\n",
53
+ "\n",
54
+ " # Preprocess the image\n",
55
+ " inputs = feature_extractor(images=img, return_tensors=\"pt\").to(device)\n",
56
+ "\n",
57
+ " # Make the prediction\n",
58
+ " with torch.no_grad():\n",
59
+ " outputs = model(**inputs).logits\n",
60
+ " probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0] # Calculate probabilities\n",
61
+ " predicted_class_idx = outputs.argmax(dim=1).item() # Get the predicted class index\n",
62
+ " predicted_class = class_names[predicted_class_idx] # Get class name from index\n",
63
+ "\n",
64
+ " return predicted_class, probabilities\n",
65
+ "\n",
66
+ "def predict_images_in_folder(folder_path, model, feature_extractor, device, class_names):\n",
67
+ " \"\"\"Predict the class of each image in a folder.\"\"\"\n",
68
+ " results = []\n",
69
+ " for filename in os.listdir(folder_path):\n",
70
+ " if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):\n",
71
+ " image_path = os.path.join(folder_path, filename)\n",
72
+ " predicted_class, probabilities = predict_image_class(image_path, model, feature_extractor, device, class_names)\n",
73
+ " if predicted_class is not None:\n",
74
+ " results.append({'Image Name': filename, 'Predicted Class': predicted_class, 'Probabilities': probabilities})\n",
75
+ "\n",
76
+ " return results\n",
77
+ "\n",
78
+ "def save_results_to_excel(results, output_file):\n",
79
+ " \"\"\"Save the prediction results to an Excel file.\"\"\"\n",
80
+ " # Flatten the probabilities array and create a DataFrame\n",
81
+ " rows = []\n",
82
+ " for result in results:\n",
83
+ " for idx, prob in enumerate(result['Probabilities']):\n",
84
+ " rows.append({\n",
85
+ " 'Image Name': result['Image Name'],\n",
86
+ " 'Predicted Class': result['Predicted Class'],\n",
87
+ " 'Class Index': idx,\n",
88
+ " 'Probability': prob\n",
89
+ " })\n",
90
+ "\n",
91
+ " df = pd.DataFrame(rows)\n",
92
+ "\n",
93
+ " # Sort by probability in descending order\n",
94
+ " df = df.sort_values(by='Probability', ascending=False)\n",
95
+ "\n",
96
+ " # Save to Excel\n",
97
+ " df.to_excel(output_file, index=False)\n",
98
+ "\n",
99
+ "def main(input_path, model_path, output_file):\n",
100
+ " \"\"\"Main function to perform image classification prediction, handling single images or folders, and saving results to Excel.\"\"\"\n",
101
+ " class_names = ['anti_war_protest', 'combat', 'construction', 'fire', 'human_damage',\n",
102
+ " 'humanitarian_aid', 'infrastructure', 'military_parade', 'military_vehicle',\n",
103
+ " 'meeting', 'speech', 'refugee', 'victory']\n",
104
+ "\n",
105
+ " model, feature_extractor, device = load_model(model_path)\n",
106
+ "\n",
107
+ " if os.path.isdir(input_path):\n",
108
+ " # If the input path is a folder, predict all images in that folder\n",
109
+ " results = predict_images_in_folder(input_path, model, feature_extractor, device, class_names)\n",
110
+ " save_results_to_excel(results, output_file)\n",
111
+ " print(f'Prediction results saved to: {output_file}')\n",
112
+ " elif os.path.isfile(input_path):\n",
113
+ " # If the input path is a single image, make a direct prediction\n",
114
+ " predicted_class, probabilities = predict_image_class(input_path, model, feature_extractor, device, class_names)\n",
115
+ " print(f'The predicted class for image {os.path.basename(input_path)}: {predicted_class}')\n",
116
+ " else:\n",
117
+ " print('Invalid input path. Please provide a valid file or folder path.')\n"
118
+ ],
119
+ "metadata": {
120
+ "id": "340nVjm4AcDO"
121
+ },
122
+ "execution_count": null,
123
+ "outputs": []
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "source": [
128
+ "# Example call\n",
129
+ "input_path = '/path/to/your/image_or_folder' # Replace with your image or folder path\n",
130
+ "model_path = '/kaggle/working/best_modelq.pth' # Replace with your model path\n",
131
+ "output_file = 'predictions.xlsx' # Name of the output Excel file\n",
132
+ "main(input_path, model_path, output_file)"
133
+ ],
134
+ "metadata": {
135
+ "id": "CY-fkhjdAeMM"
136
+ },
137
+ "execution_count": null,
138
+ "outputs": []
139
+ }
140
+ ]
141
+ }
predict.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Untitled5.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1kfJMKD93CT0RxoHkh_T3hdcVMovTMHCe
8
+ """
9
+
10
+ import os
11
+ import torch
12
+ from transformers import ViTForImageClassification, ViTFeatureExtractor
13
+ from PIL import Image
14
+ import io
15
+ import pandas as pd
16
+
17
+ def load_model(model_path):
18
+ """Load the pre-trained model and feature extractor."""
19
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
+ feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
21
+ model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=13)
22
+ model.load_state_dict(torch.load(model_path))
23
+ model = model.to(device)
24
+ model.eval() # Set the model to evaluation mode
25
+ return model, feature_extractor, device
26
+
27
+ def safe_load_image(path):
28
+ """Safely load an image, handling possible errors."""
29
+ try:
30
+ with open(path, 'rb') as f:
31
+ img = Image.open(io.BytesIO(f.read()))
32
+ img = img.convert('RGB')
33
+ return img
34
+ except Exception as e:
35
+ print(f"Error loading image {path}: {e}")
36
+ return None
37
+
38
+ def predict_image_class(image_path, model, feature_extractor, device, class_names):
39
+ """Predict the class of a given image."""
40
+ img = safe_load_image(image_path)
41
+ if img is None:
42
+ return None, None
43
+
44
+ # Preprocess the image
45
+ inputs = feature_extractor(images=img, return_tensors="pt").to(device)
46
+
47
+ # Make the prediction
48
+ with torch.no_grad():
49
+ outputs = model(**inputs).logits
50
+ probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0] # Calculate probabilities
51
+ predicted_class_idx = outputs.argmax(dim=1).item() # Get the predicted class index
52
+ predicted_class = class_names[predicted_class_idx] # Get class name from index
53
+
54
+ return predicted_class, probabilities
55
+
56
+ def predict_images_in_folder(folder_path, model, feature_extractor, device, class_names):
57
+ """Predict the class of each image in a folder."""
58
+ results = []
59
+ for filename in os.listdir(folder_path):
60
+ if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
61
+ image_path = os.path.join(folder_path, filename)
62
+ predicted_class, probabilities = predict_image_class(image_path, model, feature_extractor, device, class_names)
63
+ if predicted_class is not None:
64
+ results.append({'Image Name': filename, 'Predicted Class': predicted_class, 'Probabilities': probabilities})
65
+
66
+ return results
67
+
68
+ def save_results_to_excel(results, output_file):
69
+ """Save the prediction results to an Excel file."""
70
+ # Flatten the probabilities array and create a DataFrame
71
+ rows = []
72
+ for result in results:
73
+ for idx, prob in enumerate(result['Probabilities']):
74
+ rows.append({
75
+ 'Image Name': result['Image Name'],
76
+ 'Predicted Class': result['Predicted Class'],
77
+ 'Class Index': idx,
78
+ 'Probability': prob
79
+ })
80
+
81
+ df = pd.DataFrame(rows)
82
+
83
+ # Sort by probability in descending order
84
+ df = df.sort_values(by='Probability', ascending=False)
85
+
86
+ # Save to Excel
87
+ df.to_excel(output_file, index=False)
88
+
89
+ def main(input_path, model_path, output_file):
90
+ """Main function to perform image classification prediction, handling single images or folders, and saving results to Excel."""
91
+ class_names = ['anti_war_protest', 'combat', 'construction', 'fire', 'human_damage',
92
+ 'humanitarian_aid', 'infrastructure', 'military_parade', 'military_vehicle',
93
+ 'meeting', 'speech', 'refugee', 'victory']
94
+
95
+ model, feature_extractor, device = load_model(model_path)
96
+
97
+ if os.path.isdir(input_path):
98
+ # If the input path is a folder, predict all images in that folder
99
+ results = predict_images_in_folder(input_path, model, feature_extractor, device, class_names)
100
+ save_results_to_excel(results, output_file)
101
+ print(f'Prediction results saved to: {output_file}')
102
+ elif os.path.isfile(input_path):
103
+ # If the input path is a single image, make a direct prediction
104
+ predicted_class, probabilities = predict_image_class(input_path, model, feature_extractor, device, class_names)
105
+ print(f'The predicted class for image {os.path.basename(input_path)}: {predicted_class}')
106
+ else:
107
+ print('Invalid input path. Please provide a valid file or folder path.')
108
+
109
+ # Example call
110
+ input_path = '/path/to/your/image_or_folder' # Replace with your image or folder path
111
+ model_path = '/kaggle/working/best_modelq.pth' # Replace with your model path
112
+ output_file = 'predictions.xlsx' # Name of the output Excel file
113
+ main(input_path, model_path, output_file)