"""Utilities for splitting markdown documents into size-bounded chunks
while keeping inline formatting sequences and fenced code blocks balanced."""
import re
import urllib
import urllib.parse
from collections import namedtuple
from enum import Enum
from pathlib import Path
from typing import Generator, List, Tuple, Union

from loguru import logger
# Markdown inline formatting sequences (emphasis, strikethrough, spoiler).
FORMATTING_SEQUENCES = {"*", "**", "***", "_", "__", "~~", "||"}
# Inline code and fenced code-block sequences.
CODE_BLOCK_SEQUENCES = {"`", "``", "```"}
# All sequences the splitter must keep balanced across chunk boundaries.
ALL_SEQUENCES = FORMATTING_SEQUENCES | CODE_BLOCK_SEQUENCES
# Longest sequence length; also the lower bound on allowed chunk sizes.
MAX_FORMATTING_SEQUENCE_LENGTH = max(len(seq) for seq in ALL_SEQUENCES)
class SplitCandidates(Enum):
    """Kinds of positions where a chunk may be cut (values are ordinals)."""

    SPACE = 1
    NEWLINE = 2
    LAST_CHAR = 3
# Order in which split positions are tried: prefer a newline, then a space,
# falling back to the very last character seen.
# (Name keeps the original spelling "PREFRENCE"; it is referenced elsewhere.)
SPLIT_CANDIDATES_PREFRENCE = [
    SplitCandidates.NEWLINE,
    SplitCandidates.SPACE,
    SplitCandidates.LAST_CHAR,
]
# Regexes for logical block boundaries: H1, H2 and H3 markdown headings.
BLOCK_SPLIT_CANDIDATES = [r"\n#\s+", r"\n##\s+", r"\n###\s+"]
# Sentinel heading level assigned to code-block chunks.
CODE_BLOCK_LEVEL = 10
# A chunk of markdown text plus the heading level it belongs to.
MarkdownChunk = namedtuple("MarkdownChunk", "string level")
class SplitCandidateInfo:
    """State of one potential split position while scanning markdown.

    Tracks where the candidate was last seen and which formatting /
    code-fence sequences were still open at that point, so a chunk cut there
    can be closed cleanly and the next chunk re-opened with the same
    sequences.
    """

    # Index into the markdown string of the most recent occurrence of this
    # split candidate; None until first seen.
    last_seen: Union[int, None]
    # Stack of formatting/code sequences open at `last_seen` (innermost last).
    active_sequences: List[str]
    # Total character length of all entries in `active_sequences`.
    active_sequences_length: int

    def __init__(self):
        self.last_seen = None
        self.active_sequences = []
        self.active_sequences_length = 0

    def process_sequence(self, seq: str, is_in_code_block: bool):
        """Update the active-sequence stack for `seq`.

        The return value is assigned by the caller as its new
        "is in code block" flag.

        NOTE(review): the return values look suspicious — closing a code
        fence returns True (i.e. still "in" the block), and a sequence seen
        inside a code block that does not close it falls through and returns
        None (falsy, i.e. "out" of the block). Confirm intended behavior
        before changing.
        """
        if is_in_code_block:
            # Inside a code block only the matching closing fence is honored.
            if self.active_sequences and seq == self.active_sequences[-1]:
                last_seq = self.active_sequences.pop()
                self.active_sequences_length -= len(last_seq)
                return True
        elif seq in CODE_BLOCK_SEQUENCES:
            # Opening a code fence: push it and report code-block mode.
            self.active_sequences.append(seq)
            self.active_sequences_length += len(seq)
            return True
        else:
            # Formatting sequence: if it closes an open sequence, pop that
            # sequence together with everything nested inside it.
            for k in range(len(self.active_sequences) - 1, -1, -1):
                if seq == self.active_sequences[k]:
                    sequences_being_removed = self.active_sequences[k:]
                    self.active_sequences = self.active_sequences[:k]
                    self.active_sequences_length -= sum(
                        len(seq) for seq in sequences_being_removed
                    )
                    return False
            # Otherwise it opens a new formatting sequence.
            self.active_sequences.append(seq)
            self.active_sequences_length += len(seq)
            return False

    def copy_from(self, other):
        """Clone split-candidate state from `other` (copies the stack)."""
        self.last_seen = other.last_seen
        self.active_sequences = other.active_sequences.copy()
        self.active_sequences_length = other.active_sequences_length
