---
language:
- en
license: apache-2.0
library_name: transformers
tags:
- text-classification
- distilbert
- fine-tuned
- pytorch
datasets:
- cassieli226/cities-text-dataset
base_model: distilbert-base-uncased
model-index:
- name: hw2-text-distilbert
  results:
  - task:
      type: text-classification
      name: Text Classification
    dataset:
      type: cassieli226/cities-text-dataset
      name: Cities Text Dataset
      split: test
    metrics:
    - type: accuracy
      value: 99.5
      name: Test Accuracy
    - type: f1
      value: 99.5
      name: Test F1 Score (Macro)
---
# DistilBERT Text Classification Model
This model is a fine-tuned version of [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased) for text classification tasks.
## Model Description
This model is DistilBERT fine-tuned for binary text classification: it labels English text as relating to either Pittsburgh or Shanghai. On a held-out test set of 200 samples drawn from the augmented dataset, it reaches 99.5% accuracy and 99.5% macro F1.
- **Model type:** Text Classification (Binary)
- **Language(s) (NLP):** English
- **Base model:** distilbert-base-uncased
- **Classes:** Pittsburgh, Shanghai
## Intended Uses & Limitations
### Intended Uses
- Binary text classification between Pittsburgh and Shanghai-related content
- City-based text categorization tasks
- Research and educational purposes in NLP and text classification
### Limitations
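- Limited to English-language text
- Only two output classes: any input, including text about neither city, is assigned to Pittsburgh or Shanghai
- Performance on out-of-domain data is not reported and may be lower
- Inputs longer than 256 tokens are truncated, so text beyond that point is ignored
- Trained on a small dataset: 100 original samples, expanded to 1000 through augmentation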
## Training and Evaluation Data
### Training Data
- **Base dataset:** [cassieli226/cities-text-dataset](https://huggingface.co/datasets/cassieli226/cities-text-dataset)
- **Classes:** Pittsburgh (507 samples) and Shanghai (493 samples) in the augmented dataset
- **Original dataset:** 100 samples (50 Pittsburgh, 50 Shanghai)
- **Data augmentation:** Applied to increase dataset size from 100 to 1000 samples
- **Train/Test Split:** 80/20 split (800 train, 200 test) with stratified sampling
- **External validation:** Original 100 samples used for additional validation
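
The card doesn't say which tool produced the stratified 80/20 split (scikit-learn's `train_test_split` with `stratify=` is a common choice, but that is an assumption). The sketch below shows the idea in plain Python: allocate the 200 test slots across classes in proportion to class size using largest-remainder rounding, then sample within each class. With the card's 507/493 class counts it yields 101 Pittsburgh and 99 Shanghai test samples, matching the per-class counts reported under Evaluation. The function name `stratified_split` and the seed are illustrative, not taken from the original training code.

```python
import math
import random
from collections import Counter


def stratified_split(labels, test_size, seed=0):
    """Split sample indices into (train, test) so each class keeps its share of the test set."""
    counts = Counter(labels)
    total = len(labels)

    # Ideal (fractional) number of test samples per class, rounded down first
    quotas = {c: n * test_size / total for c, n in counts.items()}
    alloc = {c: math.floor(q) for c, q in quotas.items()}

    # Give leftover test slots to the classes with the largest fractional remainders
    leftover = test_size - sum(alloc.values())
    by_remainder = sorted(quotas, key=lambda c: quotas[c] - alloc[c], reverse=True)
    for c in by_remainder[:leftover]:
        alloc[c] += 1

    # Shuffle each class's indices and take the first alloc[c] of them as test samples
    rng = random.Random(seed)
    train, test = [], []
    for c in counts:
        idx = [i for i, y in enumerate(labels) if y == c]
        rng.shuffle(idx)
        test.extend(idx[:alloc[c]])
        train.extend(idx[alloc[c]:])
    return train, test


# Class counts of the augmented dataset, as reported in this card
labels = ["Pittsburgh"] * 507 + ["Shanghai"] * 493
train_idx, test_idx = stratified_split(labels, test_size=200)
print(len(train_idx), len(test_idx))       # 800 200
print(Counter(labels[i] for i in test_idx))  # Counter({'Pittsburgh': 101, 'Shanghai': 99})
```

The library equivalent would be scikit-learn's `train_test_split(indices, test_size=200, stratify=labels, random_state=0)`; its per-class rounding may differ from this sketch in edge cases.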
### Preprocessing
- Text tokenization using DistilBERT tokenizer
- Maximum sequence length: 256 tokens
- Truncation applied to longer sequences
## Training Procedure
### Training Hyperparameters
- **Learning rate:** 5e-5
- **Training batch size:** 16
- **Evaluation batch size:** 32
- **Number of epochs:** 4
- **Weight decay:** 0.01
- **Warmup ratio:** 0.1
- **LR scheduler:** Linear
- **Gradient accumulation steps:** 1
- **Mixed precision:** FP16 (if GPU available)
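
Assuming training used the Hugging Face `Trainer` on a single device (the card doesn't state this explicitly), the warmup ratio and linear scheduler translate into concrete step counts: 800 training samples at batch size 16 give 50 optimizer steps per epoch, 200 steps over 4 epochs, and 20 warmup steps (the warmup fraction is rounded up, as `Trainer` does). The sketch below reproduces that arithmetic and the linear warmup-then-decay shape of the `transformers` linear schedule. Early stopping may end training before step 200.

```python
import math

# Hyperparameters from the card; a single training device is assumed
LEARNING_RATE = 5e-5
TRAIN_BATCH_SIZE = 16
GRAD_ACCUM_STEPS = 1
NUM_EPOCHS = 4
WARMUP_RATIO = 0.1
NUM_TRAIN_SAMPLES = 800  # 80% of the 1000 augmented samples


def schedule_lengths(num_samples, batch_size, grad_accum, epochs, warmup_ratio):
    """Return (total optimizer steps, warmup steps) for the full run."""
    batches_per_epoch = math.ceil(num_samples / batch_size)
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)  # round the warmup fraction up
    return total_steps, warmup_steps


def linear_lr(step, base_lr, total_steps, warmup_steps):
    """Learning rate after `step` optimizer steps: linear warmup from 0, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


total_steps, warmup_steps = schedule_lengths(
    NUM_TRAIN_SAMPLES, TRAIN_BATCH_SIZE, GRAD_ACCUM_STEPS, NUM_EPOCHS, WARMUP_RATIO
)
print(total_steps, warmup_steps)  # 200 20
for step in (0, 10, 20, 110, 200):
    print(step, linear_lr(step, LEARNING_RATE, total_steps, warmup_steps))
```

The learning rate peaks at 5e-5 at step 20, falls to half that at step 110, and reaches 0 at step 200.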
### Training Configuration
- **Optimizer:** AdamW (default)
- **Early stopping:** Enabled with patience of 2 epochs
- **Best model selection:** Based on F1 score (macro)
- **Evaluation strategy:** Every epoch
- **Save strategy:** Every epoch (best model only)
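
If the model was trained with the Hugging Face `Trainer` (likely given these options, but not stated in the card), the configuration above maps onto `TrainingArguments` keyword arguments roughly as sketched below. Assumptions: a recent `transformers` release in which the evaluation option is named `eval_strategy` (older releases call it `evaluation_strategy`), and a `compute_metrics` function that reports macro F1 under the key `"f1"`. The helper `build_training_kwargs` is hypothetical.

```python
def build_training_kwargs(gpu_available):
    """Collect the card's settings as keyword arguments for transformers.TrainingArguments."""
    return {
        "learning_rate": 5e-5,
        "per_device_train_batch_size": 16,
        "per_device_eval_batch_size": 32,
        "num_train_epochs": 4,
        "weight_decay": 0.01,
        "warmup_ratio": 0.1,
        "lr_scheduler_type": "linear",
        "gradient_accumulation_steps": 1,
        "fp16": gpu_available,           # mixed precision only when a GPU is present
        "eval_strategy": "epoch",        # named evaluation_strategy in older releases
        "save_strategy": "epoch",        # must match eval_strategy for load_best_model_at_end
        "load_best_model_at_end": True,  # restore the best checkpoint when training ends
        "metric_for_best_model": "f1",   # assumes compute_metrics returns {"f1": ...}
        "greater_is_better": True,
    }


kwargs = build_training_kwargs(gpu_available=False)
print(kwargs["eval_strategy"], kwargs["metric_for_best_model"])  # epoch f1
```

These would be passed as `TrainingArguments(output_dir="out", **build_training_kwargs(torch.cuda.is_available()))`, with early stopping added through `Trainer(..., callbacks=[EarlyStoppingCallback(early_stopping_patience=2)])`. `EarlyStoppingCallback` requires `load_best_model_at_end=True` and a `metric_for_best_model`, which is why both are set here.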
## Evaluation
### Metrics
The model was evaluated using:
- **Accuracy:** Overall classification accuracy
- **F1 Score (Macro):** Macro-averaged F1 score across all classes
- **Per-class accuracy:** Fraction of each class's samples classified correctly
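
The card doesn't include its metric code. A minimal plain-Python sketch of the three metrics follows, reading "per-class accuracy" as the fraction of each class's samples predicted correctly (per-class recall), which is consistent with the per-class figures reported below. Inside a `Trainer` `compute_metrics` function, predicted labels would first be obtained by taking the argmax over the logits. The function names are illustrative.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true, y_pred):
    """Unweighted mean of the per-class F1 scores."""
    classes = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores.append(f1)
    return sum(scores) / len(scores)


def per_class_accuracy(y_true, y_pred):
    """For each true class, the fraction of its samples predicted correctly (recall)."""
    result = {}
    for c in sorted(set(y_true)):
        pairs = [(t, p) for t, p in zip(y_true, y_pred) if t == c]
        result[c] = sum(t == p for t, p in pairs) / len(pairs)
    return result


# Toy labels for illustration only (0 and 1 stand for the two classes)
y_true = [0, 0, 1, 1]
y_pred = [0, 1, 1, 1]
print(accuracy(y_true, y_pred))            # 0.75
print(per_class_accuracy(y_true, y_pred))  # {0: 0.5, 1: 1.0}
print(round(macro_f1(y_true, y_pred), 4))  # 0.7333
```

Macro averaging weights both classes equally regardless of size, which matters little here because the classes are nearly balanced.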
### Results
- **Test Set Performance** (200 samples):
  - Accuracy: 99.5%
  - F1 Score (Macro): 99.5%
- **External Validation** (original 100 samples):
  - Accuracy: 100.0%
  - F1 Score (Macro): 100.0%
### Detailed Performance
- **Pittsburgh Class:** 99.01% accuracy (101 samples)
- **Shanghai Class:** 100.0% accuracy (99 samples)
- **Confusion Matrix:** Only 1 misclassification out of 200 test samples
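
As a consistency check, the reported test figures fit a single confusion matrix: 100 of 101 Pittsburgh samples and all 99 Shanghai samples correct, so the one error is a Pittsburgh sample predicted as Shanghai. This matrix is reconstructed from the percentages above, not taken from training logs. Writing $P$ for precision and $R$ for recall:

$$
\begin{aligned}
\text{Acc}_{\text{Pittsburgh}} &= \tfrac{100}{101} \approx 99.01\%, \qquad \text{Acc}_{\text{Shanghai}} = \tfrac{99}{99} = 100\%, \qquad \text{Acc} = \tfrac{199}{200} = 99.5\% \\
F_1^{\text{Pittsburgh}} &= \frac{2 \cdot 1 \cdot \tfrac{100}{101}}{1 + \tfrac{100}{101}} = \tfrac{200}{201} \approx 0.9950 \qquad \left(P = \tfrac{100}{100},\ R = \tfrac{100}{101}\right) \\
F_1^{\text{Shanghai}} &= \frac{2 \cdot \tfrac{99}{100} \cdot 1}{\tfrac{99}{100} + 1} = \tfrac{198}{199} \approx 0.9950 \qquad \left(P = \tfrac{99}{100},\ R = \tfrac{99}{99}\right) \\
F_1^{\text{macro}} &= \tfrac{1}{2}\left(\tfrac{200}{201} + \tfrac{198}{199}\right) \approx 0.9950
\end{aligned}
$$

Both per-class F1 scores round to 99.5%, matching the reported macro F1.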
## Usage
```python
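from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Load the fine-tuned model and its tokenizer from the Hugging Face Hub
model_name = "Anyuhhh/hw2-text-distilbert"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Tokenize the input, truncating to the 256-token limit used in training
text = "Your input text here"
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)

# Run a forward pass without tracking gradients
with torch.no_grad():
    outputs = model(**inputs)

# Convert logits to class probabilities and pick the most likely class
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_class = torch.argmax(probs, dim=-1).item()

# id2label maps class ids to names; it may hold generic names
# (e.g. LABEL_0) if no custom mapping was saved with the model
label = model.config.id2label[predicted_class]
print(f"Predicted class: {predicted_class} ({label}), probability: {probs[0, predicted_class].item():.4f}")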