#!/usr/bin/env python3
"""
Simple inference script for Polish Twitter Emotion Classifier (ONNX).
Predicts emotions and sentiment in Polish text using default threshold (0.5).
Usage:
python predict.py "Your Polish text here"
python predict.py "Text" --threshold 0.3
Requirements:
pip install optimum[onnxruntime] transformers numpy
"""
import argparse
import re
import sys
from pathlib import Path
import numpy as np
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer
DEFAULT_THRESHOLD = 0.5
def preprocess_text(text: str, anonymize_mentions: bool = True) -> str:
    """Anonymize @mentions so input matches the model's training data.

    The model was trained on text whose @mentions were replaced with
    ``@anonymized_account``; applying the same rewrite at inference time
    is recommended for best performance.

    Args:
        text: Raw input text.
        anonymize_mentions: When True, rewrite every ``@mention``.

    Returns:
        The (possibly rewritten) text.
    """
    if not anonymize_mentions:
        return text
    return re.sub(r"@\w+", "@anonymized_account", text)
def parse_args() -> argparse.Namespace:
    """Define and evaluate the command-line interface.

    Returns:
        Parsed arguments: ``text``, ``threshold``, ``no_anonymize``, ``device``.
    """
    cli = argparse.ArgumentParser(
        description="Predict emotions in Polish text using ONNX model"
    )
    # Required positional: the text to run through the classifier.
    cli.add_argument("text", type=str, help="Text to classify")
    cli.add_argument(
        "--threshold",
        type=float,
        default=DEFAULT_THRESHOLD,
        help=f"Classification threshold (default: {DEFAULT_THRESHOLD})",
    )
    cli.add_argument(
        "--no-anonymize",
        action="store_true",
        help="Disable @mention anonymization (not recommended)",
    )
    cli.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cpu",
        help="Device for inference (default: cpu)",
    )
    return cli.parse_args()
def main() -> int:
    """Load the ONNX classifier, run it on the CLI text, and print results.

    Returns:
        Process exit code (0 on success).
    """
    args = parse_args()

    # The model artifacts are expected to sit next to this script.
    model_dir = Path(__file__).parent.resolve()
    print(f"Loading model from: {model_dir}")

    # Map the requested device onto an ONNX Runtime execution provider.
    if args.device == "cuda":
        provider = "CUDAExecutionProvider"
    else:
        provider = "CPUExecutionProvider"

    tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
    model = ORTModelForSequenceClassification.from_pretrained(
        str(model_dir),
        provider=provider,
    )

    # Recover the label names in id order from the model config.
    labels = [model.config.id2label[idx] for idx in range(model.config.num_labels)]

    # Mirror the training-time preprocessing unless the user opted out.
    anonymize = not args.no_anonymize
    text = preprocess_text(args.text, anonymize_mentions=anonymize)
    if anonymize and text != args.text:
        print(f"Preprocessed: {text}")
    print(f"\nInput: {args.text}\n")

    # NOTE(review): return_tensors="pt" and .numpy() below imply a torch
    # dependency that the documented requirements do not list — confirm.
    # max_length=8192 also looks large for a typical encoder; verify against
    # the model's actual context length.
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=8192,
    )
    outputs = model(**encoded)
    logits = outputs.logits.squeeze().numpy()

    # Multi-label head: independent sigmoid per label, then threshold.
    probabilities = 1 / (1 + np.exp(-logits))
    predictions = probabilities > args.threshold

    assigned = [name for name, hit in zip(labels, predictions) if hit]

    if assigned:
        print("Assigned Labels:")
        print("-" * 40)
        for name in assigned:
            print(f"  {name}")
        print()
    else:
        print("No labels assigned (all below threshold)\n")

    print("All Labels (with probabilities):")
    print("-" * 40)
    for name, prob, hit in zip(labels, probabilities, predictions):
        status = "✓" if hit else " "
        print(f"{status} {name:15s}: {prob:.4f}")

    return 0
if __name__ == "__main__":
sys.exit(main())