File size: 238 Bytes
15c5ffb
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
import torch
from model.model import BAPULM

def load_model(checkpoint_path, device):
    """Build a BAPULM model and restore its weights from a checkpoint.

    A fresh ``BAPULM`` is constructed with no arguments, moved to
    ``device``, populated from the state dict stored at
    ``checkpoint_path``, and switched to evaluation mode.

    Args:
        checkpoint_path: Location of the saved state dict, passed
            straight to ``torch.load``.
        device: Target device; used both for ``.to(...)`` and as the
            ``map_location`` when reading the checkpoint.

    Returns:
        The ``BAPULM`` instance after ``eval()`` has been called on it.
    """
    # presumably BAPULM is a torch.nn.Module subclass -- confirm in model/model.py
    network = BAPULM()
    network.to(device)

    # NOTE(review): torch.load unpickles the file, so only point this at
    # trusted checkpoints (or consider weights_only=True where supported).
    weights = torch.load(checkpoint_path, map_location=device)
    network.load_state_dict(weights)

    network.eval()
    return network