# uhc-policy-chatbot / scraper / create_rag_chunks.py
# Author: Mayank Patel
# Commit 5c32ed1: Initial deployment: UHC Medical Policy Chatbot
import json
import re
import hashlib
from tqdm import tqdm
# Input: per-policy extracted sections (JSON list read by main()).
INPUT_FILE = "data/processed/extracted_sections.json"
# Output: flat JSON list of chunk records written by main().
OUTPUT_FILE = "data/processed/rag_chunks.json"
# Token budgets below are in estimate_tokens() units (words * 1.3),
# not counts from a real tokenizer.
# Preferred chunk size: chunkers start a new chunk once adding more text
# would push the running estimate past this value.
TARGET_CHUNK_TOKENS = 400
# Pieces estimated above this are split into sentences before packing;
# the clinical-evidence chunker also uses it as its merge limit.
MAX_CHUNK_TOKENS = 600
# Trailing sentences carried into the next window when an oversized
# piece is re-packed sentence by sentence.
OVERLAP_SENTENCES = 2
# NOTE(review): not referenced anywhere in this file as shown -- confirm
# whether these sections were meant to be skipped or down-weighted.
LOW_VALUE_SECTIONS = {"References", "U.S. Food and Drug Administration"}
# Sections that produce no chunks at all.
BOILERPLATE_SECTIONS = {"Instructions for Use", "Policy History/Revision Information"}
def estimate_tokens(text):
    """Return a rough token estimate for ``text``.

    Counts whitespace-separated words (``str.split()``), scales by 1.3
    tokens per word, and truncates the result toward zero with ``int``.

    Args:
        text: Any string.

    Returns:
        Estimated token count as an ``int``.
    """
    tokens_per_word = 1.3
    return int(tokens_per_word * len(text.split()))
def split_sentences(text):
    """Split ``text`` into sentence-like pieces for windowed chunking.

    The primary cut is at whitespace that follows "." or ";" and precedes an
    uppercase ASCII letter. Any resulting piece whose estimate still exceeds
    MAX_CHUNK_TOKENS is cut again at every whitespace run that follows ":"
    or ";". Pieces are stripped, and empty ones are dropped.

    Args:
        text: Prose to split.

    Returns:
        List of non-empty, stripped pieces in their original order.
    """
    pieces = []
    for candidate in re.split(r'(?<=[.;])\s+(?=[A-Z])', text):
        if estimate_tokens(candidate) <= MAX_CHUNK_TOKENS:
            pieces.append(candidate)
        else:
            # Too long even as one "sentence": fall back to clause boundaries.
            pieces.extend(re.split(r'(?<=[:;])\s+', candidate))
    cleaned = (piece.strip() for piece in pieces)
    return [piece for piece in cleaned if piece]
def chunk_id(policy, section, idx):
    """Build a short, deterministic id for one chunk.

    The key ``"<policy>__<section>__<idx>"`` is UTF-8 encoded and hashed
    with MD5. The first 12 hex characters of the digest are returned.

    Args:
        policy: Policy identifier (formatted into the key).
        section: Section identifier (formatted into the key).
        idx: Chunk position (formatted into the key).

    Returns:
        12-character lowercase hex string.
    """
    key = "{}__{}__{}".format(policy, section, idx)
    digest = hashlib.md5(key.encode()).hexdigest()
    return digest[:12]
