| import argparse | |
| import json | |
| import os | |
| import re | |
| import torch | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
def preprocess_text(text, anonymize_mentions=True):
    """Return *text* with every @mention replaced by '@anonymized_account'.

    When *anonymize_mentions* is False the input is returned untouched.
    """
    if not anonymize_mentions:
        return text
    return re.sub(r'@\w+', '@anonymized_account', text)
def load_calibration_artifacts(calib_path):
    """Load and return the calibration artifacts dict from a JSON file.

    Args:
        calib_path: Path to ``calibration_artifacts.json``.

    Returns:
        The parsed JSON object (expected keys include ``temperatures`` and
        ``optimal_thresholds`` — see the callers).

    Raises:
        FileNotFoundError: If *calib_path* does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    if not os.path.exists(calib_path):
        raise FileNotFoundError(f"Calibration artifacts not found at: {calib_path}")
    # Explicit encoding: JSON is UTF-8 by spec; the platform default may differ.
    with open(calib_path, 'r', encoding='utf-8') as f:
        return json.load(f)
def _resolve_calibration_path(args):
    """Pick the calibration-artifacts path.

    Priority: explicit ``--calibration-path`` flag, then a file beside a local
    model directory, then a file beside this script.
    """
    if args.calibration_path:
        return args.calibration_path
    if os.path.isdir(args.model_path):
        return os.path.join(args.model_path, "calibration_artifacts.json")
    return os.path.join(os.path.dirname(__file__), "calibration_artifacts.json")


def main():
    """CLI entry point: classify one text and print calibrated per-label results.

    Loads a sequence-classification model, applies per-label temperature
    scaling to its logits, thresholds each label independently (multi-label
    setup), and prints both the assigned labels and the full probability table.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("text", type=str, help="Text to classify")
    parser.add_argument("--model-path", type=str, default="yazoniak/twitter-emotion-pl-classifier",
                        help="Path to model or HF model ID")
    parser.add_argument("--calibration-path", type=str, default=None,
                        help="Path to calibration_artifacts.json (default: auto-detect)")
    parser.add_argument("--no-anonymize", action="store_true",
                        help="Disable mention anonymization (not recommended)")
    args = parser.parse_args()

    print(f"Loading model from: {args.model_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_path)
    model.eval()
    labels = [model.config.id2label[i] for i in range(model.config.num_labels)]

    anonymize = not args.no_anonymize
    processed_text = preprocess_text(args.text, anonymize_mentions=anonymize)
    if anonymize and processed_text != args.text:
        print(f"Preprocessed text: {processed_text}")

    calib_path = _resolve_calibration_path(args)
    print(f"Loading calibration from: {calib_path}")
    calib_artifacts = load_calibration_artifacts(calib_path)
    # Per-label temperature and decision threshold, keyed by label name.
    temperatures = calib_artifacts["temperatures"]
    optimal_thresholds = calib_artifacts["optimal_thresholds"]

    print(f"\nInput text: {args.text}\n")
    # NOTE(review): max_length=8192 assumes a long-context encoder — confirm
    # it matches the model's actual position limit.
    inputs = tokenizer(processed_text, return_tensors="pt", truncation=True, max_length=8192)
    with torch.no_grad():
        outputs = model(**inputs)
    # squeeze(0) removes only the batch dimension. A bare squeeze() would also
    # collapse a single-label logit vector to a 0-d scalar and break the
    # logits[i] indexing below.
    logits = outputs.logits.squeeze(0).numpy()

    calibrated_probs = []
    predictions = []
    for i, label in enumerate(labels):
        temp = temperatures[label]
        threshold = optimal_thresholds[label]
        # Temperature-scale the logit, then apply an independent sigmoid
        # (each label is a separate binary decision).
        prob = 1 / (1 + np.exp(-logits[i] / temp))
        calibrated_probs.append(prob)
        predictions.append(prob > threshold)

    assigned_labels = [labels[i] for i in range(len(labels)) if predictions[i]]
    if assigned_labels:
        print("Assigned Labels (Calibrated):")
        print("-" * 40)
        for label in assigned_labels:
            print(f" {label}")
        print()
    else:
        print("No labels assigned (all below optimal thresholds)\n")

    print("All Labels (with calibrated probabilities):")
    print("-" * 40)
    for i, label in enumerate(labels):
        status = "✓" if predictions[i] else " "
        threshold = optimal_thresholds[label]
        print(f"{status} {label:15s}: {calibrated_probs[i]:.4f} (threshold: {threshold:.2f})")
# Run the CLI only when executed as a script (not when imported as a module).
if __name__ == "__main__":
    main()