import os

from transformers import ViTForImageClassification
# Project root: the parent of the directory that contains this module.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Directory holding the model files, located directly under the project root.
MODEL_DIR = os.path.join(ROOT_DIR, "model")
def load_model(model_dir=None, **kwargs):
    """Load the ViT image-classification model from a local directory.

    Args:
        model_dir: Directory containing a checkpoint loadable by
            ``ViTForImageClassification.from_pretrained``. Defaults to
            ``MODEL_DIR`` when omitted or ``None``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``ViTForImageClassification.from_pretrained``.

    Returns:
        The model instance returned by ``from_pretrained``.

    Raises:
        FileNotFoundError: If ``model_dir`` is not an existing directory.
    """
    # Resolve the default lazily so the module-level constant is only read
    # when the caller did not supply a directory.
    if model_dir is None:
        model_dir = MODEL_DIR
    # Fail fast with a clear message: if the path is not a local directory,
    # ``from_pretrained`` would treat the string as a Hugging Face Hub repo
    # id, producing a confusing error (and possibly a network lookup).
    if not os.path.isdir(model_dir):
        raise FileNotFoundError(f"Model directory not found: {model_dir}")
    return ViTForImageClassification.from_pretrained(model_dir, **kwargs)