biengsen4real commited on
Commit
03a544e
·
verified ·
1 Parent(s): 347252f

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +42 -99
predict.py CHANGED
@@ -1,118 +1,61 @@
1
  import os
2
  import torch
 
 
3
  from transformers import ViTForImageClassification, ViTFeatureExtractor
4
  from PIL import Image
5
- import io
6
- import pandas as pd
7
 
8
  def load_model(model_path):
9
- """Load the pre-trained model and feature extractor."""
10
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
- feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
12
-
13
- # Load the model with weights mapped to the appropriate device
14
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=13, ignore_mismatched_sizes=True)
15
-
16
- # Load the state dict with map_location to handle devices properly
17
- state_dict = torch.load(model_path, map_location=device)
18
- model.load_state_dict(state_dict, strict=False) # Use strict=False to ignore size mismatches
19
-
20
- model = model.to(device)
21
- model.eval() # Set the model to evaluation mode
22
- return model, feature_extractor, device
23
 
24
- def safe_load_image(path):
25
- """Safely load an image, handling potential errors."""
26
- try:
27
- with open(path, 'rb') as f:
28
- img = Image.open(io.BytesIO(f.read()))
29
- img = img.convert('RGB')
30
- img = img.resize((224, 224)) # Resize the image to (224, 224)
31
- return img
32
- except Exception as e:
33
- print(f"Error loading image {path}: {e}")
34
- return None
35
-
36
- def predict_image_class(image_path, model, feature_extractor, device, class_names):
37
- """Predict the class of a given image."""
38
- img = safe_load_image(image_path)
39
- if img is None:
40
- return None, None
41
-
42
- # Preprocess the image
43
- inputs = feature_extractor(images=img, return_tensors="pt").to(device)
44
 
45
- # Perform prediction
 
 
46
  with torch.no_grad():
47
- outputs = model(**inputs).logits
48
- probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0] # Calculate probabilities
49
- predicted_class_idx = outputs.argmax(dim=1).item() # Get the predicted class index
50
- predicted_class = class_names[predicted_class_idx] # Get the class name based on the index
51
-
52
- return predicted_class, probabilities
53
-
54
- def predict_images_in_folder(folder_path, model, feature_extractor, device, class_names):
55
- """Predict the class of each image in a folder."""
56
- results = []
57
- for filename in os.listdir(folder_path):
58
- if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
59
- image_path = os.path.join(folder_path, filename)
60
- predicted_class, probabilities = predict_image_class(image_path, model, feature_extractor, device, class_names)
61
- if predicted_class is not None:
62
- results.append({'Image Name': filename, 'Predicted Class': predicted_class, 'Probabilities': probabilities})
63
-
64
- return results
65
-
66
- def save_results_to_excel(results, output_file, class_names):
67
- """Save prediction results to an Excel file."""
68
- # Flatten probability array and create DataFrame
69
- rows = []
70
- for result in results:
71
- # Add each probability with corresponding class name
72
- for idx, prob in enumerate(result['Probabilities']):
73
- rows.append({
74
- 'Image Name': result['Image Name'],
75
- 'Predicted Class': result['Predicted Class'],
76
- 'Class': class_names[idx],
77
- 'Probability': prob
78
- })
79
-
80
- df = pd.DataFrame(rows)
81
-
82
- # Sort by probability in descending order
83
- df = df.sort_values(by='Probability', ascending=False)
84
-
85
- # Save to Excel
86
- df.to_excel(output_file, index=False)
87
- print(f'Results saved to {output_file}') # Confirm saving
88
 
89
  def main(input_path, model_path, output_file):
90
- """Main function to execute image classification predictions, processing single images or folders, and saving results to Excel."""
91
- class_names = ['anti_war_protest', 'combat', 'construction', 'fire', 'human_damage',
92
- 'humanitarian_aid', 'infrastructure', 'military_parade', 'military_vehicle',
93
- 'meeting', 'speech', 'refugee', 'victory']
94
-
95
- model, feature_extractor, device = load_model(model_path)
96
 
