grano1 commited on
Commit
af63e06
·
verified ·
1 Parent(s): 71c969a

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +12 -23
README.md CHANGED
@@ -87,24 +87,16 @@ For **26 parent subjects**, F1-score improves to **0.934** with full metadata.
87
  from transformers import TextClassificationPipeline, pipeline
88
  import torch
89
 
90
- # Define the Custom Pipeline
91
  class ASJCMultiLabelPipeline(TextClassificationPipeline):
92
  """
93
- Custom pipeline for multi-label ASJC classification.
94
-
95
- This pipeline:
96
- - Applies sigmoid to the model logits.
97
- - Filters labels by a threshold.
98
- - Returns all labels with scores above the threshold.
99
-
100
- Threshold can be specified during pipeline creation.
101
- If not provided, it defaults to the `threshold` in the model's config.json, or 0.3.
102
  """
103
  def __init__(self, *args, **kwargs):
 
104
  self.threshold = kwargs.pop("threshold", None)
105
  super().__init__(*args, **kwargs)
106
-
107
- # Use threshold from config if none is passed explicitly
108
  if self.threshold is None:
109
  self.threshold = getattr(self.model.config, "threshold", 0.3)
110
 
@@ -113,39 +105,36 @@ class ASJCMultiLabelPipeline(TextClassificationPipeline):
113
  scores = torch.sigmoid(torch.tensor(model_outputs["logits"])).tolist()
114
  results = []
115
 
116
- # Collect labels above the threshold
117
  for i, score in enumerate(scores):
118
  if score >= self.threshold:
119
  label = self.model.config.id2label[str(i)]
120
  results.append({"label": label, "score": float(score)})
121
 
122
- # Sort results by descending probability
123
  results = sorted(results, key=lambda x: x["score"], reverse=True)
124
  return results
125
- ```
126
 
127
- ```python
128
- # Create pipeline with the multi-label model
129
  pipe = pipeline(
130
- "text-classification",
131
- model="asjc-classification/scibert_multilabel_asjc_classifier"
 
132
  )
133
 
134
- # Example text input (title, container_title, abstract)
135
  text = (
136
  "title={Jodometrie}, "
137
  "container_title={Fresenius' Zeitschrift für analytische Chemie, Zeitschrift für analytische Chemie}, "
138
  "abstract={}"
139
  )
140
 
141
- # Get predictions
142
  result = pipe(text)
143
  print(result)
144
 
145
- # Expected labels (based on actual ASJC categories):
146
  # - Clinical Biochemistry
147
  # - Analytical Chemistry
148
-
149
  ```
150
 
151
  ---
 
87
  from transformers import TextClassificationPipeline, pipeline
88
  import torch
89
 
90
+ # --- Custom multi-label pipeline ---
91
  class ASJCMultiLabelPipeline(TextClassificationPipeline):
92
  """
93
+ Multi-label classification pipeline for ASJC categories.
94
+ Uses a configurable threshold to return all labels with scores above the threshold.
 
 
 
 
 
 
 
95
  """
96
  def __init__(self, *args, **kwargs):
97
+ # Allow threshold override; default falls back to model config
98
  self.threshold = kwargs.pop("threshold", None)
99
  super().__init__(*args, **kwargs)
 
 
100
  if self.threshold is None:
101
  self.threshold = getattr(self.model.config, "threshold", 0.3)
102
 
 
105
  scores = torch.sigmoid(torch.tensor(model_outputs["logits"])).tolist()
106
  results = []
107
 
 
108
  for i, score in enumerate(scores):
109
  if score >= self.threshold:
110
  label = self.model.config.id2label[str(i)]
111
  results.append({"label": label, "score": float(score)})
112
 
113
+ # Sort by descending score
114
  results = sorted(results, key=lambda x: x["score"], reverse=True)
115
  return results
 
116
 
117
+ # --- Create the pipeline explicitly using the custom class ---
 
118
  pipe = pipeline(
119
+ task="text-classification",
120
+ model="asjc-classification/scibert_multilabel_asjc_classifier",
121
+ pipeline_class=ASJCMultiLabelPipeline
122
  )
123
 
124
+ # --- Example text input ---
125
  text = (
126
  "title={Jodometrie}, "
127
  "container_title={Fresenius' Zeitschrift für analytische Chemie, Zeitschrift für analytische Chemie}, "
128
  "abstract={}"
129
  )
130
 
131
+ # --- Get multi-label predictions ---
132
  result = pipe(text)
133
  print(result)
134
 
135
+ # Expected labels:
136
  # - Clinical Biochemistry
137
  # - Analytical Chemistry
 
138
  ```
139
 
140
  ---