def physical_split(markdown: str, max_chunk_size: int) -> Generator[str, None, None]:
    """Yield chunks of `markdown`, each at most `max_chunk_size` characters.

    Scans character by character, tracking the best recent split positions
    (newline preferred, then space, then the last character scanned) together
    with the formatting / code-fence sequences open at each position. When a
    chunk is cut, the open sequences are appended to close the chunk and
    prepended to re-open them at the start of the next chunk.

    Raises:
        ValueError: if `max_chunk_size` is not larger than the longest
            formatting sequence (the re-opened prefix must fit in a chunk).
    """
    if max_chunk_size <= MAX_FORMATTING_SEQUENCE_LENGTH:
        raise ValueError(
            f"max_chunk_size must be greater than {MAX_FORMATTING_SEQUENCE_LENGTH}"
        )

    # Per-candidate state: where each candidate was last seen and which
    # sequences were open there.
    split_candidates = {
        SplitCandidates.SPACE: SplitCandidateInfo(),
        SplitCandidates.NEWLINE: SplitCandidateInfo(),
        SplitCandidates.LAST_CHAR: SplitCandidateInfo(),
    }
    is_in_code_block = False
    # Start index of the current chunk, its running character count, and the
    # sequences to re-open at its start.
    chunk_start_from, chunk_char_count, chunk_prefix = 0, 0, ""

    def split_chunk():
        """Cut a chunk at the most preferred available candidate.

        Returns (chunk, next_start, next_char_count, next_prefix).
        NOTE(review): returns None when no candidate has been seen yet; in
        practice LAST_CHAR is always set before a split is triggered.
        """
        for split_variant in SPLIT_CANDIDATES_PREFRENCE:
            split_candidate = split_candidates[split_variant]
            if split_candidate.last_seen is None:
                continue
            # LAST_CHAR keeps the split character; NEWLINE/SPACE drop it.
            chunk_end = split_candidate.last_seen + (
                1 if split_variant == SplitCandidates.LAST_CHAR else 0
            )
            # Close any sequences still open at the split point so the chunk
            # is well-formed on its own.
            chunk = (
                chunk_prefix
                + markdown[chunk_start_from:chunk_end]
                + "".join(reversed(split_candidate.active_sequences))
            )
            # The next chunk re-opens the same sequences.
            next_chunk_prefix = "".join(split_candidate.active_sequences)
            next_chunk_char_count = len(next_chunk_prefix)
            next_chunk_start_from = chunk_end + (
                0 if split_variant == SplitCandidates.LAST_CHAR else 1
            )
            # Earlier NEWLINE/SPACE candidates are invalid in the new chunk.
            split_candidates[SplitCandidates.NEWLINE] = SplitCandidateInfo()
            split_candidates[SplitCandidates.SPACE] = SplitCandidateInfo()
            return (
                chunk,
                next_chunk_start_from,
                next_chunk_char_count,
                next_chunk_prefix,
            )

    i = 0
    while i < len(markdown):
        # Longest-match scan for a formatting / code-fence sequence at i.
        for j in range(MAX_FORMATTING_SEQUENCE_LENGTH, 0, -1):
            seq = markdown[i: i + j]
            if seq in ALL_SEQUENCES:
                # Would appending this sequence (plus the closers needed for
                # currently open sequences) overflow the chunk? Split first.
                last_char_split_candidate_len = (
                    chunk_char_count
                    + split_candidates[
                        SplitCandidates.LAST_CHAR
                    ].active_sequences_length
                    + len(seq)
                )
                if last_char_split_candidate_len >= max_chunk_size:
                    (
                        next_chunk,
                        chunk_start_from,
                        chunk_char_count,
                        chunk_prefix,
                    ) = split_chunk()
                    yield next_chunk
                is_in_code_block = split_candidates[
                    SplitCandidates.LAST_CHAR
                ].process_sequence(seq, is_in_code_block)
                i += len(seq)
                chunk_char_count += len(seq)
                split_candidates[SplitCandidates.LAST_CHAR].last_seen = i - 1
                break
        if i >= len(markdown):
            break
        # Treat the character at i as a potential split position.
        split_candidates[SplitCandidates.LAST_CHAR].last_seen = i
        chunk_char_count += 1
        if markdown[i] == "\n":
            split_candidates[SplitCandidates.NEWLINE].copy_from(
                split_candidates[SplitCandidates.LAST_CHAR]
            )
        elif markdown[i] == " ":
            split_candidates[SplitCandidates.SPACE].copy_from(
                split_candidates[SplitCandidates.LAST_CHAR]
            )
        # Split exactly at the size budget, accounting for the characters
        # needed to close currently open sequences.
        last_char_split_candidate_len = (
            chunk_char_count
            + split_candidates[SplitCandidates.LAST_CHAR].active_sequences_length
        )
        if last_char_split_candidate_len == max_chunk_size:
            next_chunk, chunk_start_from, chunk_char_count, chunk_prefix = split_chunk()
            yield next_chunk
        i += 1

    # Emit whatever remains as the final chunk.
    if chunk_start_from < len(markdown):
        yield chunk_prefix + markdown[chunk_start_from:]
