Update predict.py
Browse files- predict.py +36 -31
predict.py
CHANGED
|
@@ -1,12 +1,3 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
"""Untitled5.ipynb
|
| 3 |
-
|
| 4 |
-
Automatically generated by Colab.
|
| 5 |
-
|
| 6 |
-
Original file is located at
|
| 7 |
-
https://colab.research.google.com/drive/1kfJMKD93CT0RxoHkh_T3hdcVMovTMHCe
|
| 8 |
-
"""
|
| 9 |
-
|
| 10 |
import os
|
| 11 |
import torch
|
| 12 |
from transformers import ViTForImageClassification, ViTFeatureExtractor
|
|
@@ -18,18 +9,25 @@ def load_model(model_path):
|
|
| 18 |
"""Load the pre-trained model and feature extractor."""
|
| 19 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 20 |
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
|
| 21 |
-
|
| 22 |
-
model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
model = model.to(device)
|
| 24 |
model.eval() # Set the model to evaluation mode
|
| 25 |
return model, feature_extractor, device
|
| 26 |
|
| 27 |
def safe_load_image(path):
|
| 28 |
-
"""Safely load an image, handling
|
| 29 |
try:
|
| 30 |
with open(path, 'rb') as f:
|
| 31 |
img = Image.open(io.BytesIO(f.read()))
|
| 32 |
img = img.convert('RGB')
|
|
|
|
| 33 |
return img
|
| 34 |
except Exception as e:
|
| 35 |
print(f"Error loading image {path}: {e}")
|
|
@@ -44,12 +42,12 @@ def predict_image_class(image_path, model, feature_extractor, device, class_name
|
|
| 44 |
# Preprocess the image
|
| 45 |
inputs = feature_extractor(images=img, return_tensors="pt").to(device)
|
| 46 |
|
| 47 |
-
#
|
| 48 |
with torch.no_grad():
|
| 49 |
outputs = model(**inputs).logits
|
| 50 |
probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0] # Calculate probabilities
|
| 51 |
predicted_class_idx = outputs.argmax(dim=1).item() # Get the predicted class index
|
| 52 |
-
predicted_class = class_names[predicted_class_idx] # Get class name
|
| 53 |
|
| 54 |
return predicted_class, probabilities
|
| 55 |
|
|
@@ -65,31 +63,33 @@ def predict_images_in_folder(folder_path, model, feature_extractor, device, clas
|
|
| 65 |
|
| 66 |
return results
|
| 67 |
|
| 68 |
-
def save_results_to_excel(results, output_file):
|
| 69 |
-
"""Save
|
| 70 |
-
# Flatten
|
| 71 |
rows = []
|
| 72 |
for result in results:
|
|
|
|
| 73 |
for idx, prob in enumerate(result['Probabilities']):
|
| 74 |
rows.append({
|
| 75 |
'Image Name': result['Image Name'],
|
| 76 |
'Predicted Class': result['Predicted Class'],
|
| 77 |
-
'Class
|
| 78 |
'Probability': prob
|
| 79 |
})
|
| 80 |
-
|
| 81 |
df = pd.DataFrame(rows)
|
| 82 |
-
|
| 83 |
# Sort by probability in descending order
|
| 84 |
df = df.sort_values(by='Probability', ascending=False)
|
| 85 |
-
|
| 86 |
# Save to Excel
|
| 87 |
df.to_excel(output_file, index=False)
|
|
|
|
| 88 |
|
| 89 |
def main(input_path, model_path, output_file):
|
| 90 |
-
"""Main function to
|
| 91 |
-
class_names = ['anti_war_protest', 'combat', 'construction', 'fire', 'human_damage',
|
| 92 |
-
'humanitarian_aid', 'infrastructure', 'military_parade', 'military_vehicle',
|
| 93 |
'meeting', 'speech', 'refugee', 'victory']
|
| 94 |
|
| 95 |
model, feature_extractor, device = load_model(model_path)
|
|
@@ -97,17 +97,22 @@ def main(input_path, model_path, output_file):
|
|
| 97 |
if os.path.isdir(input_path):
|
| 98 |
# If the input path is a folder, predict all images in that folder
|
| 99 |
results = predict_images_in_folder(input_path, model, feature_extractor, device, class_names)
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
| 102 |
elif os.path.isfile(input_path):
|
| 103 |
-
# If the input path is a single image,
|
| 104 |
predicted_class, probabilities = predict_image_class(input_path, model, feature_extractor, device, class_names)
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
| 106 |
else:
|
| 107 |
print('Invalid input path. Please provide a valid file or folder path.')
|
| 108 |
|
| 109 |
# Example call
|
| 110 |
-
input_path = '/
|
| 111 |
-
model_path = '/
|
| 112 |
-
output_file = 'predictions.xlsx' # Name of the output Excel file
|
| 113 |
main(input_path, model_path, output_file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import torch
|
| 3 |
from transformers import ViTForImageClassification, ViTFeatureExtractor
|
|
|
|
| 9 |
"""Load the pre-trained model and feature extractor."""
|
| 10 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 11 |
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
|
| 12 |
+
|
| 13 |
+
# Load the model with weights mapped to the appropriate device
|
| 14 |
+
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=13, ignore_mismatched_sizes=True)
|
| 15 |
+
|
| 16 |
+
# Load the state dict with map_location to handle devices properly
|
| 17 |
+
state_dict = torch.load(model_path, map_location=device)
|
| 18 |
+
model.load_state_dict(state_dict, strict=False) # Use strict=False to ignore size mismatches
|
| 19 |
+
|
| 20 |
model = model.to(device)
|
| 21 |
model.eval() # Set the model to evaluation mode
|
| 22 |
return model, feature_extractor, device
|
| 23 |
|
| 24 |
def safe_load_image(path):
|
| 25 |
+
"""Safely load an image, handling potential errors."""
|
| 26 |
try:
|
| 27 |
with open(path, 'rb') as f:
|
| 28 |
img = Image.open(io.BytesIO(f.read()))
|
| 29 |
img = img.convert('RGB')
|
| 30 |
+
img = img.resize((224, 224)) # Resize the image to (224, 224)
|
| 31 |
return img
|
| 32 |
except Exception as e:
|
| 33 |
print(f"Error loading image {path}: {e}")
|
|
|
|
| 42 |
# Preprocess the image
|
| 43 |
inputs = feature_extractor(images=img, return_tensors="pt").to(device)
|
| 44 |
|
| 45 |
+
# Perform prediction
|
| 46 |
with torch.no_grad():
|
| 47 |
outputs = model(**inputs).logits
|
| 48 |
probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0] # Calculate probabilities
|
| 49 |
predicted_class_idx = outputs.argmax(dim=1).item() # Get the predicted class index
|
| 50 |
+
predicted_class = class_names[predicted_class_idx] # Get the class name based on the index
|
| 51 |
|
| 52 |
return predicted_class, probabilities
|
| 53 |
|
|
|
|
| 63 |
|
| 64 |
return results
|
| 65 |
|
| 66 |
+
def save_results_to_excel(results, output_file, class_names):
|
| 67 |
+
"""Save prediction results to an Excel file."""
|
| 68 |
+
# Flatten probability array and create DataFrame
|
| 69 |
rows = []
|
| 70 |
for result in results:
|
| 71 |
+
# Add each probability with corresponding class name
|
| 72 |
for idx, prob in enumerate(result['Probabilities']):
|
| 73 |
rows.append({
|
| 74 |
'Image Name': result['Image Name'],
|
| 75 |
'Predicted Class': result['Predicted Class'],
|
| 76 |
+
'Class': class_names[idx],
|
| 77 |
'Probability': prob
|
| 78 |
})
|
| 79 |
+
|
| 80 |
df = pd.DataFrame(rows)
|
| 81 |
+
|
| 82 |
# Sort by probability in descending order
|
| 83 |
df = df.sort_values(by='Probability', ascending=False)
|
| 84 |
+
|
| 85 |
# Save to Excel
|
| 86 |
df.to_excel(output_file, index=False)
|
| 87 |
+
print(f'Results saved to {output_file}') # Confirm saving
|
| 88 |
|
| 89 |
def main(input_path, model_path, output_file):
|
| 90 |
+
"""Main function to execute image classification predictions, processing single images or folders, and saving results to Excel."""
|
| 91 |
+
class_names = ['anti_war_protest', 'combat', 'construction', 'fire', 'human_damage',
|
| 92 |
+
'humanitarian_aid', 'infrastructure', 'military_parade', 'military_vehicle',
|
| 93 |
'meeting', 'speech', 'refugee', 'victory']
|
| 94 |
|
| 95 |
model, feature_extractor, device = load_model(model_path)
|
|
|
|
| 97 |
if os.path.isdir(input_path):
|
| 98 |
# If the input path is a folder, predict all images in that folder
|
| 99 |
results = predict_images_in_folder(input_path, model, feature_extractor, device, class_names)
|
| 100 |
+
if results:
|
| 101 |
+
save_results_to_excel(results, output_file, class_names)
|
| 102 |
+
else:
|
| 103 |
+
print("No valid images found in the specified folder.")
|
| 104 |
elif os.path.isfile(input_path):
|
| 105 |
+
# If the input path is a single image, perform direct prediction
|
| 106 |
predicted_class, probabilities = predict_image_class(input_path, model, feature_extractor, device, class_names)
|
| 107 |
+
if predicted_class is not None:
|
| 108 |
+
print(f'Predicted class for image {os.path.basename(input_path)}: {predicted_class}')
|
| 109 |
+
else:
|
| 110 |
+
print("Image could not be processed.")
|
| 111 |
else:
|
| 112 |
print('Invalid input path. Please provide a valid file or folder path.')
|
| 113 |
|
| 114 |
# Example call
|
| 115 |
+
input_path = '/content/ddd.jpg' # Replace with your image or folder path
|
| 116 |
+
model_path = '/content/model.pth' # Replace with your model path
|
| 117 |
+
output_file = '/content/predictions.xlsx' # Name of the output Excel file
|
| 118 |
main(input_path, model_path, output_file)
|