grano1 commited on
Commit
71c969a
·
verified ·
1 Parent(s): 1019f0c

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +25 -4
README.md CHANGED
@@ -87,44 +87,65 @@ For **26 parent subjects**, F1-score improves to **0.934** with full metadata.
87
  from transformers import TextClassificationPipeline, pipeline
88
  import torch
89
 
 
90
  class ASJCMultiLabelPipeline(TextClassificationPipeline):
 
 
 
 
 
 
 
 
 
 
 
91
  def __init__(self, *args, **kwargs):
92
  self.threshold = kwargs.pop("threshold", None)
93
  super().__init__(*args, **kwargs)
94
 
95
- # If no explicit threshold passed use threshold from config.json
96
  if self.threshold is None:
97
  self.threshold = getattr(self.model.config, "threshold", 0.3)
98
 
99
  def postprocess(self, model_outputs, **kwargs):
 
100
  scores = torch.sigmoid(torch.tensor(model_outputs["logits"])).tolist()
101
  results = []
102
 
 
103
  for i, score in enumerate(scores):
104
  if score >= self.threshold:
105
  label = self.model.config.id2label[str(i)]
106
  results.append({"label": label, "score": float(score)})
107
 
108
- # Sort by score descending
109
  results = sorted(results, key=lambda x: x["score"], reverse=True)
110
  return results
111
  ```
112
 
113
  ```python
114
- pipe = pipeline("text-classification", model="asjc-classification/scibert_multilabel_asjc_classifier")
 
 
 
 
115
 
 
116
  text = (
117
  "title={Jodometrie}, "
118
  "container_title={Fresenius' Zeitschrift für analytische Chemie, Zeitschrift für analytische Chemie}, "
119
  "abstract={}"
120
  )
121
 
 
122
  result = pipe(text)
123
  print(result)
124
 
125
- # Actual labels:
126
  # - Clinical Biochemistry
127
  # - Analytical Chemistry
 
128
  ```
129
 
130
  ---
 
87
  from transformers import TextClassificationPipeline, pipeline
88
  import torch
89
 
90
+ # Define the Custom Pipeline
91
  class ASJCMultiLabelPipeline(TextClassificationPipeline):
92
+ """
93
+ Custom pipeline for multi-label ASJC classification.
94
+
95
+ This pipeline:
96
+ - Applies sigmoid to the model logits.
97
+ - Filters labels by a threshold.
98
+ - Returns all labels with scores above the threshold.
99
+
100
+ Threshold can be specified during pipeline creation.
101
+ If not provided, it defaults to the `threshold` in the model's config.json, or 0.3.
102
+ """
103
  def __init__(self, *args, **kwargs):
104
  self.threshold = kwargs.pop("threshold", None)
105
  super().__init__(*args, **kwargs)
106
 
107
+ # Use threshold from config if none is passed explicitly
108
  if self.threshold is None:
109
  self.threshold = getattr(self.model.config, "threshold", 0.3)
110
 
111
  def postprocess(self, model_outputs, **kwargs):
112
+ # Convert logits to probabilities using sigmoid
113
  scores = torch.sigmoid(torch.tensor(model_outputs["logits"])).tolist()
114
  results = []
115
 
116
+ # Collect labels above the threshold
117
  for i, score in enumerate(scores):
118
  if score >= self.threshold:
119
  label = self.model.config.id2label[str(i)]
120
  results.append({"label": label, "score": float(score)})
121
 
122
+ # Sort results by descending probability
123
  results = sorted(results, key=lambda x: x["score"], reverse=True)
124
  return results
125
  ```
126
 
127
  ```python
128
+ # Create pipeline with the multi-label model
129
+ pipe = pipeline(
130
+ "text-classification",
131
+ model="asjc-classification/scibert_multilabel_asjc_classifier"
132
+ )
133
 
134
+ # Example text input (title, container_title, abstract)
135
  text = (
136
  "title={Jodometrie}, "
137
  "container_title={Fresenius' Zeitschrift für analytische Chemie, Zeitschrift für analytische Chemie}, "
138
  "abstract={}"
139
  )
140
 
141
+ # Get predictions
142
  result = pipe(text)
143
  print(result)
144
 
145
+ # Expected labels (based on actual ASJC categories):
146
  # - Clinical Biochemistry
147
  # - Analytical Chemistry
148
+
149
  ```
150
 
151
  ---