# event_detect / predict.py
# Committed by: biengsen4real
# Commit message: Update predict.py
# Commit: 03a544e (verified)
import functools
import os
import warnings

import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from transformers import ViTForImageClassification, ViTFeatureExtractor
def load_model(model_path, num_labels=13, base_model='google/vit-base-patch16-224'):
    """Load the ViT image classifier and apply the fine-tuned weights.

    The architecture (with its pre-trained weights) is fetched from
    *base_model*, its classification head is sized to *num_labels* outputs,
    and the checkpoint at *model_path* is then loaded on top, on the CPU.

    Args:
        model_path: Path to a ``torch.save``-d checkpoint. Its contents are
            passed straight to ``load_state_dict``, so it must hold a state
            dict rather than a pickled model object.
        num_labels: Number of output classes of the classification head.
        base_model: Hugging Face model id (or local directory) providing the
            architecture and initial weights.

    Returns:
        The ``ViTForImageClassification`` model, switched to eval mode.

    Warns:
        UserWarning: if the checkpoint has missing or unexpected keys.
    """
    model = ViTForImageClassification.from_pretrained(
        base_model,
        num_labels=num_labels,
        # The base checkpoint's head size may differ from num_labels; let the
        # library re-initialise the head instead of failing on the mismatch.
        ignore_mismatched_sizes=True,
    )
    # NOTE(review): torch.load unpickles the file (fully so on older torch
    # versions where weights_only defaults to False) - only load trusted files.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    # strict=False tolerates partial checkpoints, but silently skipped keys
    # would leave layers (e.g. the freshly initialised head) untrained, so
    # surface any mismatch instead of ignoring it.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"Checkpoint {model_path!r} does not fully match the model: "
            f"{len(incompatible.missing_keys)} missing key(s) "
            f"(e.g. {incompatible.missing_keys[:3]}), "
            f"{len(incompatible.unexpected_keys)} unexpected key(s) "
            f"(e.g. {incompatible.unexpected_keys[:3]})."
        )
    model.eval()
    return model
@functools.lru_cache(maxsize=8)
def _get_feature_extractor(base_model):
    """Return the ViT feature extractor for *base_model*, loading it only once."""
    return ViTFeatureExtractor.from_pretrained(base_model)


def preprocess_image(image_path, base_model='google/vit-base-patch16-224'):
    """Preprocess the image at *image_path* for prediction.

    The image is converted to RGB and run through the feature extractor of
    *base_model*. The extractor is cached, so repeated calls (e.g. over a
    folder of images) do not reload it each time.

    Args:
        image_path: Path to an image file readable by Pillow.
        base_model: Hugging Face model id whose feature extractor is used.

    Returns:
        The ``"pixel_values"`` PyTorch tensor produced by the extractor,
        presumably a batch of one image of shape (1, 3, H, W) - confirm.
    """
    feature_extractor = _get_feature_extractor(base_model)
    # Use a context manager so the file handle is released; convert() yields
    # a new, fully loaded image that remains valid after the file is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")
    return feature_extractor(images=rgb_image, return_tensors="pt")["pixel_values"]
def predict(model, image_path):
    """Return softmax class probabilities for the image at *image_path*.

    The image goes through ``preprocess_image``, then through *model* with
    gradient tracking disabled. Its logits are normalised with softmax along
    dim 1 (assumed to be the class axis of (batch, num_labels) logits).
    """
    pixels = preprocess_image(image_path)
    with torch.no_grad():
        return torch.softmax(model(pixels).logits, dim=1)
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _result_row(image_name, probs):
    """Build one output row: the image name plus one column per class probability."""
    row = {"Image Name": image_name}
    for i, prob in enumerate(probs):
        row[f"Class {i} Probability"] = prob
    return row


def main(input_path, model_path, output_file, extensions=_IMAGE_EXTENSIONS):
    """Predict class probabilities for one image or a folder and save them to Excel.

    Args:
        input_path: A single image file, or a directory whose image files
            (matched by extension, case-insensitively) are all processed.
        model_path: Checkpoint passed to ``load_model``.
        output_file: Destination spreadsheet written with ``DataFrame.to_excel``.
        extensions: File suffixes treated as images when *input_path* is a
            directory.

    The spreadsheet has an "Image Name" column followed by one
    "Class {i} Probability" column per model output.
    """
    model = load_model(model_path)

    if os.path.isdir(input_path):
        wanted = tuple(ext.lower() for ext in extensions)
        # Sorted so rows come out in a stable order; os.listdir order is
        # arbitrary. Lower-casing accepts e.g. "photo.JPG"; isfile() skips
        # sub-directories whose names happen to end in an image suffix.
        targets = [
            (name, os.path.join(input_path, name))
            for name in sorted(os.listdir(input_path))
            if name.lower().endswith(wanted)
            and os.path.isfile(os.path.join(input_path, name))
        ]
    else:
        # A single image file was provided.
        targets = [(os.path.basename(input_path), input_path)]

    results = [
        _result_row(name, predict(model, path).cpu().numpy()[0])
        for name, path in targets
    ]

    # Create DataFrame and save to Excel
    df = pd.DataFrame(results)
    df.to_excel(output_file, index=False)
    print(f"Results saved to {output_file}")
# Example call: edit these paths before running this file as a script.
input_path = '/content/ddd.jpg'  # Replace with your image folder or single image path
model_path = '/content/model.pth'  # Replace with your model path
output_file = 'predictions.xlsx'  # Name of the output Excel file

# Guarded so that importing this module (e.g. to reuse load_model/predict)
# does not immediately run inference on the hard-coded example paths.
if __name__ == "__main__":
    main(input_path, model_path, output_file)