| """ |
| DocMind — Key Points Extraction |
| |
| Extracts 5-7 key facts/claims from a document with source page |
| references, and deduplicates using cosine similarity on embeddings. |
| """ |
|
|
| import logging |
| from dataclasses import dataclass |
| from typing import List |
|
|
| import numpy as np |
|
|
| from pipeline.chunker import ChunkMetadata |
| from pipeline.llm import generate_summary |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Instruction template passed to the LLM alongside each batch of chunk text.
# ``{n}`` is substituted (via ``str.format``) with the number of points to
# request from that batch.
KEY_POINTS_PROMPT = "".join(
    (
        "Extract the {n} most important facts or claims from the following document content. ",
        "Format each as a bullet point starting with '•'. ",
        "After each point, add the page reference in brackets like [PAGE X]. ",
        "Be specific and factual. Do not generalize or add interpretation.",
    )
)
|
|
|
|
@dataclass
class KeyPoint:
    """One key fact or claim extracted from a document.

    Attributes:
        text: The statement itself.
        page_ref: Source page reference such as ``"[PAGE 3]"``; defaults to
            the empty string when no reference is available.
    """

    text: str
    page_ref: str = ""
|
|
|
|
def extract_key_points(
    chunks: List[ChunkMetadata],
    embed_model=None,
    n_points: int = 7,
    batch_size: int = 5,
) -> List[KeyPoint]:
    """
    Extract key points from document chunks in batches.

    Strategy:
        1. Process chunks in batches of ``batch_size`` to stay within token limits
        2. Ask the LLM for ``min(3, n_points)`` key points per batch
        3. Merge all points and, if ``embed_model`` is given, deduplicate
           them via cosine similarity (threshold 0.85)
        4. Return the first ``n_points`` surviving points

    Extraction is best-effort: a batch whose LLM call or parsing raises is
    logged and skipped, and the remaining batches are still processed.

    Args:
        chunks: All document chunks. Each must expose ``page_num`` and ``text``.
        embed_model: Sentence-transformer model for deduplication, or ``None``
            to skip deduplication.
        n_points: Maximum number of key points to return. Values below 1
            yield an empty list.
        batch_size: Number of chunks per LLM call. Must be at least 1.

    Returns:
        Deduplicated list of at most ``n_points`` KeyPoint objects.

    Raises:
        ValueError: If ``batch_size`` is less than 1 (and ``chunks`` is
            non-empty and ``n_points`` is positive).
    """
    if not chunks:
        return []

    # Guard before doing any LLM work: a negative bound in the final slice
    # would otherwise mean "drop the last k points" instead of "return none".
    if n_points < 1:
        return []

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # The prompt is identical for every batch, so build it once.
    prompt = KEY_POINTS_PROMPT.format(n=min(3, n_points))

    all_points: List[KeyPoint] = []
    for start in range(0, len(chunks), batch_size):
        batch = chunks[start : start + batch_size]
        combined_text = "\n\n".join(
            f"[PAGE {c.page_num}] {c.text}" for c in batch
        )

        try:
            result = generate_summary(combined_text, prompt)
            all_points.extend(_parse_bullet_points(result))
        except Exception as e:
            # Deliberately broad: one failing batch must not abort the run.
            # The number logged is the batch's starting chunk offset.
            logger.error("Key points extraction failed for batch %d: %s", start, e)

    logger.info("Extracted %d raw key points", len(all_points))

    # Explicit None check so the decision does not depend on the model
    # object's truthiness (e.g. a container-like model defining __len__).
    if embed_model is not None and len(all_points) > 1:
        all_points = deduplicate_points(all_points, embed_model, threshold=0.85)

    return all_points[:n_points]
|
|
|
|
def _parse_bullet_points(text: str) -> List[KeyPoint]:
    """Parse bullet-pointed text into KeyPoint objects.

    Each non-empty line is treated as one candidate point:

    1. A single leading bullet marker (``•``, ``-``, ``*`` or ``·``) is removed.
    2. Lines shorter than 10 characters after that are discarded.
    3. The first ``[PAGE <digits>]`` marker (case-insensitive) becomes the
       point's ``page_ref``. The point's text is whatever precedes the
       marker; if nothing precedes it (the marker leads the line), the text
       following the marker is used instead.
    4. Lines left with no text at all are discarded.

    Args:
        text: Raw text, one bullet per line.

    Returns:
        The parsed key points, in input order. ``page_ref`` is ``""`` for
        lines without a page marker.
    """
    # Local import: ``re`` is not imported at module level in this file.
    import re

    # Compiled once per call instead of being looked up for every line.
    page_pattern = re.compile(r"\[PAGE\s+(\d+)\]", re.IGNORECASE)
    bullet_prefixes = ("•", "-", "*", "·")

    points = []
    for raw_line in text.split("\n"):
        line = raw_line.strip()

        # Strip at most one bullet marker.
        for prefix in bullet_prefixes:
            if line.startswith(prefix):
                line = line[len(prefix):].strip()
                break

        # Also covers empty lines.
        if len(line) < 10:
            continue

        page_ref = ""
        page_match = page_pattern.search(line)
        if page_match:
            page_ref = page_match.group(0)
            leading = line[:page_match.start()].strip()
            # The prompt asks for the reference after the point, but the
            # marker may come first; fall back to the trailing text so the
            # point is not reduced to an empty string.
            line = leading or line[page_match.end():].strip()
            if not line:
                continue

        points.append(KeyPoint(text=line, page_ref=page_ref))

    return points
|
|
|
|
def deduplicate_points(
    points: List[KeyPoint],
    embed_model,
    threshold: float = 0.85,
) -> List[KeyPoint]:
    """
    Drop key points that are near-duplicates of an earlier, retained point.

    Points are visited in order. The first point is always retained; each
    later point is retained only if its similarity to every already-retained
    point is below ``threshold``. Similarity is the dot product of the
    embeddings, which equals cosine similarity provided the model honors
    ``normalize_embeddings=True``.

    Args:
        points: Extracted key points; only their ``text`` is embedded.
        embed_model: Object exposing
            ``encode(texts, normalize_embeddings=True)`` and returning one
            vector per text (e.g. a sentence-transformer model).
        threshold: Similarity at or above which a point counts as a duplicate.

    Returns:
        The retained points in their original order. Inputs with zero or
        one point are returned unchanged without calling the model.
    """
    total = len(points)
    if total < 2:
        return points

    vectors = embed_model.encode(
        [point.text for point in points],
        normalize_embeddings=True,
    )

    # Indices of retained points; the first point has nothing to clash with.
    survivors = [0]
    for candidate in range(1, total):
        # any() stops at the first retained point that is similar enough.
        clashes = any(
            float(np.dot(vectors[candidate], vectors[kept])) >= threshold
            for kept in survivors
        )
        if not clashes:
            survivors.append(candidate)

    unique_points = [points[idx] for idx in survivors]
    dropped = total - len(unique_points)
    if dropped > 0:
        logger.info("Deduplicated key points: %d → %d (removed %d)", total, len(unique_points), dropped)

    return unique_points
|
|