# deberta-goemotions / pos_weight_predict.py
# Example prediction script (commit 17ccfca, "chore: example predict script rename").
import torch
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer
def predict_new_texts(texts, model_dir, use_per_label=True, max_length=512):
    """Classify raw texts with a saved multi-label model, applying tuned thresholds.

    Args:
        texts (list[str]): A list of raw text strings to classify.
        model_dir (str): Directory where the model + tokenizer were saved.
            Thresholds and label names are read from the saved ``model.config``.
        use_per_label (bool): If True, apply per-label thresholds. If False,
            apply the single global threshold to every label.
        max_length (int): Maximum token length for truncation (default 512,
            matching the previous hard-coded value).

    Returns:
        list[dict]: One dict per input text:
        ``{"text": ..., "predicted_labels": [...]}``.

    Raises:
        ValueError: If the model config carries no ``label_names`` attribute.
    """
    # 1) Load model and tokenizer from the saved directory.
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # 2) Pull label names and thresholds out of the model config.
    label_names = getattr(model.config, "label_names", None)
    per_label_thresh = getattr(model.config, "per_label_thresholds", None)
    global_thresh = getattr(model.config, "best_global_threshold", 0.5)
    if label_names is None:
        raise ValueError("No 'label_names' found in config or thresholds.json.")
    if per_label_thresh is None:
        # Fall back to a uniform threshold array when per-label values are absent.
        per_label_thresh = [global_thresh] * model.config.num_labels

    # 3) Tokenize the input texts as one padded batch.
    model.eval()
    encodings = tokenizer(
        texts,
        truncation=True,
        padding=True,
        max_length=max_length,
        return_tensors="pt"
    )
    # Keep inputs on the same device as the model (no-op on CPU; required
    # if the model was loaded onto a GPU).
    encodings = {k: v.to(model.device) for k, v in encodings.items()}

    # 4) Forward pass; sigmoid gives independent per-label probabilities
    #    (multi-label, so no softmax).
    with torch.no_grad():
        outputs = model(**encodings)
        logits = outputs.logits  # shape: (batch_size, num_labels)
    probs = torch.sigmoid(logits).cpu().numpy()  # shape: (batch_size, num_labels)

    # 5) Threshold probabilities into binary predictions.
    if use_per_label:
        # Broadcast the per-label threshold vector across the batch.
        threshold_array = np.array(per_label_thresh)
        preds = (probs >= threshold_array).astype(int)  # (batch_size, num_labels)
    else:
        preds = (probs >= global_thresh).astype(int)

    # 6) Convert each binary row back to its active label names.
    results = []
    for i, text in enumerate(texts):
        row_preds = preds[i]
        predicted_labels = [label_names[j] for j, val in enumerate(row_preds) if val == 1]
        results.append({"text": text, "predicted_labels": predicted_labels})
    return results
if __name__ == "__main__":
    # Example usage: run a few sample sentences through the trained checkpoint.
    model_output_dir = "./models/pos_weight_best"  # same as --output_dir in training
    texts_to_classify = [
        "I'm so happy today!",
        "Everything is awful.",
        "I can't believe how excited I am for the future!",
        "I lost my job and I'm devastated."
    ]
    # Classify and print each result as it comes back.
    for result in predict_new_texts(texts_to_classify, model_output_dir, use_per_label=True):
        print("Text:", result["text"])
        print("Predicted labels:", result["predicted_labels"])
        print("-----------")