"""
DocMind — Key Points Extraction

Extracts 5-7 key facts/claims from a document with source page references,
and deduplicates using cosine similarity on embeddings.
"""

import logging
import re
from dataclasses import dataclass
from typing import List

import numpy as np

from pipeline.chunker import ChunkMetadata
from pipeline.llm import generate_summary

logger = logging.getLogger(__name__)

KEY_POINTS_PROMPT = (
    "Extract the {n} most important facts or claims from the following document content. "
    "Format each as a bullet point starting with '•'. "
    "After each point, add the page reference in brackets like [PAGE X]. "
    "Be specific and factual. Do not generalize or add interpretation."
)

# Upper bound on the number of points requested from the LLM per batch.
_POINTS_PER_BATCH = 3

# Bullet markers stripped from the start of each LLM output line (first match only).
_BULLET_PREFIXES = ("•", "-", "*", "·")

# Lines shorter than this after bullet removal are treated as noise and skipped.
_MIN_LINE_CHARS = 10

# Matches the "[PAGE X]" reference that KEY_POINTS_PROMPT asks the LLM to emit.
# Compiled once here instead of importing/compiling inside the per-line loop.
_PAGE_REF_RE = re.compile(r"\[PAGE\s+(\d+)\]", re.IGNORECASE)


@dataclass
class KeyPoint:
    """A single extracted key point."""

    # The claim text, with the bullet marker and page reference removed.
    text: str
    # The matched reference as written by the LLM, e.g. "[PAGE 3]"; "" if none.
    page_ref: str = ""


def extract_key_points(
    chunks: List[ChunkMetadata],
    embed_model=None,
    n_points: int = 7,
    batch_size: int = 5,
) -> List[KeyPoint]:
    """
    Extract key points from document chunks in batches.

    Strategy:
    1. Process chunks in batches of ``batch_size`` to stay within token limits.
    2. Ask the LLM for up to 3 key points per batch (fewer if ``n_points`` < 3).
    3. Merge all points, deduplicate via cosine similarity (when an embedding
       model is supplied), and keep the first ``n_points`` in batch order.

    A batch whose LLM call or parsing fails is logged and skipped, so one bad
    batch does not abort extraction for the whole document.

    Args:
        chunks: All document chunks; each must expose ``page_num`` and ``text``.
        embed_model: Sentence-transformer model for deduplication, or None to
            skip deduplication.
        n_points: Target number of key points (5-7). Values <= 0 yield [].
        batch_size: Number of chunks per LLM call; must be >= 1.

    Returns:
        At most ``n_points`` KeyPoint objects, in document-batch order.

    Raises:
        ValueError: If ``batch_size`` is less than 1 (and there is work to do).
    """
    # Guard up front: a negative n_points would otherwise reach the prompt
    # ("Extract the -1 most ...") and the final slice ([:-1] keeps almost all).
    if not chunks or n_points <= 0:
        return []
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # The prompt does not depend on the batch, so build it once.
    prompt = KEY_POINTS_PROMPT.format(n=min(_POINTS_PER_BATCH, n_points))

    all_points: List[KeyPoint] = []
    for batch_idx, start in enumerate(range(0, len(chunks), batch_size)):
        batch = chunks[start : start + batch_size]
        combined_text = "\n\n".join(
            f"[PAGE {c.page_num}] {c.text}" for c in batch
        )
        try:
            result = generate_summary(combined_text, prompt)
            all_points.extend(_parse_bullet_points(result))
        except Exception as e:
            # Deliberate best-effort: log the failing batch (by batch index, not
            # chunk offset) and continue with the remaining batches.
            logger.error("Key points extraction failed for batch %d: %s", batch_idx, e)

    logger.info("Extracted %d raw key points", len(all_points))

    if embed_model is not None and len(all_points) > 1:
        all_points = deduplicate_points(all_points, embed_model, threshold=0.85)

    # Trim to target count; earlier batches (earlier chunks) take precedence.
    return all_points[:n_points]


def _parse_bullet_points(text: str) -> List[KeyPoint]:
    """
    Parse bullet-pointed LLM output into KeyPoint objects.

    Each non-empty line has one leading bullet marker stripped. Lines shorter
    than ``_MIN_LINE_CHARS`` are skipped. The first ``[PAGE N]`` reference
    becomes ``page_ref``, and the claim text is the text before it. When the
    reference leads the line, the text after it is used instead. Lines whose
    claim text ends up empty are skipped.

    Args:
        text: Raw LLM response.

    Returns:
        Parsed key points, in line order.
    """
    points: List[KeyPoint] = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue

        # Remove a single leading bullet character.
        for prefix in _BULLET_PREFIXES:
            if line.startswith(prefix):
                line = line[len(prefix):].strip()
                break

        if len(line) < _MIN_LINE_CHARS:
            continue

        page_ref = ""
        page_match = _PAGE_REF_RE.search(line)
        if page_match:
            page_ref = page_match.group(0)
            before = line[:page_match.start()].strip()
            # "[PAGE 3] Revenue grew ..." would leave nothing before the ref;
            # fall back to the text after it (minus any further refs).
            line = before or _PAGE_REF_RE.sub("", line[page_match.end():]).strip()

        # A line that was only a page reference carries no claim.
        if not line:
            continue

        points.append(KeyPoint(text=line, page_ref=page_ref))
    return points


def deduplicate_points(
    points: List[KeyPoint],
    embed_model,
    threshold: float = 0.85,
) -> List[KeyPoint]:
    """
    Remove near-duplicate key points using cosine similarity.

    Points are scanned in order. The first point is always kept. Each later
    point is dropped if its similarity to any already-kept point is
    >= ``threshold``, so the earliest occurrence of a duplicate wins.

    Args:
        points: List of extracted key points.
        embed_model: Sentence-transformer model; ``encode(texts,
            normalize_embeddings=True)`` is called once on all point texts.
        threshold: Cosine similarity threshold (at or above = duplicate).

    Returns:
        Deduplicated list of KeyPoints, preserving input order.
    """
    if len(points) <= 1:
        return points

    texts = [p.text for p in points]
    # normalize_embeddings=True requests unit-length vectors, so a dot product
    # equals cosine similarity. asarray allows fancy indexing below.
    embeddings = np.asarray(embed_model.encode(texts, normalize_embeddings=True))

    kept_indices = [0]  # Always keep the first point
    for i in range(1, len(points)):
        # Similarity of candidate i to every point kept so far, in one matmul.
        similarities = embeddings[kept_indices] @ embeddings[i]
        if not np.any(similarities >= threshold):
            kept_indices.append(i)

    deduped = [points[i] for i in kept_indices]
    removed = len(points) - len(deduped)
    if removed > 0:
        logger.info("Deduplicated key points: %d → %d (removed %d)",
                    len(points), len(deduped), removed)
    return deduped