| import torch |
| import numpy as np |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
def predict_new_texts(texts, model_dir, use_per_label=True):
    """Classify raw texts with a saved multi-label model and its thresholds.

    Args:
        texts (list[str]): Raw text strings to classify. May be empty.
        model_dir (str): Directory containing the saved model + tokenizer.
            Label names and decision thresholds are read from the model
            config (`label_names`, `per_label_thresholds`,
            `best_global_threshold`).
        use_per_label (bool): If True, apply per-label thresholds; otherwise
            apply the single global threshold to every label.

    Returns:
        list[dict]: One dict per input text:
            {"text": <original text>, "predicted_labels": [<label>, ...]}

    Raises:
        ValueError: If no label names can be recovered from the config.
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # Thresholds and label names are expected to have been written into the
    # model config when the model was saved.
    label_names = getattr(model.config, "label_names", None)
    per_label_thresh = getattr(model.config, "per_label_thresholds", None)
    global_thresh = getattr(model.config, "best_global_threshold", 0.5)

    if label_names is None:
        # Fall back to the standard HF id2label mapping before giving up.
        # Keys may be ints or strings depending on how the config was
        # (de)serialized, so sort numerically.
        id2label = getattr(model.config, "id2label", None)
        if id2label:
            label_names = [
                label
                for _, label in sorted(id2label.items(), key=lambda kv: int(kv[0]))
            ]
        else:
            raise ValueError("No 'label_names' or 'id2label' found in model config.")
    if per_label_thresh is None:
        # No tuned per-label thresholds were saved; reuse the global one.
        per_label_thresh = [global_thresh] * model.config.num_labels

    # Empty batch: nothing to tokenize or score.
    if not texts:
        return []

    model.eval()
    encodings = tokenizer(
        texts,
        truncation=True,
        padding=True,
        max_length=512,
        return_tensors="pt",
    )
    # Keep inputs on the same device as the model (no-op on CPU).
    device = next(model.parameters()).device
    encodings = {name: tensor.to(device) for name, tensor in encodings.items()}

    with torch.no_grad():
        logits = model(**encodings).logits
    # Multi-label setup: independent sigmoid per label (not softmax).
    probs = torch.sigmoid(logits).cpu().numpy()

    if use_per_label:
        # Shape (num_labels,) broadcasts across the (batch, num_labels) probs.
        thresholds = np.asarray(per_label_thresh)
    else:
        thresholds = global_thresh
    preds = (probs >= thresholds).astype(int)

    return [
        {
            "text": text,
            "predicted_labels": [
                label_names[j] for j, hit in enumerate(row) if hit
            ],
        }
        for text, row in zip(texts, preds)
    ]
|
|
if __name__ == "__main__":
    # Directory holding the trained model/tokenizer and its saved thresholds.
    model_output_dir = "./models/pos_weight_best"

    sample_texts = [
        "I'm so happy today!",
        "Everything is awful.",
        "I can't believe how excited I am for the future!",
        "I lost my job and I'm devastated."
    ]

    # Run inference with per-label thresholds and dump each prediction.
    for result in predict_new_texts(sample_texts, model_output_dir, use_per_label=True):
        print("Text:", result["text"])
        print("Predicted labels:", result["predicted_labels"])
        print("-----------")
|
|