def chunk_by_criteria(text, policy_name):
    """Split criteria-style policy prose into retrieval-sized chunks.

    The text is first cut at blank lines that introduce a new criteria
    statement. These are lines such as "The following ...",
    "For initial/continuation/... ", "A/An/The ... is proven/unproven/...",
    or lines starting with one of several procedure-type keywords. If that
    yields a single piece, the text is cut at every blank line instead.

    Consecutive pieces are packed together ("\\n\\n"-joined) until adding
    the next one would exceed TARGET_CHUNK_TOKENS. A piece larger than
    MAX_CHUNK_TOKENS is split into sentences and re-packed into " "-joined
    windows of about TARGET_CHUNK_TOKENS. Each window carries its last
    OVERLAP_SENTENCES sentences into the next one, but only when it holds
    more sentences than that.

    Args:
        text: Section content to chunk.
        policy_name: Policy name; not used by the chunking logic.

    Returns:
        List of chunk strings in document order.
    """
    criteria_boundary = (
        r'\n\n(?=(?:The following|For (?:initial|continuation|subsequent|revision|replacement)|'
        r'(?:An?|The)\s+\w.*?is (?:proven|unproven|medically necessary|not medically)|'
        r'(?:Multiplex|Implantable|Removable|Emergency|Non-Surgical|Surgical)))'
    )
    pieces = re.split(criteria_boundary, text)
    if len(pieces) < 2:
        # No criteria markers found: fall back to plain paragraph breaks.
        pieces = re.split(r'\n\n', text)

    chunks = []
    pending = []       # small pieces waiting to be merged into one chunk
    pending_size = 0   # estimated tokens currently held in `pending`

    for raw_piece in pieces:
        piece = raw_piece.strip()
        if not piece:
            continue
        size = estimate_tokens(piece)

        if size > MAX_CHUNK_TOKENS:
            # Emit what has been gathered so far to keep document order.
            if pending:
                chunks.append("\n\n".join(pending))
                pending, pending_size = [], 0
            window = []
            window_size = 0
            for sentence in split_sentences(piece):
                sentence_size = estimate_tokens(sentence)
                if window and window_size + sentence_size > TARGET_CHUNK_TOKENS:
                    chunks.append(" ".join(window))
                    window = window[-OVERLAP_SENTENCES:] if len(window) > OVERLAP_SENTENCES else []
                    window_size = sum(estimate_tokens(s) for s in window)
                window.append(sentence)
                window_size += sentence_size
            if window:
                chunks.append(" ".join(window))
            continue

        if pending and pending_size + size > TARGET_CHUNK_TOKENS:
            chunks.append("\n\n".join(pending))
            pending, pending_size = [piece], size
        else:
            pending.append(piece)
            pending_size += size

    if pending:
        chunks.append("\n\n".join(pending))
    return chunks
def chunk_code_table(text):
    """Chunk a billing-code listing (CPT / HCPCS / diagnosis / ICD-10 tables).

    The text is processed line by line:

    * blank lines are skipped;
    * boilerplate lines are dropped: the "The following list(s) ..."
      preamble, the "CPT ... is a registered ..." notice, and
      "Listing of a code ..." disclaimers;
    * a column-header line (e.g. "CPT Code Description",
      "HCPCS Code Description") is not emitted as a row; it becomes the
      header prepended to every chunk built from the rows after it;
    * remaining rows are packed into "\\n"-joined chunks until the next
      row would push the estimate past TARGET_CHUNK_TOKENS (the header's
      own tokens are not counted).

    When a header different from the current one appears, rows collected
    so far are flushed under the old header first. This ensures rows from
    one table are never labelled with the next table's header. A repeat of
    the same header (e.g. reprinted after a page break) does not force a
    flush.

    Args:
        text: Raw section content.

    Returns:
        List of chunk strings in document order.
    """
    chunks = []
    rows = []
    rows_tokens = 0
    header_line = None

    def emit(block_rows, header):
        # Prefix the active header (if any) so each chunk is self-describing.
        body = "\n".join(block_rows)
        chunks.append(header + "\n" + body if header else body)

    for line in text.split("\n"):
        stripped = line.strip()
        if not stripped:
            continue
        # "®?" mirrors the trademark check below: headers may read "CPT® Code".
        if re.match(r"^(?:CPT®?|HCPCS|Diagnosis|ICD-10)\s+(?:Code|Description)", stripped, re.IGNORECASE):
            if rows and stripped != header_line:
                # A new table starts: close out rows of the previous one
                # under the header they actually belong to.
                emit(rows, header_line)
                rows, rows_tokens = [], 0
            header_line = stripped
            continue
        if re.match(r"^The following list\(s\)", stripped):
            continue
        if re.match(r"^CPT®?\s+is a registered", stripped):
            continue
        if re.match(r"^Listing of a code", stripped):
            continue
        line_tokens = estimate_tokens(stripped)
        if rows and rows_tokens + line_tokens > TARGET_CHUNK_TOKENS:
            emit(rows, header_line)
            rows, rows_tokens = [], 0
        rows.append(stripped)
        rows_tokens += line_tokens

    if rows:
        emit(rows, header_line)
    return chunks
