oneryalcin committed
Commit 3dea709 · verified · 1 parent: 4b21fb6

Add text register FastText classifier with training scripts

README.md ADDED
@@ -0,0 +1,168 @@
# Text Register FastText Classifier

A FastText classifier that detects the **communicative register** (text type) of English text at roughly 500k predictions/sec on CPU.

## Labels

| Code | Register | Description | Example |
|------|----------|-------------|---------|
| `IN` | Informational | Factual, encyclopedic, descriptive | Wikipedia articles, reports |
| `NA` | Narrative | Story-like, temporal sequence of events | News stories, fiction, blog posts |
| `OP` | Opinion | Subjective evaluation, personal views | Reviews, editorials, comments |
| `IP` | Persuasion | Attempts to convince or sell | Marketing copy, ads, fundraising |
| `HI` | HowTo | Instructions, procedures, recipes | Tutorials, manuals, FAQs |
| `ID` | Discussion | Interactive, forum-style dialogue | Forum threads, Q&A, comments |
| `SP` | Spoken | Transcribed or spoken-style text | Interviews, podcasts, speeches |
| `LY` | Lyrical | Poetic, artistic, song-like | Poetry, song lyrics, creative prose |

Based on the Biber & Egbert (2018) register taxonomy. Multi-label classification is supported: a single text can carry more than one register (e.g., both Informational and Narrative).

## Quick Start

```python
import fasttext
from huggingface_hub import hf_hub_download

# Download the quantized model (151 MB)
model_path = hf_hub_download(
    "oneryalcin/text-register-fasttext-classifier",
    "register_fasttext_q.bin",
)
model = fasttext.load_model(model_path)

# Predict the top-3 registers
labels, probs = model.predict("Buy now and save 50%! Limited time offer!", k=3)
# -> [('__label__IP', 1.0), ...]  # IP = Persuasion
```

> **Note**: If you get a numpy error, pin `numpy<2`: `pip install "numpy<2"`

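For multi-label use, the raw `(labels, probs)` pair from `model.predict` can be post-processed with a probability threshold. A minimal sketch — the `decode` helper and the 0.3 threshold are illustrative choices, not part of the model's API:

```python
REGISTER_NAMES = {
    "IN": "Informational", "NA": "Narrative", "OP": "Opinion",
    "IP": "Persuasion", "HI": "HowTo", "ID": "Discussion",
    "SP": "Spoken", "LY": "Lyrical",
}

def decode(labels, probs, threshold=0.3):
    """Map fasttext's (labels, probs) output to (name, score) pairs above a threshold."""
    out = []
    for label, prob in zip(labels, probs):
        code = label.replace("__label__", "")
        if prob >= threshold:
            out.append((REGISTER_NAMES.get(code, code), round(float(prob), 3)))
    return out

# e.g. decode(('__label__IP', '__label__IN'), (0.97, 0.12)) -> [('Persuasion', 0.97)]
```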
## Performance

Trained on 10 English shards of [TurkuNLP/register_oscar](https://huggingface.co/datasets/TurkuNLP/register_oscar) (~1.9M documents), balanced by oversampling minority classes and undersampling majority classes to the median class size.

### Overall Metrics

| Metric | Full Model | Quantized |
|--------|-----------|-----------|
| Precision@1 | 0.831 | 0.796 |
| Recall@1 | 0.759 | 0.727 |
| Precision@2 | 0.491 | — |
| Recall@2 | 0.898 | — |
| Speed | ~500k pred/s | ~500k pred/s |
| Size | 1.1 GB | 151 MB |

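Precision@k and Recall@k follow the usual multi-label convention (also used by fastText's `model.test`): among the top-k predicted labels per document, precision is the fraction of predictions that are true labels, and recall is the fraction of true labels recovered. A small sketch of the computation — the function name and toy data are illustrative:

```python
def precision_recall_at_k(ranked_preds, gold_labels, k):
    """ranked_preds: list of label lists sorted by score; gold_labels: list of sets."""
    correct = predicted = relevant = 0
    for preds, gold in zip(ranked_preds, gold_labels):
        top_k = preds[:k]
        correct += sum(1 for label in top_k if label in gold)
        predicted += len(top_k)
        relevant += len(gold)
    return correct / predicted, correct / relevant

# Two documents, k=2: 3 of the 4 predictions are correct, all 3 gold labels recovered
p, r = precision_recall_at_k([["IN", "NA"], ["OP", "IP"]], [{"IN"}, {"OP", "IP"}], k=2)
# p = 0.75, r = 1.0
```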
### Per-Class F1 (threshold=0.3, k=2)

| Register | Precision | Recall | F1 | Test Support |
|----------|-----------|--------|-----|--------------|
| Informational | 0.910 | 0.666 | 0.769 | 108,672 |
| Narrative | 0.764 | 0.766 | 0.765 | 44,238 |
| Discussion | 0.640 | 0.774 | 0.701 | 7,420 |
| Persuasion | 0.553 | 0.794 | 0.652 | 19,193 |
| Opinion | 0.567 | 0.736 | 0.640 | 20,014 |
| HowTo | 0.515 | 0.766 | 0.616 | 7,281 |
| Spoken | 0.551 | 0.513 | 0.531 | 831 |
| Lyrical | 0.657 | 0.442 | 0.529 | 251 |

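The F1 column is the harmonic mean of the precision and recall columns, which can be spot-checked against any row:

```python
def f1(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)

# Spot-check the Informational row: f1(0.910, 0.666) rounds to 0.769
```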
### Example Predictions

```
"The company reported revenue of $4.2 billion..."  -> Informational (1.00), Narrative (0.99)
"Once upon a time in a small village..."           -> Narrative
"I honestly think this movie is terrible..."       -> Opinion (1.00)
"To install the package, first run pip install..." -> HowTo (1.00)
"Buy now and save 50%! Limited time offer..."      -> Persuasion (1.00)
"So like, I was telling her yesterday..."          -> Spoken (1.00)
"I've been walking these streets alone..."         -> Lyrical (1.00)
"Hey everyone! What do you think about..."         -> Discussion (1.00)
"Introducing the revolutionary SkinGlow Pro..."    -> Persuasion (1.00)
```

## Use Cases

- **Data curation**: Filter pretraining corpora by register (e.g., keep only Informational + HowTo)
- **Content routing**: Route incoming text to different processing pipelines
- **Boilerplate removal**: Flag Persuasion/marketing text in document corpora
- **Signal extraction**: Identify which paragraphs in a document carry factual vs. opinion content
- **RAG preprocessing**: Score chunks by register before feeding them to LLMs

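The data-curation use case can be sketched as a simple keep/drop rule over the classifier's output. The `keep_chunk` helper, its threshold, and the allowed set are illustrative choices, not a provided API:

```python
def keep_chunk(predictions, allowed=("IN", "HI"), threshold=0.3):
    """predictions: (register_code, probability) pairs for one chunk of text."""
    # Keep the chunk if any allowed register clears the confidence threshold.
    return any(code in allowed and prob >= threshold for code, prob in predictions)

# Keep informational/how-to text, drop marketing copy:
# keep_chunk([("IN", 0.92)])                  -> True
# keep_chunk([("IP", 0.99), ("OP", 0.40)])    -> False
```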
## Reproduce from Scratch

### 1. Download data

```bash
pip install huggingface_hub

# Download 10 English shards (~4 GB)
for i in $(seq 0 9); do
  hf download TurkuNLP/register_oscar \
    $(printf "en/en_%05d.jsonl.gz" $i) \
    --repo-type dataset --local-dir ./data
done
```

### 2. Prepare balanced training data

```bash
python scripts/prepare_data.py --data-dir ./data/en --output-dir ./prepared
```

### 3. Train

```bash
pip install fasttext-wheel "numpy<2"
python scripts/train.py --train ./prepared/train.txt --test ./prepared/test.txt --output ./model
```

### 4. Predict

```bash
# Interactive
python scripts/predict.py --model ./model/register_fasttext_q.bin

# Single text
python scripts/predict.py --model ./model/register_fasttext_q.bin --text "Buy now! 50% off!"

# Batch
python scripts/predict.py --model ./model/register_fasttext_q.bin --input texts.txt --output out.jsonl
```

## Training Details

- **Source data**: [TurkuNLP/register_oscar](https://huggingface.co/datasets/TurkuNLP/register_oscar) (English, 10 shards, ~1.9M labeled documents)
- **Balancing**: Minority classes oversampled and majority classes undersampled to the median class size (~129k per class)
- **Architecture**: FastText supervised with bigrams, 100-dim embeddings, one-vs-all loss
- **Hyperparameters**: lr=0.5, epoch=25, wordNgrams=2, dim=100, loss=ova, bucket=2M
- **Text preprocessing**: Whitespace collapsed, documents truncated to 500 words

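The balancing step (implemented in full in `scripts/prepare_data.py`) amounts to resampling every class to the median class size; a condensed sketch of that logic:

```python
import random

def balance_to_median(by_label, seed=42):
    """Resample each class's document list to the median class size."""
    rng = random.Random(seed)
    sizes = sorted(len(docs) for docs in by_label.values())
    target = sizes[len(sizes) // 2]  # median class size
    balanced = {}
    for label, docs in by_label.items():
        if len(docs) >= target:
            balanced[label] = rng.sample(docs, target)       # undersample without replacement
        else:
            extra = rng.choices(docs, k=target - len(docs))  # oversample with replacement
            balanced[label] = docs + extra
    return balanced
```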
## Limitations

- **Spoken** and **Lyrical** have lower F1 (~0.53) due to limited unique training data, even after oversampling
- Trained on web text only; may not generalize well to domain-specific text (legal, medical)
- Bag-of-words model; does not capture word order or deep semantics
- English only (the source dataset covers other languages that could support multilingual training)

## Citation

If you use this model, please cite the source dataset:

```bibtex
@inproceedings{register_oscar,
  title={Multilingual register classification on the full OSCAR data},
  author={R{\"o}nnqvist, Samuel and others},
  year={2023},
  note={TurkuNLP, University of Turku}
}

@article{biber2018register,
  title={Register as a predictor of linguistic variation},
  author={Biber, Douglas and Egbert, Jesse},
  journal={Corpus Linguistics and Linguistic Theory},
  year={2018}
}
```

## License

The model weights inherit the license of the source dataset ([TurkuNLP/register_oscar](https://huggingface.co/datasets/TurkuNLP/register_oscar)). Scripts are released under MIT.
register_fasttext.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3e76a01fa9946bd26ab2eeeea1842ff643cc486634c0f4db4dbe85b6b7c78017
size 1156314566
register_fasttext_q.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7b1d55be8d490dbbcb17773592e367d58fb857dfe0a603c322246aef6855de86
size 158362937
scripts/predict.py ADDED
@@ -0,0 +1,114 @@
"""
Predict text register using the trained FastText model.

Usage:
    # Interactive mode
    python predict.py --model ./model/register_fasttext_q.bin

    # Single text
    python predict.py --model ./model/register_fasttext_q.bin --text "Buy now! Limited offer!"

    # File mode (one text per line)
    python predict.py --model ./model/register_fasttext_q.bin --input texts.txt --output predictions.jsonl
"""

import argparse
import json
import sys
import time

import fasttext

REGISTER_LABELS = {
    "IN": "Informational",
    "NA": "Narrative",
    "OP": "Opinion",
    "IP": "Persuasion",
    "HI": "HowTo",
    "ID": "Discussion",
    "SP": "Spoken",
    "LY": "Lyrical",
}


def predict_one(model, text: str, k: int = 3, threshold: float = 0.1):
    """Predict register labels for a single text."""
    # FastText expects single-line input, so strip newlines first.
    labels, probs = model.predict(text.replace("\n", " "), k=k, threshold=threshold)
    results = []
    for label, prob in zip(labels, probs):
        code = label.replace("__label__", "")
        results.append({
            "label": code,
            "name": REGISTER_LABELS.get(code, code),
            "score": round(float(prob), 4),
        })
    return results


def main():
    parser = argparse.ArgumentParser(description="Predict text register")
    parser.add_argument("--model", required=True, help="Path to FastText .bin model")
    parser.add_argument("--text", help="Single text to classify")
    parser.add_argument("--input", help="Input file (one text per line)")
    parser.add_argument("--output", help="Output JSONL file")
    parser.add_argument("--k", type=int, default=3, help="Top-k labels to return")
    parser.add_argument("--threshold", type=float, default=0.1, help="Min probability threshold")
    args = parser.parse_args()

    # Suppress the load-time warning emitted by some fasttext versions
    try:
        fasttext.FastText.eprint = lambda x: None
    except Exception:
        pass

    model = fasttext.load_model(args.model)

    if args.text:
        # Single prediction
        for r in predict_one(model, args.text, args.k, args.threshold):
            print(f"  {r['name']:<15} ({r['label']})  {r['score']:.3f}")

    elif args.input:
        # Batch mode
        out_f = open(args.output, "w") if args.output else sys.stdout
        count = 0
        start = time.time()

        with open(args.input) as f:
            for line in f:
                text = line.strip()
                if not text:
                    continue
                results = predict_one(model, text, args.k, args.threshold)
                record = {"text": text[:200], "predictions": results}
                out_f.write(json.dumps(record) + "\n")
                count += 1

        # Guard against a zero elapsed time on tiny inputs
        elapsed = max(time.time() - start, 1e-9)
        if args.output:
            out_f.close()
        print(f"Processed {count} texts in {elapsed:.2f}s ({count / elapsed:.0f}/sec)", file=sys.stderr)

    else:
        # Interactive mode
        print("Text Register Classifier (type 'quit' to exit)")
        print(f"Labels: {', '.join(f'{k}={v}' for k, v in REGISTER_LABELS.items())}")
        print()
        while True:
            try:
                text = input("> ").strip()
            except (EOFError, KeyboardInterrupt):
                break
            if text.lower() in ("quit", "exit", "q"):
                break
            if not text:
                continue
            for r in predict_one(model, text, args.k, args.threshold):
                print(f"  {r['name']:<15} ({r['label']})  {r['score']:.3f}")
            print()


if __name__ == "__main__":
    main()
scripts/prepare_data.py ADDED
@@ -0,0 +1,173 @@
"""
Prepare balanced FastText training data from the TurkuNLP/register_oscar dataset.

Reads downloaded English shards, extracts labeled documents, and creates a
balanced training set by oversampling minority classes and undersampling
majority classes to the median class size.

Requirements:
    pip install huggingface_hub

Usage:
    # Download shards first:
    for i in $(seq 0 9); do
      hf download TurkuNLP/register_oscar \
        $(printf "en/en_%05d.jsonl.gz" $i) \
        --repo-type dataset --local-dir ./data
    done

    # Then run:
    python prepare_data.py --data-dir ./data/en --output-dir ./prepared
"""

import argparse
import glob
import gzip
import json
import random
import re
from collections import Counter, defaultdict
from pathlib import Path

REGISTER_LABELS = {
    "IN": "Informational",
    "NA": "Narrative",
    "OP": "Opinion",
    "IP": "Persuasion",
    "HI": "HowTo",
    "ID": "Discussion",
    "SP": "Spoken",
    "LY": "Lyrical",
}


def clean_text(text: str, max_words: int = 500) -> str:
    """Collapse whitespace and truncate to max_words."""
    text = re.sub(r"\s+", " ", text).strip()
    return " ".join(text.split()[:max_words])


def main():
    parser = argparse.ArgumentParser(description="Prepare balanced FastText training data")
    parser.add_argument("--data-dir", default="./data/en", help="Directory with .jsonl.gz shards")
    parser.add_argument("--output-dir", default="./prepared", help="Output directory for train/test files")
    parser.add_argument("--max-words", type=int, default=500, help="Max words per document")
    parser.add_argument("--min-text-len", type=int, default=50, help="Min character length to keep")
    parser.add_argument("--test-ratio", type=float, default=0.1, help="Fraction held out for test")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    random.seed(args.seed)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Collect all labeled docs grouped by primary label
    by_label = defaultdict(list)
    total = 0
    skipped_nolabel = 0
    skipped_short = 0

    shard_files = sorted(glob.glob(f"{args.data_dir}/*.jsonl.gz"))
    if not shard_files:
        raise FileNotFoundError(f"No .jsonl.gz files found in {args.data_dir}")

    print(f"Found {len(shard_files)} shard(s)")

    for shard_file in shard_files:
        print(f"  Processing {Path(shard_file).name}...")
        with gzip.open(shard_file, "rt") as f:
            for line in f:
                d = json.loads(line)
                labels = d.get("labels", [])
                text = d.get("text", "")

                if not labels:
                    skipped_nolabel += 1
                    continue
                if len(text) < args.min_text_len:
                    skipped_short += 1
                    continue

                cleaned = clean_text(text, args.max_words)
                if not cleaned:
                    continue

                label_str = " ".join(f"__label__{l}" for l in labels)
                ft_line = f"{label_str} {cleaned}\n"

                primary = labels[0]
                by_label[primary].append(ft_line)
                total += 1

    print(f"\nTotal labeled docs:  {total}")
    print(f"Skipped (no label):  {skipped_nolabel}")
    print(f"Skipped (too short): {skipped_short}")

    # Raw distribution
    print("\nRaw distribution:")
    for label in sorted(by_label):
        name = REGISTER_LABELS.get(label, label)
        print(f"  {label} ({name}): {len(by_label[label])}")

    # Balance: oversample minority to median, undersample majority to median
    sorted_sizes = sorted(len(v) for v in by_label.values())
    target = sorted_sizes[len(sorted_sizes) // 2]

    print(f"\nBalancing target (median): {target}")

    train_lines = []
    test_lines = []

    for label, lines in by_label.items():
        random.shuffle(lines)

        # Hold out --test-ratio of each class (at least 50 docs)
        n_test = max(int(len(lines) * args.test_ratio), 50)
        test_pool = lines[:n_test]
        train_pool = lines[n_test:]

        test_lines.extend(test_pool)
        n_train = len(train_pool)

        if n_train >= target:
            train_lines.extend(random.sample(train_pool, target))
            print(f"  {label}: {n_train} -> {target} (undersampled)")
        else:
            train_lines.extend(train_pool)
            n_needed = target - n_train
            train_lines.extend(random.choices(train_pool, k=n_needed))
            print(f"  {label}: {n_train} -> {target} (oversampled +{n_needed})")

    random.shuffle(train_lines)
    random.shuffle(test_lines)

    train_path = output_dir / "train.txt"
    test_path = output_dir / "test.txt"

    with open(train_path, "w") as f:
        f.writelines(train_lines)
    with open(test_path, "w") as f:
        f.writelines(test_lines)

    print(f"\nTrain: {len(train_lines)} -> {train_path}")
    print(f"Test:  {len(test_lines)} -> {test_path}")

    # Verify balance
    c = Counter()
    for line in train_lines:
        for tok in line.split():
            if tok.startswith("__label__"):
                c[tok] += 1
    print("\nFinal train label distribution:")
    for l, cnt in c.most_common():
        name = REGISTER_LABELS.get(l.replace("__label__", ""), l)
        print(f"  {l} ({name}): {cnt}")


if __name__ == "__main__":
    main()
scripts/train.py ADDED
@@ -0,0 +1,91 @@
"""
Train a FastText text register classifier.

Usage:
    python train.py --train ./prepared/train.txt --test ./prepared/test.txt --output ./model

This produces:
    - model/register_fasttext.bin    (full model)
    - model/register_fasttext_q.bin  (quantized, ~7x smaller)
"""

import argparse
import os
import time
from pathlib import Path

import fasttext


def main():
    parser = argparse.ArgumentParser(description="Train FastText register classifier")
    parser.add_argument("--train", default="./prepared/train.txt", help="Training data file")
    parser.add_argument("--test", default="./prepared/test.txt", help="Test data file")
    parser.add_argument("--output", default="./model", help="Output directory")
    parser.add_argument("--lr", type=float, default=0.5, help="Learning rate")
    parser.add_argument("--epoch", type=int, default=25, help="Number of epochs")
    parser.add_argument("--dim", type=int, default=100, help="Embedding dimension")
    parser.add_argument("--wordNgrams", type=int, default=2, help="Max n-gram length")
    parser.add_argument("--bucket", type=int, default=2000000, help="Hash bucket size")
    parser.add_argument("--thread", type=int, default=8, help="Number of threads")
    parser.add_argument("--min-count", type=int, default=5, help="Min word count")
    args = parser.parse_args()

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=== Training FastText register classifier ===")
    start = time.time()

    model = fasttext.train_supervised(
        input=args.train,
        lr=args.lr,
        epoch=args.epoch,
        wordNgrams=args.wordNgrams,
        dim=args.dim,
        loss="ova",  # one-vs-all for multi-label
        minCount=args.min_count,
        bucket=args.bucket,
        thread=args.thread,
        verbose=2,
    )

    print(f"Training time: {time.time() - start:.1f}s")

    # Save full model
    full_path = output_dir / "register_fasttext.bin"
    model.save_model(str(full_path))
    size_mb = os.path.getsize(full_path) / 1024 / 1024
    print(f"\nFull model: {full_path} ({size_mb:.1f} MB)")

    # Evaluate (model.test returns (n_examples, precision, recall))
    print("\n=== Evaluation ===")
    for k in [1, 2]:
        n, precision, recall = model.test(args.test, k=k)
        print(f"  k={k}: Precision={precision:.4f} Recall={recall:.4f} (n={n})")

    # Quantize
    print("\nQuantizing...")
    model.quantize(input=args.train, retrain=True)
    q_path = output_dir / "register_fasttext_q.bin"
    model.save_model(str(q_path))
    print(f"Quantized model: {q_path} ({os.path.getsize(q_path) / 1024 / 1024:.1f} MB)")

    n, precision, recall = model.test(args.test, k=1)
    print(f"  Quantized k=1: Precision={precision:.4f} Recall={recall:.4f}")

    # Speed test
    print("\n=== Speed Test ===")
    test_text = "The algorithm processes data in O(n log n) time complexity."
    n_preds = 100_000
    start = time.time()
    for _ in range(n_preds):
        model.predict(test_text)
    elapsed = time.time() - start
    print(f"{n_preds / elapsed:.0f} predictions/sec")

    print("\nDone!")


if __name__ == "__main__":
    main()