import torch
from transformers import ViTForImageClassification, ViTImageProcessor


def load_model_and_processor(model_path, model_name_or_path, class_names):
    """
    Load a ViT image-classification model and its image processor.

    Args:
        model_path: Path or hub id passed to
            ``ViTForImageClassification.from_pretrained`` for the model.
        model_name_or_path: Path or hub id passed to
            ``ViTImageProcessor.from_pretrained`` for the processor.
        class_names: Ordered class labels. Position ``i`` becomes the label
            for output index ``i``, and ``len(class_names)`` sets
            ``num_labels``.

    Returns:
        ``(model, processor)``, with the model switched to eval mode.
        The model is not moved to any device here.
    """
    processor = ViTImageProcessor.from_pretrained(model_name_or_path)
    # NOTE(review): id2label is keyed by str(i) while label2id maps to int.
    # Confirm whether callers index model.config.id2label with int or str
    # keys before normalizing these keys to int.
    model = ViTForImageClassification.from_pretrained(
        model_path,
        num_labels=len(class_names),
        id2label={str(i): label for i, label in enumerate(class_names)},
        label2id={label: i for i, label in enumerate(class_names)},
    )
    # Eval mode for inference (affects layers such as dropout).
    model.eval()
    return model, processor


def predict(model, processor, img, device="cpu"):
    """
    Run inference on a single image.

    Args:
        model: Classification model, e.g. from ``load_model_and_processor``.
            This function does not move it, so it is expected to already be
            on ``device``.
        processor: Image processor used to build ``pixel_values``.
        img: Image object with a ``convert`` method (presumably a PIL image).
            It is converted to RGB before preprocessing.
        device: Device the preprocessed inputs are moved to.

    Returns:
        A tuple ``(outputs, processed_input, probabilities, prediction)``:

        - ``outputs``: raw model output, computed with
          ``output_attentions=True``.
        - ``processed_input``: processor output, already moved to ``device``.
        - ``probabilities``: Python list of softmax probabilities for the
          first image in the batch.
        - ``prediction``: integer index of the highest-scoring class.
    """
    img = img.convert("RGB")
    processed_input = processor(images=img, return_tensors="pt").to(device)
    # processed_input was already moved to ``device`` above, so its tensors
    # need no second transfer.
    pixel_values = processed_input["pixel_values"]

    with torch.no_grad():
        outputs = model(pixel_values, output_attentions=True)

    logits = outputs.logits
    # Use the last (class) axis for both softmax and argmax so they agree.
    probabilities = torch.softmax(logits, dim=-1)[0].tolist()
    # .item() requires a single-element result; presumably the batch holds
    # just this one image.
    prediction = torch.argmax(logits, dim=-1).item()
    return outputs, processed_input, probabilities, prediction