# docmind/summarizer/key_points.py
# AI Engineer — Initial commit for DocMind (6cca5b1)
"""
DocMind — Key Points Extraction
Extracts 5-7 key facts/claims from a document with source page
references, and deduplicates using cosine similarity on embeddings.
"""
import logging
import re
from dataclasses import dataclass
from typing import List

import numpy as np

from pipeline.chunker import ChunkMetadata
from pipeline.llm import generate_summary
logger = logging.getLogger(__name__)
# Instruction template handed to the LLM together with each batch of chunk text.
# `{n}` is filled in with str.format() using the number of points to request.
# The '•' bullet and "[PAGE X]" conventions requested here are the ones that
# _parse_bullet_points() looks for when it reads the model's reply.
KEY_POINTS_PROMPT = (
"Extract the {n} most important facts or claims from the following document content. "
"Format each as a bullet point starting with '•'. "
"After each point, add the page reference in brackets like [PAGE X]. "
"Be specific and factual. Do not generalize or add interpretation."
)
@dataclass
class KeyPoint:
    """One fact or claim pulled out of a document.

    Attributes:
        text: The statement itself, without its page marker.
        page_ref: The bracketed page marker that accompanied the statement
            (for example ``"[PAGE 3]"``); an empty string when there is none.
    """

    text: str
    page_ref: str = ""
def extract_key_points(
    chunks: List[ChunkMetadata],
    embed_model=None,
    n_points: int = 7,
    batch_size: int = 5,
) -> List[KeyPoint]:
    """
    Extract key points from document chunks in batches.

    Strategy:
        1. Split the chunks into consecutive batches of ``batch_size`` so
           each LLM call receives a bounded amount of text.
        2. Ask the LLM for up to 3 points per batch (fewer when ``n_points``
           is below 3) and parse the bulleted reply.
        3. Merge all points, optionally deduplicate them via embedding
           similarity, and keep the first ``n_points`` in document order.

    A batch whose LLM call or parsing fails is logged and skipped; the
    remaining batches are still processed.

    Args:
        chunks: All document chunks; each must expose ``page_num`` and ``text``.
        embed_model: Model passed to ``deduplicate_points``. When ``None``,
            deduplication is skipped.
        n_points: Maximum number of key points to return. Values below 1
            yield an empty list without calling the LLM.
        batch_size: Number of chunks per LLM call; must be at least 1.

    Returns:
        At most ``n_points`` KeyPoint objects, deduplicated when a model
        was supplied.

    Raises:
        ValueError: If ``chunks`` is non-empty and ``batch_size`` is below 1.
    """
    if not chunks:
        return []
    if batch_size < 1:
        # range() only rejects a step of 0; a negative step would silently
        # produce no batches, so reject both explicitly.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    if n_points < 1:
        # Nothing was asked for: skip the LLM calls, and avoid a negative
        # slice bound below, which would drop points from the end instead.
        return []

    # The prompt does not depend on the batch, so build it once.
    prompt = KEY_POINTS_PROMPT.format(n=min(3, n_points))

    all_points: List[KeyPoint] = []
    for batch_index, start in enumerate(range(0, len(chunks), batch_size)):
        batch = chunks[start : start + batch_size]
        combined_text = "\n\n".join(
            f"[PAGE {c.page_num}] {c.text}" for c in batch
        )
        try:
            result = generate_summary(combined_text, prompt)
            all_points.extend(_parse_bullet_points(result))
        except Exception as exc:
            # Best effort: one failing batch must not abort the whole
            # extraction. logger.exception keeps the traceback.
            logger.exception(
                "Key points extraction failed for batch %d: %s", batch_index, exc
            )

    logger.info("Extracted %d raw key points", len(all_points))

    # Compare against None explicitly: model objects may define __len__ or
    # __bool__, so truthiness is not a reliable "was a model supplied" test.
    if embed_model is not None and len(all_points) > 1:
        all_points = deduplicate_points(all_points, embed_model, threshold=0.85)

    # Trim to the requested count (n_points is known to be >= 1 here).
    return all_points[:n_points]
def _parse_bullet_points(text: str) -> List[KeyPoint]:
"""Parse bullet-pointed text into KeyPoint objects."""
points = []
for line in text.split("\n"):
line = line.strip()
if not line:
continue
# Remove leading bullet characters
for prefix in ("•", "-", "*", "·"):
if line.startswith(prefix):
line = line[len(prefix):].strip()
break
if not line or len(line) < 10:
continue
# Extract page reference if present
page_ref = ""
import re
page_match = re.search(r"\[PAGE\s+(\d+)\]", line, re.IGNORECASE)
if page_match:
page_ref = page_match.group(0)
line = line[:page_match.start()].strip()
points.append(KeyPoint(text=line, page_ref=page_ref))
return points
def deduplicate_points(
    points: List[KeyPoint],
    embed_model,
    threshold: float = 0.85,
) -> List[KeyPoint]:
    """
    Drop key points that are near-duplicates of an earlier, retained point.

    All point texts are encoded in one call with ``normalize_embeddings=True``
    so that the dot product of two vectors serves as their cosine similarity.
    Points are then visited in order: the first is always retained, and each
    later point is retained only if its similarity to every retained point is
    below ``threshold``.

    Args:
        points: Extracted key points, in priority order.
        embed_model: Object exposing ``encode(texts, normalize_embeddings=True)``
            (a sentence-transformer model).
        threshold: Similarity at or above which a point counts as a duplicate.

    Returns:
        The retained points in their original order. When ``points`` has at
        most one element, the input list itself is returned.
    """
    if len(points) <= 1:
        return points

    vectors = embed_model.encode(
        [point.text for point in points],
        normalize_embeddings=True,
    )

    survivors = [0]  # index 0 has nothing earlier to duplicate
    for candidate in range(1, len(points)):
        # any() stops at the first retained point that is similar enough.
        already_covered = any(
            float(np.dot(vectors[candidate], vectors[kept])) >= threshold
            for kept in survivors
        )
        if not already_covered:
            survivors.append(candidate)

    unique_points = [points[idx] for idx in survivors]
    dropped = len(points) - len(unique_points)
    if dropped > 0:
        logger.info("Deduplicated key points: %d → %d (removed %d)", len(points), len(unique_points), dropped)
    return unique_points