def get_logical_blocks_recursively(
    markdown: str, max_chunk_size: int, all_sections: list, split_candidate_index=0
) -> List[MarkdownChunk]:
    """Split `markdown` into logical blocks by heading level, recursively.

    Tries heading levels (H1..H3, starting at `split_candidate_index`) until
    one actually splits the text; blocks still larger than `max_chunk_size`
    are re-split at the next deeper heading level, and once all heading
    levels are exhausted they fall back to `physical_split`. Results are
    appended to `all_sections`, which is also returned.
    """
    # All heading levels exhausted: fall back to size-based splitting.
    if split_candidate_index >= len(BLOCK_SPLIT_CANDIDATES):
        for chunk in physical_split(markdown, max_chunk_size):
            all_sections.append(
                MarkdownChunk(string=chunk, level=split_candidate_index)
            )
        return all_sections

    chunks = []
    add_index = 0
    # Try successively deeper heading levels until one produces a split.
    for add_index, split_candidate in enumerate(
        BLOCK_SPLIT_CANDIDATES[split_candidate_index:]
    ):
        chunks = re.split(split_candidate, markdown)
        if len(chunks) > 1:
            break

    for i, chunk in enumerate(chunks):
        level = split_candidate_index + add_index
        # Chunks after the first had their heading marker consumed by
        # re.split, so they sit one level deeper.
        if i > 0:
            level += 1
        # Heading marker to re-attach at the appropriate depth.
        prefix = "\n\n" + "#" * level + " "
        if not chunk.strip():
            continue
        if len(chunk) <= max_chunk_size:
            all_sections.append(MarkdownChunk(string=prefix + chunk, level=level - 1))
        else:
            # Too large: split again at the next deeper heading level.
            get_logical_blocks_recursively(
                chunk,
                max_chunk_size,
                all_sections,
                split_candidate_index=split_candidate_index + add_index + 1,
            )
    return all_sections
