File size: 3,614 Bytes
3dea709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
Predict text register using the trained FastText model.

Usage:
    # Interactive mode
    python predict.py --model ./model/register_fasttext_q.bin

    # Single text
    python predict.py --model ./model/register_fasttext_q.bin --text "Buy now! Limited offer!"

    # File mode (one text per line)
    python predict.py --model ./model/register_fasttext_q.bin --input texts.txt --output predictions.jsonl
"""

import fasttext
import json
import sys
import argparse
import time


# Human-readable names for the register codes the model emits (the part of a
# FastText label after "__label__"). Codes missing from this table are shown
# verbatim. Insertion order is the order used for the interactive-mode legend.
REGISTER_LABELS = dict(
    IN="Informational",
    NA="Narrative",
    OP="Opinion",
    IP="Persuasion",
    HI="HowTo",
    ID="Discussion",
    SP="Spoken",
    LY="Lyrical",
)


def predict_one(model, text: str, k: int = 3, threshold: float = 0.1):
    """Return the model's register predictions for a single piece of text.

    Args:
        model: A loaded FastText model (anything exposing ``predict``).
        text: The text to classify. Newlines are flattened to spaces first,
            because FastText's ``predict`` works on one line at a time.
        k: Maximum number of labels to request from the model.
        threshold: Minimum probability a label needs in order to be returned.

    Returns:
        A list of dicts, in the order the model returned them, each holding
        ``label`` (register code with the ``__label__`` prefix removed),
        ``name`` (readable name, or the code itself when unknown) and
        ``score`` (probability rounded to 4 decimal places).
    """
    def as_record(raw_label, raw_prob):
        code = raw_label.replace("__label__", "")
        return {
            "label": code,
            "name": REGISTER_LABELS.get(code, code),
            "score": round(float(raw_prob), 4),
        }

    flattened = text.replace("\n", " ")
    raw_labels, raw_probs = model.predict(flattened, k=k, threshold=threshold)
    return [as_record(lbl, p) for lbl, p in zip(raw_labels, raw_probs)]


def _print_results(results):
    """Print each prediction as an aligned ``name (code)  score`` line."""
    for r in results:
        print(f"  {r['name']:<15} ({r['label']})  {r['score']:.3f}")


def _run_batch(model, input_path, output_path, k, threshold):
    """Classify every non-blank line of *input_path* and write JSONL records.

    Each record holds the first 200 characters of the stripped line and its
    predictions. Records go to *output_path* when it is given, otherwise to
    stdout; a throughput summary is printed to stderr.

    The input is opened before the output so that a missing input file does
    not truncate an existing output file, and the output file is closed even
    if prediction fails part-way through.
    """
    count = 0
    with open(input_path, encoding="utf-8") as in_f:
        out_f = open(output_path, "w", encoding="utf-8") if output_path else sys.stdout
        start = time.perf_counter()
        try:
            for line in in_f:
                text = line.strip()
                if not text:
                    continue
                results = predict_one(model, text, k, threshold)
                # Only a prefix of the text is echoed to keep records compact.
                record = {"text": text[:200], "predictions": results}
                out_f.write(json.dumps(record) + "\n")
                count += 1
        finally:
            # Never close sys.stdout; only the file we opened ourselves.
            if output_path:
                out_f.close()
        elapsed = time.perf_counter() - start
    # Guard against a zero elapsed time so the summary never divides by zero.
    rate = f"{count / elapsed:.0f}/sec" if elapsed > 0 else "n/a"
    print(f"Processed {count} texts in {elapsed:.2f}s ({rate})", file=sys.stderr)


def _run_interactive(model, k, threshold):
    """Prompt for texts on stdin and print predictions until quit/EOF/Ctrl-C."""
    print("Text Register Classifier (type 'quit' to exit)")
    legend = ", ".join(f"{code}={name}" for code, name in REGISTER_LABELS.items())
    print(f"Labels: {legend}")
    print()
    while True:
        try:
            text = input("> ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if text.lower() in ("quit", "exit", "q"):
            break
        if not text:
            continue
        _print_results(predict_one(model, text, k, threshold))
        print()


def main():
    """Command-line entry point.

    Loads the FastText model given by ``--model`` and then runs one of three
    modes: ``--text`` classifies a single string, ``--input`` classifies a
    file line by line (batch mode, JSONL output to ``--output`` or stdout),
    and with neither flag an interactive prompt reads texts from stdin.
    """
    parser = argparse.ArgumentParser(description="Predict text register")
    parser.add_argument("--model", required=True, help="Path to FastText .bin model")
    parser.add_argument("--text", help="Single text to classify")
    parser.add_argument("--input", help="Input file (one text per line)")
    parser.add_argument("--output", help="Output JSONL file")
    parser.add_argument("--k", type=int, default=3, help="Top-k labels to return")
    parser.add_argument("--threshold", type=float, default=0.1, help="Min probability threshold")
    args = parser.parse_args()

    # Best-effort suppression of the warning fasttext prints on load_model;
    # skipped quietly if this fasttext build has no FastText.eprint hook.
    try:
        fasttext.FastText.eprint = lambda x: None
    except AttributeError:
        pass

    model = fasttext.load_model(args.model)

    if args.text:
        _print_results(predict_one(model, args.text, args.k, args.threshold))
    elif args.input:
        _run_batch(model, args.input, args.output, args.k, args.threshold)
    else:
        _run_interactive(model, args.k, args.threshold)


if __name__ == "__main__":
    main()