biengsen4real commited on
Commit
15327b9
·
verified ·
1 Parent(s): 3049a1c

Upload predict.ipynb

Browse files
Files changed (1) hide show
  1. predict.ipynb +179 -0
predict.ipynb ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "code",
19
+ "source": [
20
+ "import os\n",
21
+ "import torch\n",
22
+ "from transformers import ViTForImageClassification, ViTFeatureExtractor\n",
23
+ "from PIL import Image\n",
24
+ "import io\n",
25
+ "import pandas as pd\n",
26
+ "\n",
27
def load_model(model_path):
    """Load the fine-tuned ViT classifier and its feature extractor.

    Args:
        model_path: Path to a saved state dict (``.pth``) for the
            13-class classification head.

    Returns:
        Tuple ``(model, feature_extractor, device)``; the model is moved to
        ``device`` and set to eval mode, ready for inference.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

    # Instantiate the backbone with a fresh 13-class head; the pretrained
    # 1000-class head is discarded (ignore_mismatched_sizes=True).
    model = ViTForImageClassification.from_pretrained(
        'google/vit-base-patch16-224', num_labels=13, ignore_mismatched_sizes=True)

    # weights_only=True restricts unpickling to tensors/primitives, closing the
    # arbitrary-code-execution hole torch.load warns about (the FutureWarning
    # visible in this notebook's own output) for untrusted checkpoint files.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    # strict=False tolerates missing/unexpected keys in the checkpoint.
    # NOTE(review): this also silently ignores a wrong checkpoint — confirm
    # the state dict actually matches this architecture.
    model.load_state_dict(state_dict, strict=False)

    model = model.to(device)
    model.eval()  # disable dropout/batch-norm updates for inference
    return model, feature_extractor, device
42
+ "\n",
43
def safe_load_image(path):
    """Open the image at *path*; return an RGB 224x224 PIL image, or None on failure."""
    try:
        # Read the whole file first so decoding happens from an in-memory buffer.
        with open(path, 'rb') as fh:
            raw = fh.read()
        image = Image.open(io.BytesIO(raw))
        # Normalize to RGB and the ViT input resolution.
        return image.convert('RGB').resize((224, 224))
    except Exception as err:
        # Best-effort loader: report and signal failure rather than raising.
        print(f"Error loading image {path}: {err}")
        return None
54
+ "\n",
55
def predict_image_class(image_path, model, feature_extractor, device, class_names):
    """Classify one image; return (class_name, probability_array) or (None, None)."""
    image = safe_load_image(image_path)
    if image is None:
        return None, None

    # Tensorize the image and move the batch onto the model's device.
    batch = feature_extractor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**batch).logits
        # Per-class probabilities for the single image in the batch.
        probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
        best_idx = logits.argmax(dim=1).item()
        label = class_names[best_idx]

    return label, probs
72
+ "\n",
73
def predict_images_in_folder(folder_path, model, feature_extractor, device, class_names):
    """Predict the class of every supported image file directly inside *folder_path*."""
    supported = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
    predictions = []
    for name in os.listdir(folder_path):
        if not name.lower().endswith(supported):
            continue  # skip non-image files
        label, probs = predict_image_class(
            os.path.join(folder_path, name), model, feature_extractor, device, class_names)
        if label is not None:  # unreadable images are silently skipped
            predictions.append({
                'Image Name': name,
                'Predicted Class': label,
                'Probabilities': probs,
            })

    return predictions
84
+ "\n",
85
def save_results_to_excel(results, output_file, class_names):
    """Write per-class probabilities for each prediction to an Excel file.

    Each result dict is expanded into one long-format row per class, the
    sheet is sorted by probability (descending), and written with pandas.

    Args:
        results: List of dicts with keys 'Image Name', 'Predicted Class' and
            'Probabilities' (a per-class array index-aligned with class_names).
        output_file: Destination ``.xlsx`` path.
        class_names: Class labels, index-aligned with each probability array.
    """
    # One row per (image, class) pair — a comprehension replaces the
    # original nested append loop.
    rows = [
        {
            'Image Name': result['Image Name'],
            'Predicted Class': result['Predicted Class'],
            'Class': class_names[idx],
            'Probability': prob,
        }
        for result in results
        for idx, prob in enumerate(result['Probabilities'])
    ]

    df = pd.DataFrame(rows)

    # NOTE(review): the global sort interleaves rows from different images —
    # confirm this presentation is intended (vs. sorting within each image).
    df = df.sort_values(by='Probability', ascending=False)

    df.to_excel(output_file, index=False)
    print(f'Results saved to {output_file}')  # Confirm saving
107
+ "\n",
108
def main(input_path, model_path, output_file):
    """Entry point: classify a single image, or every image in a folder.

    Folder input exports all predictions to *output_file* (Excel); a single
    file just prints its predicted label.
    """
    class_names = [
        'anti_war_protest', 'combat', 'construction', 'fire', 'human_damage',
        'humanitarian_aid', 'infrastructure', 'military_parade', 'military_vehicle',
        'meeting', 'speech', 'refugee', 'victory',
    ]

    model, feature_extractor, device = load_model(model_path)

    if os.path.isdir(input_path):
        # Folder mode: batch-predict and export the results to Excel.
        results = predict_images_in_folder(input_path, model, feature_extractor, device, class_names)
        if not results:
            print("No valid images found in the specified folder.")
        else:
            save_results_to_excel(results, output_file, class_names)
    elif os.path.isfile(input_path):
        # Single-image mode: print the prediction only.
        label, _probs = predict_image_class(input_path, model, feature_extractor, device, class_names)
        if label is None:
            print("Image could not be processed.")
        else:
            print(f'Predicted class for image {os.path.basename(input_path)}: {label}')
    else:
        print('Invalid input path. Please provide a valid file or folder path.')
132
+ ],
133
+ "metadata": {
134
+ "id": "340nVjm4AcDO"
135
+ },
136
+ "execution_count": 15,
137
+ "outputs": []
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "source": [
142
# Example invocation — adjust the three paths for your environment.
input_path = '/content/ddd.jpg'            # image file, or a folder of images
model_path = '/content/model.pth'          # fine-tuned ViT state dict
output_file = '/content/predictions.xlsx'  # Excel destination (folder mode only)
main(input_path, model_path, output_file)
147
+ ],
148
+ "metadata": {
149
+ "id": "CY-fkhjdAeMM",
150
+ "outputId": "1de8f113-97a5-419d-995b-815ec5391a80",
151
+ "colab": {
152
+ "base_uri": "https://localhost:8080/"
153
+ }
154
+ },
155
+ "execution_count": 18,
156
+ "outputs": [
157
+ {
158
+ "output_type": "stream",
159
+ "name": "stderr",
160
+ "text": [
161
+ "Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:\n",
162
+ "- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([13]) in the model instantiated\n",
163
+ "- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([13, 768]) in the model instantiated\n",
164
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
165
+ "<ipython-input-15-1637f10c4a55>:17: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
166
+ " state_dict = torch.load(model_path, map_location=device)\n"
167
+ ]
168
+ },
169
+ {
170
+ "output_type": "stream",
171
+ "name": "stdout",
172
+ "text": [
173
+ "Predicted class for image ddd.jpg: infrastructure\n"
174
+ ]
175
+ }
176
+ ]
177
+ }
178
+ ]
179
+ }