|
|
|
|
|
from transformers import Pipeline |
|
|
|
|
|
from src.lemmatize_helper import reconstruct_lemma |
|
|
|
|
|
|
|
|
class ConlluTokenClassificationPipeline(Pipeline):
    """Text -> CoNLL-U-style token annotations, as a Hugging Face ``Pipeline``.

    Flow: raw text is split into sentences (``sentenizer``) and words
    (``tokenizer``), the word lists are fed to ``self.model``, and the model's
    per-word tag ids / dependency arcs are decoded back to strings via the
    vocabularies in ``self.model.config.vocabulary``.  Output is either a list
    of per-sentence dicts (``output_format='list'``) or one CoNLL-U formatted
    string (``output_format='str'``).

    NOTE(review): ``preprocess`` stashes the sentence texts on ``self._texts``
    and ``_decode_model_output`` reads them back — this couples calls through
    instance state, which looks unsafe if the pipeline ever batches or is
    shared across threads; confirm against the intended usage.
    """

    def __init__(
        self,
        model,
        tokenizer: callable = None,
        sentenizer: callable = None,
        **kwargs
    ):
        """Store the model plus the two segmentation callables.

        Args:
            model: the token-classification model; must be callable with
                ``words=...`` and ``inference_mode=True`` (see ``_forward``)
                and expose ``config.vocabulary`` mappings for decoding.
            tokenizer: callable mapping one sentence string to an iterable of
                word strings.  (Shadows the HF ``Pipeline.tokenizer`` slot —
                here it is a plain word splitter, not a HF tokenizer.)
            sentenizer: callable mapping raw text to an iterable of sentence
                strings.
            **kwargs: forwarded to ``transformers.Pipeline.__init__``.
        """
        super().__init__(model=model, **kwargs)
        self.tokenizer = tokenizer
        self.sentenizer = sentenizer

    def _sanitize_parameters(self, output_format: str = 'list', **kwargs):
        """Split pipeline kwargs per the ``Pipeline`` contract.

        Returns the three kwarg dicts for ``preprocess`` / ``_forward`` /
        ``postprocess``; only ``postprocess`` takes a parameter here.

        Raises:
            ValueError: if ``output_format`` is neither ``'list'`` nor ``'str'``.
        """
        if output_format not in ['list', 'str']:
            raise ValueError(
                f"output_format must be 'str' or 'list', not {output_format}"
            )

        return {}, {}, {'output_format': output_format}

    def preprocess(self, inputs: str) -> dict:
        """Segment raw text into sentences and words for the model.

        Args:
            inputs: the raw text (a single string).

        Returns:
            ``{"words": list[list[str]]}`` — one word list per sentence.

        Raises:
            ValueError: if ``inputs`` is not a string.
        """
        if not isinstance(inputs, str):
            raise ValueError("pipeline input must be string (text)")

        sentences = [sentence for sentence in self.sentenizer(inputs)]
        words = [
            [word for word in self.tokenizer(sentence)]
            for sentence in sentences
        ]

        # Keep the sentence texts for postprocess (read in _decode_model_output).
        self._texts = sentences
        return {"words": words}

    def _forward(self, model_inputs: dict) -> dict:
        """Run the model on the preprocessed word lists (inference mode)."""
        return self.model(**model_inputs, inference_mode=True)

    def postprocess(self, model_outputs: dict, output_format: str) -> list[dict] | str:
        """Decode model outputs to per-sentence dicts or a CoNLL-U string."""
        sentences = self._decode_model_output(model_outputs)

        if output_format == 'str':
            sentences = self._format_as_conllu(sentences)
        return sentences

    def _decode_model_output(self, model_outputs: dict) -> list[dict]:
        """Decode each sentence's tag ids / arcs into readable annotations.

        Only the output heads present in ``model_outputs`` are decoded; each
        becomes an optional keyword for ``_decode_sentence``.  Tag outputs are
        assumed to be indexable as ``[i, :n_words]`` (batch x padded words) —
        presumably tensors, given ``.tolist()``; TODO confirm.
        """
        n_sentences = len(model_outputs["words"])

        sentences_decoded = []
        for i in range(n_sentences):

            # arcs rows look like (batch_idx, from, to, rel_id); keep rows of
            # this sentence and drop the batch column.
            def select_arcs(arcs, batch_idx):
                return arcs[arcs[:, 0] == batch_idx][:, 1:]

            n_words = len(model_outputs["words"][i])

            # Collect whichever optional heads the model produced, trimmed to
            # this sentence's true (unpadded) word count.
            optional_tags = {}
            if "lemma_rules" in model_outputs:
                optional_tags["lemma_rule_ids"] = model_outputs["lemma_rules"][i, :n_words].tolist()
            if "joint_feats" in model_outputs:
                optional_tags["joint_feats_ids"] = model_outputs["joint_feats"][i, :n_words].tolist()
            if "deps_ud" in model_outputs:
                optional_tags["deps_ud"] = select_arcs(model_outputs["deps_ud"], i).tolist()
            if "deps_eud" in model_outputs:
                optional_tags["deps_eud"] = select_arcs(model_outputs["deps_eud"], i).tolist()
            if "miscs" in model_outputs:
                optional_tags["misc_ids"] = model_outputs["miscs"][i, :n_words].tolist()
            if "deepslots" in model_outputs:
                optional_tags["deepslot_ids"] = model_outputs["deepslots"][i, :n_words].tolist()
            if "semclasses" in model_outputs:
                optional_tags["semclass_ids"] = model_outputs["semclasses"][i, :n_words].tolist()

            sentence_decoded = self._decode_sentence(
                text=self._texts[i],  # relies on preprocess having run in-process
                words=model_outputs["words"][i],
                **optional_tags,
            )
            sentences_decoded.append(sentence_decoded)
        return sentences_decoded

    def _decode_sentence(
        self,
        text: str,
        words: list[str],
        lemma_rule_ids: list[int] = None,
        joint_feats_ids: list[int] = None,
        deps_ud: list[list[int]] = None,
        deps_eud: list[list[int]] = None,
        misc_ids: list[int] = None,
        deepslot_ids: list[int] = None,
        semclass_ids: list[int] = None
    ) -> dict:
        """Map one sentence's tag ids to strings via the model vocabularies.

        Each optional id list, when provided (and non-empty — the truthiness
        checks below also skip empty lists), adds a decoded field to the
        result dict.  ``deps_ud`` / ``deps_eud`` are lists of
        ``[from, to, rel_id]`` triples in word-position space.
        """
        # CoNLL-U-style token ids, including decimal ids for "#NULL" tokens.
        ids = self._enumerate_words(words)

        result = {
            "text": text,
            "words": words,
            "ids": ids
        }

        if lemma_rule_ids:
            # Each lemma is reconstructed from the surface word plus a
            # vocabulary-decoded lemmatization rule.
            result["lemmas"] = [
                reconstruct_lemma(
                    word,
                    self.model.config.vocabulary["lemma_rule"][lemma_rule_id]
                )
                for word, lemma_rule_id in zip(words, lemma_rule_ids, strict=True)
            ]

        if joint_feats_ids:
            # The joint vocabulary packs "UPOS#XPOS#FEATS" into one entry;
            # split and transpose into three parallel columns.
            upos, xpos, feats = zip(
                *[
                    self.model.config.vocabulary["joint_feats"][joint_feats_id].split('#')
                    for joint_feats_id in joint_feats_ids
                ],
                strict=True
            )
            result["upos"] = list(upos)
            result["xpos"] = list(xpos)
            result["feats"] = list(feats)

        # Re-number arc endpoints from word positions to CoNLL-U token ids and
        # decode the relation id.  A self-loop (from == to) yields head '0' —
        # presumably the encoding for the root arc; TODO confirm against the
        # model's arc conventions.
        renumerate_and_decode_arcs = lambda arcs, id2rel: [
            (
                ids[arc_from] if arc_from != arc_to else '0',
                ids[arc_to],
                id2rel[deprel_id]
            )
            for arc_from, arc_to, deprel_id in arcs
        ]
        if deps_ud:
            result["deps_ud"] = renumerate_and_decode_arcs(
                deps_ud,
                self.model.config.vocabulary["ud_deprel"]
            )
        if deps_eud:
            result["deps_eud"] = renumerate_and_decode_arcs(
                deps_eud,
                self.model.config.vocabulary["eud_deprel"]
            )

        if misc_ids:
            result["miscs"] = [
                self.model.config.vocabulary["misc"][misc_id]
                for misc_id in misc_ids
            ]

        if deepslot_ids:
            result["deepslots"] = [
                self.model.config.vocabulary["deepslot"][deepslot_id]
                for deepslot_id in deepslot_ids
            ]
        if semclass_ids:
            result["semclasses"] = [
                self.model.config.vocabulary["semclass"][semclass_id]
                for semclass_id in semclass_ids
            ]
        return result

    @staticmethod
    def _enumerate_words(words: list[str]) -> list[str]:
        """Assign CoNLL-U-style ids: regular words count "1", "2", ...;
        each "#NULL" token gets a decimal id "<prev>.<k>" (e.g. "2.1"),
        the enhanced-UD convention for empty nodes.
        """
        ids = []
        current_id = 0
        current_null_count = 0
        for word in words:
            if word == "#NULL":
                current_null_count += 1
                ids.append(f"{current_id}.{current_null_count}")
            else:
                current_id += 1
                current_null_count = 0
                ids.append(f"{current_id}")
        return ids

    @staticmethod
    def _format_as_conllu(sentences: list[dict]) -> str:
        """
        Format a list of sentence dicts into a CoNLL-U formatted string.

        Each sentence becomes a "# text = ..." comment plus one 12-column
        tab-separated line per token: the 10 standard CoNLL-U columns followed
        by two extra columns (deepslot, semclass).  Sentences are joined by a
        blank line.

        NOTE(review): absent fields are emitted as '' (empty column) except
        DEPS which falls back to '_'; standard CoNLL-U uses '_' everywhere —
        verify whether downstream consumers rely on the empty-string form.
        """
        formatted = []
        for sentence in sentences:

            lines = [f"# text = {sentence['text']}"]

            # Map token id -> row index so arcs (keyed by token id) can fill
            # per-row HEAD/DEPREL/DEPS columns.
            id2idx = {token_id: idx for idx, token_id in enumerate(sentence['ids'])}

            # Basic (UD) dependencies: one head/deprel per token.
            heads = [''] * len(id2idx)
            deprels = [''] * len(id2idx)
            if "deps_ud" in sentence:
                for arc_from, arc_to, deprel in sentence['deps_ud']:
                    token_idx = id2idx[arc_to]
                    heads[token_idx] = arc_from
                    deprels[token_idx] = deprel

            # Enhanced (E-UD) dependencies: possibly several heads per token,
            # rendered as "head:rel|head:rel|...".
            deps_dicts = [{} for _ in range(len(id2idx))]
            if "deps_eud" in sentence:
                for arc_from, arc_to, deprel in sentence['deps_eud']:
                    token_idx = id2idx[arc_to]
                    deps_dicts[token_idx][arc_from] = deprel

            for idx, token_id in enumerate(sentence['ids']):
                word = sentence['words'][idx]
                lemma = sentence['lemmas'][idx] if "lemmas" in sentence else ''
                upos = sentence['upos'][idx] if "upos" in sentence else ''
                xpos = sentence['xpos'][idx] if "xpos" in sentence else ''
                feats = sentence['feats'][idx] if "feats" in sentence else ''
                deps = '|'.join(f"{head}:{rel}" for head, rel in deps_dicts[idx].items()) or '_'
                misc = sentence['miscs'][idx] if "miscs" in sentence else ''
                deepslot = sentence['deepslots'][idx] if "deepslots" in sentence else ''
                semclass = sentence['semclasses'][idx] if "semclasses" in sentence else ''

                line = '\t'.join([
                    token_id, word, lemma, upos, xpos, feats, heads[idx],
                    deprels[idx], deps, misc, deepslot, semclass
                ])
                lines.append(line)
            formatted.append('\n'.join(lines))
        return '\n\n'.join(formatted)