jvaquet's picture
Upload MultilabelNerPipeline
8804863 verified
from transformers import Pipeline
import torch
import torch.nn as nn
# Model class names accepted by MultilabelNerPipeline; passed to
# Pipeline.check_model_type() in MultilabelNerPipeline.__init__.
# Kept as a list: check_model_type treats any non-list argument as a mapping.
MODEL_FOR_MULTILABEL_TOKEN_CLASSIFICATION = ['BertForMultiLabelTokenClassification']
class MultilabelNerPipeline(Pipeline):
    """Multi-label (nested / overlapping) named-entity recognition pipeline.

    Every token gets an independent sigmoid score per label, so a token can
    carry several tags at once. Labels in ``model.config.label2id`` are
    expected to be named ``<TAG>-<ENTITY_TYPE>`` with ``TAG`` one of
    ``B``/``I``/``E``/``S``: the span decoders look up exactly those four tags
    for every entity type.

    Call-time parameters (routed by ``_sanitize_parameters``):
        stride (int): token overlap between consecutive windows when a long
            text is split by the tokenizer (default 128).
        threshold (float): minimum sigmoid score for a tag to count as present
            (default 0.5).
        use_hierarchy_heuristic (bool): prune nested same-label spans with
            ``apply_hierarchy_heristic`` (default False).

    Each input text yields a list of ``{'label', 'span', 'text'}`` dicts, where
    ``span`` is a ``(start, end)`` character range into the input text.
    """

    # Tag prefixes the span decoders look up. Labels without one of these
    # prefixes do not name an entity type and are ignored.
    _TAG_PREFIXES = ('B-', 'I-', 'E-', 'S-')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.check_model_type(MODEL_FOR_MULTILABEL_TOKEN_CLASSIFICATION)
        # Strip the two-character tag prefix to obtain the entity types; a label
        # such as 'O' would otherwise become '' and break the 'S-' lookup.
        self.entity_types = {
            label[2:]
            for label in self.model.config.label2id
            if label[:2] in self._TAG_PREFIXES
        }

    def _sanitize_parameters(self, **kwargs):
        """Split call-time options into preprocess / forward / postprocess kwargs."""
        preprocess_kwargs = {}
        if 'stride' in kwargs:
            preprocess_kwargs['stride'] = kwargs['stride']
        postprocess_kwargs = {
            key: kwargs[key]
            for key in ('threshold', 'use_hierarchy_heuristic')
            if key in kwargs
        }
        # _forward takes no extra options.
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, inputs, stride=128):
        """Tokenize one text into one or more overlapping model windows.

        Args:
            inputs: the text to tag. Presumably a single ``str`` per call, since
                ``postprocess`` slices it with character offsets.
            stride: number of overlapping tokens between consecutive windows.

        Returns:
            dict with the ``input_ids``, ``attention_mask`` and
            ``special_tokens_mask`` tensors (one row per window), the per-window
            token character offsets ``char_offsets``, and the original ``text``.
        """
        tokenized_inputs = self.tokenizer(
            inputs,
            truncation=True,
            padding=True,  # overflow windows can differ in length; pad to stack
            stride=stride,
            return_tensors='pt',
            return_overflowing_tokens=True,
            return_special_tokens_mask=True,
        )
        n_windows = tokenized_inputs.input_ids.size(0)
        # Integer-indexing the BatchEncoding yields the per-window tokenizers
        # Encoding (fast tokenizers only); .offsets is (char_start, char_end)
        # for every token of that window.
        char_offsets = [tokenized_inputs[idx].offsets for idx in range(n_windows)]
        return {
            'input_ids': tokenized_inputs.input_ids,
            'attention_mask': tokenized_inputs.attention_mask,
            'char_offsets': char_offsets,
            'special_tokens_mask': tokenized_inputs.special_tokens_mask,
            'text': inputs,
        }

    def _forward(self, model_inputs):
        """Run the model and pass through the metadata ``postprocess`` needs."""
        # Feed the model only its real inputs: text, offsets and the special
        # tokens mask are pipeline metadata, not forward() arguments of a
        # standard Hugging Face model.
        logits = self.model(
            input_ids=model_inputs['input_ids'],
            attention_mask=model_inputs['attention_mask'],
        ).logits
        return {
            'logits': logits,
            'text': model_inputs['text'],
            'char_offsets': model_inputs['char_offsets'],
            'special_tokens_mask': model_inputs['special_tokens_mask'],
        }

    def postprocess(self, model_outputs, threshold=0.5, use_hierarchy_heuristic=False):
        """Decode per-token tag scores into de-duplicated character spans.

        Args:
            model_outputs: the dict returned by ``_forward``.
            threshold: minimum sigmoid score for a tag to count as present.
            use_hierarchy_heuristic: if True, prune nested same-label spans with
                ``apply_hierarchy_heristic``.

        Returns:
            list of ``{'label', 'span', 'text'}`` dicts.
        """
        # Independent per-label probabilities (multi-label), not a softmax.
        predictions = torch.sigmoid(model_outputs['logits'])
        # Special-token positions must never start, continue or end a span.
        predictions[model_outputs['special_tokens_mask'] == 1] = 0
        token_spans = (self.extract_single_token_spans(predictions, threshold)
                       + self.extract_multi_token_spans(predictions, threshold))
        spans = self.token_spans_to_char_spans(
            token_spans, model_outputs['char_offsets'], model_outputs['text'])
        # Overlapping (strided) windows can report the same entity twice.
        spans = self.deduplicate_spans(spans)
        if use_hierarchy_heuristic:
            spans = self.apply_hierarchy_heristic(spans)
        return spans

    def extract_single_token_spans(self, predictions, threshold):
        """Return a one-token span for every token whose ``S-<type>`` score passes.

        Args:
            predictions: (windows, tokens, labels) score tensor.
            threshold: minimum score for the tag to count.

        Returns:
            list of ``{'label', 'batch', 'span_token'}`` dicts; ``batch`` is the
            window index and ``span_token`` a half-open ``(start, end)`` token range.
        """
        label2id = self.model.config.label2id
        spans = []
        # sorted(): set iteration order varies with the string hash seed.
        for entity_type in sorted(self.entity_types):
            hits = predictions[:, :, label2id[f'S-{entity_type}']] >= threshold
            for idx_batch, idx_token in torch.nonzero(hits).tolist():
                spans.append({
                    'label': entity_type,
                    'batch': idx_batch,
                    'span_token': (idx_token, idx_token + 1),
                })
        return spans

    def extract_multi_token_spans(self, predictions, threshold):
        """Return every ``B-`` ... ``E-`` token span whose interior is all ``I-``.

        A begin token and a later end token in the same window form a span when
        every token strictly between them scores at least *threshold* on the
        matching ``I-<type>`` tag. All qualifying pairs are kept, so nested and
        overlapping spans of the same type can be produced.

        Args:
            predictions: (windows, tokens, labels) score tensor.
            threshold: minimum score for a tag to count.

        Returns:
            list of ``{'label', 'batch', 'span_token'}`` dicts with a half-open
            ``(begin, end + 1)`` token range.
        """
        label2id = self.model.config.label2id
        spans = []
        for entity_type in sorted(self.entity_types):
            is_begin = predictions[:, :, label2id[f'B-{entity_type}']] >= threshold
            is_inside = predictions[:, :, label2id[f'I-{entity_type}']] >= threshold
            is_end = predictions[:, :, label2id[f'E-{entity_type}']] >= threshold
            # End candidates are computed once per type, not once per begin.
            ends = torch.nonzero(is_end).tolist()
            for idx_batch, idx_begin in torch.nonzero(is_begin).tolist():
                for end_batch, idx_end in ends:
                    if end_batch != idx_batch or idx_end <= idx_begin:
                        continue
                    # An empty interior (adjacent B and E) counts as valid.
                    if bool(is_inside[idx_batch, idx_begin + 1:idx_end].all()):
                        spans.append({
                            'label': entity_type,
                            'batch': idx_batch,
                            'span_token': (idx_begin, idx_end + 1),
                        })
        return spans

    def token_spans_to_char_spans(self, spans, char_offsets, text):
        """Map token spans onto character spans of *text*.

        The character range runs from the start offset of the first token to
        the end offset of the last token (``span_token`` is half-open).

        Returns:
            list of ``{'label', 'span', 'text'}`` dicts.
        """
        char_spans = []
        for span in spans:
            window_offsets = char_offsets[span['batch']]
            first_token, end_token = span['span_token']
            char_start = window_offsets[first_token][0]
            char_end = window_offsets[end_token - 1][1]
            char_spans.append({
                'label': span['label'],
                'span': (char_start, char_end),
                'text': text[char_start:char_end],
            })
        return char_spans

    def deduplicate_spans(self, spans):
        """Drop exact duplicate spans, keeping the first occurrence of each.

        Spans compare by their ``(key, value)`` items, so values must be hashable.
        """
        # dict.fromkeys preserves insertion order; a set would make the output
        # order depend on the string hash seed.
        unique_items = dict.fromkeys(tuple(span.items()) for span in spans)
        return [dict(items) for items in unique_items]

    def apply_hierarchy_heristic(self, spans):
        """Prune nested same-label spans (name spelling kept for existing callers).

        Spans are grouped under the longest same-label span that contains them.
        In each group the containing span is always kept. The shorter spans are
        then visited shortest-first. One is kept only if its start or end is
        still an unclaimed boundary of the group; the container's own start and
        end never count as unclaimed.

        Returns:
            list of the kept span dicts, group by group.
        """
        def _group_spans(spans):
            # Longest first, so every group is founded by its containing span.
            groups = []
            for span in sorted(spans, key=lambda s: s['span'][0] - s['span'][1]):
                start, end = span['span']
                for group in groups:
                    if (group['label'] == span['label']
                            and group['start'] <= start
                            and group['end'] >= end):
                        group['spans'].append(span)
                        break
                else:
                    # Not inside any existing group: it founds a new one.
                    groups.append({
                        'start': start,
                        'end': end,
                        'spans': [span],
                        'label': span['label'],
                    })
            return groups

        kept = []
        for group in _group_spans(spans):
            by_length = sorted(group['spans'], key=lambda s: s['span'][1] - s['span'][0])
            container = by_length[-1]
            # Boundary positions inside the group not yet covered by a kept span.
            open_starts = {s['span'][0] for s in by_length}
            open_ends = {s['span'][1] for s in by_length}
            open_starts.discard(container['span'][0])
            open_ends.discard(container['span'][1])
            group_kept = [container]
            for candidate in by_length[:-1]:
                if not open_starts and not open_ends:
                    break
                cand_start, cand_end = candidate['span']
                if cand_start in open_starts or cand_end in open_ends:
                    group_kept.append(candidate)
                    open_starts.discard(cand_start)
                    open_ends.discard(cand_end)
            kept += group_kept
        return kept
# Register MultilabelNerPipeline with the transformers pipeline registry under
# the task name 'multilabel-ner', backed by AutoModelForTokenClassification.
# `default` presumably names the checkpoint/revision used when the caller gives
# no model — confirm against the transformers custom-pipeline docs.
from transformers import AutoModelForTokenClassification
from transformers.pipelines import PIPELINE_REGISTRY

PIPELINE_REGISTRY.register_pipeline(
    'multilabel-ner',
    pipeline_class=MultilabelNerPipeline,
    pt_model=AutoModelForTokenClassification,
    type='text',
    default={'pt': ('jvaquet/multilabel-classification-bert', 'main')},
)