| """ | |
| preprocessing/chunker.py | |
| ========================= | |
| ALZDETECT-AI β Enterprise Semantic Chunker. | |
| WHAT: Splits 19,637 cleaned papers into ~40,000 searchable chunks. | |
| WHY: Pinecone retrieves chunks, not full papers. | |
| A question about pTau217 should return the exact sentence | |
| mentioning pTau217 β not the entire abstract. | |
| WHO: Called by scripts/run_pipeline.py after cleaner. | |
| Output consumed only by vector_store/embedder.py | |
| WHERE: Reads β data/processed/cleaned_papers.json | |
| Writes β data/processed/chunks.json | |
| WHEN: Once per plan, after cleaning, before embedding. | |
| WORST-CASE DESIGN: | |
| - Abstract too short to chunk β single chunk, logged | |
| - Chunk text empty after split β chunk skipped | |
| - Missing metadata on paper β defaults used, never crashes | |
| - Output dir missing β created automatically | |
| - Duplicate chunk IDs β caught by validator | |
| """ | |
import json
from pathlib import Path
from typing import Optional

from pydantic import BaseModel, Field, field_validator, model_validator
from loguru import logger
from tqdm import tqdm

from configs.settings import get_settings
from preprocessing.cleaner import CleanedPaper

# ── Single chunk model ──────────────────────────────────────────────
class PaperChunk(BaseModel):
    """
    One chunk from one paper – the atomic unit of the RAG system.

    Analogy: One index card from the book editor.
    Contains: the content (text), the reference (pmid, chunk_idx),
    and enough metadata to filter in Pinecone (year, keywords, source).

    This is what gets embedded and stored in Pinecone.
    Every field here becomes either a vector or a metadata filter.
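
    Example (illustrative values only; the PMID, text, and title below
    are made up to show the required shape):
        PaperChunk(
            chunk_id="12345678_chunk_0",
            pmid="12345678",
            chunk_idx=0,
            text="Plasma pTau217 discriminated AD from other dementias in this cohort.",
            title="Plasma biomarkers in Alzheimer's disease",
        )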
| """ | |
| chunk_id: str = Field(..., description="Unique ID: pmid_chunk_N") | |
| pmid: str = Field(..., description="Parent paper PMID") | |
| chunk_idx: int = Field(..., ge=0, description="Position in paper") | |
| text: str = Field(..., min_length=20, | |
| description="Chunk text β what gets embedded") | |
| title: str = Field(..., description="Parent paper title") | |
| year: Optional[int] = Field(default=None) | |
| keywords: list[str] = Field(default_factory=list) | |
| journal: Optional[str] = Field(default=None) | |
| source_query: str = Field(default="unknown") | |
| source: str = Field(default="pubmed", | |
| description="pubmed or adni") | |
| word_count: int = Field(default=0, description="Words in this chunk") | |
    @field_validator("chunk_id")
    @classmethod
    def chunk_id_format(cls, v: str) -> str:
        """
        Enforce chunk ID format: pmid_chunk_N
        Wrong format = Pinecone upserts end up with inconsistent IDs.
        """
        if "_chunk_" not in v:
            raise ValueError(
                f"chunk_id '{v}' must contain '_chunk_'. "
                f"Format: 'pmid_chunk_N'"
            )
        return v

    @field_validator("text")
    @classmethod
    def text_has_content(cls, v: str) -> str:
        """Chunk must have real content – not just whitespace."""
        v = v.strip()
        if len(v.split()) < 5:
            raise ValueError(
                f"Chunk text too short ({len(v.split())} words) – skipped"
            )
        return v

    @model_validator(mode="after")
    def compute_word_count(self) -> "PaperChunk":
        """Auto-compute word count after all fields are validated."""
        self.word_count = len(self.text.split())
        return self
    def to_dict(self) -> dict:
        return self.model_dump()

    def to_pinecone_record(self) -> dict:
        """
        Format for Pinecone upsert.
        Pinecone expects: {id, values (embedding), metadata}
        This method builds the metadata part.
        The embedding is added by embedder.py.
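
        Example shape (illustrative values; the "values" embedding key
        is added later by embedder.py):
            {
                "id": "12345678_chunk_0",
                "metadata": {"pmid": "12345678", "chunk_idx": 0,
                             "text": "...", "title": "...", "year": 2024,
                             "keywords": [...], "source": "pubmed", ...},
            }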
| """ | |
| return { | |
| "id": self.chunk_id, | |
| "metadata": { | |
| "pmid": self.pmid, | |
| "chunk_idx": self.chunk_idx, | |
| "text": self.text[:1000], # Pinecone metadata limit | |
| "title": self.title[:200], | |
| "year": self.year, | |
| "keywords": self.keywords[:10], # limit metadata size | |
| "journal": self.journal, | |
| "source_query": self.source_query, | |
| "source": self.source, | |
| "word_count": self.word_count, | |
| } | |
| } | |
# ── Chunk diagnostic model ──────────────────────────────────────────
class ChunkDiagnostic(BaseModel):
    """
    Reverse-engineering (RE) inspector for the chunking stage.

    Analogy: The editor's production report –
    how many index cards were created, what was skipped,
    and the average card size.
    """

    total_papers: int
    total_chunks: int
    rejected_chunks: int
    avg_chunks_per_paper: float
    avg_words_per_chunk: float
    min_words: int
    max_words: int
    chunk_duration_secs: float
    output_path: str

    def log_summary(self) -> None:
        logger.info("=" * 60)
        logger.info("[CHUNK-DIAGNOSTIC] Run complete")
        logger.info(f"  Papers processed : {self.total_papers:,}")
        logger.info(f"  Total chunks     : {self.total_chunks:,}")
        logger.info(f"  Rejected chunks  : {self.rejected_chunks:,}")
        logger.info(f"  Avg chunks/paper : {self.avg_chunks_per_paper:.1f}")
        logger.info(f"  Avg words/chunk  : {self.avg_words_per_chunk:.1f}")
        logger.info(f"  Min words        : {self.min_words}")
        logger.info(f"  Max words        : {self.max_words}")
        logger.info(f"  Duration         : {self.chunk_duration_secs:.1f}s")
        logger.info(f"  Saved to         : {self.output_path}")
        logger.info("=" * 60)

# ── Chunking functions ──────────────────────────────────────────────
def split_into_chunks(
    text: str,
    chunk_size: int,
    overlap: int,
) -> list[str]:
    """
    Split text into overlapping word-based chunks.

    WHY word-based, not character-based:
        Words are the natural unit for biomedical text.
        Character splits can cut mid-word → "Alzheimer" becomes
        "Alzhe" + "imer" – meaningless to the embedding model.

    WHY overlap:
        If a key sentence falls at the boundary between two chunks,
        overlap ensures it appears in both, so retrieval finds it
        regardless of which chunk is retrieved.

    Analogy: Photocopying pages of a book with a 1-page overlap.
    You never miss a sentence that falls between two photocopies.

    Args:
        text       : full abstract text
        chunk_size : max words per chunk (from settings)
        overlap    : words shared between adjacent chunks

    Returns:
        list of chunk strings – at least 1, even for short abstracts
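
    Example (illustrative, with tiny numbers; real values come from
    settings.chunk_size / settings.chunk_overlap):
        >>> split_into_chunks("a b c d e f g h", chunk_size=4, overlap=2)
        ['a b c d', 'c d e f', 'e f g h']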
| """ | |
| words = text.split() | |
| # Short abstracts β return as single chunk | |
| if len(words) <= chunk_size: | |
| return [text] | |
| chunks = [] | |
| start = 0 | |
| step = max(1, chunk_size - overlap) # prevent infinite loop | |
| while start < len(words): | |
| end = min(start + chunk_size, len(words)) | |
| chunk = " ".join(words[start:end]) | |
| chunks.append(chunk) | |
| if end == len(words): | |
| break | |
| start += step | |
| return chunks | |
def chunk_paper(
    paper: CleanedPaper,
    chunk_size: int,
    overlap: int,
) -> list[PaperChunk]:
    """
    Chunk one paper into validated PaperChunk objects.

    Worst cases handled:
        - Abstract too short → single chunk returned
        - Chunk text invalid → Pydantic rejects it, chunk skipped
        - No chunks produced → empty list, caller logs a warning
    """
    raw_chunks = split_into_chunks(paper.abstract, chunk_size, overlap)

    validated = []
    for idx, chunk_text in enumerate(raw_chunks):
        try:
            chunk = PaperChunk(
                chunk_id     = f"{paper.pmid}_chunk_{idx}",
                pmid         = paper.pmid,
                chunk_idx    = idx,
                text         = chunk_text,
                title        = paper.title,
                year         = paper.year,
                keywords     = paper.keywords,
                journal      = paper.journal,
                source_query = paper.source_query,
                source       = "pubmed",
            )
            validated.append(chunk)
        except Exception as e:
            logger.debug(
                f"[CHUNKER] Skipped chunk {idx} of PMID {paper.pmid}: {e}"
            )
    return validated

# ── Core chunker class ──────────────────────────────────────────────
class PaperChunker:
    """
    Enterprise paper chunker.

    Analogy: The book editor team.
    Receives prepared papers (CleanedPaper objects),
    cuts each paper into index cards (PaperChunk objects),
    and saves the full card catalogue (chunks.json).

    Usage:
        chunker = PaperChunker()
        diagnostic = chunker.run()
    """

    def __init__(self) -> None:
        self.settings = get_settings()
        self._setup_paths()

    def _setup_paths(self) -> None:
        self.input_path = (
            self.settings.processed_data_path.parent / "cleaned_papers.json"
        )
        self.output_path = self.settings.processed_data_path
        self.output_path.parent.mkdir(parents=True, exist_ok=True)
        logger.info(f"[CHUNKER] Input : {self.input_path}")
        logger.info(f"[CHUNKER] Output: {self.output_path}")
        logger.info(
            f"[CHUNKER] chunk_size={self.settings.chunk_size} | "
            f"overlap={self.settings.chunk_overlap}"
        )

    def _load_cleaned_papers(self) -> list[dict]:
        """Load cleaned papers – fail fast if the file is missing."""
        if not self.input_path.exists():
            logger.error(f"[CHUNKER] Input not found: {self.input_path}")
            raise FileNotFoundError(
                f"Run cleaner first. No file at: {self.input_path}"
            )
        with open(self.input_path, encoding="utf-8") as f:
            papers = json.load(f)
        logger.info(f"[CHUNKER] Loaded {len(papers):,} cleaned papers")
        return papers
    def run(self) -> ChunkDiagnostic:
        """
        Main entry point – chunks all cleaned papers.

        Flow:
            1. Load cleaned papers
            2. For each paper → split into chunks
            3. Validate each chunk through PaperChunk
            4. Save all chunks to JSON
            5. Return ChunkDiagnostic
        """
        import time
        start_time = time.time()
        logger.info("[CHUNKER] Starting enterprise chunking")

        raw_papers = self._load_cleaned_papers()
        all_chunks = []
        rejected_total = 0
        word_counts = []

        for raw in tqdm(raw_papers, desc="Chunking", unit="paper"):
            try:
                paper = CleanedPaper(**raw)
                chunks = chunk_paper(
                    paper,
                    self.settings.chunk_size,
                    self.settings.chunk_overlap,
                )
                if not chunks:
                    logger.warning(
                        f"[CHUNKER] No chunks produced for PMID {paper.pmid}"
                    )
                    rejected_total += 1
                    continue
                for chunk in chunks:
                    all_chunks.append(chunk.to_dict())
                    word_counts.append(chunk.word_count)
            except Exception as e:
                logger.warning(f"[CHUNKER] Failed paper: {e}")
                rejected_total += 1

        # Save output
        try:
            with open(self.output_path, "w", encoding="utf-8") as f:
                json.dump(all_chunks, f, ensure_ascii=False, indent=2)
            logger.info(
                f"[CHUNKER] Saved {len(all_chunks):,} chunks → {self.output_path}"
            )
        except Exception as e:
            logger.error(f"[CHUNKER] FATAL – could not save: {e}")
            raise

        duration = round(time.time() - start_time, 1)
        diagnostic = ChunkDiagnostic(
            total_papers         = len(raw_papers),
            total_chunks         = len(all_chunks),
            rejected_chunks      = rejected_total,
            avg_chunks_per_paper = round(len(all_chunks) / max(len(raw_papers), 1), 1),
            avg_words_per_chunk  = round(sum(word_counts) / max(len(word_counts), 1), 1),
            min_words            = min(word_counts) if word_counts else 0,
            max_words            = max(word_counts) if word_counts else 0,
            chunk_duration_secs  = duration,
            output_path          = str(self.output_path),
        )
        diagnostic.log_summary()
        return diagnostic

# ── RE probe ────────────────────────────────────────────────────────
def diagnose_chunks(filepath: Optional[str] = None) -> ChunkDiagnostic:
    """
    Reverse-engineering probe for the chunking stage.

    WHY: Run this when Pinecone retrieval feels wrong.
         Inspect chunks.json without re-running chunking.

    Usage:
        python -c "from preprocessing.chunker import diagnose_chunks; diagnose_chunks()"
    """
    settings = get_settings()
    path = Path(filepath) if filepath else settings.processed_data_path

    if not path.exists():
        logger.error(f"[RE-CHUNKER] File not found: {path}")
        raise FileNotFoundError(f"No chunks at {path}. Run chunker first.")

    logger.info(f"[RE-CHUNKER] Inspecting: {path}")
    with open(path, encoding="utf-8") as f:
        chunks = json.load(f)

    word_counts = [len(c.get("text", "").split()) for c in chunks]
    unique_pmids = len(set(c.get("pmid") for c in chunks))

    diagnostic = ChunkDiagnostic(
        total_papers         = unique_pmids,
        total_chunks         = len(chunks),
        rejected_chunks      = 0,
        avg_chunks_per_paper = round(len(chunks) / max(unique_pmids, 1), 1),
        avg_words_per_chunk  = round(sum(word_counts) / max(len(word_counts), 1), 1),
        min_words            = min(word_counts) if word_counts else 0,
        max_words            = max(word_counts) if word_counts else 0,
        chunk_duration_secs  = 0.0,
        output_path          = str(path),
    )
    diagnostic.log_summary()
    return diagnostic
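

# ── Standalone entry point ──────────────────────────────────────────
# Minimal sketch (assumption: running this module directly is only a
# debugging convenience; the documented entry point remains
# scripts/run_pipeline.py).
if __name__ == "__main__":
    chunker = PaperChunker()
    chunker.run()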