---
license: apache-2.0
datasets:
- ucirvine/sms_spam
language:
- en
- hi
- te
metrics:
- accuracy
- f1
base_model:
- distilbert/distilbert-base-uncased
tags:
- text_classification
- spam_detection
- distilbert
---

# Spam Detection using DistilBERT

This model is a fine-tuned `distilbert-base-uncased` transformer for binary spam classification (spam vs. ham).

## Labels

- 0 → Ham
- 1 → Spam

## Usage

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Replace <your-username> with the Hub account that hosts the model.
tokenizer = AutoTokenizer.from_pretrained("<your-username>/spam-detection-distilbert")
model = AutoModelForSequenceClassification.from_pretrained("<your-username>/spam-detection-distilbert")

# Tokenize one message, truncating or padding it to 128 tokens.
inputs = tokenizer(
    "You won a free iPhone!",
    return_tensors="pt",
    truncation=True,
    padding="max_length",
    max_length=128,
)

# Inference only, so gradient tracking is not needed.
with torch.no_grad():
    outputs = model(**inputs)

# Index of the largest logit: 0 = ham, 1 = spam.
prediction = torch.argmax(outputs.logits, dim=1).item()
print("SPAM" if prediction == 1 else "HAM")
```

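If a confidence score is more useful than a hard label, the two logits can be turned into probabilities with a softmax. The following is a minimal sketch, not part of the original repository: it assumes the logits are ordered as in the Labels section (index 0 for ham, index 1 for spam), and the names `ID2LABEL`, `softmax`, `classify`, and `spam_threshold` are hypothetical helpers introduced only for illustration. It uses plain Python lists so it runs without the model.

```python
import math

# Label mapping taken from the Labels section of this card.
ID2LABEL = {0: "HAM", 1: "SPAM"}


def softmax(logits):
    """Convert raw logits to probabilities (numerically stable form)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify(logits, spam_threshold=0.5):
    """Return (label, spam_probability) for one pair of logits.

    spam_threshold is a hypothetical knob, not something the model card defines.
    """
    spam_prob = softmax(logits)[1]
    label = ID2LABEL[1] if spam_prob > spam_threshold else ID2LABEL[0]
    return label, spam_prob


# Made-up logits for illustration; with the real model you could pass
# outputs.logits[0].tolist() from the usage snippet instead.
label, spam_prob = classify([-1.2, 2.3])
print(label, round(spam_prob, 3))
```

With the default threshold of 0.5 this gives the same decision as the `argmax` in the usage snippet. Raising the threshold trades missed spam for fewer legitimate messages flagged as spam.
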
## GitHub Repository

Code for training and inference is available here:
https://github.com/revanthreddy0906/spam-detection-distilbert.git