"""
Calibrated inference script for Polish Twitter Emotion Classifier (ONNX).

Predicts emotions and sentiment using temperature scaling and optimal
per-label thresholds for improved accuracy. This is the recommended
inference method.

Usage:
    python predict_calibrated.py "Your Polish text here"

Requirements:
    pip install optimum[onnxruntime] transformers numpy
"""
|
|
| import argparse |
| import json |
| import re |
| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| from optimum.onnxruntime import ORTModelForSequenceClassification |
| from transformers import AutoTokenizer |
|
|
|
|
def preprocess_text(text: str, anonymize_mentions: bool = True) -> str:
    """Replace @mentions with the placeholder token used during training.

    The model was trained on tweets whose @mentions were anonymized, so
    applying the same substitution at inference time is recommended for
    best performance.

    Args:
        text: Raw input text.
        anonymize_mentions: When True, rewrite every @mention as
            @anonymized_account.

    Returns:
        The (possibly rewritten) text.
    """
    if not anonymize_mentions:
        return text
    return re.sub(r"@\w+", "@anonymized_account", text)
|
|
|
|
def load_calibration_artifacts(model_path: Path) -> dict:
    """
    Load calibration artifacts from JSON file.

    Args:
        model_path: Path to model directory containing calibration_artifacts.json

    Returns:
        Dictionary with temperatures and optimal_thresholds

    Raises:
        FileNotFoundError: If calibration file not found
    """
    calib_path = model_path / "calibration_artifacts.json"
    if not calib_path.exists():
        raise FileNotFoundError(f"Calibration artifacts not found: {calib_path}")

    # JSON is UTF-8 by spec; pin the encoding so decoding of non-ASCII
    # (Polish) label names does not depend on the platform locale default.
    with open(calib_path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and evaluate sys.argv."""
    parser = argparse.ArgumentParser(
        description="Predict emotions using calibrated ONNX model (recommended)"
    )
    # Positional: the text to classify.
    parser.add_argument("text", type=str, help="Text to classify")
    # Opt-out switch for the training-time @mention anonymization.
    parser.add_argument(
        "--no-anonymize",
        action="store_true",
        help="Disable @mention anonymization (not recommended)",
    )
    # Execution provider selection; maps to ONNX Runtime providers in main().
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cpu",
        help="Device for inference (default: cpu)",
    )
    return parser.parse_args()
|
|
|
|
def main() -> int:
    """Run calibrated inference on the command-line text and print a report.

    Loads the ONNX model, tokenizer, and calibration artifacts from the
    directory containing this script, applies per-label temperature scaling
    to the logits, thresholds the calibrated probabilities, and prints the
    assigned labels plus a per-label probability table.

    Returns:
        Process exit code (0 on success).
    """
    args = parse_args()

    # Model, tokenizer, and calibration files live next to this script.
    model_path = Path(__file__).parent.resolve()

    print(f"Loading model from: {model_path}")

    provider = "CUDAExecutionProvider" if args.device == "cuda" else "CPUExecutionProvider"

    tokenizer = AutoTokenizer.from_pretrained(str(model_path))
    model = ORTModelForSequenceClassification.from_pretrained(
        str(model_path),
        provider=provider,
    )

    # Label order follows the model config's id2label mapping, so index i
    # of the logits vector corresponds to labels[i].
    labels = [model.config.id2label[i] for i in range(model.config.num_labels)]

    calib_artifacts = load_calibration_artifacts(model_path)
    temperatures = calib_artifacts["temperatures"]
    optimal_thresholds = calib_artifacts["optimal_thresholds"]

    anonymize = not args.no_anonymize
    processed_text = preprocess_text(args.text, anonymize_mentions=anonymize)

    if anonymize and processed_text != args.text:
        print(f"Preprocessed: {processed_text}")

    print(f"\nInput: {args.text}\n")

    # NOTE(review): max_length=8192 looks large for a typical encoder
    # classifier (often 512) — confirm against the model's actual
    # position-embedding limit.
    inputs = tokenizer(
        processed_text,
        return_tensors="pt",
        truncation=True,
        max_length=8192,
    )

    outputs = model(**inputs)
    # atleast_1d keeps logits[i] valid even for a single-label model,
    # where squeeze() would otherwise collapse to a 0-d array.
    logits = np.atleast_1d(outputs.logits.squeeze().numpy())

    calibrated_probs = []
    predictions = []

    for i, label in enumerate(labels):
        temp = temperatures[label]
        threshold = optimal_thresholds[label]

        # Temperature scaling: divide the raw logit before the sigmoid.
        calibrated_logit = logits[i] / temp
        prob = 1 / (1 + np.exp(-calibrated_logit))

        calibrated_probs.append(prob)
        predictions.append(prob > threshold)

    assigned_labels = [labels[i] for i in range(len(labels)) if predictions[i]]

    if assigned_labels:
        print("Assigned Labels (Calibrated):")
        print("-" * 40)
        for label in assigned_labels:
            print(f" {label}")
        print()
    else:
        print("No labels assigned (all below optimal thresholds)\n")

    print("All Labels (with calibrated probabilities):")
    print("-" * 40)
    for i, label in enumerate(labels):
        status = "✓" if predictions[i] else " "
        threshold = optimal_thresholds[label]
        print(f"{status} {label:15s}: {calibrated_probs[i]:.4f} (threshold: {threshold:.2f})")

    return 0
|
|
|
|
if __name__ == "__main__":
    # Equivalent to sys.exit(main()): propagate main's return code
    # as the process exit status.
    raise SystemExit(main())
|
|