import re
from abc import ABC, abstractmethod
from collections import Counter
from typing import Any, Union

import inflect
import nltk
from flair.data import Sentence
from flair.models import SequenceTagger
| |
|
# Names exported by a star-import of this module.
# NOTE(review): RemoveDuplicates, TextCompose and default_vocabulary_transforms are defined in
# this module but are not listed here — confirm whether that omission is intentional.
__all__ = [
    "DropFileExtensions",
    "DropNonAlpha",
    "DropShortWords",
    "DropSpecialCharacters",
    "DropTokens",
    "DropURLs",
    "DropWords",
    "FilterPOS",
    "FrequencyMinWordCount",
    "FrequencyTopK",
    "ReplaceSeparators",
    "ToLowercase",
    "ToSingular",
]
| |
|
| |
|
class BaseTextTransform(ABC):
    """Abstract interface shared by the string transforms in this module.

    Concrete subclasses implement ``__call__``; the default ``__repr__`` suits transforms
    that take no constructor parameters.
    """

    def __repr__(self) -> str:
        # Render as "ClassName()" using the concrete subclass name.
        return type(self).__name__ + "()"

    @abstractmethod
    def __call__(self, text: str):
        """Transform ``text``; must be overridden by concrete subclasses."""
        raise NotImplementedError
| |
|
| |
|
| | class DropFileExtensions(BaseTextTransform): |
| | """Remove file extensions from the input text.""" |
| |
|
| | def __call__(self, text: str): |
| | """ |
| | Args: |
| | text (str): Text to remove file extensions from. |
| | """ |
| | text = re.sub(r"\.\w+", "", text) |
| |
|
| | return text |
| |
|
| |
|
class DropNonAlpha(BaseTextTransform):
    """Delete every character that is neither an ASCII letter nor whitespace."""

    def __call__(self, text: str):
        """
        Args:
            text (str): Text to remove non-alphabetic characters from.

        Returns:
            str: The text containing only ``a-z``, ``A-Z`` and whitespace characters.
        """
        non_alpha = re.compile(r"[^a-zA-Z\s]")
        return non_alpha.sub("", text)
| |
|
| |
|
class DropShortWords(BaseTextTransform):
    """Discard whitespace-separated words shorter than a minimum length.

    Args:
        min_length (int): Minimum length of words to keep.
    """

    def __init__(self, min_length) -> None:
        super().__init__()
        self.min_length = min_length

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(min_length={self.min_length})"

    def __call__(self, text: str):
        """
        Args:
            text (str): Text to remove short words from.

        Returns:
            str: The surviving words joined by single spaces.
        """
        long_enough = (token for token in text.split() if not len(token) < self.min_length)
        return " ".join(long_enough)
| |
|
| |
|
class DropSpecialCharacters(BaseTextTransform):
    """Strip special characters from the input text.

    A character counts as special unless it is a word character, whitespace, a hyphen,
    a period, an apostrophe, or an ampersand.
    """

    def __call__(self, text: str):
        """
        Args:
            text (str): Text to remove special characters from.

        Returns:
            str: The text with all special characters deleted.
        """
        special = re.compile(r"[^\w\s\-\.\'\&]")
        return special.sub("", text)
| |
|
| |
|
class DropTokens(BaseTextTransform):
    """Delete tokens from the input text.

    A token is any non-empty run of characters enclosed in angle brackets, e.g. <token>.
    """

    def __call__(self, text: str):
        """
        Args:
            text (str): Text to remove tokens from.

        Returns:
            str: The text with every angle-bracketed token deleted.
        """
        token = re.compile(r"<[^>]+>")
        return token.sub("", text)
| |
|
| |
|
class DropURLs(BaseTextTransform):
    """Delete URLs from the input text.

    A URL is matched as the literal ``http`` followed by one or more non-whitespace characters.
    """

    def __call__(self, text: str):
        """
        Args:
            text (str): Text to remove URLs from.

        Returns:
            str: The text with every matched URL deleted.
        """
        url = re.compile(r"http\S+")
        return url.sub("", text)
| |
|
| |
|
class DropWords(BaseTextTransform):
    """Remove a fixed set of words from the input text.

    Matching is case-insensitive and anchored on word boundaries, so only whole words are
    removed. Inflected forms (e.g. plurals) are not matched unless they are listed explicitly.
    Because the pattern is anchored with ``\\b``, entries that begin or end with a non-word
    character only match when a word character is adjacent on that side.

    Args:
        words (list[str]): Literal words to drop. Each entry is regex-escaped, so
            metacharacters are matched verbatim instead of being interpreted.
    """

    def __init__(self, words: list[str]) -> None:
        super().__init__()
        self.words = words
        # Escape every word so that characters such as "." or "(" cannot corrupt the pattern.
        alternatives = "|".join(re.escape(word) for word in words)
        self.pattern = r"\b(?:{})\b".format(alternatives)

    def __call__(self, text: str):
        """
        Args:
            text (str): Text to remove words from.

        Returns:
            str: The text with every matched word deleted (surrounding whitespace is kept).
        """
        return re.sub(self.pattern, "", text, flags=re.IGNORECASE)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(pattern={self.pattern})"
