Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| import os | |
| import torch | |
| import string | |
| import onnxruntime as ort | |
| from dataclasses import dataclass | |
| from omegaconf import OmegaConf | |
| from typing import List, Optional, Union, Dict | |
| from sentencepiece import SentencePieceProcessor | |
| from torch.utils.data import Dataset, DataLoader | |
| from typing import Iterator, List, Iterable, Tuple | |
# Sentinel label emitted by the model for acronym tokens; the collector puts a
# period after every character of such a token (see PunctCapCollector.produce).
ACRONYM_TOKEN = "<ACRONYM>"
# Inference-only process: disable autograd and cuDNN, and hide all CUDA
# devices so torch stays on CPU (the model itself runs through ONNX Runtime).
torch.set_grad_enabled(False)
torch.backends.cudnn.enabled = False
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
@dataclass
class PunctCapConfigONNX:
    """Filesystem layout of an exported punctuation/capitalization model.

    ``directory`` is the folder containing the SentencePiece model, the ONNX
    graph, and the YAML model config; the filename fields default to the
    names used on export.

    The ``@dataclass`` decorator is required: callers instantiate this class
    with keyword arguments (``PunctCapConfigONNX(directory=...)``), which a
    plain class body of annotated attributes does not support.
    """

    spe_filename: str = "xlm_roberta_encoding.model"
    model_filename: str = "nemo_model.onnx"
    config_filename: str = "config.yaml"
    directory: Optional[str] = None
class PunctCapModelONNX:
    """Punctuation + capitalization inference over an exported NeMo ONNX model.

    Loads a SentencePiece tokenizer, an ONNX Runtime session, and an OmegaConf
    model config from the directory described by ``cfg``.
    """

    def __init__(self, cfg: PunctCapConfigONNX):
        self._spe_path = os.path.join(cfg.directory, cfg.spe_filename)
        onnx_path = os.path.join(cfg.directory, cfg.model_filename)
        config_path = os.path.join(cfg.directory, cfg.config_filename)
        self._tokenizer: SentencePieceProcessor = SentencePieceProcessor(self._spe_path)
        self._ort_session: ort.InferenceSession = ort.InferenceSession(onnx_path)
        self._config = OmegaConf.load(config_path)
        self._max_len = self._config.max_length
        # Label inventories for the pre-token and post-token punctuation heads.
        self._pre_labels: List[str] = self._config.pre_labels
        self._post_labels: List[str] = self._config.post_labels
        self._languages: List[str] = self._config.languages
        # Label meaning "no punctuation at this position".
        self._null_token = self._config.get("null_token", "<NULL>")

    def _setup_dataloader(self, texts: List[str], batch_size_tokens: int, overlap: int) -> DataLoader:
        """Build a DataLoader over tokenized segments, batched by token budget."""
        dataset: TextInferenceDataset = TextInferenceDataset(
            texts=texts,
            batch_size_tokens=batch_size_tokens,
            overlap=overlap,
            max_length=self._max_len,
            spe_model_path=self._spe_path,
        )
        return DataLoader(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_sampler=dataset.sampler,
        )

    def punctuation_removal(self, texts: List[str]) -> List[str]:
        """Lower-case each text and strip punctuation.

        Removes ASCII punctuation plus common Arabic, CJK, and typographic
        marks, but deliberately keeps hyphens and apostrophes (word-internal).
        """
        punkt = string.punctuation + """`÷×؛<>_()*&^%][ـ،/:"؟.,'{}~¦+|!”…–ـ""" + """!?。。"""
        # Keep word-internal characters.
        punkt = punkt.replace("-", "")
        punkt = punkt.replace("'", "")
        punkt += "„“"
        return [text.translate(str.maketrans("", "", punkt)).lower().strip() for text in texts]

    def infer(
        self,
        texts: List[str],
        apply_sbd: bool = False,
        batch_size_tokens: int = 4096,
        overlap: int = 16,
    ) -> Union[List[str], List[List[str]]]:
        """Punctuate and capitalize ``texts``.

        Returns one string per input text, or — when ``apply_sbd`` is True —
        one list of sentence strings per input text.
        """
        # The model expects unpunctuated, lower-cased input.
        texts = self.punctuation_removal(texts)
        # One collector per input text; each re-assembles its (possibly
        # overlapping) segments back into final text.
        collectors: List[PunctCapCollector] = [
            PunctCapCollector(sp_model=self._tokenizer, apply_sbd=apply_sbd, overlap=overlap)
            for _ in range(len(texts))
        ]
        dataloader: DataLoader = self._setup_dataloader(texts=texts, batch_size_tokens=batch_size_tokens, overlap=overlap)
        for batch in dataloader:
            input_ids, batch_indices, input_indices, lengths = batch
            # Four output heads: pre-token punctuation, post-token punctuation,
            # per-character capitalization, and sentence-boundary detection.
            pre_preds, post_preds, cap_preds, seg_preds = self._ort_session.run(None, {"input_ids": input_ids.numpy()})
            batch_size = input_ids.shape[0]
            for i in range(batch_size):
                length = lengths[i].item()
                batch_idx = batch_indices[i].item()
                input_idx = input_indices[i].item()
                # Slice [1 : length - 1] drops the BOS/EOS positions and any
                # right-padding beyond `length`.
                segment_ids = input_ids[i, 1 : length - 1].tolist()
                segment_pre_preds = pre_preds[i, 1 : length - 1].tolist()
                segment_post_preds = post_preds[i, 1 : length - 1].tolist()
                segment_cap_preds = cap_preds[i, 1 : length - 1].tolist()
                segment_sbd_preds = seg_preds[i, 1 : length - 1].tolist()
                # Map label indices to label strings; null labels become None.
                pre_tokens = [self._pre_labels[i] for i in segment_pre_preds]
                post_tokens = [self._post_labels[i] for i in segment_post_preds]
                pre_tokens = [x if x != self._null_token else None for x in pre_tokens]
                post_tokens = [x if x != self._null_token else None for x in post_tokens]
                collectors[batch_idx].collect(
                    ids=segment_ids,
                    pre_preds=pre_tokens,
                    post_preds=post_tokens,
                    cap_preds=segment_cap_preds,
                    sbd_preds=segment_sbd_preds,
                    idx=input_idx,
                )
        outputs: Union[List[str], List[List[str]]] = [x.produce() for x in collectors]
        return outputs
