# Image_generation / chunking.py
# Author: manasdhir — commit efa9d0b ("minor changes")
# import re
# from typing import List
# def word_count(text: str) -> int:
# return len(text.strip().split())
# def split_paragraphs(text: str) -> List[str]:
# """
# Split text by double newlines assuming paragraphs separated by blank lines.
# """
# paras = [p.strip() for p in text.split('\n\n') if p.strip()]
# return paras
# def split_content_into_batches(
# content: str,
# max_words: int = 2000,
# batch_size: int = 16,
# overlap_words: int = 100 # number of words to overlap between chunk boundaries
# ):
# """
# Splits content by # (level 1), then groups level 2 subsections (##) into chunks with max_words,
# overlaps last subsection of previous chunk into next chunk, and batches chunks into batch_size.
# Prints the word count of final chunks and how many chunks were split by each level.
# Returns:
# List of batches, each is a list of chunk strings.
# """
# level_1_pattern = re.compile(r'^# (.+)', re.MULTILINE)
# level_2_pattern = re.compile(r'^## (.+)', re.MULTILINE)
# level_1_indexes = [(m.start(), m.end(), m.group(1)) for m in level_1_pattern.finditer(content)]
# ends = [pos[0] for pos in level_1_indexes[1:]] + [len(content)]
# chunks = []
# # Counters for debugging
# count_level_1_chunks = 0
# count_level_2_chunks = 0 # will count grouped chunks here
# count_paragraph_chunks = 0
# def get_word_slice(text: str, word_limit: int, from_end=False) -> str:
# """
# Utility to return first/last `word_limit` words of text.
# """
# words = text.strip().split()
# if from_end:
# return " ".join(words[-word_limit:])
# else:
# return " ".join(words[:word_limit])
# for i, (start, heading_end, heading_text) in enumerate(level_1_indexes):
# section_start = heading_end
# section_end = ends[i]
# section_text = content[section_start:section_end].strip()
# section_wc = word_count(section_text)
# if section_wc <= max_words:
# # Whole section fits in one chunk, no grouping needed
# final_chunk = f"# {heading_text}\n{section_text}"
# chunks.append(final_chunk)
# print(f"Final chunk word count: {word_count(final_chunk)}")
# count_level_1_chunks += 1
# continue
# # Get all level 2 subsections within this level 1 section
# level_2_indexes = [(m.start(), m.end(), m.group(1)) for m in level_2_pattern.finditer(section_text)]
# if not level_2_indexes:
# # No subsections, fallback: split by paragraphs
# paragraphs = split_paragraphs(section_text)
# para_chunks = []
# current_chunk = ""
# for para in paragraphs:
# if word_count(current_chunk) + word_count(para) + 1 <= max_words:
# current_chunk += "\n\n" + para if current_chunk else para
# else:
# para_chunks.append(current_chunk)
# current_chunk = para
# if current_chunk:
# para_chunks.append(current_chunk)
# for pc in para_chunks:
# final_chunk = f"# {heading_text}\n{pc}"
# chunks.append(final_chunk)
# print(f"Final chunk word count: {word_count(final_chunk)}")
# count_paragraph_chunks += 1
# continue
# # Otherwise, group multiple level 2 subsections until max_words limit reached
# level_2_ends = [pos[0] for pos in level_2_indexes[1:]] + [len(section_text)]
# grouped_chunks = []
# current_chunk_subsections = []
# current_chunk_word_count = 0
# for j, (l2_start, l2_end, l2_heading_text) in enumerate(level_2_indexes):
# subsec_start = l2_end
# subsec_end = level_2_ends[j]
# subsec_text = section_text[subsec_start:subsec_end].strip()
# subsec_wc = word_count(subsec_text)
# # Prepare full subsection text (heading + content)
# full_subsec_text = f"## {l2_heading_text}\n{subsec_text}"
# # If adding this subsection exceeds limit, yield current chunk and start new chunk
# if current_chunk_word_count + subsec_wc > max_words and current_chunk_subsections:
# # Form chunk text concatenating all subsections added so far
# chunk_text = "\n\n".join(current_chunk_subsections)
# # Prepend level 1 heading
# chunk_text = f"# {heading_text}\n{chunk_text}"
# grouped_chunks.append(chunk_text)
# print(f"Final chunk word count: {word_count(chunk_text)}")
# count_level_2_chunks += 1
# # Prepare overlap: repeat last subsection (or last part of it) in next chunk
# overlap_text = current_chunk_subsections[-1]
# overlap_words_text = get_word_slice(overlap_text, overlap_words, from_end=True)
# overlap_chunk_text = f"## {l2_heading_text}\n{overlap_words_text}"
# # Start new chunk with overlap
# current_chunk_subsections = [overlap_chunk_text]
# current_chunk_word_count = word_count(overlap_words_text)
# # Add current subsection fresh after overlap
# current_chunk_subsections.append(full_subsec_text)
# current_chunk_word_count += subsec_wc
# else:
# # Add subsection to current chunk normally
# current_chunk_subsections.append(full_subsec_text)
# current_chunk_word_count += subsec_wc
# # Add the last chunk if any
# if current_chunk_subsections:
# chunk_text = "\n\n".join(current_chunk_subsections)
# chunk_text = f"# {heading_text}\n{chunk_text}"
# grouped_chunks.append(chunk_text)
# print(f"Final chunk word count: {word_count(chunk_text)}")
# count_level_2_chunks += 1
# chunks.extend(grouped_chunks)
# batches = [chunks[i:i + batch_size] for i in range(0, len(chunks), batch_size)]
# print(f"Chunks split by level 1 headings (#): {count_level_1_chunks}")
# print(f"Chunks split by grouped level 2 headings (##): {count_level_2_chunks}")
# print(f"Chunks split by paragraphs: {count_paragraph_chunks}")
# return batches
import re
from typing import List
def word_count(text: str) -> int:
    """Return the number of whitespace-separated words in *text*."""
    # str.split() with no argument already ignores leading/trailing whitespace.
    return len(text.split())