| |
|
| |
|
class FilterPOS(BaseTextTransform):
    """Filter words by part-of-speech (POS) tags.

    NOTE(review): the two engines apply ``tags`` in opposite directions. With
    ``engine="nltk"`` the words whose tag is in ``tags`` are removed, whereas with
    ``engine="flair"`` only the tokens whose tag is in ``tags`` are kept. Both behaviors are
    preserved as-is; confirm which one is intended before unifying them.

    Args:
        tags (list): List of POS tags to filter on (removed with "nltk", kept with "flair").
        engine (str): POS tagger to use. Must be one of "nltk" or "flair". Defaults to "nltk".
        keep_compound_nouns (bool): Whether to append compound nouns, i.e. pairs of adjacent
            distinct words both tagged "NN", joined by an underscore. Defaults to True.

    Raises:
        ValueError: If ``engine`` is not one of the supported engines.
    """

    _ENGINES = ("nltk", "flair")

    def __init__(self, tags: list, engine: str = "nltk", keep_compound_nouns: bool = True) -> None:
        super().__init__()
        # Fail fast: an unknown engine used to be accepted and silently did no filtering.
        if engine not in self._ENGINES:
            raise ValueError(f"engine must be one of {self._ENGINES}, got {engine!r}")

        self.tags = tags
        self.engine = engine
        self.keep_compound_nouns = keep_compound_nouns

        if engine == "nltk":
            nltk.download("averaged_perceptron_tagger", quiet=True)
            nltk.download("punkt", quiet=True)
            self.tagger = lambda x: nltk.pos_tag(nltk.word_tokenize(x))
        else:
            self.tagger = SequenceTagger.load("flair/pos-english-fast").predict

    @staticmethod
    def _compound_nouns(word_tags: list) -> list:
        """Return "first_second" for each adjacent pair of distinct words both tagged "NN"."""
        return [
            f"{first}_{second}"
            for (first, first_tag), (second, second_tag) in zip(word_tags, word_tags[1:])
            if first_tag == "NN" and second_tag == "NN" and first != second
        ]

    def __call__(self, text: str):
        """
        Args:
            text (str): Text to filter by POS tags.

        Returns:
            str: The retained words, followed by the compound nouns when enabled, joined by
            single spaces. No stray separator is emitted when either part is empty.
        """
        if self.engine == "nltk":
            word_tags = list(self.tagger(text))
            kept = [word for word, tag in word_tags if tag not in self.tags]
        elif self.engine == "flair":
            sentence = Sentence(text)
            self.tagger(sentence)  # annotates the sentence tokens in place
            word_tags = [(token.text, token.tag) for token in sentence.tokens]
            kept = [word for word, tag in word_tags if tag in self.tags]
        else:
            # Only reachable if ``engine`` was changed after construction: leave text as-is.
            return text

        if self.keep_compound_nouns:
            kept.extend(self._compound_nouns(word_tags))

        return " ".join(kept)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(tags={self.tags}, engine={self.engine})"
| |
|
| |
|
class FrequencyMinWordCount(BaseTextTransform):
    """Keep only words that occur at least a minimum number of times in the input text.

    If the threshold is too strict and no word reaches it, the threshold is lowered to the
    count of the most frequent word.

    Args:
        min_count (int): Minimum number of occurrences of a word to keep.
    """

    def __init__(self, min_count) -> None:
        super().__init__()
        self.min_count = min_count

    def __call__(self, text: str):
        """
        Args:
            text (str): Text to remove infrequent words from.

        Returns:
            str: The surviving words, in their original order, joined by single spaces. The
            text is returned untouched when ``min_count`` is 1 or less.
        """
        if self.min_count <= 1:
            return text

        words = text.split()
        # Single pass instead of calling ``words.count`` once per word.
        word_counts = Counter(words)

        # Never ask for more occurrences than the most frequent word has.
        max_word_count = max(word_counts.values(), default=0)
        threshold = min(self.min_count, max_word_count)

        return " ".join(word for word in words if word_counts[word] >= threshold)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(min_count={self.min_count})"
| |
|
| |
|
class FrequencyTopK(BaseTextTransform):
    """Keep only the top k most frequent words in the input text.

    In case of a tie at the k-th position, all words with the same count as the k-th word are
    kept as well, so more than k distinct words may survive.

    Args:
        top_k (int): Number of top words to keep.
    """

    def __init__(self, top_k: int) -> None:
        super().__init__()
        self.top_k = top_k

    def __call__(self, text: str):
        """
        Args:
            text (str): Text to remove infrequent words from.

        Returns:
            str: The surviving words, in their original order, joined by single spaces. The
            text is returned untouched when ``top_k`` is smaller than 1.
        """
        if self.top_k < 1:
            return text

        words = text.split()
        if not words:
            return ""

        word_counts = Counter(words)
        ranked_counts = sorted(word_counts.values(), reverse=True)

        # Count of the k-th most frequent word (or of the rarest word if there are fewer than
        # k distinct words). Keeping everything at or above it retains the more frequent words
        # and every word tied with the k-th one.
        cutoff = ranked_counts[min(self.top_k, len(ranked_counts)) - 1]

        return " ".join(word for word in words if word_counts[word] >= cutoff)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(top_k={self.top_k})"
