"""
Predict text register using the trained FastText model.

Usage:
    # Interactive mode
    python predict.py --model ./model/register_fasttext_q.bin

    # Single text
    python predict.py --model ./model/register_fasttext_q.bin --text "Buy now! Limited offer!"

    # File mode (one text per line)
    python predict.py --model ./model/register_fasttext_q.bin --input texts.txt --output predictions.jsonl
"""
|
|
import argparse
import contextlib
import json
import sys
import time

import fasttext
|
|
|
|
# Human-readable names for the short register codes. predict_one() looks a
# code up here after removing the "__label__" prefix and falls back to the
# raw code when it is not listed.
REGISTER_LABELS = {
    "IN": "Informational",
    "NA": "Narrative",
    "OP": "Opinion",
    "IP": "Persuasion",
    "HI": "HowTo",
    "ID": "Discussion",
    "SP": "Spoken",
    "LY": "Lyrical",
}
|
|
|
|
def predict_one(model, text: str, k: int = 3, threshold: float = 0.1) -> list:
    """Predict register labels for a single text.

    Args:
        model: Object exposing ``predict(text, k=..., threshold=...)`` that
            returns a pair ``(labels, probabilities)`` (e.g. a loaded
            fastText model).
        text: The text to classify. Newlines are replaced with spaces
            before prediction.
        k: Maximum number of labels requested from the model.
        threshold: Minimum probability passed through to the model.

    Returns:
        A list of dicts, one per predicted label, each with the keys
        ``"label"`` (short code), ``"name"`` (display name from
        ``REGISTER_LABELS``, or the code itself when unknown) and
        ``"score"`` (probability rounded to 4 decimals).
    """
    prefix = "__label__"
    # The model is fed a single line; presumably fastText's predict()
    # refuses embedded newlines, so they are flattened to spaces first.
    single_line = text.replace("\n", " ")
    labels, probs = model.predict(single_line, k=k, threshold=threshold)

    results = []
    for label, prob in zip(labels, probs):
        # Strip only a *leading* marker so a label that happens to contain
        # the marker text elsewhere is not mangled.
        code = label[len(prefix):] if label.startswith(prefix) else label
        results.append({
            "label": code,
            "name": REGISTER_LABELS.get(code, code),
            "score": round(float(prob), 4),
        })
    return results
|
|
|
|
def _print_predictions(results):
    """Print one "name (code) score" line per prediction to stdout."""
    for r in results:
        print(f" {r['name']:<15} ({r['label']}) {r['score']:.3f}")


def _classify_file(model, input_path, output_path, top_k, threshold):
    """Classify every non-blank line of *input_path* and emit JSONL records.

    Records go to *output_path* when given, otherwise to stdout.
    Returns a ``(count, elapsed_seconds)`` tuple.
    """
    count = 0
    start = time.perf_counter()
    # ExitStack closes whichever files were opened even if prediction or
    # writing raises part-way through; stdout is never closed.
    with contextlib.ExitStack() as stack:
        # Open the input first so a bad input path does not truncate an
        # existing output file.
        in_f = stack.enter_context(open(input_path, encoding="utf-8"))
        if output_path:
            out_f = stack.enter_context(open(output_path, "w", encoding="utf-8"))
        else:
            out_f = sys.stdout

        for line in in_f:
            text = line.strip()
            if not text:
                continue
            results = predict_one(model, text, top_k, threshold)
            # Only the first 200 characters of the text are echoed back.
            record = {"text": text[:200], "predictions": results}
            out_f.write(json.dumps(record) + "\n")
            count += 1
        elapsed = time.perf_counter() - start
    return count, elapsed


def _interactive_loop(model, top_k, threshold):
    """Read texts from stdin until EOF, Ctrl-C, or a quit command."""
    print("Text Register Classifier (type 'quit' to exit)")
    legend = ", ".join(f"{code}={name}" for code, name in REGISTER_LABELS.items())
    print(f"Labels: {legend}")
    print()
    while True:
        try:
            text = input("> ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if text.lower() in ("quit", "exit", "q"):
            break
        if not text:
            continue
        _print_predictions(predict_one(model, text, top_k, threshold))
        print()


def main():
    """Command-line entry point.

    Parses the arguments, loads the model given by ``--model`` and then runs
    exactly one mode: single text (``--text``), file mode (``--input``,
    optionally with ``--output``), or an interactive prompt when neither is
    given.
    """
    parser = argparse.ArgumentParser(description="Predict text register")
    parser.add_argument("--model", required=True, help="Path to FastText .bin model")
    parser.add_argument("--text", help="Single text to classify")
    parser.add_argument("--input", help="Input file (one text per line)")
    parser.add_argument("--output", help="Output JSONL file")
    parser.add_argument("--k", type=int, default=3, help="Top-k labels to return")
    parser.add_argument("--threshold", type=float, default=0.1, help="Min probability threshold")
    args = parser.parse_args()

    # Best-effort: replace fastText's eprint hook with a no-op (presumably to
    # silence the message it prints while loading a model). Skipped quietly
    # if this fasttext build does not expose the attribute.
    with contextlib.suppress(AttributeError):
        fasttext.FastText.eprint = lambda x: None

    model = fasttext.load_model(args.model)

    # Compare against None so an explicitly supplied (even empty) argument
    # selects its mode instead of silently dropping into interactive mode.
    if args.text is not None:
        _print_predictions(predict_one(model, args.text, args.k, args.threshold))
    elif args.input is not None:
        count, elapsed = _classify_file(
            model, args.input, args.output, args.k, args.threshold
        )
        # Guard the rate against a zero-length interval (e.g. empty input).
        rate = count / elapsed if elapsed > 0 else 0.0
        print(
            f"Processed {count} texts in {elapsed:.2f}s ({rate:.0f}/sec)",
            file=sys.stderr,
        )
    else:
        _interactive_loop(model, args.k, args.threshold)
|
|
|
|
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|