def split_paragraphs(text: str) -> List[str]:
    """Return the non-empty, stripped paragraphs of *text*.

    Paragraphs are assumed to be separated by blank lines (double newlines).
    """
    candidates = (piece.strip() for piece in text.split("\n\n"))
    return [piece for piece in candidates if piece]
def split_to_fit(text: str, max_words: int) -> List[str]:
    """
    Break *text* into chunks of at most ``max_words`` words each.

    Strategy:
      1. If the whole text fits, return it as a single (space-joined) chunk.
      2. Otherwise pack blank-line-separated paragraphs greedily into chunks
         up to the cap.
      3. Any single paragraph longer than ``max_words`` is hard-split on word
         boundaries with a fixed-size window.

    NOTE: chunk text is re-joined with single spaces, so newlines and other
    whitespace inside a chunk are normalized.

    Args:
        text: Input text (may be empty).
        max_words: Maximum word count per returned chunk (assumed > 0).

    Returns:
        List of chunk strings, each with word count <= max_words.
    """
    words = text.split()
    if len(words) <= max_words:
        return [" ".join(words)]
    # Paragraph-level greedy packing.
    paras = [p.strip() for p in text.split("\n\n") if p.strip()]
    chunks: List[str] = []
    current: List[str] = []
    # Running word count of `current` — avoids recounting every paragraph
    # on each iteration (the original recomputed the sum, which was O(n^2)).
    current_wc = 0
    for para in paras:
        para_len = len(para.split())
        if current_wc + para_len <= max_words:
            current.append(para)
            current_wc += para_len
        else:
            if current:
                chunks.append(" ".join(current))
            if para_len > max_words:
                # Single paragraph exceeds the cap: hard-split on word boundaries.
                para_words = para.split()
                for i in range(0, para_len, max_words):
                    chunks.append(" ".join(para_words[i : i + max_words]))
                current = []
                current_wc = 0
            else:
                current = [para]
                current_wc = para_len
    if current:
        chunks.append(" ".join(current))
    # Every chunk above is <= max_words by construction (greedy packing checks
    # the cap; window pieces are sliced to max_words), so no final re-split
    # pass is needed.
    return chunks
