biengsen4real commited on
Commit
a47c735
·
verified ·
1 Parent(s): f0fc4a6

Delete file predict.ipynb

Browse files
Files changed (1) hide show
  1. predict.ipynb +0 -104
predict.ipynb DELETED
@@ -1,104 +0,0 @@
1
- import os
2
- import torch
3
- from transformers import ViTForImageClassification, ViTFeatureExtractor
4
- from PIL import Image
5
- import io
6
- import pandas as pd
7
-
8
- def load_model(model_path):
9
- """Load the pre-trained model and feature extractor."""
10
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
- feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
12
- model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=13)
13
- model.load_state_dict(torch.load(model_path))
14
- model = model.to(device)
15
- model.eval() # Set the model to evaluation mode
16
- return model, feature_extractor, device
17
-
18
- def safe_load_image(path):
19
- """Safely load an image, handling possible errors."""
20
- try:
21
- with open(path, 'rb') as f:
22
- img = Image.open(io.BytesIO(f.read()))
23
- img = img.convert('RGB')
24
- return img
25
- except Exception as e:
26
- print(f"Error loading image {path}: {e}")
27
- return None
28
-
29
- def predict_image_class(image_path, model, feature_extractor, device, class_names):
30
- """Predict the class of a given image."""
31
- img = safe_load_image(image_path)
32
- if img is None:
33
- return None, None
34
-
35
- # Preprocess the image
36
- inputs = feature_extractor(images=img, return_tensors="pt").to(device)
37
-
38
- # Make the prediction
39
- with torch.no_grad():
40
- outputs = model(**inputs).logits
41
- probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0] # Calculate probabilities
42
- predicted_class_idx = outputs.argmax(dim=1).item() # Get the predicted class index
43
- predicted_class = class_names[predicted_class_idx] # Get class name from index
44
-
45
- return predicted_class, probabilities
46
-
47
- def predict_images_in_folder(folder_path, model, feature_extractor, device, class_names):
48
- """Predict the class of each image in a folder."""
49
- results = []
50
- for filename in os.listdir(folder_path):
51
- if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
52
- image_path = os.path.join(folder_path, filename)
53
- predicted_class, probabilities = predict_image_class(image_path, model, feature_extractor, device, class_names)
54
- if predicted_class is not None:
55
- results.append({'Image Name': filename, 'Predicted Class': predicted_class, 'Probabilities': probabilities})
56
-
57
- return results
58
-
59
- def save_results_to_excel(results, output_file):
60
- """Save the prediction results to an Excel file."""
61
- # Flatten the probabilities array and create a DataFrame
62
- rows = []
63
- for result in results:
64
- for idx, prob in enumerate(result['Probabilities']):
65
- rows.append({
66
- 'Image Name': result['Image Name'],
67
- 'Predicted Class': result['Predicted Class'],
68
- 'Class Index': idx,
69
- 'Probability': prob
70
- })
71
-
72
- df = pd.DataFrame(rows)
73
-
74
- # Sort by probability in descending order
75
- df = df.sort_values(by='Probability', ascending=False)
76
-
77
- # Save to Excel
78
- df.to_excel(output_file, index=False)
79
-
80
- def main(input_path, model_path, output_file):
81
- """Main function to perform image classification prediction, handling single images or folders, and saving results to Excel."""
82
- class_names = ['anti_war_protest', 'combat', 'construction', 'fire', 'human_damage',
83
- 'humanitarian_aid', 'infrastructure', 'military_parade', 'military_vehicle',
84
- 'meeting', 'speech', 'refugee', 'victory']
85
-
86
- model, feature_extractor, device = load_model(model_path)
87
-
88
- if os.path.isdir(input_path):
89
- # If the input path is a folder, predict all images in that folder
90
- results = predict_images_in_folder(input_path, model, feature_extractor, device, class_names)
91
- save_results_to_excel(results, output_file)
92
- print(f'Prediction results saved to: {output_file}')
93
- elif os.path.isfile(input_path):
94
- # If the input path is a single image, make a direct prediction
95
- predicted_class, probabilities = predict_image_class(input_path, model, feature_extractor, device, class_names)
96
- print(f'The predicted class for image {os.path.basename(input_path)}: {predicted_class}')
97
- else:
98
- print('Invalid input path. Please provide a valid file or folder path.')
99
-
100
- # Example call
101
- input_path = '/path/to/your/image_or_folder' # Replace with your image or folder path
102
- model_path = '/kaggle/working/best_modelq.pth' # Replace with your model path
103
- output_file = 'predictions.xlsx' # Name of the output Excel file
104
- main(input_path, model_path, output_file)