BiLSTM Text Classifier

A simple BiLSTM model, implemented in PyTorch and trained for spam detection on the SMS Spam Collection dataset
(Almeida, Tiago and José Hidalgo. 2011. SMS Spam Collection.
UCI Machine Learning Repository. https://doi.org/10.24432/C5CC84).

Important Notes

  • The model returns logits as output; to obtain probabilities, apply torch.sigmoid.
  • The model uses the bert-base-uncased tokenizer only for tokenization (the encoder is NOT BERT).
  • Number of parameters: ~4.4M
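
The logit-to-probability step from the first note can be sketched as follows. The logit values are made up for illustration, and both the 0.5 threshold and the reading of the positive class as "spam" are assumptions rather than something stated by the model files; the output shape of the real model may also differ.

```python
import torch

# Hypothetical logits standing in for model output;
# real values would come from model(input_ids).
logits = torch.tensor([[-2.0], [0.0], [3.0]])

# Sigmoid maps each logit to a probability in (0, 1).
probs = torch.sigmoid(logits)

# Assumed decision rule: flag as spam when the probability is at least 0.5.
THRESHOLD = 0.5
is_spam = probs >= THRESHOLD

for p, flag in zip(probs.flatten().tolist(), is_spam.flatten().tolist()):
    print(f"p(spam)={p:.3f} -> {'spam' if flag else 'ham'}")
```

A threshold other than 0.5 may suit the application better, for example a higher one if false spam flags are costly.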

Files

  • BiLSTMClassifier.safetensors: trained weights
  • BiLSTMClassifier.py: model definition
  • config.json: hyperparameters

Usage

import json
import torch
from transformers import BertTokenizer
from safetensors.torch import load_file
from BiLSTMClassifier import BiLSTMClassifier

with open("config.json") as f:
    cfg = json.load(f)

model = BiLSTMClassifier(**cfg)

state_dict = load_file("BiLSTMClassifier.safetensors")
model.load_state_dict(state_dict)
model.eval()

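sample_text = "URGENT HIRING! Earn $500/day working from home. No experience needed."

# The tokenizer only maps text to token ids; no BERT weights are loaded.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokens = tokenizer(sample_text, return_tensors="pt")

# Disable gradient tracking for inference.
with torch.no_grad():
    logits = model(tokens["input_ids"])
prob = torch.sigmoid(logits)  # convert logits to a probability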
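
The parameter count quoted in the notes can be checked once the model is loaded. The helper below is a minimal sketch: `count_parameters` is a name introduced here, and the small `nn.LSTM` is only a stand-in so the snippet runs on its own. Whether the quoted ~4.4M refers to trainable parameters only is an assumption.

```python
import torch.nn as nn


def count_parameters(module: nn.Module) -> int:
    """Return the number of trainable parameters in a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Hypothetical stand-in layer so the snippet is self-contained;
# in practice, pass the loaded BiLSTMClassifier instance instead.
toy = nn.LSTM(input_size=8, hidden_size=16, bidirectional=True, batch_first=True)
print(f"{count_parameters(toy):,} parameters")
```

Calling `count_parameters(model)` on the loaded classifier should give a figure close to the one listed above if the weights and config match.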