nevreal's picture
Upload Completed files
ecfa0da verified
import sys
import os.path
import torch
# Absolute directory of this module, with a trailing "/" so file names can be
# appended by plain string concatenation. It is also added to sys.path so that
# modules living beside this file can be imported by absolute name.
code_path = f"{os.path.dirname(os.path.abspath(__file__))}/"
sys.path.append(code_path)
import yaml
from ml_collections import ConfigDict
# Process-wide setting: per the PyTorch docs, "medium" lets float32 matrix
# multiplications use faster, lower-precision internal math where supported.
torch.set_float32_matmul_precision(precision="medium")
def get_model(
    config_path,
    weights_path,
    device,
):
    """Build a ``MultiMaskMultiSourceBandSplitRNNSimple`` and load its weights.

    Args:
        config_path: Path to a YAML file; its ``model`` section is passed as
            keyword arguments to ``MultiMaskMultiSourceBandSplitRNNSimple``.
        weights_path: Path to a checkpoint holding the model's state dict.
            If falsy (``None`` or ``""``), the checkpoint bundled next to this
            module (``model_bandit_plus_dnr_sdr_11.47.chpt``) is used instead.
        device: Target passed to ``model.to(...)``.

    Returns:
        A ``(model, config)`` tuple, where ``config`` is the parsed YAML
        wrapped in an ``ml_collections.ConfigDict``.
    """
    # Imported lazily; the module's directory is appended to sys.path at import
    # time (presumably so this local `models` package resolves).
    from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple

    # NOTE(review): FullLoader accepts some Python-specific YAML tags; only
    # load trusted config files through this function.
    with open(config_path) as f:
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

    model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)

    # Honour the caller's checkpoint path; fall back to the checkpoint shipped
    # alongside this module when none is given.
    if not weights_path:
        weights_path = code_path + "model_bandit_plus_dnr_sdr_11.47.chpt"

    # Load tensors onto the CPU first so a checkpoint saved from a GPU can be
    # restored on a machine without that device; model.to() moves it after.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.to(device)
    return model, config