97
  if os.path.isdir(input_path):
98
- # If the input path is a folder, predict all images in that folder
99
- results = predict_images_in_folder(input_path, model, feature_extractor, device, class_names)
100
- if results:
101
- save_results_to_excel(results, output_file, class_names)
102
- else:
103
- print("No valid images found in the specified folder.")
104
- elif os.path.isfile(input_path):
105
- # If the input path is a single image, perform direct prediction
106
- predicted_class, probabilities = predict_image_class(input_path, model, feature_extractor, device, class_names)
107
- if predicted_class is not None:
108
- print(f'Predicted class for image {os.path.basename(input_path)}: {predicted_class}')
109
- else:
110
- print("Image could not be processed.")
111
  else:
112
- print('Invalid input path. Please provide a valid file or folder path.')
 
 
 
 
 
 
 
 
 
 
113
 
114
  # Example call
115
- input_path = '/content/ddd.jpg' # Replace with your image or folder path
116
  model_path = '/content/model.pth' # Replace with your model path
117
- output_file = '/content/predictions.xlsx' # Name of the output Excel file
118
  main(input_path, model_path, output_file)
 
1
  import os
2
  import torch
3
+ import pandas as pd
4
+ from torchvision import transforms
5
  from transformers import ViTForImageClassification, ViTFeatureExtractor
6
  from PIL import Image
 
 
7
 
8
  def load_model(model_path):
9
+ """Load the pre-trained model."""
 
 
 
 
10
  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=13, ignore_mismatched_sizes=True)
11
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
12
+ model.eval()
13
+ return model
 
 
 
 
 
14
 
15
+ def preprocess_image(image_path):
16
+ """Preprocess the image for prediction."""
17
+ feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
18
+ image = Image.open(image_path).convert("RGB")
19
+ image = feature_extractor(images=image, return_tensors="pt")["pixel_values"]
20
+ return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ def predict(model, image_path):
23
+ """Predict the class probabilities for an image."""
24
+ image = preprocess_image(image_path)
25
  with torch.no_grad():
26
+ outputs = model(image).logits
27
+ probabilities = torch.softmax(outputs, dim=1)
28
+ return probabilities
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  def main(input_path, model_path, output_file):
31
+ """Main function to predict and save results to Excel."""
32
+ model = load_model(model_path)
33
+ results = []
 
 
 
34
 
35
  if os.path.isdir(input_path):
36
+ for img_name in os.listdir(input_path):
37
+ img_path = os.path.join(input_path, img_name)
38
+ if img_path.endswith(('.png', '.jpg', '.jpeg')): # Check for image file types
39
+ probs = predict(model, img_path).cpu().numpy()[0]
40
+ result = {"Image Name": img_name}
41
+ for i, prob in enumerate(probs):
42
+ result[f"Class {i} Probability"] = prob # Store probabilities
43
+ results.append(result)
 
 
 
 
 
44
  else:
45
+ # If a single image file is provided
46
+ probs = predict(model, input_path).cpu().numpy()[0]
47
+ result = {"Image Name": os.path.basename(input_path)}
48
+ for i, prob in enumerate(probs):
49
+ result[f"Class {i} Probability"] = prob # Store probabilities
50
+ results.append(result)
51
+
52
+ # Create DataFrame and save to Excel
53
+ df = pd.DataFrame(results)
54
+ df.to_excel(output_file, index=False)
55
+ print(f"Results saved to {output_file}")
56
 
57
  # Example call
58
+ input_path = '/content/ddd.jpg' # Replace with your image folder or single image path
59
  model_path = '/content/model.pth' # Replace with your model path
60
+ output_file = 'predictions.xlsx' # Name of the output Excel file
61
  main(input_path, model_path, output_file)