@dataclass
class TokenizedSegment:
    """One window of token ids cut from a single input text.

    Attributes:
        input_ids: SentencePiece ids for this window (BOS/EOS are added later,
            in ``TextInferenceDataset.__getitem__``).
        batch_idx: index of the source text within the input batch.
        input_idx: position of this window within its source text.

    ``@dataclass`` is required: instances are built positionally in
    ``TextInferenceDataset._tokenize_inputs``, which a plain annotated class
    body does not support.
    """

    input_ids: List[int]
    batch_idx: int
    input_idx: int

    def __len__(self) -> int:
        return len(self.input_ids)
class TokenBatchSampler(Iterable):
    """Yields batches of dataset indices capped by a total token budget.

    The cost of a batch is ``longest_segment * batch_size`` because shorter
    segments get padded up to the longest one, so segments are sorted
    longest-first to group similar lengths together and minimize padding.
    """

    def __init__(self, segments: List[TokenizedSegment], batch_size_tokens: int):
        self._batches = self._make_batches(segments, batch_size_tokens)

    def _make_batches(self, segments: List[TokenizedSegment], batch_size_tokens: int) -> List[List[int]]:
        """Group segment indices into batches of at most ``batch_size_tokens``
        padded tokens each.

        A segment longer than the budget still gets its own single-element
        batch (there is no other way to schedule it).
        """
        segments_with_index = [(segment, i) for i, segment in enumerate(segments)]
        segments_with_index.sort(key=lambda x: len(x[0]), reverse=True)
        batches: List[List[int]] = []
        current_batch_elements: List[int] = []
        current_max_len = 0
        for segment, idx in segments_with_index:
            potential_max_len = max(current_max_len, len(segment))
            if current_batch_elements and potential_max_len * (len(current_batch_elements) + 1) > batch_size_tokens:
                # Budget exceeded: flush the current batch. Guarding on a
                # non-empty batch fixes a bug where a first segment longer
                # than the budget produced an empty batch.
                batches.append(current_batch_elements)
                current_batch_elements = []
                # Start the new batch's max from this segment alone; the old
                # code carried the previous batch's max forward, shrinking
                # later batches unnecessarily.
                potential_max_len = len(segment)
            current_batch_elements.append(idx)
            current_max_len = potential_max_len
        if current_batch_elements:
            batches.append(current_batch_elements)
        return batches

    def __iter__(self) -> Iterator:
        yield from self._batches

    def __len__(self) -> int:
        return len(self._batches)
