---
language:
- en
license: apache-2.0
library_name: transformers
tags:
- text-classification
- distilbert
- fine-tuned
- pytorch
datasets:
- cassieli226/cities-text-dataset
base_model: distilbert-base-uncased
model-index:
- name: hw2-text-distilbert
  results:
  - task:
      type: text-classification
      name: Text Classification
    dataset:
      type: cassieli226/cities-text-dataset
      name: Cities Text Dataset
      split: test
    metrics:
    - type: accuracy
      value: 99.5
      name: Test Accuracy
    - type: f1
      value: 99.5
      name: Test F1 Score (Macro)
---
|
# DistilBERT Text Classification Model

This model is a fine-tuned version of [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased) for text classification.

## Model Description

This is a DistilBERT model fine-tuned for binary text classification: it labels English text as related to either Pittsburgh or Shanghai. It reaches 99.5% accuracy on the held-out test set.

- **Model type:** Text Classification (Binary)
- **Language(s) (NLP):** English
- **Base model:** distilbert-base-uncased
- **Classes:** Pittsburgh, Shanghai
|
## Intended Uses & Limitations

### Intended Uses
- Binary classification of Pittsburgh- versus Shanghai-related text
- City-based text categorization tasks
- Research and educational use in NLP and text classification

### Limitations
- English-language text only
- Performance may degrade on out-of-domain data
- Inputs are truncated to a maximum of 256 tokens
|
## Training and Evaluation Data

### Training Data
- **Base dataset:** [cassieli226/cities-text-dataset](https://huggingface.co/datasets/cassieli226/cities-text-dataset)
- **Original dataset:** 100 samples (50 Pittsburgh, 50 Shanghai)
- **Data augmentation:** expanded the dataset from 100 to 1,000 samples
- **Class balance (augmented):** Pittsburgh (507 samples), Shanghai (493 samples)
- **Train/test split:** stratified 80/20 split (800 train, 200 test), sketched below
- **External validation:** the original 100 samples served as an additional validation set
|
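This is a minimal sketch of that split, not the original preparation script; the variable names `texts` and `labels`, the label mapping, and the seed are assumptions:

```python
from sklearn.model_selection import train_test_split

# Hypothetical inputs: `texts` and `labels` hold the 1,000 augmented samples,
# e.g. labels 0 = Pittsburgh, 1 = Shanghai (the label mapping is an assumption).
train_texts, test_texts, train_labels, test_labels = train_test_split(
    texts,
    labels,
    test_size=0.2,       # 80/20 split: 800 train, 200 test
    stratify=labels,     # keep the 507/493 class balance in both splits
    random_state=42,     # fixed seed for reproducibility (assumed)
)
```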
|
### Preprocessing
- Tokenization with the DistilBERT tokenizer (sketched below)
- Maximum sequence length: 256 tokens
- Longer sequences are truncated
|
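A minimal sketch of this preprocessing, assuming the data is held in a Hugging Face `datasets.Dataset` with a `text` column (the original preprocessing code is not published):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def preprocess(batch):
    # Truncate anything past 256 tokens; padding can be left to the data collator.
    return tokenizer(batch["text"], truncation=True, max_length=256)

# tokenized = dataset.map(preprocess, batched=True)
```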
|
## Training Procedure

### Training Hyperparameters
- **Learning rate:** 5e-5
- **Training batch size:** 16
- **Evaluation batch size:** 32
- **Number of epochs:** 4
- **Weight decay:** 0.01
- **Warmup ratio:** 0.1
- **LR scheduler:** linear
- **Gradient accumulation steps:** 1
- **Mixed precision:** FP16 (when a GPU is available)

### Training Configuration
- **Optimizer:** AdamW (default)
- **Early stopping:** enabled, patience of 2 epochs
- **Best model selection:** highest macro F1 score
- **Evaluation strategy:** every epoch
- **Save strategy:** every epoch (best checkpoint kept)

These settings map onto the `Trainer` API as sketched below.
|
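This is a reconstruction under stated assumptions: the original training script is not published, so `model`, the dataset variables, and `compute_metrics` (expected to report macro F1 under the key `f1`) are placeholders.

```python
import torch
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback

args = TrainingArguments(
    output_dir="hw2-text-distilbert",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=4,
    weight_decay=0.01,
    warmup_ratio=0.1,
    lr_scheduler_type="linear",
    gradient_accumulation_steps=1,
    fp16=torch.cuda.is_available(),   # mixed precision only when a GPU is present
    eval_strategy="epoch",            # `evaluation_strategy` on older transformers versions
    save_strategy="epoch",
    save_total_limit=1,               # keep the best checkpoint only
    load_best_model_at_end=True,
    metric_for_best_model="f1",       # macro F1 reported by compute_metrics
)

trainer = Trainer(
    model=model,                      # DistilBERT with a 2-class classification head
    args=args,
    train_dataset=train_dataset,      # tokenized splits from the preprocessing step
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,  # see the sketch under "Metrics" below
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)
# trainer.train()
```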
|
## Evaluation

### Metrics
The model was evaluated using:
- **Accuracy:** overall classification accuracy
- **F1 score (macro):** macro-averaged F1 across both classes
- **Per-class accuracy:** individual performance for each class

A `compute_metrics` sketch follows.
|
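One plausible implementation for the `Trainer` loop, since the original function is not published (per-class accuracy appears in the detailed results below):

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average="macro"),  # macro-averaged F1
    }
```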
|
### Results
- **Test set (200 samples):**
  - Accuracy: 99.5%
  - F1 score (macro): 99.5%
- **External validation (original 100 samples):**
  - Accuracy: 100.0%
  - F1 score (macro): 100.0%

### Detailed Performance
- **Pittsburgh class:** 99.01% accuracy (101 samples)
- **Shanghai class:** 100.0% accuracy (99 samples)
- **Confusion matrix:** only 1 misclassification out of 200 test samples (a reproduction sketch follows)
|
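The per-class figures above can be reproduced from raw predictions with scikit-learn; a sketch assuming arrays `labels` and `preds` for the 200 test samples (label order 0 = Pittsburgh, 1 = Shanghai is an assumption):

```python
from sklearn.metrics import classification_report, confusion_matrix

print(confusion_matrix(labels, preds))   # off-diagonal entries are misclassifications
print(classification_report(labels, preds, target_names=["Pittsburgh", "Shanghai"]))
```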
|
## Usage

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Load model and tokenizer
model_name = "Anyuhhh/hw2-text-distilbert"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Example usage
text = "Your input text here"
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)

with torch.no_grad():
    outputs = model(**inputs)

predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_class = torch.argmax(predictions, dim=-1)

print(f"Predicted class: {predicted_class.item()}")
```