File size: 3,113 Bytes
3dea709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
Train a FastText text register classifier.

Usage:
    python train.py --train ./prepared/train.txt --test ./prepared/test.txt --output ./model

This produces:
    - model/register_fasttext.bin       (full model)
    - model/register_fasttext_q.bin     (quantized, ~7x smaller)
"""

import fasttext
import time
import os
import argparse
from pathlib import Path


def _parse_args() -> argparse.Namespace:
    """Parse CLI paths and FastText hyperparameters (defaults match the docstring usage)."""
    parser = argparse.ArgumentParser(description="Train FastText register classifier")
    parser.add_argument("--train", default="./prepared/train.txt", help="Training data file")
    parser.add_argument("--test", default="./prepared/test.txt", help="Test data file")
    parser.add_argument("--output", default="./model", help="Output directory")
    parser.add_argument("--lr", type=float, default=0.5, help="Learning rate")
    parser.add_argument("--epoch", type=int, default=25, help="Number of epochs")
    parser.add_argument("--dim", type=int, default=100, help="Embedding dimension")
    parser.add_argument("--wordNgrams", type=int, default=2, help="Max n-gram length")
    parser.add_argument("--bucket", type=int, default=2000000, help="Hash bucket size")
    parser.add_argument("--thread", type=int, default=8, help="Number of threads")
    parser.add_argument("--min-count", type=int, default=5, help="Min word count")
    return parser.parse_args()


def _size_mb(path: Path) -> float:
    """Return the on-disk size of *path* in mebibytes (1024 * 1024 bytes)."""
    return os.path.getsize(path) / 1024 / 1024


def _train_model(args: argparse.Namespace):
    """Train a supervised FastText model from *args* and report wall-clock time.

    Returns the trained model. Uses one-vs-all loss so each document can be
    assigned multiple register labels.
    """
    print("=== Training FastText register classifier ===")
    start = time.time()
    model = fasttext.train_supervised(
        input=args.train,
        lr=args.lr,
        epoch=args.epoch,
        wordNgrams=args.wordNgrams,
        dim=args.dim,
        loss="ova",  # one-vs-all for multi-label
        minCount=args.min_count,
        bucket=args.bucket,
        thread=args.thread,
        verbose=2,
    )
    print(f"Training time: {time.time() - start:.1f}s")
    return model


def _speed_test(model, n: int = 100000) -> None:
    """Benchmark single-sentence prediction throughput over *n* calls."""
    print("\n=== Speed Test ===")
    test_text = "The algorithm processes data in O(n log n) time complexity."
    start = time.time()
    for _ in range(n):
        model.predict(test_text)
    elapsed = time.time() - start
    print(f"{n / elapsed:.0f} predictions/sec")


def main():
    """Train, evaluate, quantize, and benchmark the register classifier.

    Side effects: writes ``register_fasttext.bin`` (full) and
    ``register_fasttext_q.bin`` (quantized) into the ``--output`` directory
    and prints metrics to stdout.
    """
    args = _parse_args()

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    model = _train_model(args)

    # Save full model and report its on-disk footprint.
    full_path = output_dir / "register_fasttext.bin"
    model.save_model(str(full_path))
    print(f"\nFull model: {full_path} ({_size_mb(full_path):.1f} MB)")

    # Evaluate precision/recall at top-1 and top-2 predicted labels.
    print("\n=== Evaluation ===")
    for k in (1, 2):
        r = model.test(args.test, k=k)
        print(f"  k={k}: Precision={r[1]:.4f}  Recall={r[2]:.4f}  (n={r[0]})")

    # Quantize in place; retrain=True fine-tunes weights after compression.
    print("\nQuantizing...")
    model.quantize(input=args.train, retrain=True)
    q_path = output_dir / "register_fasttext_q.bin"
    model.save_model(str(q_path))
    print(f"Quantized model: {q_path} ({_size_mb(q_path):.1f} MB)")

    # Re-check accuracy after quantization to quantify the size/quality trade.
    r = model.test(args.test, k=1)
    print(f"  Quantized k=1: Precision={r[1]:.4f}  Recall={r[2]:.4f}")

    _speed_test(model)

    print("\nDone!")


# Script entry point: run the full train/evaluate/quantize pipeline when
# executed directly (not on import).
if __name__ == "__main__":
    main()