Register pipeline
Browse files- register_pipeline.py +13 -0
register_pipeline.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import pipeline, AutoModelForSequenceClassification
|
| 2 |
+
from transformers.pipelines import PIPELINE_REGISTRY
|
| 3 |
+
|
| 4 |
+
from bert_paper_classifier import SciBertPaperClassifierPipeline
|
| 5 |
+
|
| 6 |
+
PIPELINE_REGISTRY.register_pipeline(
|
| 7 |
+
"text-classification",
|
| 8 |
+
pipeline_class=SciBertPaperClassifierPipeline,
|
| 9 |
+
pt_model=AutoModelForSequenceClassification,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
pipe = pipeline(task="paper-classification", model="HibiscusMaximus/scibert_paper_classification")
|
| 13 |
+
pipe.push_to_hub("bert-paper-classification-pipeline")
|