File size: 5,134 Bytes
b5a0bec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#
import codecs
import re
import typing as tp
from functools import lru_cache
import spacy
import torch
from sacremoses import MosesDetokenizer, MosesPunctNormalizer
from stopes.pipelines.monolingual.utils.sentence_split import get_split_algo
from stopes.utils.language_codes import language_code_to_short_code
# Compiled once at import time: remove_emojis() is called per document, and
# re-compiling a large character class on every call is wasted work.
_EMOJI_PATTERN = re.compile(
    "["
    "\U0001f600-\U0001f64f"  # emoticons
    "\U0001f300-\U0001f5ff"  # symbols & pictographs
    "\U0001f680-\U0001f6ff"  # transport & map symbols
    "\U0001f1e0-\U0001f1ff"  # regional indicator symbols (flags, iOS)
    "\U00002600-\U000026ff"  # miscellaneous symbols
    "\U00002702-\U000027b0"  # dingbats
    "\U000024c2"             # circled M
    "\U0001f170-\U0001f251"  # enclosed alphanumeric / enclosed ideographic supplement
    "\U0001f900-\U0001f9ff"  # Supplemental Symbols and Pictographs
    "\U0001f700-\U0001f77f"  # Alchemical Symbols
    "\U0001f780-\U0001f7ff"  # Geometric Shapes Extended
    "\U0001f800-\U0001f8ff"  # Supplemental Arrows-C
    "\U0001fa00-\U0001fa6f"  # Chess Symbols
    "\U0001fa70-\U0001faff"  # Symbols and Pictographs Extended-A
    "]+",
    flags=re.UNICODE,
)


def remove_emojis(text: str) -> str:
    """Return *text* with emoji and pictographic symbols stripped out.

    Fix: the previous pattern contained the catch-all range
    U+24C2..U+1F251, which also matches CJK ideographs, kana, Hangul
    syllables and fullwidth forms — i.e. it deleted virtually all
    Chinese/Japanese/Korean text, which is fatal for a multilingual
    pipeline. The range is replaced here by the narrow emoji blocks the
    original comments actually describe; plain letters and ideographs
    in any script are now preserved.
    """
    return _EMOJI_PATTERN.sub("", text)
def batched(inputs: tp.Iterable, batch_size=10000) -> tp.Iterable:
    """Yield items from *inputs* grouped into lists of up to *batch_size*.

    The final list holds the remainder and may be shorter than
    *batch_size*.

    Fix: the original always yielded the trailing buffer, so inputs
    whose length is an exact multiple of *batch_size* (including empty
    inputs) produced a spurious empty batch at the end. The trailing
    batch is now yielded only when non-empty, matching the convention
    of ``itertools.batched``.
    """
    batch = []
    for item in inputs:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # remainder, if any — never an empty list
        yield batch
def filter_empty_string(text):
    """Return ``True`` when *text* contains no alphanumeric character.

    Used as a "drop this sentence" predicate: pure punctuation or
    whitespace counts as empty; a single letter or digit does not.
    """
    for character in text:
        if character.isalnum():
            return False
    return True
def remove_non_printable_chars(string):
    """Keep only printable ASCII (U+0020..U+007E); drop everything else.

    Note this removes tabs, newlines and all non-ASCII text — hence it
    is kept out of the multilingual path by its caller.
    """
    return "".join(ch for ch in string if " " <= ch <= "~")
def deescape_special_chars(string):
    """Turn literal backslash escapes into real characters (``"\\n"`` -> newline).

    NOTE(review): ``unicode_escape`` treats the input as latin-1, so
    non-ASCII text can be mangled — presumably why the caller keeps
    this step disabled for multilingual input; confirm before enabling.
    """
    return codecs.decode(string, "unicode_escape")
def resplit(text: str, max_length: int, sep: str) -> tp.List[str]:
    """Greedily re-chunk *text* at *sep* into pieces of at most *max_length*.

    *sep* is kept attached to the end of the piece it terminated, so the
    concatenation of the returned pieces reproduces *text* exactly. A
    single separator-free run longer than *max_length* cannot be split
    and is returned as an oversized piece; an empty *text* yields ``[]``.
    """
    chunks = text.split(sep)
    # Re-attach the separator to every chunk except the final one, so the
    # greedy accumulation below never has to special-case the tail.
    pieces = [chunk + sep for chunk in chunks[:-1]]
    pieces.append(chunks[-1])

    out: tp.List[str] = []
    buffer = ""
    for piece in pieces:
        if len(buffer) + len(piece) <= max_length:
            buffer += piece
        else:
            if buffer:
                out.append(buffer)
            buffer = piece
    if buffer:
        out.append(buffer)
    return out