def chunk_clinical_evidence(text):
    """Chunk a "Clinical Evidence" section, cutting between study summaries.

    Boundaries are blank lines followed by one of three kinds of line:

    * a capitalized word with an optional "et al."/"and"/"&" and a
      parenthesized four-digit year later on the same line;
    * "A phase/prospective/retrospective/randomized/..." study phrasing;
    * a "Professional Societies"/"American"/"European"/"National"/
      "International" lead-in.

    If no boundary is found, the text is cut at every blank line instead.

    Pieces are merged ("\\n\\n"-joined) until adding the next one would
    exceed MAX_CHUNK_TOKENS. A single piece larger than MAX_CHUNK_TOKENS is
    split into sentences and emitted as " "-joined windows of about
    TARGET_CHUNK_TOKENS. Consecutive windows overlap by OVERLAP_SENTENCES
    sentences whenever a window holds more sentences than that.

    Args:
        text: Section content to chunk.

    Returns:
        List of chunk strings in document order.
    """
    study_boundary = (
        r'\n\n(?=(?:[A-Z][a-z]+(?:\s+(?:et al\.|and|&))?.*?\(\d{4}\))|'
        r'(?:A\s+(?:phase|prospective|retrospective|randomized|multicenter|systematic|meta-analysis|Cochrane))|'
        r'(?:Professional Societies|American|European|National|International))'
    )
    blocks = re.split(study_boundary, text)
    if len(blocks) < 2:
        blocks = text.split("\n\n")

    def sentence_windows(block):
        # Re-pack one oversized block into overlapping sentence windows.
        window, window_tokens = [], 0
        for sentence in split_sentences(block):
            n_tokens = estimate_tokens(sentence)
            if window and window_tokens + n_tokens > TARGET_CHUNK_TOKENS:
                yield " ".join(window)
                window = window[-OVERLAP_SENTENCES:] if len(window) > OVERLAP_SENTENCES else []
                window_tokens = sum(estimate_tokens(s) for s in window)
            window.append(sentence)
            window_tokens += n_tokens
        if window:
            yield " ".join(window)

    chunks = []
    merged, merged_tokens = [], 0
    for block in (b.strip() for b in blocks):
        if not block:
            continue
        block_tokens = estimate_tokens(block)
        if block_tokens > MAX_CHUNK_TOKENS:
            if merged:
                chunks.append("\n\n".join(merged))
                merged, merged_tokens = [], 0
            chunks.extend(sentence_windows(block))
        elif merged and merged_tokens + block_tokens > MAX_CHUNK_TOKENS:
            chunks.append("\n\n".join(merged))
            merged, merged_tokens = [block], block_tokens
        else:
            merged.append(block)
            merged_tokens += block_tokens
    if merged:
        chunks.append("\n\n".join(merged))
    return chunks
