from collections.abc import Callable

from transformers import Pipeline

from src.lemmatize_helper import reconstruct_lemma
|
class ConlluTokenClassificationPipeline(Pipeline):
    """Pipeline that turns raw text into CoNLL-U style morpho-syntactic annotations.

    The text is split into sentences (``sentenizer``) and words (``tokenizer``),
    fed to the underlying model, and the model's id-based predictions are
    decoded back into tag strings via ``self.model.config.vocabulary``.
    Output is either a list of per-sentence dicts or a CoNLL-U formatted string.
    """

    def __init__(
        self,
        model,
        tokenizer: Callable | None = None,
        sentenizer: Callable | None = None,
        **kwargs,
    ):
        """
        Args:
            model: token-classification model invoked in ``_forward``.
            tokenizer: callable mapping a sentence string to an iterable of words.
            sentenizer: callable mapping a text string to an iterable of sentences.
            **kwargs: forwarded to ``transformers.Pipeline``.
        """
        super().__init__(model=model, **kwargs)
        # NOTE: deliberately overrides the attribute set by transformers.Pipeline —
        # here `tokenizer` is a plain word-splitting callable, not a HF tokenizer.
        self.tokenizer = tokenizer
        self.sentenizer = sentenizer

    def _sanitize_parameters(self, output_format: str = 'list', **kwargs):
        """Route pipeline kwargs: only ``postprocess`` receives ``output_format``.

        Raises:
            ValueError: if ``output_format`` is neither ``'list'`` nor ``'str'``.
        """
        if output_format not in ('list', 'str'):
            raise ValueError(
                f"output_format must be 'str' or 'list', not {output_format}"
            )
        return {}, {}, {'output_format': output_format}

    def preprocess(self, inputs: str) -> dict:
        """Split raw text into sentences and per-sentence word lists.

        Raises:
            ValueError: if ``inputs`` is not a string.
        """
        if not isinstance(inputs, str):
            raise ValueError("pipeline input must be string (text)")

        sentences = list(self.sentenizer(inputs))
        words = [list(self.tokenizer(sentence)) for sentence in sentences]

        # NOTE(review): stashing sentence texts on the instance makes the
        # pipeline stateful between preprocess and postprocess; this breaks if
        # calls for different inputs interleave — confirm single-input usage.
        self._texts = sentences
        return {"words": words}

    def _forward(self, model_inputs: dict) -> dict:
        """Run the model in inference mode on the tokenized sentences."""
        return self.model(**model_inputs, inference_mode=True)

    def postprocess(self, model_outputs: dict, output_format: str) -> list[dict] | str:
        """Decode model outputs into sentence dicts, or a CoNLL-U string if requested."""
        sentences = self._decode_model_output(model_outputs)
        if output_format == 'str':
            sentences = self._format_as_conllu(sentences)
        return sentences

    def _decode_model_output(self, model_outputs: dict) -> list[dict]:
        """Slice per-sentence predictions out of the batched model output and decode them."""

        def select_arcs(arcs, batch_idx):
            # Arc rows start with a batch index; keep this sentence's rows and
            # drop that index column. (Hoisted out of the loop — it does not
            # depend on loop state beyond its arguments.)
            return arcs[arcs[:, 0] == batch_idx][:, 1:]

        sentences_decoded = []
        for i, sentence_words in enumerate(model_outputs["words"]):
            n_words = len(sentence_words)

            # Decode only the prediction heads the model actually produced.
            optional_tags = {}
            if "lemma_rules" in model_outputs:
                optional_tags["lemma_rule_ids"] = model_outputs["lemma_rules"][i, :n_words].tolist()
            if "joint_feats" in model_outputs:
                optional_tags["joint_feats_ids"] = model_outputs["joint_feats"][i, :n_words].tolist()
            if "deps_ud" in model_outputs:
                optional_tags["deps_ud"] = select_arcs(model_outputs["deps_ud"], i).tolist()
            if "deps_eud" in model_outputs:
                optional_tags["deps_eud"] = select_arcs(model_outputs["deps_eud"], i).tolist()
            if "miscs" in model_outputs:
                optional_tags["misc_ids"] = model_outputs["miscs"][i, :n_words].tolist()
            if "deepslots" in model_outputs:
                optional_tags["deepslot_ids"] = model_outputs["deepslots"][i, :n_words].tolist()
            if "semclasses" in model_outputs:
                optional_tags["semclass_ids"] = model_outputs["semclasses"][i, :n_words].tolist()

            sentences_decoded.append(
                self._decode_sentence(
                    text=self._texts[i],
                    words=sentence_words,
                    **optional_tags,
                )
            )
        return sentences_decoded

    def _decode_sentence(
        self,
        text: str,
        words: list[str],
        lemma_rule_ids: list[int] | None = None,
        joint_feats_ids: list[int] | None = None,
        deps_ud: list[list[int]] | None = None,
        deps_eud: list[list[int]] | None = None,
        misc_ids: list[int] | None = None,
        deepslot_ids: list[int] | None = None,
        semclass_ids: list[int] | None = None,
    ) -> dict:
        """Translate one sentence's predicted ids into human-readable tags.

        Returns a dict always containing ``text``, ``words`` and CoNLL-U token
        ``ids``; every other key is present only if the corresponding ids were
        supplied (and non-empty).
        """
        ids = self._enumerate_words(words)

        result = {
            "text": text,
            "words": words,
            "ids": ids,
        }

        if lemma_rule_ids:
            result["lemmas"] = [
                reconstruct_lemma(
                    word,
                    self.model.config.vocabulary["lemma_rule"][lemma_rule_id]
                )
                for word, lemma_rule_id in zip(words, lemma_rule_ids, strict=True)
            ]

        if joint_feats_ids:
            # Each joint label packs "UPOS#XPOS#FEATS" into one '#'-separated string.
            upos, xpos, feats = zip(
                *[
                    self.model.config.vocabulary["joint_feats"][joint_feats_id].split('#')
                    for joint_feats_id in joint_feats_ids
                ],
                strict=True,
            )
            result["upos"] = list(upos)
            result["xpos"] = list(xpos)
            result["feats"] = list(feats)

        def renumerate_and_decode_arcs(arcs, id2rel):
            # Map model word positions to CoNLL-U token ids and decode relation
            # labels. A self-loop (from == to) encodes the root, which CoNLL-U
            # represents as head '0'. (Was a PEP 8 E731 lambda.)
            return [
                (
                    ids[arc_from] if arc_from != arc_to else '0',
                    ids[arc_to],
                    id2rel[deprel_id],
                )
                for arc_from, arc_to, deprel_id in arcs
            ]

        if deps_ud:
            result["deps_ud"] = renumerate_and_decode_arcs(
                deps_ud,
                self.model.config.vocabulary["ud_deprel"],
            )
        if deps_eud:
            result["deps_eud"] = renumerate_and_decode_arcs(
                deps_eud,
                self.model.config.vocabulary["eud_deprel"],
            )

        if misc_ids:
            result["miscs"] = [
                self.model.config.vocabulary["misc"][misc_id]
                for misc_id in misc_ids
            ]
        if deepslot_ids:
            result["deepslots"] = [
                self.model.config.vocabulary["deepslot"][deepslot_id]
                for deepslot_id in deepslot_ids
            ]
        if semclass_ids:
            result["semclasses"] = [
                self.model.config.vocabulary["semclass"][semclass_id]
                for semclass_id in semclass_ids
            ]
        return result

    @staticmethod
    def _enumerate_words(words: list[str]) -> list[str]:
        """Assign CoNLL-U token ids: regular words get "1", "2", …; each
        "#NULL" (empty node) gets a decimal id "<prev>.<k>" after its anchor."""
        ids = []
        current_id = 0
        current_null_count = 0
        for word in words:
            if word == "#NULL":
                current_null_count += 1
                ids.append(f"{current_id}.{current_null_count}")
            else:
                current_id += 1
                current_null_count = 0
                ids.append(f"{current_id}")
        return ids

    @staticmethod
    def _format_as_conllu(sentences: list[dict]) -> str:
        """Format a list of sentence dicts into a CoNLL-U formatted string.

        FIX: absent columns are now rendered as '_' — the CoNLL-U placeholder
        for unspecified fields — instead of an empty string, which produced
        invalid CoNLL-U lines (the DEPS column already used '_').
        """
        formatted = []
        for sentence in sentences:
            lines = [f"# text = {sentence['text']}"]

            id2idx = {token_id: idx for idx, token_id in enumerate(sentence['ids'])}
            n_tokens = len(id2idx)

            # Basic UD arcs: a single head/deprel per token.
            heads = ['_'] * n_tokens
            deprels = ['_'] * n_tokens
            if "deps_ud" in sentence:
                for arc_from, arc_to, deprel in sentence['deps_ud']:
                    token_idx = id2idx[arc_to]
                    heads[token_idx] = arc_from
                    deprels[token_idx] = deprel

            # Enhanced UD arcs: possibly several heads per token.
            deps_dicts = [{} for _ in range(n_tokens)]
            if "deps_eud" in sentence:
                for arc_from, arc_to, deprel in sentence['deps_eud']:
                    deps_dicts[id2idx[arc_to]][arc_from] = deprel

            def column(key: str, idx: int) -> str:
                # '_' is the CoNLL-U placeholder for an unannotated field.
                return sentence[key][idx] if key in sentence else '_'

            for idx, token_id in enumerate(sentence['ids']):
                deps = '|'.join(
                    f"{head}:{rel}" for head, rel in deps_dicts[idx].items()
                ) or '_'
                line = '\t'.join([
                    token_id,
                    sentence['words'][idx],
                    column('lemmas', idx),
                    column('upos', idx),
                    column('xpos', idx),
                    column('feats', idx),
                    heads[idx],
                    deprels[idx],
                    deps,
                    column('miscs', idx),
                    column('deepslots', idx),
                    column('semclasses', idx),
                ])
                lines.append(line)
            formatted.append('\n'.join(lines))
        return '\n\n'.join(formatted)