metadata
license: mit
library_name: pytorch
tags:
  - bilstm
  - lstm
  - pytorch
  - text-classification
  - spam-detection
model_details:
  parameters: 4403585
task_categories:
  - text-classification
datasets:
  - ucirvine/sms_spam
language:
  - en

BiLSTM Text Classifier

A simple BiLSTM model, implemented in PyTorch and trained for spam detection on the
SMS Spam Collection dataset (Almeida, Tiago and José Hidalgo. 2011. SMS Spam Collection.
UCI Machine Learning Repository. https://doi.org/10.24432/C5CC84).

Important Notes

  • The model returns logits as output; to obtain probabilities, apply torch.sigmoid.
  • The model uses the bert-base-uncased tokenizer only for tokenization (the encoder is NOT BERT).
  • Number of parameters: 4,403,585 (~4.4M)

Files

  • BiLSTMClassifier.safetensors: trained weights
  • BiLSTMClassifier.py: model definition
  • config.json: hyperparameters

Usage

import json
import torch
from transformers import BertTokenizer
from safetensors.torch import load_file
from BiLSTMClassifier import BiLSTMClassifier

with open("config.json") as f:
    cfg = json.load(f)

model = BiLSTMClassifier(**cfg)

state_dict = load_file("BiLSTMClassifier.safetensors")
model.load_state_dict(state_dict)
model.eval()

sample_text = "URGENT HIRING! Earn $500/day working from home. No experience needed."

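# Only the tokenizer comes from BERT; the encoder is the BiLSTM loaded above.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokens = tokenizer(sample_text, return_tensors="pt")

# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    logits = model(tokens["input_ids"])
    prob = torch.sigmoid(logits)  # convert raw logits to probabilities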
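
The usage example stops at a probability, so a minimal sketch of turning a single logit into a label may help. It uses only the standard library: `spam_probability` computes the same logistic sigmoid that `torch.sigmoid` applies to a scalar. The helper names, the 0.5 threshold, and the assumption that the positive class means spam are illustrative and are not taken from the model's repository.

```python
import math
from typing import Tuple


def spam_probability(logit: float) -> float:
    """Numerically stable logistic sigmoid for one raw logit."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    # For negative inputs, avoid overflow in exp(-logit).
    z = math.exp(logit)
    return z / (1.0 + z)


def classify(logit: float, threshold: float = 0.5) -> Tuple[str, float]:
    """Map a raw logit to a (label, probability) pair.

    Assumption: the positive class is "spam" and 0.5 is a reasonable
    cut-off; both should be checked against how the model was trained.
    """
    prob = spam_probability(logit)
    label = "spam" if prob >= threshold else "ham"
    return label, prob
```

When a single message is scored, a scalar logit for this helper could be obtained with `logits.item()`. Raising the threshold trades missed spam for fewer false alarms on legitimate messages.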