# Source: oneryalcin — "Add text register FastText classifier with training scripts" (commit 3dea709, verified)
"""
Train a FastText text register classifier.
Usage:
python train.py --train ./prepared/train.txt --test ./prepared/test.txt --output ./model
This produces:
- model/register_fasttext.bin (full model)
- model/register_fasttext_q.bin (quantized, ~7x smaller)
"""
import fasttext
import time
import os
import argparse
from pathlib import Path
def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for training hyperparameters and file paths."""
    parser = argparse.ArgumentParser(description="Train FastText register classifier")
    parser.add_argument("--train", default="./prepared/train.txt", help="Training data file")
    parser.add_argument("--test", default="./prepared/test.txt", help="Test data file")
    parser.add_argument("--output", default="./model", help="Output directory")
    parser.add_argument("--lr", type=float, default=0.5, help="Learning rate")
    parser.add_argument("--epoch", type=int, default=25, help="Number of epochs")
    parser.add_argument("--dim", type=int, default=100, help="Embedding dimension")
    parser.add_argument("--wordNgrams", type=int, default=2, help="Max n-gram length")
    parser.add_argument("--bucket", type=int, default=2000000, help="Hash bucket size")
    parser.add_argument("--thread", type=int, default=8, help="Number of threads")
    parser.add_argument("--min-count", type=int, default=5, help="Min word count")
    return parser


def _save_and_report(model, path: Path, label: str) -> None:
    """Save *model* to *path* and print its on-disk size in MB.

    *label* is printed verbatim (it may carry a leading newline, matching
    the script's original output format).
    """
    model.save_model(str(path))
    size_mb = os.path.getsize(path) / 1024 / 1024
    print(f"{label}: {path} ({size_mb:.1f} MB)")


def _speed_test(model, n_preds: int = 100000) -> None:
    """Benchmark single-sentence prediction throughput over *n_preds* calls."""
    print("\n=== Speed Test ===")
    test_text = "The algorithm processes data in O(n log n) time complexity."
    # perf_counter is monotonic, so the measurement is immune to wall-clock jumps.
    start = time.perf_counter()
    for _ in range(n_preds):
        model.predict(test_text)
    elapsed = time.perf_counter() - start
    print(f"{n_preds / elapsed:.0f} predictions/sec")


def main():
    """Train, evaluate, quantize, and benchmark a FastText register classifier.

    Writes two artifacts into the --output directory:
      - register_fasttext.bin   (full model)
      - register_fasttext_q.bin (quantized model, substantially smaller)
    """
    args = _build_parser().parse_args()

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=== Training FastText register classifier ===")
    start = time.perf_counter()
    model = fasttext.train_supervised(
        input=args.train,
        lr=args.lr,
        epoch=args.epoch,
        wordNgrams=args.wordNgrams,
        dim=args.dim,
        loss="ova",  # one-vs-all: supports multi-label register assignment
        minCount=args.min_count,
        bucket=args.bucket,
        thread=args.thread,
        verbose=2,
    )
    train_time = time.perf_counter() - start
    print(f"Training time: {train_time:.1f}s")

    # Save the full-precision model.
    _save_and_report(model, output_dir / "register_fasttext.bin", "\nFull model")

    # Evaluate precision/recall at k=1 and k=2 predicted labels.
    print("\n=== Evaluation ===")
    for k in [1, 2]:
        r = model.test(args.test, k=k)  # returns (n_samples, precision, recall)
        print(f" k={k}: Precision={r[1]:.4f} Recall={r[2]:.4f} (n={r[0]})")

    # Quantize (with retraining) to shrink the model, then re-check quality.
    print("\nQuantizing...")
    model.quantize(input=args.train, retrain=True)
    _save_and_report(model, output_dir / "register_fasttext_q.bin", "Quantized model")
    r = model.test(args.test, k=1)
    print(f" Quantized k=1: Precision={r[1]:.4f} Recall={r[2]:.4f}")

    _speed_test(model)
    print("\nDone!")


if __name__ == "__main__":
    main()