"""
preprocessing/chunker.py
=========================
ALZDETECT-AI β€” Enterprise Semantic Chunker.

WHAT:   Splits 19,637 cleaned papers into ~40,000 searchable chunks.
WHY:    Pinecone retrieves chunks, not full papers.
        A question about pTau217 should return the exact sentence
        mentioning pTau217 β€” not the entire abstract.
WHO:    Called by scripts/run_pipeline.py after cleaner.
        Output consumed only by vector_store/embedder.py
WHERE:  Reads  β†’ data/processed/cleaned_papers.json
        Writes β†’ data/processed/chunks.json
WHEN:   Once per plan, after cleaning, before embedding.

WORST-CASE DESIGN:
    - Abstract too short to chunk  β†’ single chunk, logged
    - Chunk text empty after split β†’ chunk skipped
    - Missing metadata on paper    β†’ defaults used, never crashes
    - Output dir missing           β†’ created automatically
    - Duplicate chunk IDs          β†’ caught by validator
"""

import json
import re
from pathlib import Path
from typing import Optional
from datetime import datetime

from pydantic import BaseModel, Field, field_validator, model_validator
from loguru import logger
from tqdm import tqdm

from configs.settings import get_settings
from preprocessing.cleaner import CleanedPaper
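
# Typical invocation (sketch — the exact call site in scripts/run_pipeline.py is
# assumed, not shown here): after the cleaning stage the pipeline constructs a
# PaperChunker and calls run(), which reads cleaned_papers.json and writes
# chunks.json.
#
#   from preprocessing.chunker import PaperChunker
#   diagnostic = PaperChunker().run()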


# ── Single chunk model ────────────────────────────────────────────

class PaperChunk(BaseModel):
    """
    One chunk from one paper β€” the atomic unit of the RAG system.

    Analogy: One index card from the book editor.
    Contains: the content (text), the reference (pmid, chunk_idx),
    and enough metadata to filter in Pinecone (year, keywords, source).

    This is what gets embedded and stored in Pinecone.
    Every field here becomes either a vector or a metadata filter.
    """
    chunk_id:    str         = Field(..., description="Unique ID: pmid_chunk_N")
    pmid:        str         = Field(..., description="Parent paper PMID")
    chunk_idx:   int         = Field(..., ge=0, description="Position in paper")
    text:        str         = Field(..., min_length=20,
                                     description="Chunk text β€” what gets embedded")
    title:       str         = Field(..., description="Parent paper title")
    year:        Optional[int] = Field(default=None)
    keywords:    list[str]   = Field(default_factory=list)
    journal:     Optional[str] = Field(default=None)
    source_query: str        = Field(default="unknown")
    source:      str         = Field(default="pubmed",
                                     description="pubmed or adni")
    word_count:  int         = Field(default=0, description="Words in this chunk")

    @field_validator("chunk_id")
    @classmethod
    def chunk_id_format(cls, v: str) -> str:
        """
        Enforce chunk ID format: pmid_chunk_N
        Wrong format = Pinecone upsert will have inconsistent IDs.
        """
        if "_chunk_" not in v:
            raise ValueError(
                f"chunk_id '{v}' must contain '_chunk_'. "
                f"Format: 'pmid_chunk_N'"
            )
        return v

    @field_validator("text")
    @classmethod
    def text_has_content(cls, v: str) -> str:
        """Chunk must have real content β€” not just whitespace."""
        v = v.strip()
        if len(v.split()) < 5:
            raise ValueError(
                f"Chunk text too short ({len(v.split())} words) β€” skipped"
            )
        return v

    @model_validator(mode="after")
    def compute_word_count(self) -> "PaperChunk":
        """Auto-compute word count after all fields validated."""
        self.word_count = len(self.text.split())
        return self

    def to_dict(self) -> dict:
        return self.model_dump()

    def to_pinecone_record(self) -> dict:
        """
        Format for Pinecone upsert.
        Pinecone expects: {id, values (embedding), metadata}
        This method builds the metadata part.
        Embedding is added by embedder.py.
        """
        return {
            "id": self.chunk_id,
            "metadata": {
                "pmid":         self.pmid,
                "chunk_idx":    self.chunk_idx,
                "text":         self.text[:1000],  # Pinecone metadata limit
                "title":        self.title[:200],
                "year":         self.year,
                "keywords":     self.keywords[:10],  # limit metadata size
                "journal":      self.journal,
                "source_query": self.source_query,
                "source":       self.source,
                "word_count":   self.word_count,
            }
        }
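
    # Illustrative record shape (hypothetical values — the embedding "values"
    # field that Pinecone also expects is added later by vector_store/embedder.py):
    #
    #   {
    #       "id": "12345678_chunk_0",
    #       "metadata": {
    #           "pmid": "12345678",
    #           "chunk_idx": 0,
    #           "text": "Plasma pTau217 discriminated ...",   # truncated to 1000 chars
    #           "year": 2023,
    #           "source": "pubmed",
    #           ...
    #       },
    #   }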


# ── Chunk diagnostic model ────────────────────────────────────────

class ChunkDiagnostic(BaseModel):
    """
    Reverse-engineering (RE) inspector for the chunking stage.

    Analogy: The editor's production report.
    How many index cards were created, what was skipped,
    what is the average card size.
    """
    total_papers:      int
    total_chunks:      int
    rejected_chunks:   int
    avg_chunks_per_paper: float
    avg_words_per_chunk:  float
    min_words:         int
    max_words:         int
    chunk_duration_secs: float
    output_path:       str

    def log_summary(self) -> None:
        logger.info("=" * 60)
        logger.info("[CHUNK-DIAGNOSTIC] Run complete")
        logger.info(f"  Papers processed  : {self.total_papers:,}")
        logger.info(f"  Total chunks      : {self.total_chunks:,}")
        logger.info(f"  Rejected chunks   : {self.rejected_chunks:,}")
        logger.info(f"  Avg chunks/paper  : {self.avg_chunks_per_paper:.1f}")
        logger.info(f"  Avg words/chunk   : {self.avg_words_per_chunk:.1f}")
        logger.info(f"  Min words         : {self.min_words}")
        logger.info(f"  Max words         : {self.max_words}")
        logger.info(f"  Duration          : {self.chunk_duration_secs:.1f}s")
        logger.info(f"  Saved to          : {self.output_path}")
        logger.info("=" * 60)


# ── Chunking functions ────────────────────────────────────────────

def split_into_chunks(
    text: str,
    chunk_size: int,
    overlap: int,
) -> list[str]:
    """
    Split text into overlapping word-based chunks.

    WHY word-based not character-based:
        Words are the natural unit for biomedical text.
        Character splits can cut mid-word β€” "Alzheimer" becomes
        "Alzhe" + "imer" β€” meaningless to the embedding model.

    WHY overlap:
        If a key sentence falls at the boundary between two chunks,
        overlap ensures it appears in both β€” so retrieval finds it
        regardless of which chunk is retrieved.

    Analogy: Photocopying pages of a book with 1 page overlap.
    You never miss a sentence that falls between two photocopies.

    Args:
        text       : full abstract text
        chunk_size : max words per chunk (from settings)
        overlap    : words shared between adjacent chunks

    Returns:
        list of chunk strings β€” at least 1, even for short abstracts
    """
    words = text.split()

    # Short abstracts β€” return as single chunk
    if len(words) <= chunk_size:
        return [text]

    chunks = []
    start  = 0
    step   = max(1, chunk_size - overlap)  # prevent infinite loop

    while start < len(words):
        end   = min(start + chunk_size, len(words))
        chunk = " ".join(words[start:end])
        chunks.append(chunk)
        if end == len(words):
            break
        start += step

    return chunks
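
# Worked example (illustrative values, not taken from real settings): with
# chunk_size=5 and overlap=2 the step is 3 words, so an 8-word text yields two
# chunks that share a 2-word overlap:
#
#   split_into_chunks("a b c d e f g h", chunk_size=5, overlap=2)
#   -> ["a b c d e", "d e f g h"]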


def chunk_paper(
    paper: CleanedPaper,
    chunk_size: int,
    overlap: int,
) -> list[PaperChunk]:
    """
    Chunk one paper into validated PaperChunk objects.

    Worst-case handled:
        - Abstract too short β†’ single chunk returned
        - Chunk text invalid β†’ Pydantic rejects, chunk skipped
        - No chunks produced β†’ empty list, caller logs warning
    """
    raw_chunks = split_into_chunks(paper.abstract, chunk_size, overlap)
    validated  = []

    for idx, chunk_text in enumerate(raw_chunks):
        try:
            chunk = PaperChunk(
                chunk_id     = f"{paper.pmid}_chunk_{idx}",
                pmid         = paper.pmid,
                chunk_idx    = idx,
                text         = chunk_text,
                title        = paper.title,
                year         = paper.year,
                keywords     = paper.keywords,
                journal      = paper.journal,
                source_query = paper.source_query,
                source       = "pubmed",
            )
            validated.append(chunk)
        except Exception as e:
            logger.debug(
                f"[CHUNKER] Skipped chunk {idx} of PMID {paper.pmid}: {e}"
            )

    return validated


# ── Core chunker class ────────────────────────────────────────────

class PaperChunker:
    """
    Enterprise paper chunker.

    Analogy: The book editor team.
    Receives prepared manuscripts (CleanedPaper objects),
    cuts each paper into index cards (PaperChunk objects),
    saves the full card catalogue (chunks.json).

    Usage:
        chunker    = PaperChunker()
        diagnostic = chunker.run()
    """

    def __init__(self) -> None:
        self.settings = get_settings()
        self._setup_paths()

    def _setup_paths(self) -> None:
        self.input_path  = (
            self.settings.processed_data_path.parent / "cleaned_papers.json"
        )
        self.output_path = self.settings.processed_data_path
        self.output_path.parent.mkdir(parents=True, exist_ok=True)
        logger.info(f"[CHUNKER] Input : {self.input_path}")
        logger.info(f"[CHUNKER] Output: {self.output_path}")
        logger.info(
            f"[CHUNKER] chunk_size={self.settings.chunk_size} | "
            f"overlap={self.settings.chunk_overlap}"
        )

    def _load_cleaned_papers(self) -> list[dict]:
        """Load cleaned papers β€” fail fast if file missing."""
        if not self.input_path.exists():
            logger.error(f"[CHUNKER] Input not found: {self.input_path}")
            raise FileNotFoundError(
                f"Run cleaner first. No file at: {self.input_path}"
            )
        with open(self.input_path, encoding="utf-8") as f:
            papers = json.load(f)
        logger.info(f"[CHUNKER] Loaded {len(papers):,} cleaned papers")
        return papers

    def run(self) -> ChunkDiagnostic:
        """
        Main entry point β€” chunks all cleaned papers.

        Flow:
            1. Load cleaned papers
            2. For each paper β†’ split into chunks
            3. Validate each chunk through PaperChunk
            4. Save all chunks to JSON
            5. Return ChunkDiagnostic
        """
        import time
        start_time = time.time()

        logger.info("[CHUNKER] Starting enterprise chunking")

        raw_papers     = self._load_cleaned_papers()
        all_chunks     = []
        rejected_total = 0
        word_counts    = []

        for raw in tqdm(raw_papers, desc="Chunking", unit="paper"):
            try:
                paper  = CleanedPaper(**raw)
                chunks = chunk_paper(
                    paper,
                    self.settings.chunk_size,
                    self.settings.chunk_overlap,
                )

                if not chunks:
                    logger.warning(
                        f"[CHUNKER] No chunks produced for PMID {paper.pmid}"
                    )
                    rejected_total += 1
                    continue

                for chunk in chunks:
                    all_chunks.append(chunk.to_dict())
                    word_counts.append(chunk.word_count)

            except Exception as e:
                logger.warning(f"[CHUNKER] Failed paper: {e}")
                rejected_total += 1

        # Save output
        try:
            with open(self.output_path, "w", encoding="utf-8") as f:
                json.dump(all_chunks, f, ensure_ascii=False, indent=2)
            logger.info(
                f"[CHUNKER] Saved {len(all_chunks):,} chunks β†’ {self.output_path}"
            )
        except Exception as e:
            logger.error(f"[CHUNKER] FATAL β€” could not save: {e}")
            raise

        duration = round(time.time() - start_time, 1)

        diagnostic = ChunkDiagnostic(
            total_papers         = len(raw_papers),
            total_chunks         = len(all_chunks),
            rejected_chunks      = rejected_total,
            avg_chunks_per_paper = round(len(all_chunks) / max(len(raw_papers), 1), 1),
            avg_words_per_chunk  = round(sum(word_counts) / max(len(word_counts), 1), 1),
            min_words            = min(word_counts) if word_counts else 0,
            max_words            = max(word_counts) if word_counts else 0,
            chunk_duration_secs  = duration,
            output_path          = str(self.output_path),
        )
        diagnostic.log_summary()
        return diagnostic
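
    # Illustrative chunks.json entry (hypothetical values — the file is a JSON
    # list of PaperChunk.model_dump() dicts, consumed by vector_store/embedder.py):
    #
    #   {
    #       "chunk_id": "12345678_chunk_0",
    #       "pmid": "12345678",
    #       "chunk_idx": 0,
    #       "text": "Plasma pTau217 discriminated ...",
    #       "title": "...",
    #       "year": 2023,
    #       "keywords": ["pTau217", "biomarker"],
    #       "journal": "...",
    #       "source_query": "...",
    #       "source": "pubmed",
    #       "word_count": 142
    #   }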


# ── RE probe ──────────────────────────────────────────────────────

def diagnose_chunks(filepath: Optional[str] = None) -> ChunkDiagnostic:
    """
    Reverse engineering probe for chunking stage.

    WHY: Run this when Pinecone retrieval feels wrong.
    Inspect chunks.json without re-running chunking.

    Usage:
        python -c "from preprocessing.chunker import diagnose_chunks; diagnose_chunks()"
    """
    import time
    settings = get_settings()
    path     = Path(filepath) if filepath else settings.processed_data_path

    if not path.exists():
        logger.error(f"[RE-CHUNKER] File not found: {path}")
        raise FileNotFoundError(f"No chunks at {path}. Run chunker first.")

    logger.info(f"[RE-CHUNKER] Inspecting: {path}")

    with open(path, encoding="utf-8") as f:
        chunks = json.load(f)

    word_counts = [len(c.get("text", "").split()) for c in chunks]

    diagnostic = ChunkDiagnostic(
        total_papers         = len(set(c.get("pmid") for c in chunks)),
        total_chunks         = len(chunks),
        rejected_chunks      = 0,
        avg_chunks_per_paper = round(
            len(chunks) / max(len(set(c.get("pmid") for c in chunks)), 1), 1
        ),
        avg_words_per_chunk  = round(
            sum(word_counts) / max(len(word_counts), 1), 1
        ),
        min_words            = min(word_counts) if word_counts else 0,
        max_words            = max(word_counts) if word_counts else 0,
        chunk_duration_secs  = 0.0,
        output_path          = str(path),
    )
    diagnostic.log_summary()
    return diagnostic