| |
|
| |
|
class ReplaceSeparators(BaseTextTransform):
    """Turn every underscore and dash into a space."""

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def __call__(self, text: str):
        """
        Args:
            text (str): Text to replace separators in.

        Returns:
            str: The text with each ``_`` and ``-`` replaced by one space.
        """
        separator = re.compile(r"[_\-]")
        return separator.sub(" ", text)
| |
|
| |
|
class RemoveDuplicates(BaseTextTransform):
    """Remove duplicate words from the input text.

    Only the first occurrence of each whitespace-separated word is kept, and the words stay in
    their original order so the output is deterministic across runs.
    """

    def __call__(self, text: str):
        """
        Args:
            text (str): Text to remove duplicate words from.

        Returns:
            str: The distinct words, in order of first appearance, joined by single spaces.
        """
        # dict keys are unique and keep insertion order, unlike a set whose iteration order
        # for strings varies with hash randomization.
        return " ".join(dict.fromkeys(text.split()))
| |
|
| |
|
class TextCompose:
    """Chain several text transforms into a single callable.

    It differs from the torchvision.transforms.Compose class in that it applies the transforms
    to a string instead of a PIL Image or Tensor. In addition, it automatically joins a list of
    input strings into a single string and splits the output string into a list of words.

    Args:
        transforms (list): List of transforms to compose, applied in order.
    """

    def __init__(self, transforms: list[BaseTextTransform]) -> None:
        self.transforms = transforms

    def __call__(self, text: Union[str, list[str]]) -> Any:
        """
        Args:
            text (str | list[str]): Input text; a list is joined with single spaces first.

        Returns:
            The whitespace-split words of the fully transformed text.
        """
        current = " ".join(text) if isinstance(text, list) else text

        for transform in self.transforms:
            current = transform(current)

        return current.split()

    def __repr__(self) -> str:
        # One transform per line, wrapped in "ClassName(" ... ")".
        entries = "".join(f"\n {transform}" for transform in self.transforms)
        return f"{self.__class__.__name__}({entries}\n)"
| |
|
| |
|
| | class ToLowercase(BaseTextTransform): |
| | """Convert text to lowercase.""" |
| |
|
| | def __call__(self, text: str): |
| | """ |
| | Args: |
| | text (str): Text to convert to lowercase. |
| | """ |
| | text = text.lower() |
| |
|
| | return text |
| |
|
| |
|
class ToSingular(BaseTextTransform):
    """Convert plural words to their singular form.

    Only words ending in "s" are considered, and words ending in "ss", "us", "is", "ies" or
    "oes" are left untouched.
    """

    def __init__(self) -> None:
        super().__init__()
        self.transform = inflect.engine().singular_noun

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def _singularize(self, word: str) -> str:
        """Return the singular form of ``word``, or ``word`` itself when it is left as-is."""
        is_candidate = word.endswith("s") and not word.endswith(("ss", "us", "is", "ies", "oes"))
        if not is_candidate:
            return word

        # The transform yields a falsy value when it has no singular form to offer.
        return self.transform(word) or word

    def __call__(self, text: str):
        """
        Args:
            text (str): Text to convert to singular form.

        Returns:
            str: The converted words joined by single spaces.
        """
        return " ".join([self._singularize(word) for word in text.split()])
| |
|
| |
|
def default_vocabulary_transforms() -> TextCompose:
    """Build the default preprocessing pipeline for vocabulary text.

    Returns:
        TextCompose: The composed transforms, applied in the order they are listed below.
    """
    # POS tags handed to the flair-based filter: nouns, adjectives and participles.
    pos_tags = ["NN", "NNS", "NNP", "NNPS", "JJ", "JJR", "JJS", "VBG", "VBN"]
    # Generic image-related words that are dropped from the text.
    words_to_drop = [
        "image",
        "photo",
        "picture",
        "thumbnail",
        "logo",
        "symbol",
        "clipart",
        "portrait",
        "painting",
        "illustration",
        "icon",
        "profile",
    ]

    pipeline = [
        DropTokens(),
        DropURLs(),
        DropSpecialCharacters(),
        DropFileExtensions(),
        ReplaceSeparators(),
        DropShortWords(min_length=3),
        DropNonAlpha(),
        ToLowercase(),
        ToSingular(),
        DropWords(words=words_to_drop),
        FrequencyMinWordCount(min_count=2),
        FilterPOS(tags=pos_tags, engine="flair", keep_compound_nouns=False),
        RemoveDuplicates(),
    ]

    return TextCompose(pipeline)
| |
|