# Scraped page header (not code): EdwardSamuel13 — "Upload 19 files" —
# commit a3355a3 (verified).
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
def load_model_and_processor(model_path, model_name_or_path, class_names):
    """Load a ViT image classifier together with its image processor.

    Args:
        model_path: Checkpoint path or hub id handed to
            ``ViTForImageClassification.from_pretrained``.
        model_name_or_path: Path or hub id handed to
            ``ViTImageProcessor.from_pretrained``.
        class_names: Ordered label names; the name at position ``i`` is
            registered as class id ``i``.

    Returns:
        A ``(model, processor)`` tuple; ``model.eval()`` has already been
        called on the model.
    """
    image_processor = ViTImageProcessor.from_pretrained(model_name_or_path)

    # Build both label maps in one pass over the names. id2label keys are
    # stringified indices, label2id values are plain ints (later duplicates
    # of a name overwrite earlier ones, as with a dict comprehension).
    index_to_name = {}
    name_to_index = {}
    for idx, name in enumerate(class_names):
        index_to_name[str(idx)] = name
        name_to_index[name] = idx

    classifier = ViTForImageClassification.from_pretrained(
        model_path,
        num_labels=len(class_names),
        id2label=index_to_name,
        label2id=name_to_index,
    )
    # Evaluation mode: affects training-only layers such as dropout.
    classifier.eval()
    return classifier, image_processor
def predict(model, processor, img, device="cpu"):
    """Run a single image through the classifier and summarize the result.

    Args:
        model: Classifier called as ``model(pixel_values,
            output_attentions=True)``; its output must expose ``.logits``.
            NOTE(review): the model is not moved here; it is assumed to
            already live on ``device`` — confirm at call sites.
        processor: Image processor called as
            ``processor(images=img, return_tensors="pt")``; the result must
            support ``.to(device)`` and ``["pixel_values"]``.
        img: Image object with a ``convert`` method (presumably a PIL image).
            It is converted to RGB before processing.
        device: Device the processed inputs are moved to.

    Returns:
        Tuple ``(outputs, processed_input, probabilities, prediction)``:
        the raw model outputs (attentions requested), the processed inputs
        already on ``device``, the softmax probabilities of the first (only)
        image as a list of Python floats, and the index of the highest-scoring
        class as an ``int``.
    """
    img = img.convert("RGB")
    # Move the whole processed batch once; its "pixel_values" entry is then
    # already on `device`, so no second transfer is needed.
    processed_input = processor(images=img, return_tensors="pt").to(device)
    pixel_values = processed_input["pixel_values"]

    with torch.no_grad():
        outputs = model(pixel_values, output_attentions=True)

    # Single image in the batch: take its row, then use the class axis
    # (last dim) for both softmax and argmax so the two always agree.
    first_logits = outputs.logits[0]
    probabilities = torch.softmax(first_logits, dim=-1).tolist()
    prediction = torch.argmax(first_logits, dim=-1).item()
    return outputs, processed_input, probabilities, prediction