BiLSTM Text Classifier
A simple BiLSTM model, implemented in PyTorch and trained for spam detection on the SMS Spam Collection dataset
(Almeida, Tiago and José Hidalgo. 2011. SMS Spam Collection.
UCI Machine Learning Repository. https://doi.org/10.24432/C5CC84).
Important Notes
- The model returns logits as output; to obtain probabilities, apply torch.sigmoid.
- The model uses the bert-base-uncased tokenizer only for tokenization (the encoder is NOT BERT).
- Number of parameters: ~4.4M
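To make the logit-to-probability note concrete, here is a minimal, dependency-free sketch of what a sigmoid computes for a single value. The logits below are hypothetical values chosen for illustration, not outputs of this model, and the helper name `sigmoid` is our own.

```python
import math


def sigmoid(logit: float) -> float:
    """Map a real-valued logit to a probability in (0, 1)."""
    # Branch on the sign so exp() never receives a large positive argument
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


# Hypothetical logits (illustration only, not real model outputs)
for logit in (-2.0, 0.0, 2.0):
    print(f"logit={logit:+.1f} -> probability={sigmoid(logit):.3f}")
```

A logit of 0 maps to a probability of exactly 0.5, so thresholding the probability at 0.5 is equivalent to thresholding the raw logit at 0.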
Files
- BiLSTMClassifier.safetensors: trained weights
- BiLSTMClassifier.py: model definition
- config.json: hyperparameters
Usage
import json

import torch
from safetensors.torch import load_file
from transformers import BertTokenizer

from BiLSTMClassifier import BiLSTMClassifier

# Build the model from the saved hyperparameters
with open("config.json") as f:
    cfg = json.load(f)
model = BiLSTMClassifier(**cfg)

# Load the trained weights and switch to inference mode
state_dict = load_file("BiLSTMClassifier.safetensors")
model.load_state_dict(state_dict)
model.eval()

sample_text = "URGENT HIRING! Earn $500/day working from home. No experience needed."

# bert-base-uncased is used only to convert text into token IDs
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokens = tokenizer(sample_text, return_tensors="pt")

# The model outputs logits; sigmoid converts them to probabilities
with torch.no_grad():
    logits = model(tokens["input_ids"])
    prob = torch.sigmoid(logits)
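For repeated use, the inference steps can be wrapped in a small helper. This is only a sketch: the function name `predict_probability` is hypothetical, and it assumes the model returns a single logit per input text (of shape `(1,)` or `(1, 1)`), which this card does not state explicitly.

```python
import torch


def predict_probability(text, model, tokenizer):
    """Return sigmoid(logit) for one text as a Python float."""
    tokens = tokenizer(text, return_tensors="pt")
    # Inference only, so gradient tracking is not needed
    with torch.no_grad():
        logits = model(tokens["input_ids"])
    # Assumption: one logit per input; reshape(-1) handles (1,) and (1, 1)
    return torch.sigmoid(logits).reshape(-1)[0].item()
```

With `model` and `tokenizer` loaded as shown earlier, `predict_probability(sample_text, model, tokenizer)` returns a value between 0 and 1. Which class the positive label denotes, and which decision threshold works best, should be checked against the training setup.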