MiniLM Toxic Comment Classifier
Multi-label toxic comment classifier fine-tuned on the Jigsaw dataset. Detects 6 toxicity categories simultaneously. Ships as ONNX for fast CPU/GPU inference.
Base model: MiniLMv2-L6-H384-distilled-from-BERT-Large (23M parameters).
Labels
| ID | Label | Description |
|---|---|---|
| 0 | toxic | Generally toxic content |
| 1 | severe_toxic | Highly aggressive or abusive |
| 2 | obscene | Vulgar or profane language |
| 3 | threat | Threatening language |
| 4 | insult | Insulting or demeaning content |
| 5 | identity_hate | Hate speech targeting identity groups |
Test Set Performance
Evaluated on a held-out 10% stratified test split (15,958 samples). Thresholds calibrated on the eval set.
| Label | AUC-ROC | F1 | Precision | Recall | Threshold |
|---|---|---|---|---|---|
| toxic | 0.9852 | 0.8205 | 0.8208 | 0.8203 | 0.50 |
| severe_toxic | 0.9844 | 0.5140 | 0.4335 | 0.6312 | 0.21 |
| obscene | 0.9924 | 0.8237 | 0.7881 | 0.8627 | 0.44 |
| threat | 0.9454 | 0.1455 | 0.1026 | 0.2500 | 0.08 |
| insult | 0.9887 | 0.7400 | 0.6559 | 0.8490 | 0.32 |
| identity_hate | 0.9636 | 0.2877 | 0.2781 | 0.2979 | 0.20 |
AUC-ROC is threshold-independent. F1/Precision/Recall use calibrated per-label thresholds.
threat and identity_hate have low prevalence (~0.3% and ~0.9%), which limits F1.
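As a quick sanity check, each F1 value in the table should be the harmonic mean of its precision and recall. The sketch below recomputes F1 from the reported figures; the helper name `f1_from_pr` is illustrative and not part of the repository.

```python
def f1_from_pr(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0.0 if both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# (precision, recall, reported F1) copied from the table above
reported = {
    "toxic": (0.8208, 0.8203, 0.8205),
    "severe_toxic": (0.4335, 0.6312, 0.5140),
    "obscene": (0.7881, 0.8627, 0.8237),
    "threat": (0.1026, 0.2500, 0.1455),
    "insult": (0.6559, 0.8490, 0.7400),
    "identity_hate": (0.2781, 0.2979, 0.2877),
}

for label, (p, r, f1) in reported.items():
    print(f"{label:14s} recomputed={f1_from_pr(p, r):.4f} reported={f1:.4f}")
```

The recomputed values agree with the table to within rounding, which also shows why threat scores so low: a precision of about 0.10 caps the harmonic mean regardless of AUC-ROC.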
Usage
ONNX (recommended)
import numpy as np
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer
# -----------------------------
# Load model + tokenizer
# -----------------------------
model_id = "navodPeiris/minilm-toxic-classifier"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ORTModelForSequenceClassification.from_pretrained(model_id)
# -----------------------------
# Thresholds
# -----------------------------
thresholds = {
    "toxic": 0.5,
    "severe_toxic": 0.21,
    "obscene": 0.44,
    "threat": 0.08,
    "insult": 0.32,
    "identity_hate": 0.2
}
LABEL_COLS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
thresh_arr = np.array([thresholds[c] for c in LABEL_COLS])
# -----------------------------
# Classification function
# -----------------------------
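def classify(texts: list[str]) -> list[dict]:
    enc = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="np"
    )
    outputs = model(**enc)
    logits = outputs.logits
    if not isinstance(logits, np.ndarray):
        logits = logits.numpy()
    probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid, shape [batch, 6]
    flags = probs >= thresh_arr  # per-label decisions using calibrated thresholds
    results = []
    for i in range(len(texts)):
        max_idx = int(np.argmax(probs[i]))
        results.append({
            "label": LABEL_COLS[max_idx],  # highest-probability label
            "prob": float(probs[i, max_idx]),
            # every label whose probability clears its threshold
            "flagged": [c for c, f in zip(LABEL_COLS, flags[i]) if f],
        })
    return results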
# -----------------------------
# Test
# -----------------------------
texts = ["You are an idiot", "i like to slap your face"]
results = classify(texts)
for text, result in zip(texts, results):
    print(f"{text!r:40s} res={result}")
Repository Contents
navodPeiris/minilm-toxic-classifier/
├── onnx/
│   └── model.onnx          # ONNX model (~91 MB), outputs logits [batch, 6]
├── config.json             # model config with id2label mapping
├── tokenizer.json          # fast tokenizer
├── tokenizer_config.json
└── thresholds.json         # per-label calibrated thresholds
The ONNX model outputs raw logits [batch, 6]. Apply sigmoid to get probabilities, then compare against thresholds.json for binary flags.
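The post-processing step described above can be sketched in plain NumPy. This is a minimal illustration, not the repository's own code: the function name `logits_to_flags` is hypothetical, and thresholds.json is assumed to be a flat `{label: threshold}` mapping holding the values listed in this card.

```python
import numpy as np

LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

# Calibrated thresholds from this card. Assumption: thresholds.json stores
# the same values as a flat {label: threshold} mapping.
THRESHOLDS = {
    "toxic": 0.50,
    "severe_toxic": 0.21,
    "obscene": 0.44,
    "threat": 0.08,
    "insult": 0.32,
    "identity_hate": 0.20,
}


def logits_to_flags(logits, thresholds):
    """Convert raw logits [batch, 6] into one {label: bool} dict per row."""
    logits = np.asarray(logits, dtype=np.float64)
    probs = 1.0 / (1.0 + np.exp(-logits))  # element-wise sigmoid
    cutoffs = np.array([thresholds[name] for name in LABELS])
    flags = probs >= cutoffs  # broadcasts the 6 cutoffs over the batch
    return [dict(zip(LABELS, (bool(f) for f in row))) for row in flags]


# Made-up logits for a single comment, for illustration only
example_logits = np.array([[2.0, -3.0, 0.5, -4.0, 1.0, -2.0]])
print(logits_to_flags(example_logits, THRESHOLDS))
```

Because the labels are independent, a single comment can be flagged for several categories at once, or for none.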
Training Details
Data
- Source: Jigsaw Toxic Comment Classification Challenge (159,571 labeled comments)
- Split: 80 / 10 / 10 using MultilabelStratifiedShuffleSplit (stratified on all 6 labels simultaneously)
- Hard negatives: Single characters, digits, and ~50 common short benign phrases (hi, hello, ok, ...) were added to the train/eval sets to prevent the model from misclassifying minimal inputs as toxic
Model Architecture
MiniLMv2-L6-H384-distilled-from-BERT-Large with a linear classification head added on top of [CLS]. The encoder weights are loaded from the published MLM checkpoint with LayerNorm parameter remapping (gamma → weight, beta → bias). The classification head is randomly initialized and trained from scratch.
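The LayerNorm remapping can be illustrated with a small key-renaming helper. This is a sketch under assumptions: the card does not list the checkpoint's actual key names, so the `LayerNorm.gamma` / `LayerNorm.beta` suffixes (the older BERT-style naming) and the example keys below are hypothetical, as is the function name `remap_layernorm_keys`.

```python
def remap_layernorm_keys(state_dict: dict) -> dict:
    """Return a copy with LayerNorm.gamma/beta renamed to LayerNorm.weight/bias.

    Assumes old BERT-style key suffixes; other keys pass through unchanged.
    """
    remapped = {}
    for key, value in state_dict.items():
        if key.endswith("LayerNorm.gamma"):
            key = key[: -len("gamma")] + "weight"
        elif key.endswith("LayerNorm.beta"):
            key = key[: -len("beta")] + "bias"
        remapped[key] = value
    return remapped


# Hypothetical keys; real values would be tensors rather than lists
old = {
    "encoder.layer.0.output.LayerNorm.gamma": [1.0],
    "encoder.layer.0.output.LayerNorm.beta": [0.0],
    "encoder.layer.0.output.dense.weight": [0.5],
}
print(sorted(remap_layernorm_keys(old)))
```

The remapped dictionary would then typically be loaded with a non-strict state-dict load so that the freshly initialized classification head is left untouched.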
Loss & Optimizer
- Loss: Column-wise mean BCE: per-element BCE (no reduction), scaled by 1/num_classes, then averaged per column over the batch and summed across columns. Unknown labels (-1) are masked out.
- Optimizer: Adam(amsgrad=True, lr=3e-5, weight_decay=3e-6), flat learning rate, no LLRD
- Scheduler: Cosine with 10% warmup
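The loss described above can be sketched in NumPy as follows. This is one reading of the description, not the training code itself: the function name `columnwise_masked_bce` is hypothetical, and averaging each column over its known (unmasked) entries only is an assumption, since the card does not say how masked entries affect the denominator.

```python
import numpy as np


def columnwise_masked_bce(logits, targets):
    """Column-wise mean BCE-with-logits; targets equal to -1 are ignored."""
    logits = np.asarray(logits, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)
    num_classes = logits.shape[1]

    mask = targets != -1  # True where the label is known
    safe_targets = np.where(mask, targets, 0.0)

    # Numerically stable per-element BCE-with-logits:
    #   max(x, 0) - x * y + log(1 + exp(-|x|))
    per_elem = (
        np.maximum(logits, 0.0)
        - logits * safe_targets
        + np.log1p(np.exp(-np.abs(logits)))
    )
    per_elem = per_elem * mask / num_classes  # zero out unknowns, scale by 1/C

    # Assumption: each column is averaged over its known entries only.
    counts = np.maximum(mask.sum(axis=0), 1)
    per_column = per_elem.sum(axis=0) / counts
    return float(per_column.sum())  # sum across columns


logits = np.zeros((2, 2))
targets = np.array([[1, 0], [0, -1]])  # last entry is unknown
print(columnwise_masked_bce(logits, targets))
```

With no masked entries this reduces to the ordinary mean BCE over all batch elements and labels, because the 1/num_classes scaling and the sum across columns cancel into a mean.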
Hyperparameters
| Parameter | Value |
|---|---|
| Base model | MiniLMv2-L6-H384 |
| Max sequence length | 256 |
| Train batch size (effective) | 30 (10 × 3 grad accum) |
| Learning rate | 3e-5 |
| Weight decay | 3e-6 |
| LR scheduler | Cosine |
| Warmup | 10% of epoch-3 steps |
| Max epochs | 100 |
| Early stopping patience | 3 epochs |
| Monitor | eval_loss |
| Epochs trained | 6 |
| Seed | 1234 |
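For reference, a cosine schedule with linear warmup has the shape sketched below. The step counts are hypothetical: the card does not state the total number of optimizer steps, and "10% of epoch-3 steps" is not fully specified, so `total_steps` and `warmup_steps` here are illustrative values only. The function name `cosine_lr_with_warmup` is likewise an assumption.

```python
import math


def cosine_lr_with_warmup(step, total_steps, warmup_steps, base_lr=3e-5):
    """Linear warmup from 0 to base_lr, then cosine decay down to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)  # clamp in case step exceeds total_steps
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Hypothetical run: 1000 optimizer steps, 10% of them used for warmup
for step in (0, 50, 100, 550, 1000):
    print(step, cosine_lr_with_warmup(step, total_steps=1000, warmup_steps=100))
```

With an effective batch size of 30, one optimizer step corresponds to three forward/backward passes of 10 samples each.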
Threshold Calibration
Per-label thresholds were swept from 0.02 to 0.95 (step 0.01) on the eval set to maximise per-label F1, then applied at test time. This is important for rare labels: threat (0.3% prevalence) uses threshold 0.08, and severe_toxic / identity_hate use 0.21 and 0.20 respectively.
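The sweep can be sketched as below. This is a minimal NumPy illustration rather than the original calibration script: the function names are hypothetical, and breaking ties in favour of the lowest threshold is an assumption the card does not confirm.

```python
import numpy as np


def f1_score_binary(y_true, y_pred):
    """F1 for one binary label; returns 0.0 when there are no positives at all."""
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    denom = 2 * tp + fp + fn
    return float(2 * tp / denom) if denom > 0 else 0.0


def calibrate_thresholds(probs, labels, lo=0.02, hi=0.95, step=0.01):
    """Pick, per label column, the threshold in [lo, hi] that maximises F1."""
    probs = np.asarray(probs, dtype=np.float64)
    labels = np.asarray(labels)
    n_steps = int(round((hi - lo) / step)) + 1
    grid = np.round(lo + step * np.arange(n_steps), 2)  # 0.02, 0.03, ..., 0.95
    best = []
    for col in range(probs.shape[1]):
        scores = [
            f1_score_binary(labels[:, col], (probs[:, col] >= t).astype(int))
            for t in grid
        ]
        # np.argmax returns the first maximum, i.e. the lowest tied threshold
        best.append(float(grid[int(np.argmax(scores))]))
    return best


# Tiny synthetic eval set with two labels, for illustration only
eval_probs = np.array([[0.1, 0.06], [0.2, 0.9], [0.6, 0.03], [0.7, 0.04]])
eval_labels = np.array([[0, 1], [0, 1], [1, 0], [1, 0]])
print(calibrate_thresholds(eval_probs, eval_labels))
```

In the synthetic example the second label needs a threshold far below 0.5 to recover its low-scoring positive, which mirrors why the rare labels in this card end up with thresholds such as 0.08.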
Limitations
- Trained only on English Wikipedia talk-page comments; the model is not multilingual.
- threat and identity_hate have very low prevalence in the training data, resulting in lower F1 despite high AUC-ROC.
License
Apache 2.0. The Jigsaw dataset is subject to its own terms of use.