Spaces:
Build error
Build error
| import torch | |
| from transformers import ViTForImageClassification, ViTImageProcessor | |
def load_model_and_processor(model_path, model_name_or_path, class_names, device=None):
    """
    Load a ViT image-classification model and its image processor.

    Args:
        model_path: Checkpoint passed to
            ``ViTForImageClassification.from_pretrained``.
        model_name_or_path: Checkpoint passed to
            ``ViTImageProcessor.from_pretrained``; may differ from
            ``model_path``.
        class_names: Ordered label names; the name at position ``i`` is
            registered as class id ``i``. Any iterable is accepted.
        device: Optional torch device (e.g. ``"cuda"``) to move the model to.
            ``None`` (the default) leaves the model wherever
            ``from_pretrained`` placed it, as before.

    Returns:
        A ``(model, processor)`` tuple; the model is in eval mode.
    """
    # Materialize once: the names are counted and enumerated twice below,
    # which would silently break for a one-shot iterable such as a generator.
    class_names = list(class_names)
    processor = ViTImageProcessor.from_pretrained(model_name_or_path)
    model = ViTForImageClassification.from_pretrained(
        model_path,
        num_labels=len(class_names),
        # transformers documents id2label as Dict[int, str]; integer keys keep
        # it consistent with label2id (whose values are ints) and make
        # lookups by an integer class index (e.g. an argmax result) work.
        id2label={i: label for i, label in enumerate(class_names)},
        label2id={label: i for i, label in enumerate(class_names)},
    )
    if device is not None:
        # Inference code moves its input tensors to a device; the model's
        # weights must live on that same device or the forward pass fails.
        model.to(device)
    # Disable dropout etc. so repeated predictions are deterministic.
    model.eval()
    return model, processor
def predict(model, processor, img, device="cpu"):
    """
    Classify a single image.

    Args:
        model: Classification model, invoked as
            ``model(pixel_values, output_attentions=True)``. It must already
            be on ``device``; this function only moves the inputs.
        processor: Image processor, invoked as
            ``processor(images=..., return_tensors="pt")``.
        img: Image object with a ``convert`` method (e.g. a PIL image); it is
            converted to RGB before preprocessing.
        device: Device the processed inputs are moved to.

    Returns:
        A tuple ``(outputs, processed_input, probabilities, prediction)``:
        the raw model output (requested with attentions), the processor
        output moved to ``device``, the softmax probabilities for the image
        as a Python list, and the index of the highest-scoring class as an int.
    """
    rgb_img = img.convert("RGB")
    # A single transfer suffices: BatchFeature.to moves every tensor it holds,
    # so pixel_values is already on `device` after this line.
    processed_input = processor(images=rgb_img, return_tensors="pt").to(device)
    pixel_values = processed_input["pixel_values"]
    with torch.no_grad():
        outputs = model(pixel_values, output_attentions=True)
    # Assumes logits is (batch, num_labels); exactly one image was processed,
    # so take row 0 and derive both results from it along the same axis.
    image_logits = outputs.logits[0]
    probabilities = torch.softmax(image_logits, dim=-1).tolist()
    prediction = torch.argmax(image_logits, dim=-1).item()
    return outputs, processed_input, probabilities, prediction