# imdb-bilstm-demo / inference.py
# (from the ecroatt/imdb-bilstm-sentiment model repo — "Update inference.py", commit ffe93a3)
import torch, pickle, json, nltk, string
from pathlib import Path
from huggingface_hub import hf_hub_download
from lstm_model import LSTMClassifier
# ── download model artefacts from the model repo once ──────────────────
REPO_ID = "ecroatt/imdb-bilstm-sentiment" # ← change to your handle

# Pull every artefact the script needs into the working directory.
# With force_download=False a locally cached copy is reused, so this
# is effectively a one-time fetch.
ARTEFACTS = ("config.json", "vocab.pkl", "pytorch_model.bin")
for artefact in ARTEFACTS:
    hf_hub_download(
        repo_id=REPO_ID,
        filename=artefact,
        local_dir=".",
        repo_type="model",
        force_download=False,
    )
# ── load config, vocab, weights ────────────────────────────────────────
ROOT = Path(__file__).resolve().parent

# Use context managers so the file handles are closed deterministically
# (the original `json.load(open(...))` left them to the garbage collector).
with open(ROOT / "config.json", encoding="utf-8") as f:
    cfg = json.load(f)
with open(ROOT / "vocab.pkl", "rb") as f:
    # NOTE(review): unpickling is only acceptable because this artefact
    # comes from our own model repo — never unpickle untrusted data.
    vocab = pickle.load(f)

# Rebuild the architecture from the saved config, then load the weights.
model = LSTMClassifier(
    vocab_size    = cfg["vocab_size"],
    embed_dim     = cfg["embed_dim"],
    hidden_dim    = cfg["hidden_dim"],
    n_layers      = cfg["n_layers"],
    bidirectional = cfg["bidirectional"],
).eval()  # inference mode: disables dropout etc.
model.load_state_dict(torch.load(ROOT / "pytorch_model.bin", map_location="cpu"))
# ── preprocessing helpers ──────────────────────────────────────────────
nltk.download("stopwords", quiet=True)  # no-op when the corpus is already present

# English stop words to drop, and a translation table that strips all
# ASCII punctuation in a single C-level pass.
STOP = set(nltk.corpus.stopwords.words("english"))
PUNC = str.maketrans("", "", string.punctuation)

# Fixed sequence length the model was trained with, plus the two
# reserved vocabulary ids.
PAD_LEN = cfg["pad_len"]
PAD = 0
UNK = 1
def preprocess(text: str) -> list:
    """Lower-case *text*, strip punctuation, drop stop words, cap at PAD_LEN tokens."""
    cleaned = text.lower().translate(PUNC)
    kept = [tok for tok in cleaned.split() if tok not in STOP]
    return kept[:PAD_LEN]
def encode(tokens):
    """Map *tokens* to model inputs.

    Returns a pair of tensors:
      * ids    — shape (1, PAD_LEN): vocab ids, UNK for out-of-vocabulary
                 words, right-padded with PAD.
      * length — shape (1,): the number of real (non-pad) tokens.

    Fixes over the original:
      * tokens are truncated to PAD_LEN here too, so the id tensor is
        always exactly (1, PAD_LEN) even if a caller skips preprocess();
      * the reported length is clamped to at least 1 — a zero length
        would crash length-aware RNN layers (e.g. pack_padded_sequence
        rejects zero-length sequences), so an empty review is encoded
        as a single PAD/UNK step instead.
    """
    ids = [vocab.get(w, UNK) for w in tokens[:PAD_LEN]]
    length = max(1, len(ids))          # guard against empty input
    ids += [PAD] * (PAD_LEN - len(ids))
    return torch.tensor(ids).unsqueeze(0), torch.tensor([length])
@torch.no_grad()
def predict(text: str) -> float:
    """Score a review: probability in [0, 1] that *text* is positive."""
    tokens = preprocess(text)
    ids, length = encode(tokens)
    # Model emits a raw logit; squash to a probability.
    score = model(ids, length)
    return torch.sigmoid(score).item()