SupremoUGH's picture
first commit
ab8b628 unverified
raw
history blame
264 Bytes
from transformers import ViTForImageClassification
import os
# Project root: the parent of the directory that contains this source file,
# resolved to an absolute, normalized path.
ROOT_DIR = os.path.abspath(
    os.path.join(
        os.path.dirname(__file__),
        "..",
    )
)
# Location of the saved model files, directly under the project root.
MODEL_DIR = os.path.join(ROOT_DIR, "model")
def load_model(model_dir=None, **kwargs):
    """Load the ViT image-classification model from a local directory.

    Args:
        model_dir: Directory containing the saved model files. Defaults to
            ``MODEL_DIR`` (the ``model`` folder under the project root).
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``ViTForImageClassification.from_pretrained``.

    Returns:
        Whatever ``ViTForImageClassification.from_pretrained`` returns for
        ``model_dir``.

    Raises:
        FileNotFoundError: If ``model_dir`` is not an existing directory.
    """
    if model_dir is None:
        model_dir = MODEL_DIR
    # Fail fast with a clear message: a missing local path would otherwise be
    # handed to from_pretrained, which may try to treat it as a Hub repo id
    # and raise a less helpful error.
    if not os.path.isdir(model_dir):
        raise FileNotFoundError(f"Model directory not found: {model_dir!r}")
    return ViTForImageClassification.from_pretrained(model_dir, **kwargs)