| import torch | |
| from transformers import ViTFeatureExtractor | |
| from config import UNTRAINED | |
# Preprocessor shared by predict(); built once at import time from UNTRAINED
# (a model name or local path taken from config).
# NOTE(review): newer transformers releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor — consider migrating once the pinned version allows.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path=UNTRAINED,
)
def predict(model, image, extractor=None):
    """Classify a single image and return its human-readable label.

    Args:
        model: A classification model called as ``model(**inputs)`` whose
            output exposes ``.logits`` and whose ``config.id2label`` maps class
            indices to label names (e.g. a Hugging Face image-classification
            model). Its train/eval mode is not changed here; call
            ``model.eval()`` beforehand if it uses dropout.
        image: The image to classify, in any form the preprocessor accepts
            (presumably a PIL image or array — confirm against callers).
            Only one image is supported: ``.item()`` below requires exactly
            one prediction.
        extractor: Optional preprocessor to use instead of the module-level
            ``feature_extractor``; it must accept ``return_tensors="pt"`` and
            return a mapping of model inputs.

    Returns:
        The label string of the highest-scoring class.

    Raises:
        KeyError: If the predicted index is not present in
            ``model.config.id2label`` (as either an int or a str key).
    """
    if extractor is None:
        extractor = feature_extractor
    inputs = extractor(image, return_tensors="pt")

    # Put the inputs on the model's device (e.g. CUDA). Hugging Face models
    # expose ``.device``; plain callables may not, hence the guarded lookup.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {
            key: value.to(device) if torch.is_tensor(value) else value
            for key, value in inputs.items()
        }

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Index of the highest-scoring class; .item() requires a single image.
    predicted_label = logits.argmax(-1).item()

    # transformers configs store id2label with int keys (string keys from
    # config.json are converted on load), so try the int key first and only
    # fall back to the string form for configs that kept string keys.
    id2label = model.config.id2label
    if predicted_label in id2label:
        return id2label[predicted_label]
    return id2label[str(predicted_label)]