class TextInferenceDataset(Dataset):
    """Tokenizes input texts into overlapping fixed-length segments for
    batched inference.

    Each text is encoded with SentencePiece and sliced into windows of at most
    ``max_length`` tokens (two positions reserved for BOS/EOS), with
    ``overlap`` tokens shared between consecutive windows so the collector can
    stitch predictions back together.
    """

    def __init__(
        self,
        texts: List[str],
        spe_model_path: str,
        batch_size_tokens: int = 4096,
        max_length: int = 512,
        overlap: int = 32,
    ):
        self._spe_model = SentencePieceProcessor(spe_model_path)
        self._segments = self._tokenize_inputs(texts, max_length, overlap)
        self._sampler = TokenBatchSampler(self._segments, batch_size_tokens)

    @property
    def sampler(self) -> Iterable:
        # Must be a property: the DataLoader is constructed with
        # ``batch_sampler=dataset.sampler`` (no call). As a plain method,
        # the DataLoader received a bound method and failed when trying to
        # iterate batches from it.
        return self._sampler

    def _tokenize_inputs(self, texts: List[str], max_len: int, overlap: int) -> List[TokenizedSegment]:
        """Encode each text and slice its ids into overlapping windows."""
        # Reserve two positions for the BOS/EOS added in __getitem__.
        max_len -= 2
        segments = []
        for batch_idx, text in enumerate(texts):
            ids, start, input_idx = self._spe_model.EncodeAsIds(text), 0, 0
            while start < len(ids):
                # Every window after the first starts `overlap` tokens early so
                # neighbouring windows share context at their boundary.
                adjusted_start = start - overlap if input_idx else 0
                segments.append(
                    TokenizedSegment(
                        ids[adjusted_start : adjusted_start + max_len],
                        batch_idx,
                        input_idx,
                    )
                )
                start += max_len - overlap
                input_idx += 1
        return segments

    def __len__(self) -> int:
        return len(self._segments)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int, int]:
        """Return (BOS + ids + EOS, batch index, segment index) for one window."""
        segment = self._segments[idx]
        # Build an integer tensor: token ids must stay exact integers, whereas
        # the previous torch.Tensor([...]) produced float32 values.
        input_ids = torch.tensor(
            [self._spe_model.bos_id(), *segment.input_ids, self._spe_model.eos_id()],
            dtype=torch.long,
        )
        return input_ids, segment.batch_idx, segment.input_idx

    def collate_fn(self, batch: List[Tuple[torch.Tensor, int, int]]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Right-pad id sequences with pad_id and stack the index metadata.

        Returns (padded ids, batch indices, segment indices, true lengths).
        """
        input_ids = [x[0] for x in batch]
        lengths = torch.tensor([x.shape[0] for x in input_ids])
        max_len = lengths.max().item()
        batched_ids = torch.full((len(input_ids), max_len), self._spe_model.pad_id(), dtype=torch.long)
        for idx, ids in enumerate(input_ids):
            batched_ids[idx, : lengths[idx]] = ids
        return (
            batched_ids,
            torch.tensor([x[1] for x in batch]),
            torch.tensor([x[2] for x in batch]),
            lengths,
        )
@dataclass
class PCSegment:
    """Model predictions for one tokenized segment of an input text.

    Attributes:
        ids: token ids of the segment (BOS/EOS already stripped).
        pre_preds: pre-token punctuation label per token, None for "none".
        post_preds: post-token punctuation label per token, None for "none".
        cap_preds: per-token list of per-character capitalization flags.
        sbd_preds: per-token sentence-boundary flags.

    ``@dataclass`` is required: ``PunctCapCollector.collect`` constructs this
    class with keyword arguments, which a plain annotated class body does not
    support.
    """

    ids: List[int]
    pre_preds: List[Optional[str]]
    post_preds: List[Optional[str]]
    cap_preds: List[List[int]]
    sbd_preds: List[int]

    def __len__(self):
        return len(self.ids)
class PunctCapCollector:
    """Accumulates per-segment predictions for one input text and stitches
    them back into punctuated, capitalized output."""

    def __init__(self, apply_sbd: bool, overlap: int, sp_model: SentencePieceProcessor):
        # Segments keyed by their order (input_idx) within the source text.
        self._segments: Dict[int, PCSegment] = {}
        self._apply_sbd = apply_sbd
        self._overlap = overlap
        self._sp_model = sp_model

    def collect(
        self,
        ids: List[int],
        pre_preds: List[Optional[str]],
        post_preds: List[Optional[str]],
        sbd_preds: List[int],
        cap_preds: List[List[int]],
        idx: int,
    ):
        """Store the model predictions for segment number ``idx``."""
        self._segments[idx] = PCSegment(
            ids=ids,
            pre_preds=pre_preds,
            post_preds=post_preds,
            sbd_preds=sbd_preds,
            cap_preds=cap_preds,
        )

    def produce(self) -> Union[List[str], str]:
        """Assemble the collected segments into final text.

        Returns a single string, or a list of sentence strings when this
        collector was created with ``apply_sbd=True``.
        """
        # Phase 1: concatenate segments in order, trimming half the overlap
        # from each interior edge so overlapping regions are not duplicated.
        ids: List[int] = []
        pre_preds: List[Optional[str]] = []
        post_preds: List[Optional[str]] = []
        cap_preds: List[List[int]] = []
        sbd_preds: List[int] = []
        for i in range(len(self._segments)):
            segment = self._segments[i]
            start = 0
            stop = len(segment)
            if i > 0:
                start += self._overlap // 2
            if i < len(self._segments) - 1:
                stop -= self._overlap // 2
            ids.extend(segment.ids[start:stop])
            pre_preds.extend(segment.pre_preds[start:stop])
            post_preds.extend(segment.post_preds[start:stop])
            sbd_preds.extend(segment.sbd_preds[start:stop])
            cap_preds.extend(segment.cap_preds[start:stop])
        # Phase 2: detokenize character by character, applying punctuation and
        # capitalization predictions as we go.
        input_tokens = [self._sp_model.IdToPiece(x) for x in ids]
        output_texts: List[str] = []
        current_chars: List[str] = []
        for token_idx, token in enumerate(input_tokens):
            # "▁" is SentencePiece's word-start marker; emit a space between
            # words (but not before the very first one).
            if token.startswith("▁") and current_chars:
                current_chars.append(" ")
            char_start = 1 if token.startswith("▁") else 0
            for token_char_idx, char in enumerate(token[char_start:], start=char_start):
                # Pre-token punctuation (e.g. "¿") goes before the first char.
                if token_char_idx == char_start and pre_preds[token_idx] is not None:
                    current_chars.append(pre_preds[token_idx])
                # cap_preds is indexed per character within the raw token.
                if cap_preds[token_idx][token_char_idx]:
                    char = char.upper()
                current_chars.append(char)
                label = post_preds[token_idx]
                if label == ACRONYM_TOKEN:
                    # Acronyms get a period after every character.
                    current_chars.append(".")
                elif token_char_idx == len(token) - 1 and post_preds[token_idx] is not None:
                    # Other post-token punctuation goes after the last char.
                    current_chars.append(post_preds[token_idx])
                # A sentence boundary predicted on this token closes the
                # current sentence (only meaningful when applying SBD).
                if self._apply_sbd and token_char_idx == len(token) - 1 and sbd_preds[token_idx]:
                    output_texts.append("".join(current_chars))
                    current_chars = []
        if current_chars:
            output_texts.append("".join(current_chars))
        if not self._apply_sbd:
            if len(output_texts) > 1:
                raise ValueError(f"Not applying SBD but got more than one result: {output_texts}")
            return output_texts[0]
        return output_texts
class MultiLingual:
    """Thin convenience wrapper around the multilingual ONNX punctuation and
    capitalization model shipped with the service image."""

    def __init__(self):
        config = PunctCapConfigONNX(directory="/code/models/multilingual")
        self._punctuator = PunctCapModelONNX(config)

    def punctuate(self, data: str) -> str:
        """Restore punctuation and capitalization in a single input string."""
        restored = self._punctuator.infer([data])
        return restored[0]