def split_content_into_batches(
    content: str,
    max_words: int = 2000,
    batch_size: int = 16,
    overlap_words: int = 200,  # number of words to overlap between chunk boundaries
) -> List[List[str]]:
    """
    Splits content by # (level 1), then groups level 2 subsections (##) into chunks with max_words,
    overlaps last subsection of previous chunk into next chunk, and batches chunks into batch_size.
    Enforces strict cap: no chunk exceeds max_words.
    Prints final word count of each chunk.

    Args:
        content: Markdown-style text using "# " / "## " headings at line starts.
        max_words: Hard upper bound on the word count of any emitted chunk.
        batch_size: Number of chunks per batch in the returned list of lists.
        overlap_words: Tail words from the previous chunk carried into the next
            chunk when a level-2 group is flushed (context continuity).

    Returns:
        List of batches; each batch is a list of chunk strings.
    """
    # Headings are only recognized at the start of a line.
    level_1_pattern = re.compile(r"^# (.+)", re.MULTILINE)
    level_2_pattern = re.compile(r"^## (.+)", re.MULTILINE)
    # (match_start, match_end, heading_text) for every level-1 heading.
    level_1_indexes = [(m.start(), m.end(), m.group(1)) for m in level_1_pattern.finditer(content)]
    # Each level-1 section spans from its heading to the next heading (or EOF).
    # NOTE(review): text before the first "# " heading is never chunked.
    ends = [pos[0] for pos in level_1_indexes[1:]] + [len(content)]
    chunks = []
    # Counters for debugging
    count_level_1_chunks = 0
    count_level_2_chunks = 0
    count_paragraph_chunks = 0
    def get_word_slice(text: str, word_limit: int, from_end: bool = False) -> str:
        # Return the first (or last, when from_end) `word_limit` words of text.
        words = text.strip().split()
        if from_end:
            return " ".join(words[-word_limit:])
        else:
            return " ".join(words[:word_limit])
    for i, (start, heading_end, heading_text) in enumerate(level_1_indexes):
        # Section body = everything after the heading line's match, up to the
        # next level-1 heading.
        section_start = heading_end
        section_end = ends[i]
        section_text = content[section_start:section_end].strip()
        section_wc = word_count(section_text)
        if section_wc <= max_words:
            # Whole section already fits: emit it as a single chunk.
            final_chunk = f"# {heading_text}\n{section_text}"
            chunks.append(final_chunk)
            count_level_1_chunks += 1
            continue
        # Section too big: try to split on its level-2 subsections.
        level_2_indexes = [(m.start(), m.end(), m.group(1)) for m in level_2_pattern.finditer(section_text)]
        if not level_2_indexes:
            # Fallback: split by paragraphs, enforcing cap
            paragraphs = split_paragraphs(section_text)
            para_chunks = []
            current_chunk = ""
            for para in paragraphs:
                # Greedily grow the current chunk; flush (cap-enforced) when
                # adding the next paragraph would exceed max_words.
                candidate = f"{current_chunk}\n\n{para}" if current_chunk else para
                if word_count(candidate) <= max_words:
                    current_chunk = candidate
                else:
                    if current_chunk:
                        para_chunks.extend(split_to_fit(current_chunk, max_words))
                    current_chunk = para
            if current_chunk:
                para_chunks.extend(split_to_fit(current_chunk, max_words))
            for pc in para_chunks:
                # Re-attach the level-1 heading to every paragraph chunk.
                final_chunk = f"# {heading_text}\n{pc}"
                chunks.append(final_chunk)
                count_paragraph_chunks += 1
            continue
        # Group level-2 subsections greedily up to max_words per chunk.
        level_2_ends = [pos[0] for pos in level_2_indexes[1:]] + [len(section_text)]
        current_chunk_subsections = []
        current_chunk_word_count = 0
        for j, (l2_start, l2_end, l2_heading_text) in enumerate(level_2_indexes):
            subsec_start = l2_end
            subsec_end = level_2_ends[j]
            raw_subsec_text = section_text[subsec_start:subsec_end].strip()
            # ensure subsections themselves don't violate cap
            subsec_pieces = split_to_fit(raw_subsec_text, max_words)
            for piece in subsec_pieces:
                piece_wc = word_count(piece)
                # Each piece keeps its level-2 heading (repeated across pieces
                # of an oversize subsection).
                heading_with_piece = f"## {l2_heading_text}\n{piece}"
                if current_chunk_word_count + piece_wc > max_words and current_chunk_subsections:
                    # Flush the accumulated group as one chunk (split again if
                    # the prepended level-1 heading pushes it over the cap).
                    chunk_text = "\n\n".join(current_chunk_subsections)
                    chunk_text = f"# {heading_text}\n{chunk_text}"
                    for final_piece in split_to_fit(chunk_text, max_words):
                        chunks.append(final_piece)
                    count_level_2_chunks += 1  # counts flushes, not pieces
                    # overlap: take last subsection (or piece) and include tail
                    overlap_text = current_chunk_subsections[-1]
                    overlap_words_text = get_word_slice(overlap_text, overlap_words, from_end=True)
                    # NOTE(review): the overlap text carries no heading prefix
                    # of its own — it is raw tail words of the last subsection.
                    current_chunk_subsections = [overlap_words_text]
                    current_chunk_word_count = word_count(overlap_words_text)
                    current_chunk_subsections.append(heading_with_piece)
                    current_chunk_word_count += piece_wc
                else:
                    current_chunk_subsections.append(heading_with_piece)
                    current_chunk_word_count += piece_wc
        if current_chunk_subsections:
            # Flush the trailing group for this section.
            chunk_text = "\n\n".join(current_chunk_subsections)
            chunk_text = f"# {heading_text}\n{chunk_text}"
            for final_piece in split_to_fit(chunk_text, max_words):
                chunks.append(final_piece)
            count_level_2_chunks += 1
    # Final batching
    batches = [chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)]
    # Print final word counts per chunk
    for bi, batch in enumerate(batches, start=1):
        print(f"\nBatch {bi}:")
        for ci, chunk in enumerate(batch, start=1):
            wc = word_count(chunk)
            print(f" Chunk {ci} word count: {wc}")
    print(f"\nSummary:")
    print(f" Chunks split by level 1 headings (#): {count_level_1_chunks}")
    print(f" Chunks split by grouped level 2 headings (##): {count_level_2_chunks}")
    print(f" Chunks split by paragraphs: {count_paragraph_chunks}")
    return batches