def markdown_splitter(
    path: Union[str, Path], max_chunk_size: int, **additional_splitter_settings
) -> List[dict]:
    """Split a markdown file into chunks of at most `max_chunk_size` characters.

    Fenced code blocks are annotated with a language hint for an LLM, merged
    into the preceding section when they fit, and physically split (keeping a
    language header on every piece) when they do not.

    Args:
        path: Path to the markdown file.
        max_chunk_size: Maximum characters per chunk.
        **additional_splitter_settings: Options forwarded to pre- and
            post-processing (e.g. remove_images, find_metadata,
            skip_first, merge_sections).

    Returns:
        A list of {"text": ..., "metadata": ...} dicts; an empty list when
        the file cannot be read.
    """
    try:
        # Markdown is assumed UTF-8; an explicit encoding keeps behavior
        # consistent across platforms instead of using the locale default.
        with open(path, "r", encoding="utf-8") as f:
            markdown = f.read()
    except OSError:
        return []

    # Small documents are returned as a single chunk without processing.
    if len(markdown) < max_chunk_size:
        return [{"text": markdown, "metadata": {"heading": ""}}]

    sections = [MarkdownChunk(string="", level=0)]
    markdown, additional_metadata = preprocess_markdown(
        markdown, additional_splitter_settings
    )

    # Split by code and non-code
    chunks = markdown.split("```")
    for i, chunk in enumerate(chunks):
        if i % 2 == 0:  # Every even element (0 indexed) is a non-code
            logical_blocks = get_logical_blocks_recursively(
                chunk, max_chunk_size=max_chunk_size, all_sections=[]
            )
            sections += logical_blocks
        else:  # Process the code section
            rows = chunk.split("\n")
            code = rows[1:]
            lang = rows[0]  # Get the language name
            # Provide a hint to LLM
            all_code_rows = (
                [
                    f"\nFollowing is a code section in {lang}, delimited by triple backticks:",
                    f"```{lang}",
                ]
                + code
                + ["```"]
            )
            all_code_str = "\n".join(all_code_rows)
            # Merge code to a previous logical block if there is enough space
            if len(sections[-1].string) + len(all_code_str) < max_chunk_size:
                sections[-1] = MarkdownChunk(
                    string=sections[-1].string + all_code_str,
                    level=sections[-1].level,
                )
            # If code block is larger than max size, physically split it
            elif len(all_code_str) >= max_chunk_size:
                code_chunks = physical_split(
                    all_code_str, max_chunk_size=max_chunk_size
                )
                for cchunk in code_chunks:
                    # Assign language header to the code chunk, if it doesn't exist
                    if f"```{lang}" not in cchunk:
                        cchunk_parts = cchunk.split("```")
                        # Bug fix: a chunk may contain no backticks at all
                        # (e.g. when the split lands inside the hint line);
                        # previously `[1]` raised IndexError. Fall back to
                        # wrapping the whole chunk.
                        body = (
                            cchunk_parts[1]
                            if len(cchunk_parts) > 1
                            else cchunk_parts[0]
                        )
                        cchunk = f"```{lang}\n" + body + "```"
                    sections.append(
                        MarkdownChunk(string=cchunk, level=CODE_BLOCK_LEVEL)
                    )
            # Otherwise, add as a single chunk
            else:
                sections.append(
                    MarkdownChunk(string=all_code_str, level=CODE_BLOCK_LEVEL)
                )

    all_out = postprocess_sections(
        sections,
        max_chunk_size,
        additional_splitter_settings,
        additional_metadata,
        path,
    )
    return all_out
def preprocess_markdown(markdown: str, additional_settings: dict) -> Tuple[str, dict]:
    """Run the configured preprocessing passes over raw markdown.

    Supported settings:
      * remove_images (bool, default False) — strip markdown image tags.
      * remove_extra_newlines (bool, default True) — collapse 3+ newlines.
      * find_metadata (dict, default {}) — map of label -> search string.

    Returns:
        The (possibly transformed) markdown and a dict with any metadata
        values that were found.

    Raises:
        TypeError: if the find_metadata setting is truthy but not a dict.
    """
    if additional_settings.get("remove_images", False):
        markdown = remove_images(markdown)
    if additional_settings.get("remove_extra_newlines", True):
        markdown = remove_extra_newlines(markdown)

    metadata_searches = additional_settings.get("find_metadata", dict())
    found_metadata = {}
    if metadata_searches:
        if not isinstance(metadata_searches, dict):
            raise TypeError(
                f"find_metadata settings should be of type dict. Got {type(metadata_searches)}"
            )
        for label, search_string in metadata_searches.items():
            logger.info(f"Looking for metadata: {search_string}")
            value = find_metadata(markdown, search_string)
            if value:
                logger.info(f"\tFound metadata for {label} - {value}")
                found_metadata[label] = value
    return markdown, found_metadata
