|
|
--- |
|
|
tags: |
|
|
- swahili |
|
|
- classification |
|
|
- multilabel |
|
|
- roberta |
|
|
- transformers |
|
|
- onnx |
|
|
- africa |
|
|
- nlp |
|
|
license: apache-2.0 |
|
|
language: |
|
|
- sw |
|
|
- swa |
|
|
datasets: |
|
|
- custom |
|
|
metrics: |
|
|
- f1_score |
|
|
- precision |
|
|
- recall |
|
|
- hamming_loss |
|
|
pipeline_tag: text-classification |
|
|
task_categories: |
|
|
- text-classification |
|
|
task_ids: |
|
|
- multi-label-classification |
|
|
base_model: |
|
|
- benjamin/roberta-base-wechsel-swahili |
|
|
library_name: transformers |
|
|
--- |
|
|
|
|
|
# Swahili Topic Classifier - Multi-label Classification |
|
|
|
|
|
## Model Details |
|
|
|
|
|
### Model Description |
|
|
A multi-label text classification model fine-tuned on RoBERTa-base Wechsel Swahili for classifying Swahili text into 8 predefined topics. The model can identify multiple applicable topics for a given text, providing confidence scores for each topic. |
|
|
|
|
|
- **Developed by**: NeboTech |
|
|
- **Model type**: Transformer-based (RoBERTa) |
|
|
- **Language(s)**: Swahili (Kiswahili) |
|
|
- **License**: Apache 2.0 |
|
|
- **Finetuned from**: [RoBERTa-base Wechsel Swahili](https://huggingface.co/benjamin/roberta-base-wechsel-swahili)
|
|
- **Model version**: v2.0 (Multi-label Classification) |
|
|
|
|
|
### Model Architecture |
|
|
- **Base Model**: RoBERTa-base Wechsel Swahili |
|
|
- **Task**: Multi-label Sequence Classification |
|
|
- **Problem Type**: `multi_label_classification` |
|
|
- **Number of Labels**: 8 |
|
|
- **Activation Function**: Sigmoid (for multi-label) |
|
|
- **Loss Function**: BCEWithLogitsLoss |
|
|
- **Output Format**: Binary vectors [batch_size, num_labels] |
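To make the sigmoid-based multi-label output concrete, here is a small NumPy sketch with illustrative logits (not real model outputs): each of the 8 logits is squashed independently, so probabilities do not sum to 1 and any number of labels can be active at once.

```python
import numpy as np

# Sigmoid maps each of the 8 logits to an independent probability,
# unlike softmax, where probabilities compete and must sum to 1.
def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Illustrative logits for one example across the 8 labels
logits = np.array([2.0, -1.5, 0.3, -3.0, 1.2, -0.5, -2.0, 0.8])
probs = sigmoid(logits)

# Thresholding at 0.5 yields the multi-hot output vector
multi_hot = (probs > 0.5).astype(int)
print(multi_hot)  # labels 0, 2, 4 and 7 active
```

Note that the sum of `probs` is generally not 1; that independence is what lets a single text carry several topics simultaneously.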
|
|
|
|
|
### Model Variants |
|
|
- **v2.0** (Current): Multi-label classification - Returns multiple topics with confidence scores |
|
|
- **v1.0** (Legacy): Single-label classification - Returns single topic (available at `revision="v1.0-single-label"`) |
|
|
|
|
|
## Intended Use |
|
|
|
|
|
### Primary Use Cases |
|
|
- **Content Classification**: Categorize Swahili text messages, reports, or documents |
|
|
- **Case Management**: Automatically tag and route cases to appropriate departments |
|
|
- **Content Moderation**: Identify topics requiring attention (e.g., health emergencies, violence) |
|
|
- **Data Analytics**: Analyze trends and patterns in Swahili text data |
|
|
- **Information Routing**: Direct messages to relevant stakeholders based on topics |
|
|
|
|
|
### Out-of-Scope Uses |
|
|
- **Not suitable for**: Languages other than Swahili |
|
|
- **Not suitable for**: Very short text (< 5 words) or very long text (> 512 tokens) |
|
|
- **Not suitable for**: Real-time critical decision making without human oversight |
|
|
- **Not suitable for**: Medical diagnosis or legal advice |
|
|
|
|
|
## Training Details |
|
|
|
|
|
### Training Data |
|
|
- **Dataset**: Custom Swahili text dataset |
|
|
- **Language**: Swahili (Kiswahili) |
|
|
- **Data Collection**: U-Report platform messages and related Swahili text |
|
|
- **Preprocessing**: Text cleaning, normalization, and tokenization |
|
|
- **Data Balance**: Dataset balanced across 8 topics |
|
|
|
|
|
### Training Procedure |
|
|
- **Training Type**: Fine-tuning from pre-trained RoBERTa-base Wechsel Swahili |
|
|
- **Optimizer**: AdamW |
|
|
- **Learning Rate**: 2e-5 |
|
|
- **Batch Size**: Variable (with gradient accumulation) |
|
|
- **Epochs**: 3 |
|
|
- **Gradient Accumulation**: 4 steps |
|
|
- **Weight Decay**: 0.01 |
|
|
- **Mixed Precision**: Enabled (FP16) |
|
|
- **Early Stopping**: Enabled (patience=2) |
|
|
|
|
|
### Training Hyperparameters

```yaml
learning_rate: 2e-5
per_device_train_batch_size: 4
gradient_accumulation_steps: 4
num_train_epochs: 3
weight_decay: 0.01
warmup_steps: 0
max_grad_norm: 1.0
fp16: true
```

## Evaluation
|
|
|
|
|
### Testing Data, Factors & Metrics |
|
|
- **Evaluation Dataset**: Held-out test set from balanced dataset |
|
|
- **Evaluation Metrics**: |
|
|
- **F1 Score (Micro)**: Aggregated across all labels |
|
|
- **F1 Score (Macro)**: Average per-label F1 |
|
|
- **F1 Score (Samples)**: Average per-sample F1 |
|
|
- **Precision (Micro/Macro)**: Classification precision |
|
|
- **Recall (Micro/Macro)**: Classification recall |
|
|
- **Hamming Loss**: Fraction of incorrectly predicted labels |
|
|
- **Subset Accuracy**: Exact match accuracy |
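For intuition on the two less common metrics, hamming loss and subset accuracy can be computed directly from multi-hot arrays; the values below are toy data, not the model's evaluation set:

```python
import numpy as np

# Toy multi-hot ground truth and predictions: 3 samples x 4 labels
y_true = np.array([[1, 0, 1, 0],
                   [0, 1, 0, 0],
                   [1, 1, 0, 1]])
y_pred = np.array([[1, 0, 1, 0],   # exact match
                   [0, 1, 1, 0],   # one wrong label
                   [1, 1, 0, 1]])  # exact match

# Hamming loss: fraction of individual label decisions that are wrong
hamming = np.mean(y_true != y_pred)
print(hamming)  # 1 wrong out of 12 decisions

# Subset accuracy: fraction of samples whose entire label set matches
subset_acc = np.mean(np.all(y_true == y_pred, axis=1))
print(subset_acc)  # 2 of 3 samples match exactly
```

This is why subset accuracy is always the stricter number: one wrong label costs the whole sample, while hamming loss only charges for that single decision.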
|
|
|
|
|
### Results |
|
|
| Metric | Score |
|--------|-------|
| F1 Score (Micro) | 0.96 |
| F1 Score (Macro) | 0.96 |
| F1 Score (Samples) | 0.96 |
| Precision (Micro) | 0.96 |
| Recall (Micro) | 0.96 |
| Hamming Loss | 0.009054 |
| Subset Accuracy | 0.962 |
|
|
|
|
|
## Model Performance Characteristics |
|
|
|
|
|
### Strengths |
|
|
- **Multi-label Capability**: Can identify multiple topics in a single text |
|
|
- **Confidence Scores**: Provides probability scores for each topic |
|
|
- **Swahili Language Support**: Specifically fine-tuned for Swahili text |
|
|
- **Efficient Inference**: ONNX format available for fast CPU inference |
|
|
- **Balanced Performance**: Trained on balanced dataset across all topics |
|
|
|
|
|
### Limitations |
|
|
- **Language Specific**: Only works with Swahili text |
|
|
- **Topic Coverage**: Limited to 8 predefined topics |
|
|
- **Context Dependency**: Performance may vary with text length and context |
|
|
- **Dialect Variations**: May not handle all Swahili dialects equally well |
|
|
- **Threshold Sensitivity**: Requires careful threshold tuning for optimal performance |
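One common way to handle the threshold sensitivity noted above is to sweep candidate thresholds on a held-out validation set and keep the one maximizing micro-F1. A minimal sketch with toy probabilities (not real model outputs):

```python
import numpy as np

# Toy validation probabilities and multi-hot labels: 2 samples x 3 labels
probs = np.array([[0.90, 0.42, 0.18],
                  [0.61, 0.70, 0.12]])
y_true = np.array([[1, 1, 0],
                   [1, 1, 0]])

def micro_f1(y_true, y_pred):
    # Micro-F1 pools true/false positives and negatives across all labels
    tp = np.sum((y_pred == 1) & (y_true == 1))
    fp = np.sum((y_pred == 1) & (y_true == 0))
    fn = np.sum((y_pred == 0) & (y_true == 1))
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0

# Sweep candidate thresholds; keep the first with the best micro-F1
best_t, best_f1 = max(
    ((t, micro_f1(y_true, (probs >= t).astype(int))) for t in np.arange(0.1, 0.9, 0.05)),
    key=lambda x: x[1],
)
print(best_t, best_f1)
```

A variant worth considering is tuning a separate threshold per label, since per-label probability calibration often differs.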
|
|
|
|
|
### Known Biases |
|
|
- **Training Data Bias**: Model reflects biases present in training data |
|
|
- **Geographic Bias**: May perform better on texts from regions in training data |
|
|
- **Topic Imbalance**: Some topics may have better representation in training data |
|
|
- **Cultural Context**: May not capture all cultural nuances in Swahili communication |
|
|
|
|
|
## How to Get Started with the Model |
|
|
|
|
|
### Using Transformers (PyTorch) |
|
|
|
|
|
```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Load model
model = AutoModelForSequenceClassification.from_pretrained(
    "NeboTech/swahili-text-classifier",
    problem_type="multi_label_classification"  # CRITICAL for multi-label
)
tokenizer = AutoTokenizer.from_pretrained("NeboTech/swahili-text-classifier")

# Prepare input
text = "Nataka kujua dalili za COVID-19 na jinsi ya kujilinda"
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256)

# Get predictions
model.eval()
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits  # Shape: [1, 8]

# Apply sigmoid for multi-label
probs = torch.sigmoid(logits)

# Apply threshold
threshold = 0.5
predictions = (probs > threshold).float()

# Get applicable topics
applicable_topics = torch.where(predictions[0] == 1)[0].tolist()
print(f"Applicable topics: {applicable_topics}")
print(f"Probabilities: {probs[0].tolist()}")
```

### Using ONNX Runtime
|
|
|
|
|
```python
import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("NeboTech/swahili-text-classifier")

# Load ONNX model
session = ort.InferenceSession("swahili_classifier.onnx")

# Prepare input
text = "Nataka kujua dalili za COVID-19"
inputs = tokenizer(text, return_tensors="np", padding="max_length", truncation=True, max_length=256)

# Run inference
outputs = session.run(
    None,
    {
        "input_ids": inputs["input_ids"].astype(np.int64),
        "attention_mask": inputs["attention_mask"].astype(np.int64)
    }
)

logits = outputs[0]  # Shape: [1, 8]

# Apply sigmoid
probs = 1 / (1 + np.exp(-logits))

# Apply threshold
threshold = 0.5
predictions = (probs > threshold).astype(float)

# Get topics
applicable_topics = np.where(predictions[0] == 1)[0]
print(f"Applicable topics: {applicable_topics}")
```

## Topics (Label Mapping)
|
|
|
|
|
| ID | Topic | Description | |
|
|
|----|-------|-------------| |
|
|
| 0 | COVID | COVID-19 related topics, symptoms, prevention | |
|
|
| 1 | EDUCATION | Educational content, school-related topics | |
|
|
| 2 | HEALTH | General health topics, medical information | |
|
|
| 3 | HIV/AIDS | HIV/AIDS related information and support | |
|
|
| 4 | MENSTRUAL HYGIENE | Menstrual health and hygiene topics | |
|
|
| 5 | NUTRITION | Nutrition, food, and dietary information | |
|
|
| 6 | U-REPORT | U-Report platform related content | |
|
|
| 7 | VIOLENCE AGAINST CHILDREN | Child protection and violence prevention | |
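The indices returned by the inference examples can be turned into readable topic names with a mapping like the one below, which mirrors the table above. If the checkpoint's `config.json` includes `id2label`, prefer reading `model.config.id2label` instead of hardcoding.

```python
# Label mapping mirroring the table in this model card
ID2LABEL = {
    0: "COVID",
    1: "EDUCATION",
    2: "HEALTH",
    3: "HIV/AIDS",
    4: "MENSTRUAL HYGIENE",
    5: "NUTRITION",
    6: "U-REPORT",
    7: "VIOLENCE AGAINST CHILDREN",
}

# Convert predicted label indices into topic names
predicted_ids = [0, 2]  # example indices as returned by the code above
topics = [ID2LABEL[i] for i in predicted_ids]
print(topics)  # ['COVID', 'HEALTH']
```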
|
|
|
|
|
## Ethical Considerations |
|
|
|
|
|
### Ethical Use |
|
|
- **Human Oversight**: Always include human review for critical decisions |
|
|
- **Privacy**: Respect user privacy when processing text data |
|
|
- **Transparency**: Inform users when automated classification is used |
|
|
- **Fairness**: Monitor for biased outcomes across different user groups |
|
|
|
|
|
### Potential Risks |
|
|
- **Misclassification**: Incorrect topic assignment could misroute important messages |
|
|
- **False Positives/Negatives**: May miss urgent cases or flag non-urgent content |
|
|
- **Privacy Concerns**: Processing sensitive health and personal information |
|
|
- **Cultural Sensitivity**: May not fully capture cultural context and nuances |
|
|
|
|
|
### Recommendations |
|
|
- **Regular Monitoring**: Continuously monitor model performance in production |
|
|
- **Human Review**: Implement human review for high-stakes classifications |
|
|
- **Feedback Loop**: Collect and incorporate user feedback for improvements |
|
|
- **Bias Auditing**: Regularly audit for biases and fairness issues |
|
|
- **Threshold Tuning**: Adjust thresholds based on use case requirements |
|
|
|
|
|
## Citation |
|
|
|
|
|
```bibtex
@misc{swahili-topic-classifier-multilabel,
  title={Swahili Topic Classifier - Multi-label Classification},
  author={NeboTech},
  year={2024},
  publisher={Hugging Face},
  howpublished={\url{https://huggingface.co/NeboTech/swahili-text-classifier}},
  note={Version 2.0 - Multi-label Classification}
}
```

## Additional Information
|
|
|
|
|
### Model Files |
|
|
- `config.json`: Model configuration |
|
|
- `pytorch_model.bin` or `model.safetensors`: Model weights |
|
|
- `tokenizer.json`: Tokenizer model |
|
|
- `tokenizer_config.json`: Tokenizer configuration |
|
|
- `vocab.json`, `merges.txt`: Vocabulary files |
|
|
- `swahili_classifier.onnx`: ONNX model (separate repository) |
|
|
|
|
|
### Version History |
|
|
- **v2.0** (Current): Multi-label classification with sigmoid activation |
|
|
- **v1.0** (Legacy): Single-label classification with softmax activation |
|
|
|
|
|
### Contact |
|
|
For questions, issues, or contributions, please open a discussion on the model's Hugging Face repository.