biengsen4real commited on
Commit
7e3e39b
·
verified ·
1 Parent(s): 15327b9

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +36 -31
predict.py CHANGED
@@ -1,12 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- """Untitled5.ipynb
3
-
4
- Automatically generated by Colab.
5
-
6
- Original file is located at
7
- https://colab.research.google.com/drive/1kfJMKD93CT0RxoHkh_T3hdcVMovTMHCe
8
- """
9
-
10
  import os
11
  import torch
12
  from transformers import ViTForImageClassification, ViTFeatureExtractor
@@ -18,18 +9,25 @@ def load_model(model_path):
18
  """Load the pre-trained model and feature extractor."""
19
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
  feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
21
- model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=13)
22
- model.load_state_dict(torch.load(model_path))
 
 
 
 
 
 
23
  model = model.to(device)
24
  model.eval() # Set the model to evaluation mode
25
  return model, feature_extractor, device
26
 
27
  def safe_load_image(path):
28
- """Safely load an image, handling possible errors."""
29
  try:
30
  with open(path, 'rb') as f:
31
  img = Image.open(io.BytesIO(f.read()))
32
  img = img.convert('RGB')
 
33
  return img
34
  except Exception as e:
35
  print(f"Error loading image {path}: {e}")
@@ -44,12 +42,12 @@ def predict_image_class(image_path, model, feature_extractor, device, class_name
44
  # Preprocess the image
45
  inputs = feature_extractor(images=img, return_tensors="pt").to(device)
46
 
47
- # Make the prediction
48
  with torch.no_grad():
49
  outputs = model(**inputs).logits
50
  probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0] # Calculate probabilities
51
  predicted_class_idx = outputs.argmax(dim=1).item() # Get the predicted class index
52
- predicted_class = class_names[predicted_class_idx] # Get class name from index
53
 
54
  return predicted_class, probabilities
55
 
@@ -65,31 +63,33 @@ def predict_images_in_folder(folder_path, model, feature_extractor, device, clas
65
 
66
  return results
67
 
68
- def save_results_to_excel(results, output_file):
69
- """Save the prediction results to an Excel file."""
70
- # Flatten the probabilities array and create a DataFrame
71
  rows = []
72
  for result in results:
 
73
  for idx, prob in enumerate(result['Probabilities']):
74
  rows.append({
75
  'Image Name': result['Image Name'],
76
  'Predicted Class': result['Predicted Class'],
77
- 'Class Index': idx,
78
  'Probability': prob
79
  })
80
-
81
  df = pd.DataFrame(rows)
82
-
83
  # Sort by probability in descending order
84
  df = df.sort_values(by='Probability', ascending=False)
85
-
86
  # Save to Excel
87
  df.to_excel(output_file, index=False)
 
88
 
89
  def main(input_path, model_path, output_file):
90
- """Main function to perform image classification prediction, handling single images or folders, and saving results to Excel."""
91
- class_names = ['anti_war_protest', 'combat', 'construction', 'fire', 'human_damage',
92
- 'humanitarian_aid', 'infrastructure', 'military_parade', 'military_vehicle',
93
  'meeting', 'speech', 'refugee', 'victory']
94
 
95
  model, feature_extractor, device = load_model(model_path)
@@ -97,17 +97,22 @@ def main(input_path, model_path, output_file):
97
  if os.path.isdir(input_path):
98
  # If the input path is a folder, predict all images in that folder
99
  results = predict_images_in_folder(input_path, model, feature_extractor, device, class_names)
100
- save_results_to_excel(results, output_file)
101
- print(f'Prediction results saved to: {output_file}')
 
 
102
  elif os.path.isfile(input_path):
103
- # If the input path is a single image, make a direct prediction
104
  predicted_class, probabilities = predict_image_class(input_path, model, feature_extractor, device, class_names)
105
- print(f'The predicted class for image {os.path.basename(input_path)}: {predicted_class}')
 
 
 
106
  else:
107
  print('Invalid input path. Please provide a valid file or folder path.')
108
 
109
  # Example call
110
- input_path = '/path/to/your/image_or_folder' # Replace with your image or folder path
111
- model_path = '/kaggle/working/best_modelq.pth' # Replace with your model path
112
- output_file = 'predictions.xlsx' # Name of the output Excel file
113
  main(input_path, model_path, output_file)
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import torch
3
  from transformers import ViTForImageClassification, ViTFeatureExtractor
 
9
  """Load the pre-trained model and feature extractor."""
10
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
  feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
12
+
13
+ # Load the model with weights mapped to the appropriate device
14
+ model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=13, ignore_mismatched_sizes=True)
15
+
16
+ # Load the state dict with map_location to handle devices properly
17
+ state_dict = torch.load(model_path, map_location=device)
18
+ model.load_state_dict(state_dict, strict=False) # Use strict=False to ignore size mismatches
19
+
20
  model = model.to(device)
