| """ |
| Train a FastText text register classifier. |
| |
| Usage: |
| python train.py --train ./prepared/train.txt --test ./prepared/test.txt --output ./model |
| |
| This produces: |
| - model/register_fasttext.bin (full model) |
| - model/register_fasttext_q.bin (quantized, ~7x smaller) |
| """ |
|
|
| import fasttext |
| import time |
| import os |
| import argparse |
| from pathlib import Path |
|
|
|
|
def main():
    """Train, evaluate, quantize, and speed-test a FastText register classifier.

    Reads hyperparameters and data paths from the command line, trains a
    supervised FastText model, saves the full model, reports precision/recall
    on the test file, then quantizes the model (much smaller on disk) and
    reports its metrics and single-text prediction throughput.

    Side effects: creates the output directory and writes
    ``register_fasttext.bin`` and ``register_fasttext_q.bin`` into it.
    """
    args = _parse_args()

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=== Training FastText register classifier ===")
    # perf_counter is monotonic and high-resolution — the correct clock for
    # elapsed-time measurement (time.time can jump with system clock changes).
    start = time.perf_counter()

    model = fasttext.train_supervised(
        input=args.train,
        lr=args.lr,
        epoch=args.epoch,
        wordNgrams=args.wordNgrams,
        dim=args.dim,
        loss="ova",  # one-vs-all loss, per fastText's supervised API
        minCount=args.min_count,
        bucket=args.bucket,
        thread=args.thread,
        verbose=2,
    )

    train_time = time.perf_counter() - start
    print(f"Training time: {train_time:.1f}s")

    # --- Save the full model ---------------------------------------------
    full_path = output_dir / "register_fasttext.bin"
    model.save_model(str(full_path))
    size_mb = full_path.stat().st_size / 1024 / 1024
    print(f"\nFull model: {full_path} ({size_mb:.1f} MB)")

    # --- Evaluate on held-out data ---------------------------------------
    # model.test returns (num_examples, precision, recall).
    print("\n=== Evaluation ===")
    for k in (1, 2):
        n, precision, recall = model.test(args.test, k=k)
        print(f" k={k}: Precision={precision:.4f} Recall={recall:.4f} (n={n})")

    # --- Quantize (smaller model; retrain=True fine-tunes after pruning) --
    print("\nQuantizing...")
    model.quantize(input=args.train, retrain=True)
    q_path = output_dir / "register_fasttext_q.bin"
    model.save_model(str(q_path))
    size_q = q_path.stat().st_size / 1024 / 1024
    print(f"Quantized model: {q_path} ({size_q:.1f} MB)")

    _, precision, recall = model.test(args.test, k=1)
    print(f" Quantized k=1: Precision={precision:.4f} Recall={recall:.4f}")

    # --- Throughput benchmark --------------------------------------------
    print("\n=== Speed Test ===")
    n_preds = 100000  # single constant: loop count and per-second divisor
    test_text = "The algorithm processes data in O(n log n) time complexity."
    start = time.perf_counter()
    for _ in range(n_preds):
        model.predict(test_text)
    elapsed = time.perf_counter() - start
    print(f"{n_preds / elapsed:.0f} predictions/sec")

    print("\nDone!")


def _parse_args():
    """Build and parse the command-line interface for the trainer."""
    parser = argparse.ArgumentParser(description="Train FastText register classifier")
    parser.add_argument("--train", default="./prepared/train.txt", help="Training data file")
    parser.add_argument("--test", default="./prepared/test.txt", help="Test data file")
    parser.add_argument("--output", default="./model", help="Output directory")
    parser.add_argument("--lr", type=float, default=0.5, help="Learning rate")
    parser.add_argument("--epoch", type=int, default=25, help="Number of epochs")
    parser.add_argument("--dim", type=int, default=100, help="Embedding dimension")
    parser.add_argument("--wordNgrams", type=int, default=2, help="Max n-gram length")
    parser.add_argument("--bucket", type=int, default=2000000, help="Hash bucket size")
    parser.add_argument("--thread", type=int, default=8, help="Number of threads")
    parser.add_argument("--min-count", type=int, default=5, help="Min word count")
    return parser.parse_args()
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|