# twitter-emotion-pl-classifier-ONNX / predict_calibrated.py
# Repository file uploaded by yazoniak ("ONNX model uploaded", commit 6970198, verified).
#!/usr/bin/env python3
"""
Calibrated inference script for Polish Twitter Emotion Classifier (ONNX).
Predicts emotions and sentiment using temperature scaling and optimal
per-label thresholds for improved accuracy. This is the recommended
inference method.
Usage:
python predict_calibrated.py "Your Polish text here"
Requirements:
pip install optimum[onnxruntime] transformers numpy
"""
import argparse
import json
import re
import sys
from pathlib import Path
import numpy as np
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer
def preprocess_text(text: str, anonymize_mentions: bool = True) -> str:
    """
    Normalize input text the same way the training data was normalized.

    The model was trained on tweets whose @mentions were replaced with a
    fixed placeholder, so applying the same substitution at inference time
    is recommended for best performance.

    Args:
        text: Raw input text.
        anonymize_mentions: When True (default), replace every @mention
            with the literal token ``@anonymized_account``.

    Returns:
        The (possibly rewritten) text.
    """
    if not anonymize_mentions:
        return text
    return re.sub(r"@\w+", "@anonymized_account", text)
def load_calibration_artifacts(model_path: Path) -> dict:
    """
    Load calibration artifacts from JSON file.

    Args:
        model_path: Path to model directory containing calibration_artifacts.json

    Returns:
        Dictionary with temperatures and optimal_thresholds

    Raises:
        FileNotFoundError: If calibration file not found
    """
    calib_path = model_path / "calibration_artifacts.json"
    if not calib_path.exists():
        raise FileNotFoundError(f"Calibration artifacts not found: {calib_path}")
    # Explicit UTF-8: JSON files are UTF-8 by spec; without `encoding` the
    # platform locale would be used and could corrupt non-ASCII label names.
    with open(calib_path, "r", encoding="utf-8") as f:
        return json.load(f)
def parse_args() -> argparse.Namespace:
    """Build the CLI parser and return the parsed arguments."""
    cli = argparse.ArgumentParser(
        description="Predict emotions using calibrated ONNX model (recommended)"
    )
    # Required positional: the text to classify.
    cli.add_argument("text", type=str, help="Text to classify")
    # Opt-out flag for the training-time @mention anonymization.
    cli.add_argument(
        "--no-anonymize",
        action="store_true",
        help="Disable @mention anonymization (not recommended)",
    )
    # Execution device; maps to the ONNX Runtime provider in main().
    cli.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device for inference (default: cpu)",
    )
    return cli.parse_args()
def main() -> int:
    """Run calibrated emotion prediction for one CLI-supplied text.

    Loads the ONNX model, tokenizer, and calibration artifacts from the
    directory containing this script, applies temperature scaling and
    per-label thresholds to the model logits, prints the assigned labels
    and all calibrated probabilities, and returns exit code 0.
    """
    args = parse_args()
    # Model path is the same directory as this script
    model_path = Path(__file__).parent.resolve()
    print(f"Loading model from: {model_path}")
    # Select execution provider
    provider = "CUDAExecutionProvider" if args.device == "cuda" else "CPUExecutionProvider"
    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(str(model_path))
    model = ORTModelForSequenceClassification.from_pretrained(
        str(model_path),
        provider=provider,
    )
    # Get labels from model config; index order here must match logit order.
    labels = [model.config.id2label[i] for i in range(model.config.num_labels)]
    # Load calibration artifacts: per-label temperature and decision threshold.
    calib_artifacts = load_calibration_artifacts(model_path)
    temperatures = calib_artifacts["temperatures"]
    optimal_thresholds = calib_artifacts["optimal_thresholds"]
    # Preprocess text (anonymization on unless --no-anonymize was passed).
    anonymize = not args.no_anonymize
    processed_text = preprocess_text(args.text, anonymize_mentions=anonymize)
    if anonymize and processed_text != args.text:
        print(f"Preprocessed: {processed_text}")
    print(f"\nInput: {args.text}\n")
    # Tokenize and run inference
    inputs = tokenizer(
        processed_text,
        return_tensors="pt",
        truncation=True,
        # NOTE(review): max_length=8192 assumes a long-context model;
        # confirm it does not exceed the model's max position embeddings.
        max_length=8192,
    )
    outputs = model(**inputs)
    # squeeze() drops the batch dimension — batch size is always 1 here.
    logits = outputs.logits.squeeze().numpy()
    # Apply temperature scaling and optimal thresholds (multi-label:
    # each label gets an independent sigmoid probability and threshold).
    calibrated_probs = []
    predictions = []
    for i, label in enumerate(labels):
        temp = temperatures[label]
        threshold = optimal_thresholds[label]
        # Temperature scaling: divide the raw logit before the sigmoid.
        calibrated_logit = logits[i] / temp
        prob = 1 / (1 + np.exp(-calibrated_logit))
        calibrated_probs.append(prob)
        predictions.append(prob > threshold)
    # Get assigned labels
    assigned_labels = [labels[i] for i in range(len(labels)) if predictions[i]]
    # Print results
    if assigned_labels:
        print("Assigned Labels (Calibrated):")
        print("-" * 40)
        for label in assigned_labels:
            print(f"  {label}")
        print()
    else:
        print("No labels assigned (all below optimal thresholds)\n")
    print("All Labels (with calibrated probabilities):")
    print("-" * 40)
    for i, label in enumerate(labels):
        status = "✓" if predictions[i] else " "
        threshold = optimal_thresholds[label]
        print(f"{status} {label:15s}: {calibrated_probs[i]:.4f} (threshold: {threshold:.2f})")
    return 0
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    sys.exit(main())