Kalaoke committed on
Commit
05f6b94
·
1 Parent(s): c268aa4

add custom handler and modify pipeline

Browse files
__pycache__/bert_for_sequence_classification.cpython-37.pyc CHANGED
Binary files a/__pycache__/bert_for_sequence_classification.cpython-37.pyc and b/__pycache__/bert_for_sequence_classification.cpython-37.pyc differ
 
__pycache__/bibert_multitask_classification.cpython-37.pyc CHANGED
Binary files a/__pycache__/bibert_multitask_classification.cpython-37.pyc and b/__pycache__/bibert_multitask_classification.cpython-37.pyc differ
 
__pycache__/handler.cpython-37.pyc CHANGED
Binary files a/__pycache__/handler.cpython-37.pyc and b/__pycache__/handler.cpython-37.pyc differ
 
handler.py CHANGED
@@ -8,9 +8,14 @@ from bibert_multitask_classification import BiBert_MultiTaskPipeline
8
  from bert_for_sequence_classification import BertForSequenceClassification
9
  from transformers.utils import logging
10
 
 
 
 
11
  logging.set_verbosity_info()
12
  logger = logging.get_logger("transformers")
13
 
 
 
14
 
15
  @dataclass
16
  class Task:
@@ -19,20 +24,16 @@ class Task:
19
  type: str
20
  num_labels: int
21
 
 
 
 
 
22
 
23
  class EndpointHandler():
24
  def __init__(self, path=""):
25
  # Preload all the elements you are going to need at inference.
26
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
  tokenizer = AutoTokenizer.from_pretrained(path)
28
 
29
- PIPELINE_REGISTRY.register_pipeline("bibert-multitask-classification", pipeline_class=BiBert_MultiTaskPipeline, pt_model=BertForSequenceClassification)
30
- tasks = [
31
- Task(id=0, name='label_classification', type='seq_classification', num_labels=5),
32
- Task(id=1, name='binary_classification', type='seq_classification', num_labels=2)
33
- ]
34
-
35
-
36
  model = BertForSequenceClassification.from_pretrained(path, tasks_map=tasks).to(device)
37
 
38
  self.classifier_s = pipeline("bibert-multitask-classification", model = model, task_id="0", tokenizer=tokenizer, device = device)
 
8
  from bert_for_sequence_classification import BertForSequenceClassification
9
  from transformers.utils import logging
10
 
11
+
12
+ PIPELINE_REGISTRY.register_pipeline("bibert-multitask-classification", pipeline_class=BiBert_MultiTaskPipeline, pt_model=BertForSequenceClassification)
13
+
14
  logging.set_verbosity_info()
15
  logger = logging.get_logger("transformers")
16
 
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+
19
 
20
  @dataclass
21
  class Task:
 
24
  type: str
25
  num_labels: int
26
 
27
+ tasks = [
28
+ Task(id=0, name='label_classification', type='seq_classification', num_labels=5),
29
+ Task(id=1, name='binary_classification', type='seq_classification', num_labels=2)
30
+ ]
31
 
32
  class EndpointHandler():
33
  def __init__(self, path=""):
34
  # Preload all the elements you are going to need at inference.
 
35
  tokenizer = AutoTokenizer.from_pretrained(path)
36
 
 
 
 
 
 
 
 
37
  model = BertForSequenceClassification.from_pretrained(path, tasks_map=tasks).to(device)
38
 
39
  self.classifier_s = pipeline("bibert-multitask-classification", model = model, task_id="0", tokenizer=tokenizer, device = device)