# twitter-emotion-pl-classifier / predict_calibrated.py
# Source: Hugging Face repo by yazoniak (commit 7336cba, verified).
import argparse
import json
import os
import re
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
def preprocess_text(text, anonymize_mentions=True):
    """Replace every @-mention with the ``@anonymized_account`` placeholder.

    Mirrors the preprocessing applied before training; when
    *anonymize_mentions* is False the text is returned untouched.
    """
    if not anonymize_mentions:
        return text
    return re.sub(r'@\w+', '@anonymized_account', text)
def load_calibration_artifacts(calib_path):
    """Load the calibration artifacts (temperatures, thresholds) from JSON.

    Args:
        calib_path: Path to a ``calibration_artifacts.json`` file.

    Returns:
        The decoded JSON object (a dict with per-label calibration data).

    Raises:
        FileNotFoundError: If *calib_path* does not exist.
    """
    if not os.path.exists(calib_path):
        raise FileNotFoundError(f"Calibration artifacts not found at: {calib_path}")
    # JSON is defined as UTF-8; be explicit instead of relying on the
    # platform's default encoding.
    with open(calib_path, 'r', encoding='utf-8') as f:
        return json.load(f)
def _resolve_calibration_path(model_path, override=None):
    """Pick the calibration-artifacts file to load.

    Priority: an explicit *override* path; then a file inside *model_path*
    when it is a local directory; otherwise a file next to this script
    (the case where *model_path* is a Hugging Face model ID).
    """
    if override:
        return override
    if os.path.isdir(model_path):
        return os.path.join(model_path, "calibration_artifacts.json")
    return os.path.join(os.path.dirname(__file__), "calibration_artifacts.json")


def main():
    """CLI entry point: classify one text with calibrated probabilities.

    Loads the multi-label emotion classifier, applies per-label temperature
    scaling to the logits, and assigns each label whose calibrated sigmoid
    probability exceeds that label's optimal threshold.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("text", type=str, help="Text to classify")
    parser.add_argument("--model-path", type=str, default="yazoniak/twitter-emotion-pl-classifier",
                        help="Path to model or HF model ID")
    parser.add_argument("--calibration-path", type=str, default=None,
                        help="Path to calibration_artifacts.json (default: auto-detect)")
    parser.add_argument("--no-anonymize", action="store_true",
                        help="Disable mention anonymization (not recommended)")
    args = parser.parse_args()

    print(f"Loading model from: {args.model_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_path)
    model.eval()
    # Label order follows the model config's id->label mapping.
    labels = [model.config.id2label[i] for i in range(model.config.num_labels)]

    # Mirror the training-time preprocessing (mentions anonymized) unless
    # the user explicitly opted out.
    anonymize = not args.no_anonymize
    processed_text = preprocess_text(args.text, anonymize_mentions=anonymize)
    if anonymize and processed_text != args.text:
        print(f"Preprocessed text: {processed_text}")

    calib_path = _resolve_calibration_path(args.model_path, args.calibration_path)
    print(f"Loading calibration from: {calib_path}")
    calib_artifacts = load_calibration_artifacts(calib_path)
    temperatures = calib_artifacts["temperatures"]            # per-label temperature
    optimal_thresholds = calib_artifacts["optimal_thresholds"]  # per-label cutoff

    print(f"\nInput text: {args.text}\n")
    # NOTE(review): max_length=8192 assumes a long-context model — confirm
    # against the tokenizer's model_max_length.
    inputs = tokenizer(processed_text, return_tensors="pt", truncation=True, max_length=8192)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits.squeeze().numpy()

    # Multi-label decision: independent calibrated sigmoid per label.
    calibrated_probs = []
    predictions = []
    for i, label in enumerate(labels):
        temp = temperatures[label]
        threshold = optimal_thresholds[label]
        calibrated_logit = logits[i] / temp          # temperature scaling
        prob = 1 / (1 + np.exp(-calibrated_logit))   # sigmoid
        calibrated_probs.append(prob)
        predictions.append(prob > threshold)

    assigned_labels = [labels[i] for i in range(len(labels)) if predictions[i]]
    if assigned_labels:
        print("Assigned Labels (Calibrated):")
        print("-" * 40)
        for label in assigned_labels:
            print(f" {label}")
        print()
    else:
        print("No labels assigned (all below optimal thresholds)\n")

    print("All Labels (with calibrated probabilities):")
    print("-" * 40)
    for i, label in enumerate(labels):
        status = "✓" if predictions[i] else " "
        threshold = optimal_thresholds[label]
        print(f"{status} {label:15s}: {calibrated_probs[i]:.4f} (threshold: {threshold:.2f})")


if __name__ == "__main__":
    main()