Spaces:
Running
Running
| from itertools import chain | |
| import random | |
| from typing import List, Optional, Tuple, Union | |
| from tokenizers import AddedToken | |
| from transformers import ByT5Tokenizer | |
| import numpy as np | |
| import torch | |
| from surya.model.recognition.config import LANGUAGE_MAP, TOTAL_TOKENS, TOKEN_OFFSET | |
def text_to_utf16_numbers(text):
    """Encode *text* as UTF-16LE and return one integer per 2-byte code unit.

    Characters outside the BMP become two numbers (a surrogate pair), matching
    how UTF-16 represents them.
    """
    encoded = text.encode('utf-16le')  # little-endian: low byte comes first
    # Each UTF-16 code unit is exactly two bytes; fold each pair into an int.
    return [
        int.from_bytes(encoded[pos:pos + 2], 'little')
        for pos in range(0, len(encoded), 2)
    ]
def utf16_numbers_to_text(numbers):
    """Inverse of text_to_utf16_numbers: decode code-unit ints back to a string.

    Each number is emitted as two little-endian bytes (values are truncated to
    16 bits, as the original byte extraction did); invalid sequences such as
    lone surrogates are silently dropped.
    """
    raw = b"".join((number & 0xFFFF).to_bytes(2, 'little') for number in numbers)
    return raw.decode('utf-16le', errors="ignore")
def _tokenize(text: str, langs: List[str] | None, eos_token_id: int = 1, add_eos: bool = False, add_bos: bool = True):
    """Convert text to shifted UTF-16 token ids, prefixed with language tokens.

    Args:
        text: Input string; each UTF-16 code unit becomes one token.
        langs: Optional language codes (keys of LANGUAGE_MAP) to prepend.
        eos_token_id: Id used for both BOS and EOS (they share an id here).
        add_eos: Append the EOS id at the end when True. (Previously this
            parameter was accepted but silently ignored — bug fix.)
        add_bos: Insert the BOS/EOS id at the front when True.

    Returns:
        Tuple of (token id list, language token id list).
    """
    # Shift past the special pad/eos/unk ids reserved at the bottom of the vocab.
    tokens = [t + TOKEN_OFFSET for t in text_to_utf16_numbers(text)]

    lang_list = []
    if langs:
        # Language tokens live above the regular token range.
        lang_list = [LANGUAGE_MAP[lang] + TOKEN_OFFSET + TOTAL_TOKENS for lang in langs]

    tokens = lang_list + tokens

    if add_bos:
        tokens.insert(0, eos_token_id)
    if add_eos:
        tokens.append(eos_token_id)

    return tokens, lang_list
class Byt5LangTokenizer(ByT5Tokenizer):
    """ByT5-style tokenizer that maps text to UTF-16 code-unit ids plus
    language tokens, instead of raw bytes."""

    def __init__(self,
                 eos_token="</s>",
                 unk_token="<unk>",
                 pad_token="<pad>",
                 model_max_length=None,
                 **kwargs,
                 ):
        # Special-token strings; BOS deliberately reuses the EOS string.
        self.pad_token = pad_token
        self.eos_token = eos_token
        self.unk_token = unk_token
        self.bos_token = eos_token
        self.offset = TOKEN_OFFSET
        # Fixed ids for pad/eos/unk — must match the model config.
        self.pad_id = 0
        self.eos_id = 1
        self.unk_id = 2
        self.model_max_length = model_max_length
        # First id beyond the regular token range; language tokens start here.
        self.special_token_start = TOKEN_OFFSET + TOTAL_TOKENS
        super().__init__()

    def __call__(self, texts: List[str] | str, langs: List[List[str]] | List[str] | None = None, pad_token_id: int = 0, **kwargs):
        """Tokenize one string or a batch of strings.

        Args:
            texts: A single string or a list of strings.
            langs: Language codes per text — a flat list applies to a single
                text, a list of lists matches a batch. None means no language
                tokens.
            pad_token_id: Accepted for interface compatibility; no padding is
                performed here.

        Returns:
            Dict with "input_ids" and "langs"; values are flat lists when a
            single string was passed, lists of lists for a batch.
        """
        tokenized = []
        all_langs = []

        # Normalize to list-of-texts / list-of-lang-lists form.
        is_list = True
        if isinstance(texts, str):
            texts = [texts]
            is_list = False

        if langs is None:
            langs = [None] * len(texts)

        # Guard with truthiness: langs == [] would raise IndexError on langs[0].
        if langs and isinstance(langs[0], str):
            langs = [langs]

        assert len(langs) == len(texts)

        for text, lang in zip(texts, langs):
            tokens, lang_list = _tokenize(text, lang)
            tokenized.append(tokens)
            all_langs.append(lang_list)

        # Convert back to flat format for single-string input.
        if not is_list:
            tokenized = tokenized[0]
            all_langs = all_langs[0]

        return {"input_ids": tokenized, "langs": all_langs}

    def decode(
        self,
        token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool = None,
        **kwargs,
    ) -> str:
        """Decode token ids back to text, dropping special and language tokens.

        Accepts a single int, a list of ints, or a numpy/torch tensor (the
        annotation already promised int support; previously a bare int raised
        TypeError — bug fix).
        """
        if isinstance(token_ids, (np.ndarray, torch.Tensor)):
            token_ids = token_ids.tolist()
        elif isinstance(token_ids, int):
            token_ids = [token_ids]

        # Keep only regular text tokens (drop pad/eos/unk below TOKEN_OFFSET
        # and language tokens at/above special_token_start), then un-shift.
        token_ids = [
            t - TOKEN_OFFSET
            for t in token_ids
            if TOKEN_OFFSET <= t < self.special_token_start
        ]
        return utf16_numbers_to_text(token_ids)