---
license: mit
library_name: pytorch
tags:
- bilstm
- lstm
- pytorch
- text-classification
- spam-detection
model_details:
  parameters: 4403585
task_categories:
- text-classification
datasets:
- ucirvine/sms_spam
language:
- en
---

# BiLSTM Text Classifier

A simple BiLSTM model in PyTorch, trained for spam detection on the SMS Spam Collection
(Almeida, Tiago and José Hidalgo. 2011. *SMS Spam Collection*.
UCI Machine Learning Repository. https://doi.org/10.24432/C5CC84).

## Important Notes
- The model returns **logits** as output; to obtain probabilities, apply `torch.sigmoid`.
- The model uses the `bert-base-uncased` tokenizer **only for tokenization** (the encoder is NOT BERT).
- Number of parameters: ~4.4M

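The first note can be illustrated without loading the model. The following is a minimal sketch, in plain Python, of what the sigmoid does to a single logit and how a cutoff turns the result into a decision. The helper names, the 0.5 threshold, and the example logits are assumptions for illustration and are not taken from the model's code; on tensors, `torch.sigmoid` applies the same function elementwise.

```python
import math


def sigmoid(logit: float) -> float:
    """Logistic function, written to avoid overflow for large |logit|."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


def is_positive(logit: float, threshold: float = 0.5) -> bool:
    """Hypothetical helper: True when the probability reaches the threshold."""
    return sigmoid(logit) >= threshold


# Illustrative logits, not real model outputs
for logit in (-3.0, 0.0, 2.5):
    print(f"logit={logit:+.1f}  prob={sigmoid(logit):.3f}  positive={is_positive(logit)}")
```

A threshold of 0.5 on the probability corresponds to a threshold of 0 on the raw logit, so the sign of the logit alone gives the same decision.
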
## Files
- `BiLSTMClassifier.safetensors`: trained weights
- `BiLSTMClassifier.py`: model definition
- `config.json`: hyperparameters

## Usage

```python
import json
import torch
from transformers import BertTokenizer
from safetensors.torch import load_file
from BiLSTMClassifier import BiLSTMClassifier

# Build the model from the stored hyperparameters
with open("config.json") as f:
    cfg = json.load(f)

model = BiLSTMClassifier(**cfg)

# Load the trained weights and switch to inference mode
state_dict = load_file("BiLSTMClassifier.safetensors")
model.load_state_dict(state_dict)
model.eval()

sample_text = "URGENT HIRING! Earn $500/day working from home. No experience needed."

# bert-base-uncased is used only to map text to token ids
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokens = tokenizer(sample_text, return_tensors="pt")

# The model returns logits; the sigmoid maps them to probabilities
with torch.no_grad():
    logits = model(tokens["input_ids"])
    prob = torch.sigmoid(logits)