File size: 264 Bytes
ab8b628
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
from transformers import ViTForImageClassification
import os

# Project root: the parent of the directory containing this file,
# as an absolute, normalized path string.
_here = os.path.dirname(__file__)
ROOT_DIR = os.path.abspath(os.path.join(_here, ".."))
# Directory from which load_model() reads the saved model by default.
MODEL_DIR = os.path.join(ROOT_DIR, "model")
del _here  # keep the module namespace identical to the plain two-constant form


def load_model(model_dir=None, **kwargs):
    """Load the ViT image-classification model from a local directory.

    Args:
        model_dir: Path (str or path-like) to a directory that
            ``ViTForImageClassification.from_pretrained`` can load. Defaults
            to the module-level ``MODEL_DIR``, read at call time so that
            overriding the module constant still takes effect.
        **kwargs: Forwarded unchanged to
            ``ViTForImageClassification.from_pretrained`` (for example
            ``torch_dtype``).

    Returns:
        The model returned by ``ViTForImageClassification.from_pretrained``.
        Nothing is cached: every call loads a fresh model from disk.

    Raises:
        FileNotFoundError: If ``model_dir`` is not an existing directory.
    """
    if model_dir is None:
        model_dir = MODEL_DIR
    model_dir = os.fspath(model_dir)
    # from_pretrained accepts either a local directory or a Hub model id, so a
    # missing directory may be misread as a repo id and fail with a confusing
    # validation/network error. Fail fast with a clear message instead.
    if not os.path.isdir(model_dir):
        raise FileNotFoundError(f"Model directory not found: {model_dir}")
    return ViTForImageClassification.from_pretrained(model_dir, **kwargs)