def postprocess_sections(
    sections: List[MarkdownChunk],
    max_chunk_size: int,
    additional_settings: dict,
    additional_metadata: dict,
    path: Union[str, Path],
) -> List[dict]:
    """Turn raw section chunks into the final list of output dicts.

    Applies the optional `skip_first` and `merge_sections` settings, tracks
    the current top-level heading, and wraps every section with a metadata
    preamble via `add_section_metadata`.

    Args:
        sections: Chunks produced by the splitter.
        max_chunk_size: Maximum characters per merged section.
        additional_settings: Splitter settings (skip_first, merge_sections).
        additional_metadata: Extra metadata found during preprocessing.
        path: Source file path; its name is embedded in section metadata.

    Returns:
        List of {"text": ..., "metadata": {"heading": ...}} dicts.
    """
    all_out = []
    skip_first = additional_settings.get("skip_first", False)
    merge_headers = additional_settings.get("merge_sections", False)

    # Remove all empty sections
    sections = [s for s in sections if s.string]
    if sections and skip_first:
        # remove first section
        sections = sections[1:]
    if sections and merge_headers:
        # Merge sections
        sections = merge_sections(sections, max_chunk_size=max_chunk_size)

    current_heading = ""
    sections_metadata = {"Document name": Path(path).name}
    for s in sections:
        stripped_string = s.string.strip()
        doc_metadata = {}
        if len(stripped_string) > 0:
            heading = ""
            if stripped_string.startswith("#"):  # heading detected
                heading = stripped_string.split("\n")[0].replace("#", "").strip()
                stripped_heading = heading.replace("#", "").replace(" ", "").strip()
                if not stripped_heading:
                    heading = ""
            if s.level == 0:
                current_heading = heading
                # Bug fix: requires an explicit `import urllib.parse` at
                # module level — a bare `import urllib` does not reliably
                # expose the `parse` submodule.
                doc_metadata["heading"] = urllib.parse.quote(
                    heading
                )  # isolate the heading
            else:
                doc_metadata["heading"] = ""
            final_section = add_section_metadata(
                stripped_string,
                section_metadata={
                    **sections_metadata,
                    **{"Subsection of": current_heading},
                    **additional_metadata,
                },
            )
            all_out.append({"text": final_section, "metadata": doc_metadata})
    return all_out
# Matches markdown image tags: ![alt](url) with an optional quoted title.
_IMAGE_TAG_RE = re.compile(r"""!\[[^\]]*\]\((.*?)\s*("(?:.*[^"])")?\s*\)""")


def remove_images(page_md: str) -> str:
    """Return `page_md` with all markdown image tags removed."""
    return _IMAGE_TAG_RE.sub("", page_md)
def remove_extra_newlines(page_md) -> str:
    """Collapse runs of three or more newlines down to exactly two."""
    return re.sub(r"\n{3,}", "\n\n", page_md)
def add_section_metadata(s, section_metadata: dict):
    """Wrap section text `s` with a metadata preamble.

    Only truthy metadata values are emitted; the section body is delimited
    by five-star fences so a consumer (e.g. an LLM) can separate metadata
    from content.
    """
    metadata_s = "".join(f"{k}: {v}\n" for k, v in section_metadata.items() if v)
    metadata = (
        "Metadata applicable to the next chunk of text delimited by five stars:\n"
        f">> METADATA START\n{metadata_s}>> METADATA END\n\n"
    )
    return metadata + "*****\n" + s + "\n*****"
def find_metadata(page_md: str, search_string: str) -> str:
    """Return the rest of the line following the first `search_string`.

    Bug fix: the search string is now matched literally via `re.escape`.
    Previously it was interpolated unescaped into the regex, so
    metacharacters (e.g. "(" in "Cost (USD): ") could break the pattern or
    shift the captured group to the wrong text.

    Args:
        page_md: Markdown text to search.
        search_string: Literal prefix marking the metadata value.

    Returns:
        Text after the first occurrence of `search_string` up to the end of
        that line, or "" when not found.
    """
    match = re.search(re.escape(search_string) + r"(.*)", page_md)
    return match.group(1) if match else ""
def merge_sections(
    sections: List[MarkdownChunk], max_chunk_size: int
) -> List[MarkdownChunk]:
    """Greedily merge consecutive sections into larger chunks.

    A section is appended to the current merged chunk while the combined
    length stays within `max_chunk_size` and its heading level is deeper
    than the previously merged one; otherwise the current chunk is flushed
    and a new one is started.

    NOTE(review): assumes `sections` is non-empty (the caller guards this);
    an empty list would raise IndexError on `sections[0]`.
    """
    current_section = sections[0]
    all_out = []
    prev_level = 0
    for s in sections[1:]:
        # Flush when merging would exceed the size budget, or when the next
        # section is at the same or a shallower heading level.
        if (
            len(current_section.string + s.string) > max_chunk_size
            or s.level <= prev_level
        ):
            all_out.append(current_section)
            current_section = s
            prev_level = 0
        else:
            # The merged chunk keeps the level of its first section.
            current_section = MarkdownChunk(
                string=current_section.string + s.string, level=current_section.level
            )
            # Code blocks (CODE_BLOCK_LEVEL) do not advance the level tracker.
            prev_level = s.level if s.level != CODE_BLOCK_LEVEL else prev_level
    all_out.append(current_section)
    return all_out