""" Predict text register using the trained FastText model. Usage: # Interactive mode python predict.py --model ./model/register_fasttext_q.bin # Single text python predict.py --model ./model/register_fasttext_q.bin --text "Buy now! Limited offer!" # File mode (one text per line) python predict.py --model ./model/register_fasttext_q.bin --input texts.txt --output predictions.jsonl """ import fasttext import json import sys import argparse import time REGISTER_LABELS = { "IN": "Informational", "NA": "Narrative", "OP": "Opinion", "IP": "Persuasion", "HI": "HowTo", "ID": "Discussion", "SP": "Spoken", "LY": "Lyrical", } def predict_one(model, text: str, k: int = 3, threshold: float = 0.1): """Predict register labels for a single text.""" labels, probs = model.predict(text.replace("\n", " "), k=k, threshold=threshold) results = [] for label, prob in zip(labels, probs): code = label.replace("__label__", "") results.append({ "label": code, "name": REGISTER_LABELS.get(code, code), "score": round(float(prob), 4), }) return results def main(): parser = argparse.ArgumentParser(description="Predict text register") parser.add_argument("--model", required=True, help="Path to FastText .bin model") parser.add_argument("--text", help="Single text to classify") parser.add_argument("--input", help="Input file (one text per line)") parser.add_argument("--output", help="Output JSONL file") parser.add_argument("--k", type=int, default=3, help="Top-k labels to return") parser.add_argument("--threshold", type=float, default=0.1, help="Min probability threshold") args = parser.parse_args() # Suppress load warning try: fasttext.FastText.eprint = lambda x: None except Exception: pass model = fasttext.load_model(args.model) if args.text: # Single prediction results = predict_one(model, args.text, args.k, args.threshold) for r in results: print(f" {r['name']:<15} ({r['label']}) {r['score']:.3f}") elif args.input: # Batch mode out_f = open(args.output, "w") if args.output else sys.stdout count = 0 start = time.time() with open(args.input) as f: for line in f: text = line.strip() if not text: continue results = predict_one(model, text, args.k, args.threshold) record = {"text": text[:200], "predictions": results} out_f.write(json.dumps(record) + "\n") count += 1 elapsed = time.time() - start if args.output: out_f.close() print(f"Processed {count} texts in {elapsed:.2f}s ({count / elapsed:.0f}/sec)", file=sys.stderr) else: # Interactive mode print("Text Register Classifier (type 'quit' to exit)") print(f"Labels: {', '.join(f'{k}={v}' for k, v in REGISTER_LABELS.items())}") print() while True: try: text = input("> ").strip() except (EOFError, KeyboardInterrupt): break if text.lower() in ("quit", "exit", "q"): break if not text: continue results = predict_one(model, text, args.k, args.threshold) for r in results: print(f" {r['name']:<15} ({r['label']}) {r['score']:.3f}") print() if __name__ == "__main__": main()