def chunk_section(section_name, content, policy_name):
    """Route one policy section to the chunker suited to its kind.

    * Sections in BOILERPLATE_SECTIONS produce no chunks.
    * "Applicable Codes" and "Coverage Summary" use chunk_code_table.
    * "Clinical Evidence" uses chunk_clinical_evidence.
    * The criteria-style sections listed below use chunk_by_criteria.
    * Any other section is returned whole when its estimate fits within
      TARGET_CHUNK_TOKENS, and is otherwise chunked with chunk_by_criteria.

    NOTE(review): LOW_VALUE_SECTIONS is not consulted here.

    Args:
        section_name: Section heading as extracted.
        content: Section text.
        policy_name: Forwarded to chunk_by_criteria.

    Returns:
        List of chunk strings (possibly empty).
    """
    if section_name in BOILERPLATE_SECTIONS:
        return []
    if section_name in ("Applicable Codes", "Coverage Summary"):
        return chunk_code_table(content)
    if section_name == "Clinical Evidence":
        return chunk_clinical_evidence(content)
    criteria_sections = (
        "Coverage Rationale",
        "Application",
        "Definitions",
        "Documentation Requirements",
        "Medical Records Documentation Used for Reviews",
    )
    if section_name in criteria_sections or estimate_tokens(content) > TARGET_CHUNK_TOKENS:
        return chunk_by_criteria(content, policy_name)
    return [content]
def main():
    """Chunk every extracted policy section and write the chunks to OUTPUT_FILE.

    Reads INPUT_FILE, expected to hold a JSON list of policy objects. Each
    policy must have "policy_name" and "sections". The optional keys
    "policy_number", "effective_date", "plan_type" and "doc_type" default
    to "". Each section must have "section" and "content"; "page_start"
    and "page_end" default to 0.

    Sections in BOILERPLATE_SECTIONS are skipped. Every other section is
    chunked with chunk_section; chunks are stripped, and those shorter than
    20 characters are discarded before numbering. As a result,
    ``chunk_index`` is contiguous and ``total_chunks_in_section`` equals the
    number of records written for that section. The records are dumped as
    a JSON list, and summary counts are printed to stdout.
    """
    from collections import Counter

    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        policies = json.load(f)

    all_chunks = []
    # Times each (policy_name, section) pair has been seen. A repeat (a
    # duplicated heading in one policy, or two policies sharing a name) gets
    # an occurrence suffix in its id key so chunk ids stay unique. First
    # occurrences keep the plain key.
    section_occurrences = Counter()

    for policy in tqdm(policies, desc="Creating chunks"):
        policy_name = policy["policy_name"]
        policy_number = policy.get("policy_number", "")
        effective_date = policy.get("effective_date", "")
        plan_type = policy.get("plan_type", "")
        doc_type = policy.get("doc_type", "")

        for section_data in policy["sections"]:
            section_name = section_data["section"]
            content = section_data["content"]
            page_start = section_data.get("page_start", 0)
            page_end = section_data.get("page_end", 0)
            if section_name in BOILERPLATE_SECTIONS:
                continue

            # Filter before enumerating so indices and totals describe only
            # the chunks that are actually emitted.
            stripped_chunks = (c.strip() for c in chunk_section(section_name, content, policy_name))
            text_chunks = [c for c in stripped_chunks if len(c) >= 20]

            occurrence_key = (policy_name, section_name)
            occurrence = section_occurrences[occurrence_key]
            section_occurrences[occurrence_key] += 1
            id_section = section_name if occurrence == 0 else f"{section_name}__{occurrence}"

            for idx, chunk_text in enumerate(text_chunks):
                all_chunks.append({
                    "id": chunk_id(policy_name, id_section, idx),
                    "policy_name": policy_name,
                    "policy_number": policy_number,
                    "effective_date": effective_date,
                    "plan_type": plan_type,
                    "doc_type": doc_type,
                    "section": section_name,
                    "page_start": page_start,
                    "page_end": page_end,
                    "chunk_index": idx,
                    "total_chunks_in_section": len(text_chunks),
                    "text": chunk_text,
                })

    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(all_chunks, f, indent=2, ensure_ascii=False)

    print(f"Total chunks: {len(all_chunks)}")
    print(f"Policies processed: {len(policies)}")
    print(f"Saved to: {OUTPUT_FILE}")

    section_counts = Counter(c["section"] for c in all_chunks)
    print("\nChunks per section:")
    # most_common() sorts by descending count and keeps first-seen order on ties.
    for sec, count in section_counts.most_common():
        print(f" {sec}: {count}")
# Run the chunking pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()