oneryalcin's picture
Add text register FastText classifier with training scripts
3dea709 verified
"""
Predict text register using the trained FastText model.
Usage:
# Interactive mode
python predict.py --model ./model/register_fasttext_q.bin
# Single text
python predict.py --model ./model/register_fasttext_q.bin --text "Buy now! Limited offer!"
# File mode (one text per line)
python predict.py --model ./model/register_fasttext_q.bin --input texts.txt --output predictions.jsonl
"""
import fasttext
import json
import sys
import argparse
import time
# Short register code -> human-readable register name. Insertion order is
# kept stable because the interactive banner lists the entries in this order.
REGISTER_LABELS = {
    "IN": "Informational",
    "NA": "Narrative",
    "OP": "Opinion",
    "IP": "Persuasion",
    "HI": "HowTo",
    "ID": "Discussion",
    "SP": "Spoken",
    "LY": "Lyrical",
}
def predict_one(model, text: str, k: int = 3, threshold: float = 0.1):
    """Classify a single text and return its register predictions.

    Newlines in ``text`` are replaced with spaces before the text is passed
    to ``model.predict`` together with ``k`` and ``threshold``.

    Returns a list of dicts, one per label returned by the model, with keys:
      * ``label`` - the model label with ``__label__`` removed,
      * ``name``  - the display name from ``REGISTER_LABELS`` (the code
        itself when it is not in the table),
      * ``score`` - the probability as a float rounded to 4 decimals.
    """
    single_line = text.replace("\n", " ")
    raw_labels, raw_probs = model.predict(single_line, k=k, threshold=threshold)
    codes = (raw.replace("__label__", "") for raw in raw_labels)
    return [
        {
            "label": code,
            "name": REGISTER_LABELS.get(code, code),
            "score": round(float(prob), 4),
        }
        for code, prob in zip(codes, raw_probs)
    ]
def _print_predictions(results):
    """Print one line per prediction: display name, register code, score."""
    for r in results:
        print(f" {r['name']:<15} ({r['label']}) {r['score']:.3f}")


def _run_batch(model, args):
    """Classify each non-empty line of ``args.input`` and emit JSONL records.

    Records go to ``args.output`` when given, otherwise to stdout. A
    throughput summary is written to stderr so it never mixes with the
    JSONL stream.
    """
    count = 0
    start = time.perf_counter()
    # Open the input first: if it is missing or unreadable we fail before
    # the output file is created/truncated.
    # UTF-8 is requested explicitly instead of the platform default encoding.
    with open(args.input, encoding="utf-8") as in_f:
        out_f = open(args.output, "w", encoding="utf-8") if args.output else sys.stdout
        try:
            for line in in_f:
                text = line.strip()
                if not text:
                    continue
                results = predict_one(model, text, args.k, args.threshold)
                # Only the first 200 characters of the text are echoed back.
                record = {"text": text[:200], "predictions": results}
                out_f.write(json.dumps(record) + "\n")
                count += 1
        finally:
            # Close only a file we opened ourselves; never close sys.stdout.
            if args.output:
                out_f.close()
    elapsed = time.perf_counter() - start
    # Guard against a zero reading from the timer (e.g. empty input file).
    rate = count / elapsed if elapsed > 0 else 0.0
    print(f"Processed {count} texts in {elapsed:.2f}s ({rate:.0f}/sec)", file=sys.stderr)


def _run_interactive(model, args):
    """Read texts from the prompt until EOF, Ctrl-C, or a quit command."""
    print("Text Register Classifier (type 'quit' to exit)")
    print(f"Labels: {', '.join(f'{k}={v}' for k, v in REGISTER_LABELS.items())}")
    print()
    while True:
        try:
            text = input("> ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if text.lower() in ("quit", "exit", "q"):
            break
        if not text:
            continue
        _print_predictions(predict_one(model, text, args.k, args.threshold))
        print()


def main():
    """Command-line entry point.

    Parses the arguments, loads the FastText model from ``--model`` and then
    runs one of three modes:
      * ``--text``  - classify that single text and print the predictions,
      * ``--input`` - classify a file line by line, writing JSONL to
        ``--output`` or stdout,
      * neither     - start an interactive prompt.
    """
    parser = argparse.ArgumentParser(description="Predict text register")
    parser.add_argument("--model", required=True, help="Path to FastText .bin model")
    parser.add_argument("--text", help="Single text to classify")
    parser.add_argument("--input", help="Input file (one text per line)")
    parser.add_argument("--output", help="Output JSONL file")
    parser.add_argument("--k", type=int, default=3, help="Top-k labels to return")
    parser.add_argument("--threshold", type=float, default=0.1, help="Min probability threshold")
    args = parser.parse_args()

    # Suppress load warning. Best-effort: if this fasttext build has no
    # ``FastText`` submodule to patch, just carry on with the warning.
    try:
        fasttext.FastText.eprint = lambda x: None
    except AttributeError:
        pass

    model = fasttext.load_model(args.model)

    if args.text:
        # Single prediction
        _print_predictions(predict_one(model, args.text, args.k, args.threshold))
    elif args.input:
        # Batch mode
        _run_batch(model, args)
    else:
        # Interactive mode
        _run_interactive(model, args)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()