File size: 5,006 Bytes
565e754 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
from typing import Literal
from transformers import AutoTokenizer
from langchain_text_splitters import RecursiveCharacterTextSplitter, NLTKTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
class Splitter:
"""
Класс описывает функционал разделения текста на чанки тремя способами на выбор:
- рекурсивно разбивая чанки различными разделителями
в порядке возрастания "жесткости" их эффекта;
- объединяя выделенные с помощью библиотеки NLTK предложения
в чанки определенного размера и с наложением;
- разбивая текст на семантически связанные блоки
с помощью векторных представлений текстов;
"""
def __init__(
self,
mode: Literal["recursive", "nltk", "semantic"],
model_name: str = "deepvk/USER-bge-m3",
chunk_size: int = 256,
chunk_overlap: int = 64,
**splitter_kwargs,
):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
match mode:
case "recursive":
self.splitter = RecursiveCharacterTextSplitter(
separators=[
"\n### ", "\n## ", "\n# ",
"\n\n", "\n",
"!", "?", ". ", ";", ",", ")", " ", "",
],
keep_separator="end",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=lambda x: len(self.tokenizer.encode(x, add_special_tokens=False)),
**splitter_kwargs,
)
self.split_fn = self._recursive_split
case "nltk":
self.splitter = NLTKTextSplitter(
language="russian",
**splitter_kwargs,
)
self.split_fn = self._nltk_split
case "semantic":
self.splitter = SemanticChunker(
HuggingFaceEmbeddings(
model_name=model_name,
encode_kwargs={"normalize_embeddings": True},
),
**splitter_kwargs,
)
self.split_fn = self._semantic_split
def split_text(self, text: str) -> list[str]:
"""
Доступная пользователю функция разделения текста на чанки
"""
return self.split_fn(text)
def _recursive_split(self, text: str) -> list[str]:
"""
Функция разделения текста на чанки при self.splitter == RecursiveCharacterTextSplitter
"""
return [
chunk
for chunk in self.splitter.split_text(text)
if any(ch.isalpha() for ch in set(chunk))
]
def _nltk_split(self, text: str) -> list[str]:
"""
Функция разделения текста на чанки при self.splitter == NLTKTextSplitter
"""
sentences = self.splitter.split_text(text)[0].split("\n\n")
sent_sizes = [
len(self.tokenizer.encode(sent, add_special_tokens=False))
for sent in sentences
]
chunks = []
i, n = 0, len(sentences)
while i < n:
cur_len, cur_texts = 0, []
# --- Собираем строки в чанк ---
j = i
while (j < n) and (cur_len + sent_sizes[j] <= self.chunk_size):
cur_texts.append(sentences[j])
cur_len += sent_sizes[j]
j += 1
chunks.append(cur_texts)
# --- Сдвигаем окно с overlap ---
if j >= n:
break
# Держим overlap в токенах, но не превышая его
overlap_len, k = 0, j - 1
while (k >= i) and (overlap_len + sent_sizes[k] <= self.chunk_overlap):
overlap_len += sent_sizes[k]
k -= 1 # идём назад от конца чанка
# Следующий старт = k+1
i = k + 1
return chunks
def _semantic_split(self, text: str) -> list[str]:
"""
Функция разделения текста на чанки при self.splitter == SemanticChunker
"""
return self.splitter.split_text(text)
|