Add text register FastText classifier with training scripts
Files changed:

- README.md (+168, -0)
- register_fasttext.bin (+3, -0)
- register_fasttext_q.bin (+3, -0)
- scripts/predict.py (+114, -0)
- scripts/prepare_data.py (+173, -0)
- scripts/train.py (+91, -0)
README.md
ADDED
# Text Register FastText Classifier

A FastText classifier that detects the **communicative register** (text type) of any English text at ~500k predictions/sec on CPU.

## Labels

| Code | Register | Description | Example |
|------|----------|-------------|---------|
| `IN` | Informational | Factual, encyclopedic, descriptive | Wikipedia articles, reports |
| `NA` | Narrative | Story-like, temporal sequence of events | News stories, fiction, blog posts |
| `OP` | Opinion | Subjective evaluation, personal views | Reviews, editorials, comments |
| `IP` | Persuasion | Attempts to convince or sell | Marketing copy, ads, fundraising |
| `HI` | HowTo | Instructions, procedures, recipes | Tutorials, manuals, FAQs |
| `ID` | Discussion | Interactive, forum-style dialogue | Forum threads, Q&A, comments |
| `SP` | Spoken | Transcribed or spoken-style text | Interviews, podcasts, speeches |
| `LY` | Lyrical | Poetic, artistic, song-like | Poetry, song lyrics, creative prose |

Based on the Biber & Egbert (2018) register taxonomy. Multi-label classification is supported: a text can be both Informational and Narrative.
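Because the model is multi-label, it is often more useful to keep every register whose probability clears a cutoff than to take only the top-1 label. A minimal sketch of that post-processing (the `registers_above` helper is illustrative, not part of this repo):

```python
def registers_above(labels, probs, threshold=0.5):
    """Map fastText's (labels, probs) output to {register_code: probability},
    keeping every label at or above the threshold."""
    return {
        label.replace("__label__", ""): round(float(prob), 3)
        for label, prob in zip(labels, probs)
        if prob >= threshold
    }

# With a loaded model (see Quick Start), k=-1 requests scores for all labels:
#   labels, probs = model.predict("The company reported revenue of...", k=-1)
#   registers_above(labels, probs)
```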

## Quick Start

```python
import fasttext
from huggingface_hub import hf_hub_download

# Download model (quantized, 151 MB)
model_path = hf_hub_download(
    "oneryalcin/text-register-fasttext-classifier",
    "register_fasttext_q.bin"
)
model = fasttext.load_model(model_path)

# Predict
labels, probs = model.predict("Buy now and save 50%! Limited time offer!", k=3)
# -> [('__label__IP', 1.0), ...]  # IP = Persuasion
```

> **Note**: If you get a numpy error, pin `numpy<2`: `pip install "numpy<2"`

## Performance

Trained on 10 English shards from [TurkuNLP/register_oscar](https://huggingface.co/datasets/TurkuNLP/register_oscar) (~1.9M documents), balanced by oversampling/undersampling each class to the median class size.

### Overall Metrics

| Metric | Full Model | Quantized |
|--------|-----------|-----------|
| Precision@1 | 0.831 | 0.796 |
| Recall@1 | 0.759 | 0.727 |
| Precision@2 | 0.491 | — |
| Recall@2 | 0.898 | — |
| Speed | ~500k pred/s | ~500k pred/s |
| Size | 1.1 GB | 151 MB |

### Per-Class F1 (threshold=0.3, k=2)

| Register | Precision | Recall | F1 | Test Support |
|----------|-----------|--------|-----|-------------|
| Informational | 0.910 | 0.666 | 0.769 | 108,672 |
| Narrative | 0.764 | 0.766 | 0.765 | 44,238 |
| Discussion | 0.640 | 0.774 | 0.701 | 7,420 |
| Persuasion | 0.553 | 0.794 | 0.652 | 19,193 |
| Opinion | 0.567 | 0.736 | 0.640 | 20,014 |
| HowTo | 0.515 | 0.766 | 0.616 | 7,281 |
| Spoken | 0.551 | 0.513 | 0.531 | 831 |
| Lyrical | 0.657 | 0.442 | 0.529 | 251 |

### Example Predictions

```
"The company reported revenue of $4.2 billion..."  -> Informational (1.00), Narrative (0.99)
"Once upon a time in a small village..."           -> Narrative
"I honestly think this movie is terrible..."       -> Opinion (1.00)
"To install the package, first run pip install..." -> HowTo (1.00)
"Buy now and save 50%! Limited time offer..."      -> Persuasion (1.00)
"So like, I was telling her yesterday..."          -> Spoken (1.00)
"I've been walking these streets alone..."         -> Lyrical (1.00)
"Hey everyone! What do you think about..."         -> Discussion (1.00)
"Introducing the revolutionary SkinGlow Pro..."    -> Persuasion (1.00)
```

## Use Cases

- **Data curation**: Filter pretraining corpora by register (e.g., keep only Informational + HowTo)
- **Content routing**: Route incoming text to different processing pipelines
- **Boilerplate removal**: Flag Persuasion/Marketing text in document corpora
- **Signal extraction**: Identify which paragraphs in a document carry factual vs opinion content
- **RAG preprocessing**: Score chunks by register before feeding to LLMs
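For the data-curation use case, the keep/drop decision reduces to a small predicate over the model's output. A sketch, assuming a document is kept when any wanted register clears a probability cutoff (`keep_document` is hypothetical, not one of the repo's scripts):

```python
WANTED = {"IN", "HI"}  # e.g. keep only Informational + HowTo

def keep_document(labels, probs, wanted=WANTED, threshold=0.5):
    """True if any wanted register is predicted at or above the threshold."""
    for label, prob in zip(labels, probs):
        if label.replace("__label__", "") in wanted and prob >= threshold:
            return True
    return False

# labels, probs = model.predict(doc_text.replace("\n", " "), k=-1)
# if keep_document(labels, probs):
#     corpus_out.append(doc_text)
```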

## Reproduce from Scratch

### 1. Download data

```bash
pip install huggingface_hub

# Download 10 English shards (~4 GB)
for i in $(seq 0 9); do
  hf download TurkuNLP/register_oscar \
    $(printf "en/en_%05d.jsonl.gz" $i) \
    --repo-type dataset --local-dir ./data
done
```

### 2. Prepare balanced training data

```bash
python scripts/prepare_data.py --data-dir ./data/en --output-dir ./prepared
```

### 3. Train

```bash
pip install fasttext-wheel "numpy<2"
python scripts/train.py --train ./prepared/train.txt --test ./prepared/test.txt --output ./model
```

### 4. Predict

```bash
# Interactive
python scripts/predict.py --model ./model/register_fasttext_q.bin

# Single text
python scripts/predict.py --model ./model/register_fasttext_q.bin --text "Buy now! 50% off!"

# Batch
python scripts/predict.py --model ./model/register_fasttext_q.bin --input texts.txt --output out.jsonl
```

## Training Details

- **Source data**: [TurkuNLP/register_oscar](https://huggingface.co/datasets/TurkuNLP/register_oscar) (English, 10 shards, ~1.9M labeled documents)
- **Balancing**: Minority classes oversampled and majority classes undersampled to the median class size (~129k per class)
- **Architecture**: FastText supervised classifier with bigrams, 100-dim embeddings, one-vs-all loss
- **Hyperparameters**: lr=0.5, epoch=25, wordNgrams=2, dim=100, loss=ova, bucket=2M
- **Text preprocessing**: Whitespace collapsed, documents truncated to 500 words
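The preprocessing step above is tiny but worth replicating exactly at inference time, since fastText is sensitive to tokenization mismatches. It mirrors `clean_text` in `scripts/prepare_data.py`:

```python
import re

def clean_text(text: str, max_words: int = 500) -> str:
    """Collapse all whitespace to single spaces and truncate to max_words."""
    text = re.sub(r"\s+", " ", text).strip()
    return " ".join(text.split()[:max_words])
```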

## Limitations

- **Spoken & Lyrical** classes have lower F1 (~0.53) due to limited unique training data, even after oversampling
- Trained on web text only; may not generalize well to domain-specific text (legal, medical)
- Bag-of-words model (with bigrams); does not capture long-range word order or deep semantics
- English only (the source dataset covers other languages that could be used for multilingual training)

## Citation

If you use this model, please cite the source dataset:

```bibtex
@inproceedings{register_oscar,
  title  = {Multilingual register classification on the full OSCAR data},
  author = {R{\"o}nnqvist, Samuel and others},
  year   = {2023},
  note   = {TurkuNLP, University of Turku}
}

@article{biber2018register,
  title   = {Register as a predictor of linguistic variation},
  author  = {Biber, Douglas and Egbert, Jesse},
  journal = {Corpus Linguistics and Linguistic Theory},
  year    = {2018}
}
```

## License

The model weights inherit the license of the source dataset ([TurkuNLP/register_oscar](https://huggingface.co/datasets/TurkuNLP/register_oscar)). The scripts are released under MIT.
register_fasttext.bin
ADDED

version https://git-lfs.github.com/spec/v1
oid sha256:3e76a01fa9946bd26ab2eeeea1842ff643cc486634c0f4db4dbe85b6b7c78017
size 1156314566
register_fasttext_q.bin
ADDED

version https://git-lfs.github.com/spec/v1
oid sha256:7b1d55be8d490dbbcb17773592e367d58fb857dfe0a603c322246aef6855de86
size 158362937
scripts/predict.py
ADDED

```python
"""
Predict text register using the trained FastText model.

Usage:
    # Interactive mode
    python predict.py --model ./model/register_fasttext_q.bin

    # Single text
    python predict.py --model ./model/register_fasttext_q.bin --text "Buy now! Limited offer!"

    # File mode (one text per line)
    python predict.py --model ./model/register_fasttext_q.bin --input texts.txt --output predictions.jsonl
"""

import argparse
import json
import sys
import time

import fasttext


REGISTER_LABELS = {
    "IN": "Informational",
    "NA": "Narrative",
    "OP": "Opinion",
    "IP": "Persuasion",
    "HI": "HowTo",
    "ID": "Discussion",
    "SP": "Spoken",
    "LY": "Lyrical",
}


def predict_one(model, text: str, k: int = 3, threshold: float = 0.1):
    """Predict register labels for a single text."""
    labels, probs = model.predict(text.replace("\n", " "), k=k, threshold=threshold)
    results = []
    for label, prob in zip(labels, probs):
        code = label.replace("__label__", "")
        results.append({
            "label": code,
            "name": REGISTER_LABELS.get(code, code),
            "score": round(float(prob), 4),
        })
    return results


def main():
    parser = argparse.ArgumentParser(description="Predict text register")
    parser.add_argument("--model", required=True, help="Path to FastText .bin model")
    parser.add_argument("--text", help="Single text to classify")
    parser.add_argument("--input", help="Input file (one text per line)")
    parser.add_argument("--output", help="Output JSONL file")
    parser.add_argument("--k", type=int, default=3, help="Top-k labels to return")
    parser.add_argument("--threshold", type=float, default=0.1, help="Min probability threshold")
    args = parser.parse_args()

    # Suppress fastText's load_model warning on stderr
    try:
        fasttext.FastText.eprint = lambda x: None
    except Exception:
        pass

    model = fasttext.load_model(args.model)

    if args.text:
        # Single prediction
        results = predict_one(model, args.text, args.k, args.threshold)
        for r in results:
            print(f"  {r['name']:<15} ({r['label']}) {r['score']:.3f}")

    elif args.input:
        # Batch mode
        out_f = open(args.output, "w") if args.output else sys.stdout
        count = 0
        start = time.time()

        with open(args.input) as f:
            for line in f:
                text = line.strip()
                if not text:
                    continue
                results = predict_one(model, text, args.k, args.threshold)
                record = {"text": text[:200], "predictions": results}
                out_f.write(json.dumps(record) + "\n")
                count += 1

        elapsed = time.time() - start
        if args.output:
            out_f.close()
        print(f"Processed {count} texts in {elapsed:.2f}s ({count / elapsed:.0f}/sec)", file=sys.stderr)

    else:
        # Interactive mode
        print("Text Register Classifier (type 'quit' to exit)")
        print(f"Labels: {', '.join(f'{k}={v}' for k, v in REGISTER_LABELS.items())}")
        print()
        while True:
            try:
                text = input("> ").strip()
            except (EOFError, KeyboardInterrupt):
                break
            if text.lower() in ("quit", "exit", "q"):
                break
            if not text:
                continue
            results = predict_one(model, text, args.k, args.threshold)
            for r in results:
                print(f"  {r['name']:<15} ({r['label']}) {r['score']:.3f}")
            print()


if __name__ == "__main__":
    main()
```
scripts/prepare_data.py
ADDED

```python
"""
Prepare balanced FastText training data from the TurkuNLP/register_oscar dataset.

Reads downloaded English shards, extracts labeled documents, and creates a
balanced training set by oversampling minority classes and undersampling
majority classes to the median class size.

Requirements:
    pip install huggingface_hub

Usage:
    # Download shards first:
    for i in $(seq 0 9); do
        hf download TurkuNLP/register_oscar \
            $(printf "en/en_%05d.jsonl.gz" $i) \
            --repo-type dataset --local-dir ./data
    done

    # Then run:
    python prepare_data.py --data-dir ./data/en --output-dir ./prepared
"""

import argparse
import glob
import gzip
import json
import random
import re
from collections import Counter, defaultdict
from pathlib import Path


REGISTER_LABELS = {
    "IN": "Informational",
    "NA": "Narrative",
    "OP": "Opinion",
    "IP": "Persuasion",
    "HI": "HowTo",
    "ID": "Discussion",
    "SP": "Spoken",
    "LY": "Lyrical",
}


def clean_text(text: str, max_words: int = 500) -> str:
    """Collapse whitespace and truncate to max_words."""
    text = re.sub(r"\s+", " ", text).strip()
    words = text.split()[:max_words]
    return " ".join(words)


def main():
    parser = argparse.ArgumentParser(description="Prepare balanced FastText training data")
    parser.add_argument("--data-dir", default="./data/en", help="Directory with .jsonl.gz shards")
    parser.add_argument("--output-dir", default="./prepared", help="Output directory for train/test files")
    parser.add_argument("--max-words", type=int, default=500, help="Max words per document")
    parser.add_argument("--min-text-len", type=int, default=50, help="Min character length to keep")
    parser.add_argument("--test-ratio", type=float, default=0.1, help="Fraction held out for test")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    random.seed(args.seed)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Collect all labeled docs grouped by primary label
    by_label = defaultdict(list)
    total = 0
    skipped_nolabel = 0
    skipped_short = 0

    shard_files = sorted(glob.glob(f"{args.data_dir}/*.jsonl.gz"))
    if not shard_files:
        raise FileNotFoundError(f"No .jsonl.gz files found in {args.data_dir}")

    print(f"Found {len(shard_files)} shard(s)")

    for shard_file in shard_files:
        print(f"  Processing {Path(shard_file).name}...")
        with gzip.open(shard_file, "rt") as f:
            for line in f:
                d = json.loads(line)
                labels = d.get("labels", [])
                text = d.get("text", "")

                if not labels:
                    skipped_nolabel += 1
                    continue
                if len(text) < args.min_text_len:
                    skipped_short += 1
                    continue

                cleaned = clean_text(text, args.max_words)
                if not cleaned:
                    continue

                label_str = " ".join(f"__label__{l}" for l in labels)
                ft_line = f"{label_str} {cleaned}\n"

                primary = labels[0]
                by_label[primary].append(ft_line)
                total += 1

    print(f"\nTotal labeled docs: {total}")
    print(f"Skipped (no label): {skipped_nolabel}")
    print(f"Skipped (too short): {skipped_short}")

    # Raw distribution
    print("\nRaw distribution:")
    for label in sorted(by_label.keys()):
        name = REGISTER_LABELS.get(label, label)
        print(f"  {label} ({name}): {len(by_label[label])}")

    # Balance: oversample minority to median, undersample majority to median
    sizes = {k: len(v) for k, v in by_label.items()}
    sorted_sizes = sorted(sizes.values())
    target = sorted_sizes[len(sorted_sizes) // 2]

    print(f"\nBalancing target (median): {target}")

    train_lines = []
    test_lines = []

    for label, lines in by_label.items():
        random.shuffle(lines)

        # Hold out --test-ratio of each class (at least 50 docs) for the test set
        n_test = max(int(len(lines) * args.test_ratio), 50)
        test_pool = lines[:n_test]
        train_pool = lines[n_test:]

        test_lines.extend(test_pool)
        n_train = len(train_pool)

        if n_train >= target:
            sampled = random.sample(train_pool, target)
            train_lines.extend(sampled)
            print(f"  {label}: {n_train} -> {target} (undersampled)")
        else:
            train_lines.extend(train_pool)
            n_needed = target - n_train
            oversampled = random.choices(train_pool, k=n_needed)
            train_lines.extend(oversampled)
            print(f"  {label}: {n_train} -> {target} (oversampled +{n_needed})")

    random.shuffle(train_lines)
    random.shuffle(test_lines)

    train_path = output_dir / "train.txt"
    test_path = output_dir / "test.txt"

    with open(train_path, "w") as f:
        f.writelines(train_lines)
    with open(test_path, "w") as f:
        f.writelines(test_lines)

    print(f"\nTrain: {len(train_lines)} -> {train_path}")
    print(f"Test:  {len(test_lines)} -> {test_path}")

    # Verify balance
    c = Counter()
    for line in train_lines:
        for tok in line.split():
            if tok.startswith("__label__"):
                c[tok] += 1
    print("\nFinal train label distribution:")
    for l, cnt in c.most_common():
        name = REGISTER_LABELS.get(l.replace("__label__", ""), l)
        print(f"  {l} ({name}): {cnt}")


if __name__ == "__main__":
    main()
```
scripts/train.py
ADDED

```python
"""
Train a FastText text register classifier.

Usage:
    python train.py --train ./prepared/train.txt --test ./prepared/test.txt --output ./model

This produces:
    - model/register_fasttext.bin (full model)
    - model/register_fasttext_q.bin (quantized, ~7x smaller)
"""

import argparse
import os
import time
from pathlib import Path

import fasttext


def main():
    parser = argparse.ArgumentParser(description="Train FastText register classifier")
    parser.add_argument("--train", default="./prepared/train.txt", help="Training data file")
    parser.add_argument("--test", default="./prepared/test.txt", help="Test data file")
    parser.add_argument("--output", default="./model", help="Output directory")
    parser.add_argument("--lr", type=float, default=0.5, help="Learning rate")
    parser.add_argument("--epoch", type=int, default=25, help="Number of epochs")
    parser.add_argument("--dim", type=int, default=100, help="Embedding dimension")
    parser.add_argument("--wordNgrams", type=int, default=2, help="Max n-gram length")
    parser.add_argument("--bucket", type=int, default=2000000, help="Hash bucket size")
    parser.add_argument("--thread", type=int, default=8, help="Number of threads")
    parser.add_argument("--min-count", type=int, default=5, help="Min word count")
    args = parser.parse_args()

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=== Training FastText register classifier ===")
    start = time.time()

    model = fasttext.train_supervised(
        input=args.train,
        lr=args.lr,
        epoch=args.epoch,
        wordNgrams=args.wordNgrams,
        dim=args.dim,
        loss="ova",  # one-vs-all loss for multi-label classification
        minCount=args.min_count,
        bucket=args.bucket,
        thread=args.thread,
        verbose=2,
    )

    train_time = time.time() - start
    print(f"Training time: {train_time:.1f}s")

    # Save full model
    full_path = output_dir / "register_fasttext.bin"
    model.save_model(str(full_path))
    size_mb = os.path.getsize(full_path) / 1024 / 1024
    print(f"\nFull model: {full_path} ({size_mb:.1f} MB)")

    # Evaluate
    print("\n=== Evaluation ===")
    for k in [1, 2]:
        r = model.test(args.test, k=k)
        print(f"  k={k}: Precision={r[1]:.4f} Recall={r[2]:.4f} (n={r[0]})")

    # Quantize
    print("\nQuantizing...")
    model.quantize(input=args.train, retrain=True)
    q_path = output_dir / "register_fasttext_q.bin"
    model.save_model(str(q_path))
    size_q = os.path.getsize(q_path) / 1024 / 1024
    print(f"Quantized model: {q_path} ({size_q:.1f} MB)")

    r = model.test(args.test, k=1)
    print(f"  Quantized k=1: Precision={r[1]:.4f} Recall={r[2]:.4f}")

    # Speed test
    print("\n=== Speed Test ===")
    test_text = "The algorithm processes data in O(n log n) time complexity."
    start = time.time()
    for _ in range(100000):
        model.predict(test_text)
    elapsed = time.time() - start
    print(f"{100000 / elapsed:.0f} predictions/sec")

    print("\nDone!")


if __name__ == "__main__":
    main()
```