File size: 436 Bytes
5aa6736
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
from transformers import AutoModel, AutoProcessor
from config.settings import MODEL_ID, TORCH_DTYPE, ATTN_IMPLEMENTATION, DEVICE

def load_model():
    """Load the pretrained model and its matching processor for inference.

    The checkpoint named by ``MODEL_ID`` is loaded with the ``TORCH_DTYPE``
    and ``ATTN_IMPLEMENTATION`` settings imported from ``config.settings``.
    It is moved to ``DEVICE`` and switched to evaluation mode. The processor
    is then loaded from the same ``MODEL_ID``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    load_kwargs = {
        "torch_dtype": TORCH_DTYPE,
        "attn_implementation": ATTN_IMPLEMENTATION,
    }
    network = AutoModel.from_pretrained(MODEL_ID, **load_kwargs).to(DEVICE)
    # Put the model in inference mode (affects modules such as dropout).
    network.eval()

    processor = AutoProcessor.from_pretrained(MODEL_ID)
    return network, processor