grano1 commited on
Commit
1019f0c
·
verified ·
1 Parent(s): 592f342

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +26 -1
README.md CHANGED
@@ -84,8 +84,33 @@ For **26 parent subjects**, F1-score improves to **0.934** with full metadata.
84
  ## 🔍 Example Usage
85
 
86
  ```python
87
- from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
 
89
  pipe = pipeline("text-classification", model="asjc-classification/scibert_multilabel_asjc_classifier")
90
 
91
  text = (
 
84
  ## 🔍 Example Usage
85
 
86
  ```python
87
+ from transformers import TextClassificationPipeline, pipeline
88
+ import torch
89
+
90
+ class ASJCMultiLabelPipeline(TextClassificationPipeline):
91
+ def __init__(self, *args, **kwargs):
92
+ self.threshold = kwargs.pop("threshold", None)
93
+ super().__init__(*args, **kwargs)
94
+
95
+ # If no explicit threshold passed → use threshold from config.json
96
+ if self.threshold is None:
97
+ self.threshold = getattr(self.model.config, "threshold", 0.3)
98
+
99
+ def postprocess(self, model_outputs, **kwargs):
100
+ scores = torch.sigmoid(torch.tensor(model_outputs["logits"])).tolist()
101
+ results = []
102
+
103
+ for i, score in enumerate(scores):
104
+ if score >= self.threshold:
105
+ label = self.model.config.id2label[str(i)]
106
+ results.append({"label": label, "score": float(score)})
107
+
108
+ # Sort by score descending
109
+ results = sorted(results, key=lambda x: x["score"], reverse=True)
110
+ return results
111
+ ```
112
 
113
+ ```python
114
  pipe = pipeline("text-classification", model="asjc-classification/scibert_multilabel_asjc_classifier")
115
 
116
  text = (