File size: 3,548 Bytes
7336cba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import argparse
import json
import os
import re
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification

def preprocess_text(text, anonymize_mentions=True):
    """Normalize input text before tokenization.

    Replaces every Twitter-style @-mention with the literal token
    ``@anonymized_account`` (the placeholder the model was trained with),
    unless anonymization is disabled.

    Args:
        text: Raw input string.
        anonymize_mentions: When False, return the text unchanged.

    Returns:
        The (possibly anonymized) text.
    """
    if not anonymize_mentions:
        return text
    return re.sub(r'@\w+', '@anonymized_account', text)

def load_calibration_artifacts(calib_path):
    """Load calibration artifacts (per-label temperatures/thresholds) from JSON.

    Args:
        calib_path: Path to a ``calibration_artifacts.json`` file.

    Returns:
        The parsed JSON object (typically a dict with "temperatures" and
        "optimal_thresholds" keys, keyed by label name).

    Raises:
        FileNotFoundError: If ``calib_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    if not os.path.exists(calib_path):
        raise FileNotFoundError(f"Calibration artifacts not found at: {calib_path}")

    # Explicit encoding: JSON is UTF-8 by spec; the platform default
    # (e.g. cp1252 on Windows) would corrupt non-ASCII label names.
    with open(calib_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def _parse_args():
    """Define and parse the CLI arguments for the classifier script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("text", type=str, help="Text to classify")
    parser.add_argument("--model-path", type=str, default="yazoniak/twitter-emotion-pl-classifier",
                        help="Path to model or HF model ID")
    parser.add_argument("--calibration-path", type=str, default=None,
                        help="Path to calibration_artifacts.json (default: auto-detect)")
    parser.add_argument("--no-anonymize", action="store_true",
                        help="Disable mention anonymization (not recommended)")
    return parser.parse_args()


def _resolve_calibration_path(model_path, explicit_path):
    """Pick the calibration-artifacts path.

    Priority: explicit --calibration-path flag, then a file next to a local
    model directory, then a file next to this script (for HF hub model IDs).
    """
    if explicit_path:
        return explicit_path
    if os.path.isdir(model_path):
        return os.path.join(model_path, "calibration_artifacts.json")
    return os.path.join(os.path.dirname(__file__), "calibration_artifacts.json")


def _stable_sigmoid(z):
    """Numerically stable logistic function for a scalar logit.

    Avoids the RuntimeWarning/overflow of np.exp(-z) for large |z|.
    """
    if z >= 0:
        return 1.0 / (1.0 + np.exp(-z))
    ez = np.exp(z)
    return ez / (1.0 + ez)


def _calibrate(logits, labels, temperatures, optimal_thresholds):
    """Apply per-label temperature scaling and thresholding.

    Args:
        logits: 1-D array-like of raw per-label logits, aligned with ``labels``.
        labels: Ordered list of label names.
        temperatures: Mapping label -> temperature divisor.
        optimal_thresholds: Mapping label -> decision threshold.

    Returns:
        (calibrated_probs, predictions): parallel lists of floats and bools.
    """
    calibrated_probs = []
    predictions = []
    for i, label in enumerate(labels):
        prob = float(_stable_sigmoid(logits[i] / temperatures[label]))
        calibrated_probs.append(prob)
        predictions.append(prob > optimal_thresholds[label])
    return calibrated_probs, predictions


def _print_report(labels, calibrated_probs, predictions, optimal_thresholds):
    """Print assigned labels and the full per-label probability table."""
    assigned_labels = [labels[i] for i in range(len(labels)) if predictions[i]]

    if assigned_labels:
        print("Assigned Labels (Calibrated):")
        print("-" * 40)
        for label in assigned_labels:
            print(f"  {label}")
        print()
    else:
        print("No labels assigned (all below optimal thresholds)\n")

    print("All Labels (with calibrated probabilities):")
    print("-" * 40)
    for i, label in enumerate(labels):
        status = "✓" if predictions[i] else " "
        threshold = optimal_thresholds[label]
        print(f"{status} {label:15s}: {calibrated_probs[i]:.4f} (threshold: {threshold:.2f})")


def main():
    """CLI entry point: classify one text with calibrated multi-label outputs."""
    args = _parse_args()

    print(f"Loading model from: {args.model_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForSequenceClassification.from_pretrained(args.model_path)
    model.eval()

    # id2label is index-keyed; iterate in index order to keep alignment with logits.
    labels = [model.config.id2label[i] for i in range(model.config.num_labels)]

    anonymize = not args.no_anonymize
    processed_text = preprocess_text(args.text, anonymize_mentions=anonymize)

    if anonymize and processed_text != args.text:
        print(f"Preprocessed text: {processed_text}")

    calib_path = _resolve_calibration_path(args.model_path, args.calibration_path)
    print(f"Loading calibration from: {calib_path}")
    calib_artifacts = load_calibration_artifacts(calib_path)

    temperatures = calib_artifacts["temperatures"]
    optimal_thresholds = calib_artifacts["optimal_thresholds"]

    print(f"\nInput text: {args.text}\n")

    inputs = tokenizer(processed_text, return_tensors="pt", truncation=True, max_length=8192)

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits.squeeze().numpy()

    calibrated_probs, predictions = _calibrate(
        logits, labels, temperatures, optimal_thresholds
    )

    _print_report(labels, calibrated_probs, predictions, optimal_thresholds)

if __name__ == "__main__":
    main()