from dataclasses import dataclass
from typing import Callable, List

from transformers import AutoTokenizer


@dataclass
class Tokenizer:
    """Bundle of chunking parameters and the encode/decode callables.

    Consumed by ``split_text_on_tokens`` to window a token-id sequence.
    """

    # Number of token ids shared between two consecutive chunks.
    chunk_overlap: int
    # Maximum number of token ids in a single chunk.
    tokens_per_chunk: int
    # Turns a list of token ids back into text.
    decode: Callable[[List[int]], str]
    # Turns text into a list of token ids.
    encode: Callable[[str], List[int]]


def split_text_on_tokens(text: str, tokenizer: Tokenizer) -> List[str]:
    """Split incoming text and return chunks using tokenizer.

    The text is encoded to token ids, the ids are cut into windows of at
    most ``tokenizer.tokens_per_chunk`` ids that advance by
    ``tokens_per_chunk - chunk_overlap`` ids each step, and every window is
    decoded back to a string.

    Args:
        text: The text to split.
        tokenizer: Chunking parameters plus the encode/decode callables.

    Returns:
        The decoded chunks in order. Empty when no token ids remain after
        the leading id is dropped.

    Raises:
        ValueError: If ``tokens_per_chunk`` is not greater than
            ``chunk_overlap``; the window could then never advance and the
            loop would not terminate.
    """
    stride = tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
    if stride <= 0:
        raise ValueError("tokens_per_chunk must be greater than chunk_overlap")

    # The first encoded id is discarded.
    # NOTE(review): presumably this is a BOS/special token that the
    # configured encoder prepends -- confirm; with an encoder that adds no
    # leading special token this drops a real token.
    input_ids = tokenizer.encode(text)[1:]
    total = len(input_ids)

    splits: List[str] = []
    start_idx = 0
    while start_idx < total:
        end_idx = min(start_idx + tokenizer.tokens_per_chunk, total)
        splits.append(tokenizer.decode(input_ids[start_idx:end_idx]))
        if end_idx == total:
            # This chunk already reaches the last token; any further window
            # would only repeat a tail that is fully contained in it.
            break
        start_idx += stride
    return splits


# The pretrained tokenizer is loaded when this module is imported.
_hf_tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-plus")

# Module-level default: 500-token chunks with a 50-token overlap.
tokenizer = Tokenizer(
    chunk_overlap=50,
    tokens_per_chunk=500,
    decode=_hf_tokenizer.decode,
    encode=_hf_tokenizer.encode,
)