| """ |
| pluto/embedder.py β Semantic chunking via NVIDIA NIM embedding endpoint. |
| |
| Replaces heading-based splitting with cosine-similarity boundary detection. |
| Also provides context injection: every chunk gets a header describing where |
| it sits in the document, so extraction agents never see orphaned facts. |
| """ |
|
|
from __future__ import annotations
|
|
import os
import re
import math
from typing import TYPE_CHECKING
|
|
import requests
|
|
if TYPE_CHECKING:
    from pluto.doc_index import DocIndex
|
|
|
|
NVIDIA_BASE_URL = "https://integrate.api.nvidia.com/v1"
EMBED_MODEL = "nvidia/llama-nemotron-embed-1b-v2"
SIMILARITY_THRESHOLD = 0.75
MAX_CHUNK_CHARS = 1800
MIN_CHUNK_CHARS = 200
|
|
|
|
def embed_texts(texts: list[str]) -> list[list[float]]:
    """
    Call NVIDIA NIM embedding endpoint.
    Returns list of float vectors, one per input text.
    """
    api_key = os.getenv("NVIDIA_API_KEY_EMBED") or os.getenv("NVIDIA_API_KEY", "")
    if not api_key:
        raise ValueError("NVIDIA_API_KEY_EMBED or NVIDIA_API_KEY not set")

    response = requests.post(
        f"{NVIDIA_BASE_URL}/embeddings",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": EMBED_MODEL,
            "input": texts,
            "input_type": "passage",
            "encoding_format": "float",
        },
        timeout=60,
    )
    if response.status_code != 200:
        raise RuntimeError(f"NVIDIA embeddings {response.status_code}: {response.text[:300]}")

    data = response.json().get("data", [])
    return [item.get("embedding", []) for item in data]
|
|
|
|
def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two vectors; 0.0 if either vector has zero magnitude."""
    dot = sum(x * y for x, y in zip(a, b))
    mag_a = math.sqrt(sum(x * x for x in a))
    mag_b = math.sqrt(sum(x * x for x in b))
    if mag_a == 0 or mag_b == 0:
        return 0.0
    return dot / (mag_a * mag_b)
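# Quick sanity check (pure math, no API call): identical vectors score 1.0 and
# orthogonal vectors score 0.0, e.g.
#
#   cosine_similarity([1.0, 0.0], [1.0, 0.0])  # -> 1.0
#   cosine_similarity([1.0, 0.0], [0.0, 1.0])  # -> 0.0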
|
|
|
|
def semantic_split(text: str) -> list[str]:
    """
    Split document text into semantically coherent chunks.

    Algorithm:
      1. Split into sentences.
      2. Embed each sentence (batched).
      3. Where consecutive sentence similarity drops below threshold,
         mark a boundary.
      4. Merge sentences within boundaries, respecting MAX_CHUNK_CHARS.

    Falls back to paragraph splitting if no NVIDIA API key is available.
    """
    api_key = os.getenv("NVIDIA_API_KEY_EMBED") or os.getenv("NVIDIA_API_KEY", "")
    if not api_key:
        # No API key: fall back to structural (paragraph) splitting.
        return _paragraph_split(text)
|
|
    sentences = _split_sentences(text)
    if len(sentences) <= 3:
        return [text.strip()]

    # Embed every sentence, in batches to keep request payloads small.
    embeddings: list[list[float]] = []
    batch_size = 50
    for i in range(0, len(sentences), batch_size):
        batch = sentences[i:i + batch_size]
        try:
            embeddings.extend(embed_texts(batch))
        except Exception:
            # Embedding call failed: fall back to structural splitting.
            return _paragraph_split(text)
|
|
    # Mark a boundary wherever the similarity between consecutive sentences
    # drops below SIMILARITY_THRESHOLD.
    boundaries: list[int] = [0]
    for i in range(1, len(sentences)):
        sim = cosine_similarity(embeddings[i - 1], embeddings[i])
        if sim < SIMILARITY_THRESHOLD:
            boundaries.append(i)
    boundaries.append(len(sentences))
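    # Illustration with made-up similarity scores (not real model output): for
    # five sentences whose consecutive similarities are [0.91, 0.88, 0.62, 0.85]
    # and the default threshold of 0.75, only 0.62 falls below the threshold,
    # so boundaries == [0, 3, 5] and the sentences group as [s0, s1, s2], [s3, s4].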
|
|
    # Join the sentences between boundaries into chunks, merging segments that
    # are too small and re-splitting segments that are too large.
    chunks: list[str] = []
    for b_idx in range(len(boundaries) - 1):
        start = boundaries[b_idx]
        end = boundaries[b_idx + 1]
        segment = " ".join(sentences[start:end]).strip()

        if len(segment) < MIN_CHUNK_CHARS and chunks:
            # Too short to stand alone: append to the previous chunk.
            chunks[-1] = chunks[-1] + " " + segment
        elif len(segment) > MAX_CHUNK_CHARS:
            # Too long: fall back to paragraph splitting for this segment.
            chunks.extend(_paragraph_split(segment))
        else:
            chunks.append(segment)

    return [c.strip() for c in chunks if c.strip()]
|
|
|
|
def inject_context_headers(
    chunks: list[str],
    doc_id: str,
    doc_index: "DocIndex | None" = None,
) -> list[str]:
    """
    Prepend a context header to each chunk so extraction agents
    know where the chunk sits in the document.

    Format:
        [Context | doc:X | chunk:C3 | section:Results | prev:method]
        <original chunk text>

    This guards against orphaned-fact hallucination: the agent always knows
    what section it's reading and what came before.
    """
    result: list[str] = []

    for i, chunk in enumerate(chunks):
        chunk_id = f"C{i}"

        # Look up this chunk's section and the previous chunk's topic from the
        # document index, when one is available.
        section = "unknown"
        prev_topic = ""
        if doc_index and doc_index.has_doc(doc_id):
            topic_map = doc_index.get_chunk_topics(doc_id)
            topics = topic_map.get(chunk_id, [])
            if topics:
                section = topics[-1]
            if i > 0:
                prev_topics = topic_map.get(f"C{i-1}", [])
                prev_topic = prev_topics[0] if prev_topics else ""

        header_parts = [f"doc:{doc_id}", f"chunk:{chunk_id}", f"section:{section}"]
        if prev_topic:
            header_parts.append(f"prev:{prev_topic}")

        header = "[Context | " + " | ".join(header_parts) + "]\n"
        result.append(header + chunk)

    return result
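# Example of the injected header when no DocIndex is supplied (the doc id and
# chunk text below are made up):
#
#   inject_context_headers(["Revenue grew 12% year over year."], "report-2024")
#   # -> ['[Context | doc:report-2024 | chunk:C0 | section:unknown]\n'
#   #     'Revenue grew 12% year over year.']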
|
|
|
|
|
|
def _split_sentences(text: str) -> list[str]:
    """Naive but fast sentence splitter."""
    raw = re.split(r'(?<=[.!?])\s+(?=[A-Z])', text)
    return [s.strip() for s in raw if s.strip()]
|
|
|
|
def _paragraph_split(text: str, max_chars: int = MAX_CHUNK_CHARS) -> list[str]:
    """Fallback: split on double newlines then merge to max_chars."""
    paras = [p.strip() for p in text.split("\n\n") if p.strip()]
    chunks: list[str] = []
    current = ""
    for para in paras:
        if len(current) + len(para) + 2 > max_chars and current:
            chunks.append(current)
            current = para
        else:
            current = (current + "\n\n" + para).strip() if current else para
    if current:
        chunks.append(current)
    return chunks if chunks else [text]
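

# Minimal usage sketch, assuming a plain-text document string. Without an
# NVIDIA API key, semantic_split() falls back to paragraph splitting, so the
# sketch also runs offline. The sample text and doc id are made up.
if __name__ == "__main__":
    sample = (
        "Quarterly revenue grew 12% on strong subscription demand. "
        "Gross margin expanded by two points.\n\n"
        "Separately, the company opened a new research office in Toronto. "
        "Hiring there will focus on retrieval and embedding models."
    )
    demo_chunks = semantic_split(sample)
    for ready in inject_context_headers(demo_chunks, doc_id="demo-doc"):
        print(ready)
        print("-" * 40)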
|
|