@lru_cache
def get_moses_normalizers(lang):
    """Build and cache a ``(punct_normalizer, detokenizer)`` pair for *lang*.

    The language code is first reduced to a short Moses-style code.
    The normalizer's substitution regexes are pre-compiled once here
    because ``normalize()`` is invoked per sentence downstream.
    """
    short_code = language_code_to_short_code(lang, try_replacing_with_macro=True)
    normalizer = MosesPunctNormalizer(lang=short_code)
    normalizer.substitutions = [
        (re.compile(pattern), replacement)
        for pattern, replacement in normalizer.substitutions
    ]
    detokenizer = MosesDetokenizer(lang=short_code)
    return normalizer, detokenizer
@lru_cache
def get_splitter(lang: str, model_name: tp.Optional[str] = None):
    """Return a sentence-splitting callable for *lang*.

    Tries to load a spaCy model (``{short}_core_web_sm`` for English,
    ``{short}_core_news_sm`` otherwise, unless *model_name* is given)
    and falls back to the stopes default splitter when the model is
    unavailable. Results are cached per ``(lang, model_name)``.

    Fix: ``spacy.load`` raises ``OSError`` (error E050) when the model
    package is not installed — the original ``except ModuleNotFoundError``
    never caught that case, so the stopes fallback was effectively dead
    code. Both exception types are handled now.
    """
    moses_lang = language_code_to_short_code(lang, try_replacing_with_macro=True)
    if model_name is None:
        model_name = (
            f"{moses_lang}_core_web_sm"
            if moses_lang == "en"
            else f"{moses_lang}_core_news_sm"
        )
    try:
        if torch.cuda.is_available():
            spacy.require_gpu()
        spacy_nlp = spacy.load(model_name, enable=["sentencizer"])
        spacy_nlp.add_pipe("sentencizer")

        def spacy_splitter(text):
            # Feed spaCy in <1M-character slices to stay under its default
            # max_length limit. NOTE(review): slicing is by character, so a
            # sentence can be cut at a slice boundary — acceptable here since
            # downstream resplitting re-chunks long pieces anyway; confirm.
            for batch in batched(text, batch_size=999_000):
                for sent in spacy_nlp("".join(batch)).sents:
                    yield str(sent)

        return spacy_splitter
    except (OSError, ModuleNotFoundError):
        # OSError: model not installed (spaCy E050); ModuleNotFoundError:
        # spaCy itself (or a model imported as a module) is missing.
        print(
            f"Spacy splitter not found for {lang}, switching to stopes implementation"
        )
        return get_split_algo(lang[:3], "default")
class ResplitSentenceSplitter:
    """Split a document into sentences no longer than a target length.

    First splits with the language's sentence splitter, then repeatedly
    re-splits over-long sentences on a cascade of fallback separators,
    and finally Moses-normalizes and detokenizes every surviving piece.
    """

    def __init__(
        self,
        fallback_separators=(".", "!", "?", "...", "\n", ";", ",", ":", ">", " "),
    ):
        # Tried in order: strongest sentence boundaries first, plain
        # whitespace as the last resort.
        self.fallback_separators = fallback_separators

    def __call__(
        self, document: str, lang: str = "eng_Latn", max_length: int = 200
    ) -> tp.List[str]:
        mpn, md = get_moses_normalizers(lang)
        # XXX: the two cleanup steps below are not multilingual-friendly,
        # so they stay disabled.
        # document = deescape_special_chars(document)
        # document = remove_non_printable_chars(document)
        document = remove_emojis(document)

        sentences = list(get_splitter(lang)(document))
        # Each pass re-splits every piece on the next separator; short
        # pieces come back unchanged, long ones get chopped further.
        for separator in self.fallback_separators or []:
            next_round = []
            for sentence in sentences:
                for chunk in resplit(sentence, max_length=max_length, sep=separator):
                    next_round.append(chunk.strip())
            sentences = next_round

        results = []
        for sentence in sentences:
            # Drop single characters and pieces with no alphanumerics.
            if len(sentence) > 1 and not filter_empty_string(sentence):
                results.append(mpn.normalize(md.detokenize(sentence.strip().split())))
        return results
|