21
  model.eval() # Set the model to evaluation mode
22
  return model, feature_extractor, device
23
 
24
  def safe_load_image(path):
25
+ """Safely load an image, handling potential errors."""
26
  try:
27
  with open(path, 'rb') as f:
28
  img = Image.open(io.BytesIO(f.read()))
29
  img = img.convert('RGB')
30
+ img = img.resize((224, 224)) # Resize the image to (224, 224)
31
  return img
32
  except Exception as e:
33
  print(f"Error loading image {path}: {e}")
 
42
  # Preprocess the image
43
  inputs = feature_extractor(images=img, return_tensors="pt").to(device)
44
 
45
+ # Perform prediction
46
  with torch.no_grad():
47
  outputs = model(**inputs).logits
48
  probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0] # Calculate probabilities
49
  predicted_class_idx = outputs.argmax(dim=1).item() # Get the predicted class index
50
+ predicted_class = class_names[predicted_class_idx] # Get the class name based on the index
51
 
52
  return predicted_class, probabilities
53
 
 
63
 
64
  return results
65
 
66
+ def save_results_to_excel(results, output_file, class_names):
67
+ """Save prediction results to an Excel file."""
68
+ # Flatten probability array and create DataFrame
69
  rows = []
70
  for result in results:
71
+ # Add each probability with corresponding class name
72
  for idx, prob in enumerate(result['Probabilities']):
73
  rows.append({
74
  'Image Name': result['Image Name'],
75
  'Predicted Class': result['Predicted Class'],
76
+ 'Class': class_names[idx],
77
  'Probability': prob
78
  })
79
+
80
  df = pd.DataFrame(rows)
81
+
82
  # Sort by probability in descending order
83
  df = df.sort_values(by='Probability', ascending=False)
84
+
85
  # Save to Excel
86
  df.to_excel(output_file, index=False)
87
+ print(f'Results saved to {output_file}') # Confirm saving
88
 
89
  def main(input_path, model_path, output_file):
90
+ """Main function to execute image classification predictions, processing single images or folders, and saving results to Excel."""
91
+ class_names = ['anti_war_protest', 'combat', 'construction', 'fire', 'human_damage',
92
+ 'humanitarian_aid', 'infrastructure', 'military_parade', 'military_vehicle',
93
  'meeting', 'speech', 'refugee', 'victory']
94
 
95
  model, feature_extractor, device = load_model(model_path)
 
97
  if os.path.isdir(input_path):
98
  # If the input path is a folder, predict all images in that folder
99
  results = predict_images_in_folder(input_path, model, feature_extractor, device, class_names)
100
+ if results:
101
+ save_results_to_excel(results, output_file, class_names)
102
+ else:
103
+ print("No valid images found in the specified folder.")
104
  elif os.path.isfile(input_path):
105
+ # If the input path is a single image, perform direct prediction
106
  predicted_class, probabilities = predict_image_class(input_path, model, feature_extractor, device, class_names)
107
+ if predicted_class is not None:
108
+ print(f'Predicted class for image {os.path.basename(input_path)}: {predicted_class}')
109
+ else:
110
+ print("Image could not be processed.")
111
  else:
112
  print('Invalid input path. Please provide a valid file or folder path.')
113
 
114
  # Example call
115
+ input_path = '/content/ddd.jpg' # Replace with your image or folder path
116
+ model_path = '/content/model.pth' # Replace with your model path
117
+ output_file = '/content/predictions.xlsx' # Name of the output Excel file
118
  main(input_path, model_path, output_file)