# Retrieved from the Hugging Face Hub file view: uploaded by grano1 with
# huggingface_hub ("Upload folder using huggingface_hub"), verified commit
# 8bdbd03, 1.22 kB.
'''
# @ Author: ASJC Team
# @ Create Time: 2025-11-25 23:52:55
# @ Modified by: ASJC Team
# @ Modified time: 2025-11-25 23:53:08
# @ Description: Custom pipeline for multi-label classification with fine-tuned SciBERT model.
'''
from transformers import TextClassificationPipeline
import torch
class ASJCMultiLabelPipeline(TextClassificationPipeline):
    """Multi-label text-classification pipeline driven by a probability threshold.

    Each label is scored independently with a sigmoid over the model's
    logits. Every label whose probability reaches ``threshold`` is returned,
    so zero, one, or many labels may be predicted for a single text.

    Args:
        *args: Forwarded unchanged to ``TextClassificationPipeline``.
        threshold: Optional keyword-only value giving the minimum sigmoid
            probability for a label to be kept. When it is omitted or
            ``None``, ``model.config.threshold`` is used. If the config does
            not define it, or defines it as ``None``, ``0.3`` is used.
        **kwargs: Remaining keyword arguments, forwarded unchanged to
            ``TextClassificationPipeline``.
    """

    # Fallback used when neither the caller nor config.json supplies a threshold.
    _DEFAULT_THRESHOLD = 0.3

    def __init__(self, *args, **kwargs):
        # Pop before delegating so the base pipeline constructor never
        # receives an argument it does not recognise.
        self.threshold = kwargs.pop("threshold", None)
        super().__init__(*args, **kwargs)
        # No explicit threshold passed -> use the threshold from config.json.
        if self.threshold is None:
            self.threshold = getattr(self.model.config, "threshold", None)
        # Guard against an explicit `"threshold": null` in the config too.
        if self.threshold is None:
            self.threshold = self._DEFAULT_THRESHOLD

    @staticmethod
    def _label_for(id2label, index):
        """Return the label name for ``index``, accepting int or str mapping keys."""
        # transformers normalises id2label keys to int when it builds a config,
        # but a raw config.json mapping uses str keys, so tolerate both.
        if index in id2label:
            return id2label[index]
        return id2label[str(index)]

    def postprocess(self, model_outputs, **kwargs):
        """Turn the raw logits for one input into thresholded label predictions.

        Args:
            model_outputs: Mapping holding a ``"logits"`` entry for a single
                input. The entry may be shaped ``(1, num_labels)`` (pipeline
                output with a batch dimension) or ``(num_labels,)``.
            **kwargs: Post-processing options supplied by the base pipeline;
                they are ignored here.

        Returns:
            A list of ``{"label": str, "score": float}`` dicts, sorted by
            descending score. The list is empty when no label reaches the
            threshold.

        Raises:
            KeyError: If a label index is missing from ``config.id2label``.
        """
        # as_tensor avoids the redundant copy (and the UserWarning) that
        # torch.tensor() incurs when it is given an existing tensor.
        logits = torch.as_tensor(model_outputs["logits"])
        # Model outputs carry a leading batch dimension of size 1; drop it so
        # that iterating yields one scalar score per label rather than one row.
        if logits.dim() > 1:
            logits = logits[0]
        # Upcast first so fp16/bf16 logits convert cleanly to Python floats.
        scores = torch.sigmoid(logits.float()).tolist()

        id2label = self.model.config.id2label
        results = [
            {"label": self._label_for(id2label, i), "score": float(score)}
            for i, score in enumerate(scores)
            if score >= self.threshold
        ]
        # Sort by score descending.
        results.sort(key=lambda item: item["score"], reverse=True)
        return results