| |
|
| | import torch, pickle, json, string, nltk |
| | from pathlib import Path |
| | from lstm_model import LSTMClassifier |
| |
|
# Special token ids; must match the training-time preprocessing.
PAD = 0
UNK = 1
ROOT = Path(__file__).resolve().parent

# Load hyper-parameters and vocabulary with context managers so the file
# handles are closed deterministically (the original `json.load(open(...))`
# pattern leaked them).
with open(ROOT / 'config.json') as f:
    cfg = json.load(f)
# NOTE(review): pickle is only safe on trusted artifacts — vocab.pkl must be
# our own training output, never user-supplied data.
with open(ROOT / 'vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

# Build the classifier in inference mode and restore the trained weights.
# torch.load is pickle-based under the hood — the same trust caveat applies.
model = LSTMClassifier(**cfg).eval()
model.load_state_dict(torch.load(ROOT / 'pytorch_model.bin', map_location='cpu'))

# Stopword set and punctuation-stripping translation table used by preprocess().
nltk.download('stopwords', quiet=True)
STOP = set(nltk.corpus.stopwords.words('english'))
PUNC = str.maketrans('', '', string.punctuation)
|
def preprocess(text):
    """Tokenise *text* for the model.

    Lower-cases, strips punctuation in a single translate() pass, drops
    English stopwords, and truncates to the model's padded length.
    """
    cleaned = text.lower().translate(PUNC)
    tokens = [tok for tok in cleaned.split() if tok not in STOP]
    del tokens[cfg['pad_len']:]
    return tokens
| |
|
def encode(tokens):
    """Map tokens to vocabulary ids, padded/truncated to cfg['pad_len'].

    Returns a (1, pad_len) id tensor plus a 1-element length tensor holding
    the count of real (non-pad) tokens.
    """
    pad_len = cfg['pad_len']
    # Defensive truncation: preprocess() already caps at pad_len, but a
    # direct caller with a longer list would otherwise make the pad count
    # below negative and emit an oversized row with a mismatched length.
    tokens = tokens[:pad_len]
    ids = [vocab.get(w, UNK) for w in tokens]
    ids += [PAD] * (pad_len - len(ids))
    return torch.tensor(ids).unsqueeze(0), torch.tensor([len(tokens)])
| |
|
@torch.no_grad()
def predict(text):
    """Return the positive-class probability for *text* as a Python float."""
    tokens = preprocess(text)
    ids, length = encode(tokens)
    # Single logit -> probability via sigmoid; .item() unwraps the scalar.
    return torch.sigmoid(model(ids, length)).item()
| |
|