| """ | |
| 모델들 불러오는 모듈 | |
| """ | |
| import torch | |
| # from .load_model import KCSN | |
| # from .arguments import get_train_args | |
| # args = get_train_args() | |
def load_ner(path='model/NER.pth', map_location=None):
    """Load the NER model and return it together with its checkpoint.

    The checkpoint at ``path`` is expected to be a dict containing at least:
      * ``'model'`` - a pickled model object (the whole module, not only weights);
      * ``'model_state_dict'`` - a state dict that is loaded into that model.

    Args:
        path: Filesystem path of the checkpoint file.
        map_location: Forwarded to ``torch.load``. For example, ``'cpu'`` lets a
            checkpoint saved on a GPU be loaded on a CPU-only machine. The
            default ``None`` keeps the device placement the tensors were saved with.

    Returns:
        A ``(model, checkpoint)`` tuple. ``model`` is the loaded model with
        ``model_state_dict`` applied; ``checkpoint`` is the full loaded dict.

    Raises:
        KeyError: if ``'model'`` or ``'model_state_dict'`` is missing from the
            checkpoint.
    """
    import inspect

    load_kwargs = {'map_location': map_location}
    # The checkpoint stores a whole pickled model object. PyTorch >= 2.6 refuses
    # to load that under its default ``weights_only=True``, so opt out explicitly
    # on torch versions whose ``torch.load`` knows the flag.
    # SECURITY: full unpickling can execute arbitrary code; only load
    # checkpoints that come from a trusted source.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(path, **load_kwargs)
    model = checkpoint['model']
    # Apply the saved weights on top of the pickled model object.
    model.load_state_dict(checkpoint['model_state_dict'])
    return model, checkpoint
| # def load_fs(path = 'model/FS.pth'): | |
| # """ | |
| # Find Speaker 모델 | |
| # """ | |
| # model = KCSN(args) | |
| # checkpoint = torch.load(path) | |
| # model.load_state_dict(checkpoint['model_state_dict']) | |
| # return model, checkpoint | |