Spaces:
Runtime error
Runtime error
| import os, sys | |
# Absolute directory that contains this file. Its parent directory is pushed
# to the front of sys.path, presumably so the project packages imported below
# resolve no matter where the process was started.
myPath = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, '{}/../'.format(myPath))
| # ========== | |
| import torch | |
| from ercbcm.model_loader import load | |
| from ercbcm.ERCBCM import ERCBCM | |
| from modules.tokenizer import tokenizer, normalize_v2, PAD_TOKEN_ID | |
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ==========
# Build the classifier once at import time and load the checkpoint into it.
# NOTE(review): 'ercbcm/model.pt' is resolved relative to the current working
# directory, not relative to this file — confirm the process is launched from
# the directory that contains `ercbcm/`.
model_for_predict = ERCBCM().to(device)
load('ercbcm/model.pt', model_for_predict, device)
# This module only serves predictions, so switch the model to evaluation mode:
# training-time behaviour such as dropout would otherwise make predictions
# nondeterministic. Assumes ERCBCM is a torch.nn.Module (it supports .to());
# harmless if load() already put it in eval mode.
model_for_predict.eval()
def predict(sentence, name, max_len=128):
    """Classify how *name* is used in *sentence*.

    The pair is normalized with ``normalize_v2``, tokenized, right-padded with
    ``PAD_TOKEN_ID`` and fed to the module-level ``model_for_predict`` as a
    batch of one.

    Args:
        sentence: Raw input text.
        name: The name whose usage in *sentence* is being classified.
        max_len: Length the token-id sequence is right-padded to. Sequences
            that are already at least this long are passed through unchanged
            (no truncation). Defaults to 128, the original fixed length.

    Returns:
        ``'CALLING'`` if the model's arg-max class is 1, otherwise
        ``'MENTIONING'``.
    """
    # The model's forward() takes a label alongside the input; only its second
    # output is used below, so a fixed class id serves as a placeholder.
    label = torch.tensor([0], dtype=torch.long, device=device)

    token_ids = tokenizer.encode(normalize_v2(sentence, name))
    # Right-pad to max_len. A negative repeat count yields [], so inputs that
    # are already long enough are left as they are.
    token_ids += [PAD_TOKEN_ID] * (max_len - len(token_ids))
    text = torch.tensor([token_ids], dtype=torch.long, device=device)  # batch of one

    # Pure inference: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        _, output = model_for_predict(text, label)

    # Presumably output holds per-class scores of shape (1, num_classes) —
    # TODO confirm against ERCBCM.forward.
    pred = torch.argmax(output, 1).tolist()[0]
    return 'CALLING' if pred == 1 else 'MENTIONING'