Spaces:
Runtime error
Runtime error
| import torch, pickle, json, nltk, string | |
| from pathlib import Path | |
| from huggingface_hub import hf_hub_download | |
| from lstm_model import LSTMClassifier | |
# ── download model artefacts from the model repo once ──────────────────
REPO_ID = "ecroatt/imdb-bilstm-sentiment"  # ← change to your handle

# Directory containing this script. Artefacts are downloaded into it AND read
# back from it, so the module works regardless of the current working
# directory (the download target used to be "." while loading used ROOT).
ROOT = Path(__file__).resolve().parent

for fname in ("config.json", "vocab.pkl", "pytorch_model.bin"):
    hf_hub_download(
        repo_id=REPO_ID,
        filename=fname,
        local_dir=ROOT,
        repo_type="model",
        force_download=False,
    )

# ── load config, vocab, weights ────────────────────────────────────────
# Context managers make sure the file handles are closed after reading.
with open(ROOT / "config.json", encoding="utf-8") as cfg_file:
    cfg = json.load(cfg_file)

# NOTE(security): unpickling can execute arbitrary code. Only point REPO_ID
# at a repository you control or trust.
with open(ROOT / "vocab.pkl", "rb") as vocab_file:
    vocab = pickle.load(vocab_file)

# Build the network from the hyper-parameters stored in config.json, switch
# it to evaluation mode, then load the trained weights onto the CPU.
model = LSTMClassifier(
    vocab_size=cfg["vocab_size"],
    embed_dim=cfg["embed_dim"],
    hidden_dim=cfg["hidden_dim"],
    n_layers=cfg["n_layers"],
    bidirectional=cfg["bidirectional"],
).eval()
model.load_state_dict(torch.load(ROOT / "pytorch_model.bin", map_location="cpu"))
# ── preprocessing helpers ──────────────────────────────────────────────
# Request the NLTK stop-word corpus (quiet=True suppresses download output).
nltk.download("stopwords", quiet=True)

# English stop words, held in a set for fast membership tests in preprocess().
STOP = {*nltk.corpus.stopwords.words("english")}

# Translation table that deletes every character in string.punctuation.
PUNC = str.maketrans(dict.fromkeys(string.punctuation))

# Fixed sequence length taken from the model config.
PAD_LEN = cfg["pad_len"]

# Vocabulary ids used by encode(): PAD fills unused slots, UNK stands in for
# words missing from the vocabulary.
PAD = 0
UNK = 1
def preprocess(text: str):
    """Turn a raw review into a list of at most ``PAD_LEN`` content tokens.

    The text is lower-cased, stripped of the characters removed by ``PUNC``,
    split on whitespace, and filtered against the ``STOP`` stop-word set.
    Truncation happens after filtering, so the first ``PAD_LEN`` surviving
    words are kept.
    """
    cleaned = text.lower().translate(PUNC)
    kept = []
    for word in cleaned.split():
        if word in STOP:
            continue
        kept.append(word)
    return kept[:PAD_LEN]
def encode(tokens):
    """Convert tokens into a padded id tensor plus its sequence length.

    Args:
        tokens: Sized iterable of word strings (normally the output of
            ``preprocess``).

    Returns:
        A ``(ids, length)`` tuple. ``ids`` is an int64 tensor of shape
        ``(1, PAD_LEN)`` holding vocabulary ids, with ``UNK`` for words not
        in ``vocab`` and ``PAD`` in the unused trailing slots. ``length`` is
        a one-element int64 tensor with the number of real tokens, clamped
        to the range ``[1, PAD_LEN]``.
    """
    # Truncate here too: without it, more than PAD_LEN tokens would yield a
    # tensor wider than PAD_LEN (the pad count goes negative, adding nothing)
    # and a length larger than the padded width.
    ids = [vocab.get(w, UNK) for w in tokens][:PAD_LEN]
    n_real = len(ids)
    ids.extend([PAD] * (PAD_LEN - n_real))

    # NOTE(review): an empty token list (e.g. a review made only of stop
    # words) would report length 0. Clamp to 1 so the model sees a single PAD
    # step instead. This presumes LSTMClassifier consumes `length` in a way
    # that rejects zero, such as sequence packing; confirm against lstm_model.
    length = max(n_real, 1)

    # Explicit dtype keeps the ids integral even if the list is empty.
    id_tensor = torch.tensor(ids, dtype=torch.long).unsqueeze(0)
    return id_tensor, torch.tensor([length], dtype=torch.long)
def predict(text: str) -> float:
    """Return probability (0-1) that the review is positive.

    Args:
        text: Raw review text.

    Returns:
        The sigmoid of the model's output logit as a Python float.
    """
    x, length = encode(preprocess(text))
    # Inference only: disable autograd so no graph is built or kept alive
    # for each request.
    with torch.no_grad():
        logit = model(x, length)
        prob = torch.sigmoid(logit)
    return prob.item()