|
|
import os |
|
|
import torch |
|
|
import pandas as pd |
|
|
from torchvision import transforms |
|
|
from transformers import ViTForImageClassification, ViTFeatureExtractor |
|
|
from PIL import Image |
|
|
|
|
|
def load_model(model_path, num_labels=13):
    """Load the fine-tuned ViT classifier from a saved state dict.

    The architecture is created from the ``google/vit-base-patch16-224``
    checkpoint with a ``num_labels``-way classification head
    (``ignore_mismatched_sizes=True`` allows the head to differ in size from
    the one stored in that checkpoint). The weights stored at ``model_path``
    are then loaded on top of it, mapped to the CPU, and the model is switched
    to evaluation mode.

    Args:
        model_path: Path to a file readable by ``torch.load`` that contains
            the state dict to load into the model.
        num_labels: Number of output classes of the classification head.
            Defaults to 13, the value this script was written for.

    Returns:
        The ``ViTForImageClassification`` instance, in eval mode.
    """
    import warnings  # local import so this block does not depend on file-level changes

    model = ViTForImageClassification.from_pretrained(
        'google/vit-base-patch16-224',
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
    )

    # NOTE(security): torch.load unpickles the file, so only checkpoints from
    # a trusted source should be passed here.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    # strict=False tolerates key differences; report them instead of silently
    # running with layers that never received the checkpoint's weights.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {model_path!r} does not fully match the model: "
            f"missing keys={list(load_result.missing_keys)}, "
            f"unexpected keys={list(load_result.unexpected_keys)}"
        )

    model.eval()
    return model
|
|
|
|
|
_FEATURE_EXTRACTOR = None  # created on first use and then shared by all calls


def _get_feature_extractor():
    """Return the shared ViT feature extractor, creating it on first use."""
    global _FEATURE_EXTRACTOR
    if _FEATURE_EXTRACTOR is None:
        _FEATURE_EXTRACTOR = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
    return _FEATURE_EXTRACTOR


def preprocess_image(image_path):
    """Read an image file and turn it into model-ready pixel values.

    The image is opened with PIL, converted to RGB, and passed through the
    ``google/vit-base-patch16-224`` feature extractor.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The ``"pixel_values"`` entry of the feature extractor output,
        requested as PyTorch tensors (``return_tensors="pt"``).
    """
    # The context manager closes the underlying file; convert() returns an
    # independent in-memory image, so it stays usable after the file is closed.
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert("RGB")

    # Reuse one extractor instead of re-loading it for every image.
    feature_extractor = _get_feature_extractor()
    return feature_extractor(images=rgb_image, return_tensors="pt")["pixel_values"]
|
|
|
|
|
def predict(model, image_path):
    """Compute softmax class probabilities for a single image file.

    Args:
        model: Callable model whose output exposes a ``logits`` attribute.
        image_path: Path of the image, forwarded to ``preprocess_image``.

    Returns:
        Tensor of probabilities obtained by applying softmax over dimension 1
        of the model's logits.
    """
    pixel_values = preprocess_image(image_path)

    # Inference only: no gradient bookkeeping is needed.
    with torch.no_grad():
        logits = model(pixel_values).logits

    return torch.softmax(logits, dim=1)
|
|
|
|
|
def _probability_row(model, image_path, image_name):
    """Build one result row: the image name plus one probability column per class."""
    probs = predict(model, image_path).cpu().numpy()[0]
    row = {"Image Name": image_name}
    row.update({f"Class {i} Probability": prob for i, prob in enumerate(probs)})
    return row


def main(input_path, model_path, output_file):
    """Predict class probabilities for one image or a folder and save them to Excel.

    Args:
        input_path: Either a directory or a single image file. For a
            directory, every entry whose name ends in ``.png``, ``.jpg`` or
            ``.jpeg`` (case-insensitive) is processed, in sorted name order.
            Any other path is treated as a single image.
        model_path: Path of the saved weights, forwarded to ``load_model``.
        output_file: Path of the Excel file to write.

    The output has one row per image with an ``"Image Name"`` column and one
    ``"Class <i> Probability"`` column per class.
    """
    model = load_model(model_path)
    image_suffixes = ('.png', '.jpg', '.jpeg')

    if os.path.isdir(input_path):
        results = [
            _probability_row(model, os.path.join(input_path, img_name), img_name)
            # sorted(): os.listdir order is arbitrary; keep the output reproducible.
            for img_name in sorted(os.listdir(input_path))
            # lower(): also accept upper-case extensions such as ".JPG".
            if img_name.lower().endswith(image_suffixes)
        ]
    else:
        results = [_probability_row(model, input_path, os.path.basename(input_path))]

    df = pd.DataFrame(results)
    df.to_excel(output_file, index=False)
    print(f"Results saved to {output_file}")
|
|
|
|
|
|
|
|
# Default locations used when this file is run as a script.
input_path = '/content/ddd.jpg'
model_path = '/content/model.pth'
output_file = 'predictions.xlsx'

# Guarded so that importing this module does not start a prediction run.
if __name__ == "__main__":
    main(input_path, model_path, output_file)