"""
Simple inference script for the Polish Twitter Emotion Classifier (ONNX).

Predicts emotions and sentiment in Polish text using the default threshold (0.5).

Usage:
    python predict.py "Your Polish text here"
    python predict.py "Text" --threshold 0.3

Requirements:
    pip install optimum[onnxruntime] transformers numpy
"""
|
|
| import argparse |
| import re |
| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| from optimum.onnxruntime import ORTModelForSequenceClassification |
| from transformers import AutoTokenizer |
|
|
|
|
| DEFAULT_THRESHOLD = 0.5 |
|
|
|
|
def preprocess_text(text: str, anonymize_mentions: bool = True) -> str:
    """
    Preprocess text by anonymizing @mentions.

    The model was trained on text whose mentions were anonymized, so
    applying the same transformation at inference time is recommended.

    Args:
        text: Raw input text.
        anonymize_mentions: When True, replace every @mention with
            the placeholder @anonymized_account.

    Returns:
        The (possibly rewritten) text.
    """
    if not anonymize_mentions:
        return text
    # Any "@" followed by word characters counts as a mention.
    return re.sub(r"@\w+", "@anonymized_account", text)
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description="Predict emotions in Polish text using ONNX model"
    )

    # Positional: the text to classify.
    parser.add_argument("text", type=str, help="Text to classify")

    # Optional tuning / behavior flags.
    parser.add_argument(
        "--threshold",
        type=float,
        default=DEFAULT_THRESHOLD,
        help=f"Classification threshold (default: {DEFAULT_THRESHOLD})",
    )
    parser.add_argument(
        "--no-anonymize",
        action="store_true",
        help="Disable @mention anonymization (not recommended)",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cpu",
        help="Device for inference (default: cpu)",
    )

    return parser.parse_args()
|
|
|
|
| def main() -> int: |
| """Main entry point.""" |
| args = parse_args() |
|
|
| |
| model_path = Path(__file__).parent.resolve() |
|
|
| print(f"Loading model from: {model_path}") |
|
|
| |
| provider = "CUDAExecutionProvider" if args.device == "cuda" else "CPUExecutionProvider" |
|
|
| |
| tokenizer = AutoTokenizer.from_pretrained(str(model_path)) |
| model = ORTModelForSequenceClassification.from_pretrained( |
| str(model_path), |
| provider=provider, |
| ) |
|
|
| |
| labels = [model.config.id2label[i] for i in range(model.config.num_labels)] |
|
|
| |
| anonymize = not args.no_anonymize |
| processed_text = preprocess_text(args.text, anonymize_mentions=anonymize) |
|
|
| if anonymize and processed_text != args.text: |
| print(f"Preprocessed: {processed_text}") |
|
|
| print(f"\nInput: {args.text}\n") |
|
|
| |
| inputs = tokenizer( |
| processed_text, |
| return_tensors="pt", |
| truncation=True, |
| max_length=8192, |
| ) |
|
|
| outputs = model(**inputs) |
| logits = outputs.logits.squeeze().numpy() |
|
|
| |
| probabilities = 1 / (1 + np.exp(-logits)) |
| predictions = probabilities > args.threshold |
|
|
| |
| assigned_labels = [labels[i] for i in range(len(labels)) if predictions[i]] |
|
|
| |
| if assigned_labels: |
| print("Assigned Labels:") |
| print("-" * 40) |
| for label in assigned_labels: |
| print(f" {label}") |
| print() |
| else: |
| print("No labels assigned (all below threshold)\n") |
|
|
| print("All Labels (with probabilities):") |
| print("-" * 40) |
| for i, label in enumerate(labels): |
| status = "✓" if predictions[i] else " " |
| print(f"{status} {label:15s}: {probabilities[i]:.4f}") |
|
|
| return 0 |
|
|
|
|
| if __name__ == "__main__": |
| sys.exit(main()) |
|
|