Token Classification
Transformers
Safetensors
MultiLabelBert
multilabel
multilabel-token-classification
custom_code
Instructions to use jvaquet/multilabel-classification-bert-ace2004 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use jvaquet/multilabel-classification-bert-ace2004 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("token-classification", model="jvaquet/multilabel-classification-bert-ace2004", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForTokenClassification
model = AutoModelForTokenClassification.from_pretrained("jvaquet/multilabel-classification-bert-ace2004", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
File size: 7,177 Bytes
Revision: 8804863
from transformers import Pipeline
import torch
import torch.nn as nn
# Model architecture names accepted by MultilabelNerPipeline; handed to
# Pipeline.check_model_type() when the pipeline is constructed.
MODEL_FOR_MULTILABEL_TOKEN_CLASSIFICATION = ["BertForMultiLabelTokenClassification"]
class MultilabelNerPipeline(Pipeline):
    """Token-classification pipeline for multi-label (nested/overlapping) NER.

    The model scores every ``<prefix>-<entity type>`` label independently
    (one sigmoid per label), so several labels can fire on the same token.
    Label prefixes follow a B/I/E/S scheme:
    - ``S-`` marks a single-token entity.
    - ``B-`` / ``E-`` mark the first / last token of a multi-token entity.
    - ``I-`` marks every token strictly between them.

    Call-time keyword arguments:
        stride: token overlap between consecutive windows when the text is
            longer than the tokenizer's maximum length (default 128).
        threshold: minimum sigmoid score for a label to count (default 0.5).
        use_hierarchy_heuristic: prune redundant nested spans of the same
            label (default False).

    Returns a list of ``{'label', 'span', 'text'}`` dicts. ``span`` is a
    ``(start, end)`` character range into the input text. The list is
    ordered by start, then end, then label.
    """

    # Two-character label prefixes of the B/I/E/S tagging scheme.
    _SPAN_PREFIXES = ('B-', 'I-', 'E-', 'S-')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.check_model_type(MODEL_FOR_MULTILABEL_TOKEN_CLASSIFICATION)
        # Entity types are the label names with their scheme prefix removed.
        # Labels without such a prefix (e.g. a plain 'O') are ignored.
        # Stripping them would yield a bogus type whose 'S-'/'B-' labels do
        # not exist, and post-processing would then fail with a KeyError.
        self.entity_types = {
            label[2:]
            for label in self.model.config.label2id
            if label.startswith(self._SPAN_PREFIXES)
        }

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into preprocess/_forward/postprocess kwargs."""
        preprocess_kwargs = {}
        if 'stride' in kwargs:
            preprocess_kwargs['stride'] = kwargs['stride']
        postprocess_kwargs = {}
        for name in ('threshold', 'use_hierarchy_heuristic'):
            if name in kwargs:
                postprocess_kwargs[name] = kwargs[name]
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, inputs, stride=128):
        """Tokenize ``inputs`` into one or more (possibly overlapping) windows.

        Text longer than the tokenizer's maximum length is split into several
        windows that overlap by ``stride`` tokens. Besides the model tensors,
        the result carries the data post-processing needs:
        - the per-window token character offsets,
        - the special-tokens mask,
        - the raw text.
        """
        tokenized_inputs = self.tokenizer(
            inputs,
            truncation=True,
            padding=True,
            stride=stride,
            return_tensors='pt',
            return_overflowing_tokens=True,
            return_special_tokens_mask=True,
        )
        n_windows = tokenized_inputs.input_ids.shape[0]
        # Per-window list of (start, end) character offsets, one per token.
        # Indexing a BatchEncoding by int yields an `Encoding` only for fast
        # (Rust-backed) tokenizers, so a fast tokenizer is required here.
        char_offsets = [tokenized_inputs[idx].offsets for idx in range(n_windows)]
        return {
            'input_ids': tokenized_inputs.input_ids,
            'attention_mask': tokenized_inputs.attention_mask,
            'char_offsets': char_offsets,
            'special_tokens_mask': tokenized_inputs.special_tokens_mask,
            'text': inputs,
        }

    def _forward(self, model_inputs):
        """Run the model and pass the post-processing bookkeeping through."""
        # Only real model inputs go to the model. 'text', 'char_offsets' and
        # 'special_tokens_mask' are not forward() arguments: a standard BERT
        # forward rejects them with a TypeError.
        logits = self.model(
            input_ids=model_inputs['input_ids'],
            attention_mask=model_inputs['attention_mask'],
        ).logits
        return {
            'logits': logits,
            'text': model_inputs['text'],
            'char_offsets': model_inputs['char_offsets'],
            'special_tokens_mask': model_inputs['special_tokens_mask'],
        }

    def postprocess(self, model_outputs, threshold=0.5, use_hierarchy_heuristic=False):
        """Decode thresholded per-label scores into character-level entity spans."""
        predictions = torch.sigmoid(model_outputs['logits'])
        # Positions flagged in the tokenizer's special-tokens mask can never
        # be part of an entity.
        predictions[model_outputs['special_tokens_mask'] == 1] = 0
        token_spans = (self.extract_single_token_spans(predictions, threshold)
                       + self.extract_multi_token_spans(predictions, threshold))
        spans = self.token_spans_to_char_spans(
            token_spans, model_outputs['char_offsets'], model_outputs['text'])
        # Overlapping windows may report the same entity more than once.
        # Sorting makes the output, and the heuristic's grouping tie-breaks,
        # deterministic.
        spans = sorted(self.deduplicate_spans(spans), key=self._span_sort_key)
        if use_hierarchy_heuristic:
            spans = sorted(self.apply_hierarchy_heristic(spans), key=self._span_sort_key)
        return spans

    @staticmethod
    def _span_sort_key(span):
        """Order character spans by start, then end, then label."""
        return span['span'][0], span['span'][1], span['label']

    def extract_single_token_spans(self, predictions, threshold):
        """Return one-token spans whose ``S-<type>`` score is >= ``threshold``.

        Each span is a dict with these keys:
        - ``label``: the entity type.
        - ``batch``: the window index.
        - ``span_token``: the half-open ``(first, last + 1)`` token range.
        """
        label2id = self.model.config.label2id
        spans = []
        for entity_type in sorted(self.entity_types):
            label_id = label2id.get(f'S-{entity_type}')
            if label_id is None:
                continue  # the model has no single-token label for this type
            rows, cols = torch.where(predictions[:, :, label_id] >= threshold)
            for idx_batch, idx_token in zip(rows.tolist(), cols.tolist()):
                spans.append({
                    'label': entity_type,
                    'batch': idx_batch,
                    'span_token': (idx_token, idx_token + 1),
                })
        return spans

    def extract_multi_token_spans(self, predictions, threshold):
        """Return spans of two or more tokens decoded from B-/I-/E- scores.

        A span runs from a token whose ``B-<type>`` score passes ``threshold``
        to a later token in the same window whose ``E-<type>`` score passes
        it. Every token strictly in between must pass on ``I-<type>``. Every
        valid (B, E) pair is emitted, so nested spans of one type all appear.
        The output format matches extract_single_token_spans.
        """
        label2id = self.model.config.label2id
        spans = []
        for entity_type in sorted(self.entity_types):
            label_ids = [label2id.get(f'{prefix}-{entity_type}') for prefix in ('B', 'I', 'E')]
            if None in label_ids:
                continue  # the model has no multi-token labels for this type
            id_begin, id_inside, id_end = label_ids
            is_begin = predictions[:, :, id_begin] >= threshold
            is_inside = predictions[:, :, id_inside] >= threshold
            is_end = predictions[:, :, id_end] >= threshold
            rows, cols = torch.where(is_begin)
            for idx_batch, idx_begin in zip(rows.tolist(), cols.tolist()):
                # Candidate ends lie strictly after the begin token, same window.
                (end_offsets,) = torch.where(is_end[idx_batch, idx_begin + 1:])
                for offset in end_offsets.tolist():
                    idx_end = idx_begin + 1 + offset
                    # .all() of an empty slice is True, so adjacent B/E pairs qualify.
                    if bool(is_inside[idx_batch, idx_begin + 1:idx_end].all()):
                        spans.append({
                            'label': entity_type,
                            'batch': idx_batch,
                            'span_token': (idx_begin, idx_end + 1),
                        })
        return spans

    def token_spans_to_char_spans(self, spans, char_offsets, text):
        """Map token-index spans to character spans of ``text``.

        ``char_offsets[window][token]`` is that token's ``(start, end)``
        character range. A token span ``(first, last + 1)`` therefore covers
        the characters from the start of ``first`` to the end of ``last``.
        """
        char_spans = []
        for span in spans:
            window, span_token = span['batch'], span['span_token']
            if window is None or span_token is None:
                continue  # skip incomplete entries
            offsets = char_offsets[window]
            char_start = offsets[span_token[0]][0]
            char_end = offsets[span_token[1] - 1][1]
            char_spans.append({
                'label': span['label'],
                'span': (char_start, char_end),
                'text': text[char_start:char_end],
            })
        return char_spans

    def deduplicate_spans(self, spans):
        """Drop exact duplicate span dicts, keeping each one's first occurrence."""
        unique = dict.fromkeys(tuple(span.items()) for span in spans)
        return [dict(items) for items in unique]

    @staticmethod
    def _group_nested_spans(spans):
        """Group same-label spans under an outermost span that contains them.

        Spans are visited longest first. Each span joins the first existing
        group with the same label whose bounds contain it. Otherwise it starts
        a new group bounded by itself.
        """
        groups = []
        for span in sorted(spans, key=lambda s: s['span'][0] - s['span'][1]):
            start, end = span['span']
            for group in groups:
                if (group['label'] == span['label']
                        and group['start'] <= start
                        and group['end'] >= end):
                    group['spans'].append(span)
                    break
            else:  # no containing group found
                groups.append({
                    'start': start,
                    'end': end,
                    'spans': [span],
                    'label': span['label'],
                })
        return groups

    def apply_hierarchy_heristic(self, spans):
        """Prune redundant nested spans of the same label.

        Within each group of same-label spans nested in one outermost span,
        the outermost span is always kept. The remaining spans are visited
        from shortest to longest. A span is kept only if its start or its end
        is a group boundary, other than the outermost span's own boundaries,
        that no kept span has claimed yet.

        The misspelt name is part of the public interface. A correctly spelt
        alias, ``apply_hierarchy_heuristic``, is also provided.
        """
        kept = []
        for group in self._group_nested_spans(spans):
            members = sorted(group['spans'], key=lambda s: s['span'][1] - s['span'][0])
            outer = members[-1]
            # Boundaries still to be claimed, excluding the outermost span's own.
            open_starts = {s['span'][0] for s in members} - {outer['span'][0]}
            open_ends = {s['span'][1] for s in members} - {outer['span'][1]}
            kept.append(outer)
            for span in members[:-1]:
                if not open_starts and not open_ends:
                    break
                start, end = span['span']
                if start in open_starts or end in open_ends:
                    kept.append(span)
                    open_starts.discard(start)
                    open_ends.discard(end)
        return kept

    # Correctly spelt alias of apply_hierarchy_heristic.
    apply_hierarchy_heuristic = apply_hierarchy_heristic
from transformers import AutoModelForTokenClassification
from transformers.pipelines import PIPELINE_REGISTRY

# Register the custom "multilabel-ner" task. This makes it resolve to
# MultilabelNerPipeline with PyTorch weights loaded through
# AutoModelForTokenClassification. When no model is given, the task falls
# back to jvaquet/multilabel-classification-bert at revision "main".
PIPELINE_REGISTRY.register_pipeline(
    "multilabel-ner",
    pipeline_class=MultilabelNerPipeline,
    pt_model=AutoModelForTokenClassification,
    default={"pt": ("jvaquet/multilabel-classification-bert", "main")},
    type="text",
)