|
|
import functools
import random
import time
|
|
|
|
|
|
|
|
def truncate_document(topic_model, doc_length, tokenizer, document: str):
    """Truncate a document to a certain length.

    If you want to add a custom tokenizer, then it will need to have a `decode` and
    `encode` method. An example would be the following custom tokenizer:

    ```python
    class Tokenizer:
        'A custom tokenizer that splits on commas'

        def encode(self, doc):
            return doc.split(",")

        def decode(self, doc_chunks):
            return ",".join(doc_chunks)
    ```

    You can use this tokenizer by passing it to the `tokenizer` parameter.

    Arguments:
        topic_model: A BERTopic model. Only used when `tokenizer="vectorizer"`,
            to access its internal `vectorizer_model`.
        doc_length: The maximum length of each document. If a document is longer,
            it will be truncated. If None, the entire document is passed.
        tokenizer: The tokenizer used to split the document into segments
            used to count the length of a document.
            * If tokenizer is 'char', then the document is split up
              into characters which are counted to adhere to `doc_length`
            * If tokenizer is 'whitespace', the document is split up
              into words separated by whitespaces. These words are counted
              and truncated depending on `doc_length`
            * If tokenizer is 'vectorizer', then the internal CountVectorizer
              is used to tokenize the document. These tokens are counted
              and truncated depending on `doc_length`. They are decoded with
              whitespaces.
            * If tokenizer is a callable with `encode`/`decode` methods, then it
              is used to tokenize the document. These tokens are counted and
              truncated depending on `doc_length`
        document: A single document

    Returns:
        truncated_document: A truncated document

    Raises:
        ValueError: If `doc_length` is set but `tokenizer` is neither one of the
            supported string options nor an object with `encode` and `decode`
            methods. (Previously this case crashed with `UnboundLocalError`.)
    """
    # No limit requested: pass the document through untouched.
    if doc_length is None:
        return document

    if tokenizer == "char":
        truncated_document = document[:doc_length]
    elif tokenizer == "whitespace":
        truncated_document = " ".join(document.split()[:doc_length])
    elif tokenizer == "vectorizer":
        # Reuse the topic model's fitted CountVectorizer tokenization rules.
        tokenizer = topic_model.vectorizer_model.build_tokenizer()
        truncated_document = " ".join(tokenizer(document)[:doc_length])
    elif hasattr(tokenizer, "encode") and hasattr(tokenizer, "decode"):
        encoded_document = tokenizer.encode(document)
        truncated_document = tokenizer.decode(encoded_document[:doc_length])
    else:
        # Fail loudly instead of raising UnboundLocalError on the return below.
        raise ValueError(
            "Unrecognized `tokenizer`. Use 'char', 'whitespace', 'vectorizer', "
            "or a custom tokenizer implementing `encode` and `decode`."
        )
    return truncated_document
|
|
|
|
|
|
|
|
def retry_with_exponential_backoff(
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 10,
    errors: tuple = None,
):
    """Retry a function with exponential backoff.

    Arguments:
        func: The function to wrap with retry behavior.
        initial_delay: Seconds to wait before the first retry.
        exponential_base: Multiplier applied to the delay after each failure.
        jitter: If True, randomize each delay by a factor in [1, 2) to avoid
            synchronized retries (thundering herd).
        max_retries: Maximum number of retries before giving up.
        errors: Tuple of exception types that trigger a retry. If None,
            any `Exception` is retried. (Previously, leaving this as None made
            the wrapper raise `TypeError` on the first caught exception, because
            `except None` is invalid.)

    Returns:
        A wrapped version of `func` that retries on the given errors.

    Raises:
        Exception: When `max_retries` is exceeded. Exceptions not listed in
            `errors` propagate unchanged.
    """
    # Default to retrying on any Exception; `except None` would be a TypeError.
    retriable = (Exception,) if errors is None else errors

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        num_retries = 0
        delay = initial_delay

        while True:
            try:
                return func(*args, **kwargs)
            except retriable:
                num_retries += 1

                if num_retries > max_retries:
                    raise Exception(
                        f"Maximum number of retries ({max_retries}) exceeded."
                    )

                # Exponential backoff with optional jitter; `jitter` is a bool,
                # so `jitter * random.random()` is 0 when jitter is disabled.
                delay *= exponential_base * (1 + jitter * random.random())
                time.sleep(delay)

    return wrapper