# rag-app / shared_utilities.py
# Shared utilities: text chunking, embedding generation, and figure extraction.
# (Hosting-page residue moved into comments so the module parses:
#  uploaded by bhavinmatariya, "Upload 13 files", commit 3506c42 verified.)
import os
import time
import math
import re
import tiktoken
import logging
import asyncio
from openai import AsyncOpenAI
from dotenv import load_dotenv
load_dotenv()
client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def chunk_text(
text: str, chunk_size: int = 8000, overlap: int = 400, max_chunks: int | None = None, use_tokens: bool = True) -> list[str]:
"""
Robust chunker with safety guards and fallback.
- If tiktoken is available and use_tokens=True: chunk by tokens.
- Otherwise: chunk by characters.
- Ensures forward progress even with bad params.
- For small files, reduces overlap to avoid duplicate content.
"""
if not text or not isinstance(text, str):
return []
# 1) Sanitize control chars that sometimes appear in OCR/Aspose output
# Keep \n and \t; remove other C0 control chars.
text = re.sub(r"[^\S\r\n]+", " ", text) # collapse runs of spaces
text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F]", "", text).strip()
# 2) Guard parameters
if chunk_size <= 0:
chunk_size = 1000
if overlap < 0:
overlap = 0
if overlap >= chunk_size:
# auto-fix: keep 20% overlap
overlap = max(0, int(chunk_size * 0.2))
# 3) For very small files, return as single chunk to avoid duplicates
text_length = len(text)
if text_length < chunk_size: # If file is smaller than one chunk
logger.info(f"Very small file detected ({text_length} chars), returning as single chunk")
return [text]
# 4) For small files, reduce overlap to avoid duplicate content
if text_length < chunk_size * 2: # If file is smaller than 2 chunks
overlap = min(overlap, max(0, int(text_length * 0.1))) # Reduce overlap to 10% of file size
logger.info(f"Small file detected ({text_length} chars), reduced overlap to {overlap}")
# 3) Try token-based chunking
tokens = None
encoding = None
token_mode = False
t0 = time.time()
if use_tokens:
try:
import tiktoken
encoding = tiktoken.get_encoding("cl100k_base")
tokens = encoding.encode(text)
token_mode = True
except Exception:
token_mode = False # fall back to char mode
chunks: list[str] = []
if token_mode:
# Safety: compute theoretical max chunk count to avoid infinite loops
step = chunk_size - overlap
if step <= 0:
step = max(1, chunk_size // 2) # should not happen due to guard above
theoretical = math.ceil(max(1, len(tokens)) / step)
hard_cap = min(theoretical + 2, 20000) # absolute safety cap
if max_chunks is None:
max_chunks = hard_cap
else:
max_chunks = min(max_chunks, hard_cap)
start = 0
made_progress = True
count = 0
while start < len(tokens) and count < max_chunks and made_progress:
end = min(start + chunk_size, len(tokens))
chunk_tokens = tokens[start:end]
chunk_text = encoding.decode(chunk_tokens).strip()
if chunk_text:
chunks.append(chunk_text)
prev_start = start
start = end - overlap
# Guarantee forward progress
if start <= prev_start:
start = prev_start + 1
count += 1
else:
# Character-based fallback
step = chunk_size - overlap
if step <= 0:
step = max(1, chunk_size // 2)
theoretical = math.ceil(max(1, len(text)) / step)
hard_cap = min(theoretical + 2, 20000)
if max_chunks is None:
max_chunks = hard_cap
else:
max_chunks = min(max_chunks, hard_cap)
start = 0
count = 0
while start < len(text) and count < max_chunks:
end = min(start + chunk_size, len(text))
chunk = text[start:end].strip()
if chunk:
chunks.append(chunk)
prev_start = start
start = end - overlap
if start <= prev_start:
start = prev_start + 1
count += 1
t1 = time.time()
return chunks
def validate_chunk_sizes(chunks, max_tokens=8000):
    """Ensure every chunk fits within max_tokens, splitting oversized ones.

    Falls back to returning the input unchanged if token counting fails.
    """
    try:
        encoding = tiktoken.get_encoding("cl100k_base")
        checked = []
        for idx, piece in enumerate(chunks):
            n_tokens = len(encoding.encode(piece))
            if n_tokens <= max_tokens:
                checked.append(piece)
                continue
            logger.warning(f"Chunk {idx+1} exceeds {max_tokens} tokens ({n_tokens} tokens). Splitting...")
            # Re-chunk the oversized piece down to the token budget.
            checked.extend(chunk_text(piece, chunk_size=max_tokens, overlap=200))
        return checked
    except Exception as e:
        logger.error(f"Error validating chunk sizes: {e}")
        return chunks  # best-effort: hand back the originals on failure
async def generate_embeddings_batch(texts, progress_callback=None, batch_size=100,
                                    model="text-embedding-3-small", embedding_dim=1536) -> list:
    """
    Generate embeddings for a list of texts in batches via the OpenAI API.

    Args:
        texts: A string or list of strings to embed.
        progress_callback: Optional async callable receiving status messages.
        batch_size: Number of texts sent per API request.
        model: Embedding model name (new parameter; defaults to the previously
            hard-coded "text-embedding-3-small").
        embedding_dim: Dimension of the zero-vector fallback appended when an
            individual request fails (1536 matches text-embedding-3-small).

    Returns:
        List of embedding vectors, one per input text (zero vectors on failure,
        keeping the output aligned 1:1 with inputs).

    Raises:
        Exception: If the overall batching process fails unexpectedly.
    """
    try:
        # Ensure texts is a list
        if isinstance(texts, str):
            texts = [texts]
        if progress_callback:
            await progress_callback("Generating embeddings...")
        all_embeddings = []
        total_batches = math.ceil(len(texts) / batch_size)
        for batch_idx in range(0, len(texts), batch_size):
            batch_texts = texts[batch_idx:batch_idx + batch_size]
            current_batch = (batch_idx // batch_size) + 1
            if progress_callback:
                await progress_callback(f"Generating embeddings batch {current_batch}/{total_batches}...")
            try:
                response = await client.embeddings.create(
                    model=model,
                    input=batch_texts
                )
                all_embeddings.extend(item.embedding for item in response.data)
                # Small delay to avoid rate limiting
                await asyncio.sleep(0.1)
            except Exception as e:
                logger.error(f"Error generating batch {current_batch} embeddings: {e}")
                # If the whole batch fails, retry each text individually so one
                # bad input cannot sink the rest of the batch.
                for text in batch_texts:
                    try:
                        response = await client.embeddings.create(
                            model=model,
                            input=[text]
                        )
                        all_embeddings.append(response.data[0].embedding)
                    except Exception as individual_error:
                        logger.error(f"Error generating individual embedding: {individual_error}")
                        # Zero vector as fallback keeps indices aligned with inputs.
                        all_embeddings.append([0.0] * embedding_dim)
        return all_embeddings
    except Exception as e:
        logger.error(f"Error generating batch embeddings: {e}")
        # `from e` preserves the original traceback for callers that log it.
        raise Exception(f"Error generating batch embeddings: {str(e)}") from e
def extract_visual_elements_from_text(document_text: str) -> dict:
    """
    Collect figure descriptions from a document, keeping page context.

    Splits the text on "--- PAGE n ---" markers, harvests figures per page,
    then runs a whole-document fallback pass for anything the per-page scan
    missed.

    Args:
        document_text: Full document text content

    Returns:
        Dict mapping "figure_<n>" -> {"description": ..., "page_number": ...,
        "figure_number": ..., "title": ...}
    """
    try:
        visual_elements = {}
        if not document_text:
            return visual_elements
        # re.split with a capturing group yields:
        #   [pre-marker text, page#, page body, page#, page body, ...]
        parts = re.split(r'---\s*PAGE\s*(\d+)\s*---', document_text)
        for idx in range(1, len(parts) - 1, 2):
            page_no = int(parts[idx])
            for key, data in _extract_figures_from_page_content(parts[idx + 1], page_no).items():
                visual_elements.setdefault(key, data)  # first sighting wins
        # Text before the first page marker is treated as page 1.
        for key, data in _extract_figures_from_page_content(parts[0], 1).items():
            visual_elements.setdefault(key, data)
        # Whole-document sweep catches references the page passes missed.
        _extract_figures_from_full_document(document_text, visual_elements)
        logger.info(f"Total visual elements extracted: {len(visual_elements)}")
        return visual_elements
    except Exception as e:
        logger.error(f"Error extracting visual elements from text: {e}")
        return {}
def _extract_figures_from_page_content(page_content: str, page_number: int) -> dict:
"""
Extract figures from a single page's content.
Args:
page_content: Text content of a single page
page_number: Page number
Returns:
Dict mapping figure keys to figure data
"""
figures = {}
# Pattern 1: Find "Visual Elements" sections
visual_sections_patterns = [
r'\*\*Visual Elements.*?\*\*:(.*?)(?=\n--- PAGE|\n\*\*[A-Z]|$)',
r'Visual Elements[:\-](.*?)(?=\n--- PAGE|\n\*\*[A-Z]|$)',
r'\*\*Visual.*?Elements.*?\*\*:(.*?)(?=\n--- PAGE|\n\*\*[A-Z]|$)'
]
visual_sections = []
for pattern in visual_sections_patterns:
matches = re.findall(pattern, page_content, re.DOTALL | re.IGNORECASE)
visual_sections.extend(matches)
# If no visual sections, search entire page content
if not visual_sections:
visual_sections = [page_content]
for section in visual_sections:
# Enhanced patterns to extract full figure descriptions
figure_patterns = [
r'\*\*Figure\s*(\d+):\s*([^\*]+?)(?=\*\*Figure|\*\*Table|\*\*Chart|\*\*|$)',
r'図\s*(\d+)[:\-]\s*([^\*\n]+?)(?=図|\*\*|$)',
r'Figure\s*(\d+)\s*[:\-]\s*([^\*\n]+?)(?=Figure|\*\*|$)',
r'\*\*Figure\s*(\d+)[:\-]\s*\*\*([^\*]+?)(?=\*\*Figure|\*\*Table|\*\*Chart|\*\*|$)',
# Pattern for numbered list format: "1. **Figure 5:** ..."
r'(?:^|\n)\s*\d+\.\s*\*\*Figure\s*(\d+)[:\-]\s*\*\*([^\*]+?)(?=\*\*Figure|\*\*Table|\*\*Chart|\*\*|$)',
# Pattern for bullet points: "* **Figure 5:** ..."
r'(?:^|\n)\s*\*\s*\*\*Figure\s*(\d+)[:\-]\s*\*\*([^\*]+?)(?=\*\*Figure|\*\*Table|\*\*Chart|\*\*|$)',
# Pattern for abbreviated format: "Fig. 1", "Fig. 2", etc. (with markdown)
r'\*\*Fig\.\s*(\d+)[:\-]\s*\*\*([^\*]+?)(?=\*\*Fig|\*\*Figure|\*\*Table|\*\*Chart|\*\*|$)',
r'\*\*Fig\s*(\d+)[:\-]\s*\*\*([^\*]+?)(?=\*\*Fig|\*\*Figure|\*\*Table|\*\*Chart|\*\*|$)',
# Pattern for abbreviated format: "Fig. 1:", "Fig. 2 -", "Fig.1 Title", "Fig. 5 Title" etc. (without markdown)
r'Fig\.\s*(\d+)\s*[:\-]?\s*([^\n\r]+(?:\n[^\n\r]*?){0,15}?)(?=\n\s*(?:Fig|Figure|\*\*|$))',
r'Fig\s*(\d+)\s*[:\-]?\s*([^\n\r]+(?:\n[^\n\r]*?){0,15}?)(?=\n\s*(?:Fig|Figure|\*\*|$))',
# Pattern for "Fig.1" (no space after period, no colon)
r'Fig\.(\d+)\s+([^\n\r]+(?:\n[^\n\r]*?){0,15}?)(?=\n\s*(?:Fig|Figure|\*\*|$))'
]
for pattern in figure_patterns:
matches = re.findall(pattern, section, re.DOTALL | re.IGNORECASE | re.MULTILINE)
for fig_num, description in matches:
fig_num_int = int(fig_num)
fig_key = f"figure_{fig_num_int}"
clean_description = re.sub(r'\n\s*\n+', '\n', description.strip())
clean_description = re.sub(r' +', ' ', clean_description)
if fig_key not in figures:
figures[fig_key] = {
"description": clean_description,
"page_number": page_number,
"figure_number": fig_num_int,
"title": ""
}
logger.info(f"Extracted figure {fig_num_int} from page {page_number}")
# Also look for direct figure references in page text
direct_figure_patterns = [
r'(Figure\s*\d+)[:\-]\s*([^\n\r]+(?:\n[^\n\r]*?){0,10}?)(?=\n\s*(?:Figure|\*\*|$))',
r'(図\s*\d+)[:\-]\s*([^\n\r]+(?:\n[^\n\r]*?){0,10}?)(?=\n\s*(?:図|\*\*|$))',
# Pattern for abbreviated format: "Fig. 1", "Fig. 2", "Fig.1 Title", "Fig. 5 Title" etc.
r'(Fig\.\s*\d+)[:\-]?\s*([^\n\r]+(?:\n[^\n\r]*?){0,15}?)(?=\n\s*(?:Fig|Figure|\*\*|$))',
r'(Fig\s*\d+)[:\-]?\s*([^\n\r]+(?:\n[^\n\r]*?){0,15}?)(?=\n\s*(?:Fig|Figure|\*\*|$))',
# Pattern for "Fig.1 Title" (no space after period, no colon)
r'(Fig\.\d+)\s+([^\n\r]+(?:\n[^\n\r]*?){0,15}?)(?=\n\s*(?:Fig|Figure|\*\*|$))'
]
for pattern in direct_figure_patterns:
matches = re.findall(pattern, page_content, re.MULTILINE | re.IGNORECASE)
for figure_ref, description in matches:
fig_num_match = re.search(r'(\d+)', figure_ref)
if fig_num_match:
fig_num = int(fig_num_match.group(1))
fig_key = f"figure_{fig_num}"
if fig_key not in figures: # Don't overwrite detailed descriptions
clean_description = re.sub(r'\n\s*\n+', '\n', description.strip())
clean_description = re.sub(r' +', ' ', clean_description)
figures[fig_key] = {
"description": f"{figure_ref}: {clean_description}",
"page_number": page_number,
"figure_number": fig_num,
"title": ""
}
logger.info(f"Extracted direct figure reference {fig_num} from page {page_number}")
return figures
def _extract_figures_from_full_document(document_text: str, visual_elements: dict):
"""
Fallback: Extract any missed figures from entire document without page context.
Args:
document_text: Full document text
visual_elements: Existing visual elements dict to update
"""
# Look for any figure references we might have missed
figure_ref_pattern = r'(?:^|\n)\s*(?:Figure|Fig\.|Fig|図)\s*(\d+)[:\-]'
matches = re.finditer(figure_ref_pattern, document_text, re.MULTILINE | re.IGNORECASE)
for match in matches:
fig_num = int(match.group(1))
fig_key = f"figure_{fig_num}"
# Only add if not already found
if fig_key not in visual_elements:
# Try to extract description after the figure reference
start_pos = match.end()
# Look for next figure or end of section
end_match = re.search(r'(?:Figure|図)\s*\d+[:\-]', document_text[start_pos:], re.IGNORECASE)
if end_match:
description = document_text[start_pos:start_pos + end_match.start()].strip()
else:
# Take next 500 chars
description = document_text[start_pos:start_pos + 500].strip()
if description:
clean_description = re.sub(r'\n\s*\n+', '\n', description)
clean_description = re.sub(r' +', ' ', clean_description)
visual_elements[fig_key] = {
"description": clean_description[:200], # Limit length
"page_number": None, # Unknown
"figure_number": fig_num,
"title": ""
}
logger.info(f"Extracted fallback figure reference: {fig_key}")
def match_image_to_figure(image_id: str, visual_elements: dict, used_figures: set | None = None) -> tuple:
    """
    Match image ID to figure description from visual elements using page-based matching.
    Enhanced with conflict resolution to prevent duplicate mappings.

    Args:
        image_id: Image identifier (e.g., "page2_image1")
        visual_elements: Dict of extracted visual elements (with page_number in each entry)
        used_figures: Set of figure numbers already matched (for conflict resolution).
            NOTE: mutated in place when a match is claimed, so callers can share
            one set across multiple images.

    Returns:
        Tuple of (figure_number, figure_description) or (None, "") if no match
    """
    try:
        if used_figures is None:
            used_figures = set()
        # Extract page number from image_id (e.g., "page2_image1" -> 2)
        page_match = re.search(r'page(\d+)', image_id)
        if not page_match:
            return (None, "")
        image_page_num = int(page_match.group(1))
        # Extract image index from image_id (defaults to 1 when absent)
        img_match = re.search(r'image(\d+)', image_id)
        img_index = int(img_match.group(1)) if img_match else 1
        # Strategy 1: Match by actual page number (most accurate)
        page_matches = []
        all_figures_on_page = []  # Track all figures on this page (used or not)
        for fig_key, fig_data in visual_elements.items():
            # Handle both old format (string) and new format (dict)
            if isinstance(fig_data, dict):
                fig_page_num = fig_data.get("page_number")
                fig_num = fig_data.get("figure_number")
                description = fig_data.get("description", "")
            else:
                # Old format - try to extract the figure number from the key
                fig_num_match = re.search(r'figure_(\d+)', fig_key)
                if not fig_num_match:
                    continue
                fig_num = int(fig_num_match.group(1))
                description = fig_data if isinstance(fig_data, str) else ""
                fig_page_num = None  # Unknown for old format
            # Exact page match is best
            if fig_page_num is not None and fig_page_num == image_page_num:
                # Track all figures on this page
                all_figures_on_page.append((fig_num, fig_key, description))
                # Skip if already used (conflict resolution) - but we'll check later if we can reuse
                if fig_num not in used_figures:
                    page_matches.append((fig_num, fig_key, description, img_index))
        # If we have page matches, use image index to select the right one
        if page_matches:
            # Sort by figure number, then use image index
            page_matches.sort(key=lambda x: x[0])  # Sort by figure number
            # If multiple figures on same page, match by image index order
            if len(page_matches) >= img_index:
                best_match = page_matches[img_index - 1]  # image1 = first figure, image2 = second, etc.
            else:
                best_match = page_matches[0]  # Fallback to first match
            # Mark as used
            used_figures.add(best_match[0])
            try:
                logger.info(f"Matched {image_id} (page {image_page_num}, image {img_index}) to Figure {best_match[0]} on page {image_page_num}")
            except UnicodeEncodeError:
                logger.info(f"Matched {image_id} to figure (Unicode characters present)")
            return (best_match[0], best_match[2])
        # Solution 1: Relax conflict resolution for same-page images
        # If no unused figures on this page, but all figures on page are used,
        # allow reusing the last figure for remaining images on the same page
        if not page_matches and all_figures_on_page:
            # Check if all figures on this page are already used
            all_used = all(fig_num in used_figures for fig_num, _, _ in all_figures_on_page)
            if all_used:
                # All figures on this page are used, but we have more images on this page
                # Allow reusing the last figure (or first if only one) for remaining images
                all_figures_on_page.sort(key=lambda x: x[0])  # Sort by figure number
                # Use the last figure on the page for remaining images
                # This handles the case: 1 figure, multiple images on same page
                best_match = all_figures_on_page[-1]  # Last figure on page
                fig_num, fig_key, description = best_match
                # Don't mark as used again (it's already used)
                # But allow this match since it's on the same page
                try:
                    logger.info(f"Matched {image_id} (page {image_page_num}, image {img_index}) to Figure {fig_num} on page {image_page_num} (reused - all figures on page already used)")
                except UnicodeEncodeError:
                    logger.info(f"Matched {image_id} to figure (reused, Unicode characters present)")
                return (fig_num, description)
        # Strategy 2: Fallback - if no page match, try proximity (only if page numbers unknown)
        potential_figures = []
        for fig_key, fig_data in visual_elements.items():
            if isinstance(fig_data, dict):
                fig_num = fig_data.get("figure_number")
                description = fig_data.get("description", "")
                fig_page_num = fig_data.get("page_number")
            else:
                fig_num_match = re.search(r'figure_(\d+)', fig_key)
                if not fig_num_match:
                    continue
                fig_num = int(fig_num_match.group(1))
                description = fig_data if isinstance(fig_data, str) else ""
                fig_page_num = None
            if fig_num is None:
                continue
            # Only use fallback if page number is unknown AND not already used
            if fig_page_num is None and fig_num not in used_figures:
                # Heuristic distance between figure number and page+index.
                # NOTE(review): assumes figure numbering loosely tracks page
                # order — confirm against real documents.
                distance = abs(fig_num - (image_page_num + img_index))
                if distance <= 2:  # Tighter threshold
                    potential_figures.append((fig_num, fig_key, description, distance))
        if potential_figures:
            potential_figures.sort(key=lambda x: x[3])  # Sort by distance
            best_match = potential_figures[0]
            used_figures.add(best_match[0])
            logger.warning(f"Using fallback matching for {image_id} to Figure {best_match[0]} (page number unknown)")
            return (best_match[0], best_match[2])
        return (None, "")
    except Exception as e:
        logger.error(f"Error matching image {image_id} to figure: {e}")
        return (None, "")
async def merge_visual_elements_with_ai_summary(image_id: str, ai_summary: str, document_text: str, used_figures: set = None) -> tuple:
    """
    Merge extracted visual elements with AI-generated image summary.
    Enhanced to handle page numbers and conflict resolution.

    Args:
        image_id: Image identifier
        ai_summary: AI-generated summary
        document_text: Full document text
        used_figures: Set of figure numbers already matched (for conflict resolution)

    Returns:
        Tuple of (enhanced_summary, figure_metadata_dict)
        figure_metadata_dict contains: figure_number, figure_description, page_number, or None if no match
    """
    try:
        # Run the regex-heavy extraction in an executor so it does not block
        # the event loop. Fix: get_running_loop() replaces get_event_loop(),
        # which is deprecated inside coroutines.
        try:
            loop = asyncio.get_running_loop()
            visual_elements = await loop.run_in_executor(None, extract_visual_elements_from_text, document_text)
        except Exception as e:
            # Fallback to synchronous if executor fails
            logger.warning(f"Failed to run extract_visual_elements_from_text in executor: {e}, using sync")
            visual_elements = extract_visual_elements_from_text(document_text)
        if not visual_elements:
            logger.info(f"No visual elements found for {image_id}, using AI summary only")
            return (ai_summary, None)
        # Initialize used_figures if not provided
        if used_figures is None:
            used_figures = set()
        # Try to match image to figure description with conflict resolution
        figure_num, matched_description = match_image_to_figure(image_id, visual_elements, used_figures)
        # Get page number from visual elements if available
        page_number = None
        if figure_num is not None:
            fig_data = visual_elements.get(f"figure_{figure_num}")
            if isinstance(fig_data, dict):
                page_number = fig_data.get("page_number")
        figure_metadata = None
        if matched_description and figure_num is not None:
            # Create enhanced summary combining both sources
            try:
                enhanced_summary = f"""**Figure {figure_num} - Document Description:**
{matched_description}
**AI Visual Analysis:**
{ai_summary}"""
                # Surface lone-surrogate issues early, before the summary is stored
                enhanced_summary.encode('utf-8')
                # Create figure metadata
                figure_metadata = {
                    "figure_number": figure_num,
                    "figure_key": f"figure_{figure_num}",
                    "figure_description": matched_description,
                    "image_id": image_id,
                    "page_number": page_number
                }
                logger.info(f"Enhanced summary created for {image_id} -> Figure {figure_num} (page {page_number}) using document + AI")
                return (enhanced_summary, figure_metadata)
            except UnicodeEncodeError as ue:
                logger.warning(f"Unicode encoding issue for {image_id}, using AI summary only: {ue}")
                return (ai_summary, None)
        else:
            logger.info(f"No matching figure found for {image_id}, using AI summary only")
            return (ai_summary, None)
    except Exception as e:
        logger.error(f"Error merging visual elements for {image_id}: {e}")
        return (ai_summary, None)