File size: 6,477 Bytes
23cdeed
66ad25b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# -*- coding: utf-8 -*-
"""
pluto/embedder.py — Semantic chunking via NVIDIA NIM embedding endpoint.

Replaces heading-based splitting with cosine-similarity boundary detection.
Also provides context injection: every chunk gets a header describing where
it sits in the document, so extraction agents never see orphaned facts.
"""

from __future__ import annotations

import os
import re
import math
from typing import TYPE_CHECKING

import requests

if TYPE_CHECKING:
    from pluto.doc_index import DocIndex


# NIM OpenAI-compatible API root used by embed_texts().
NVIDIA_BASE_URL = "https://integrate.api.nvidia.com/v1"
# Embedding model identifier sent in the request payload.
EMBED_MODEL = "nvidia/llama-nemotron-embed-1b-v2"
SIMILARITY_THRESHOLD = 0.75   # drop below this → new chunk boundary
MAX_CHUNK_CHARS = 1800        # segments larger than this get re-split on paragraphs
MIN_CHUNK_CHARS = 200         # segments smaller than this merge into the previous chunk


def embed_texts(texts: list[str]) -> list[list[float]]:
    """
    Embed *texts* via the NVIDIA NIM embeddings endpoint.

    Parameters
    ----------
    texts : list[str]
        Passages to embed. An empty list short-circuits to ``[]``
        without touching the network (and without requiring a key).

    Returns
    -------
    list[list[float]]
        One float vector per input text, in response order.

    Raises
    ------
    ValueError
        If neither NVIDIA_API_KEY_EMBED nor NVIDIA_API_KEY is set.
    RuntimeError
        If the endpoint responds with a non-200 status.
    """
    # Robustness: nothing to embed -> nothing to do. Avoids a pointless
    # HTTP round-trip and a spurious ValueError when no key is configured.
    if not texts:
        return []

    api_key = os.getenv("NVIDIA_API_KEY_EMBED") or os.getenv("NVIDIA_API_KEY", "")
    if not api_key:
        raise ValueError("NVIDIA_API_KEY_EMBED or NVIDIA_API_KEY not set")

    response = requests.post(
        f"{NVIDIA_BASE_URL}/embeddings",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": EMBED_MODEL,
            "input": texts,
            "input_type": "passage",   # passage-side embedding (vs. "query")
            "encoding_format": "float",
        },
        timeout=60,
    )
    if response.status_code != 200:
        # Truncate the body so a huge HTML error page doesn't flood logs.
        raise RuntimeError(f"NVIDIA embeddings {response.status_code}: {response.text[:300]}")

    data = response.json().get("data", [])
    return [item.get("embedding", []) for item in data]


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between *a* and *b*; 0.0 when either has zero norm."""
    norm_a = math.sqrt(sum(u * u for u in a))
    norm_b = math.sqrt(sum(v * v for v in b))
    # Guard against division by zero for empty or all-zero vectors.
    if not norm_a or not norm_b:
        return 0.0
    dot = sum(u * v for u, v in zip(a, b))
    return dot / (norm_a * norm_b)


def semantic_split(text: str) -> list[str]:
    """
    Chunk *text* into semantically coherent pieces.

    Sentences are embedded in batches of 50; wherever the cosine
    similarity of two adjacent sentence embeddings falls below
    SIMILARITY_THRESHOLD, a boundary is placed. Resulting segments are
    merged or re-split so chunks respect MIN_CHUNK_CHARS and
    MAX_CHUNK_CHARS. Without an NVIDIA key — or on any embedding
    failure — the function degrades to paragraph-based splitting.
    """
    key = os.getenv("NVIDIA_API_KEY_EMBED") or os.getenv("NVIDIA_API_KEY", "")
    if not key:
        # Graceful fallback when no credentials are available.
        return _paragraph_split(text)

    sentences = _split_sentences(text)
    if len(sentences) <= 3:
        return [text.strip()]

    # Embed sentences in batches of 50 (stays under NIM payload limits).
    vectors: list[list[float]] = []
    for start in range(0, len(sentences), 50):
        try:
            vectors.extend(embed_texts(sentences[start:start + 50]))
        except Exception:
            # Degrade rather than crash if the service fails mid-way.
            return _paragraph_split(text)

    # A cut sits wherever adjacent-sentence similarity dips below threshold.
    cuts = [0]
    cuts += [
        idx
        for idx in range(1, len(sentences))
        if cosine_similarity(vectors[idx - 1], vectors[idx]) < SIMILARITY_THRESHOLD
    ]
    cuts.append(len(sentences))

    # Assemble chunks from consecutive cut pairs, enforcing size bounds.
    chunks: list[str] = []
    for lo, hi in zip(cuts, cuts[1:]):
        piece = " ".join(sentences[lo:hi]).strip()
        if len(piece) < MIN_CHUNK_CHARS and chunks:
            chunks[-1] = chunks[-1] + " " + piece      # absorb runt into predecessor
        elif len(piece) > MAX_CHUNK_CHARS:
            chunks.extend(_paragraph_split(piece))     # oversize: re-split on paragraphs
        else:
            chunks.append(piece)

    return [c.strip() for c in chunks if c.strip()]


def inject_context_headers(
    chunks: list[str],
    doc_id: str,
    doc_index: "DocIndex | None" = None,
) -> list[str]:
    """
    Prepend a context header to each chunk so extraction agents
    know where the chunk sits in the document.

    Format:
      [Context | doc:X | chunk:C3 | section:Results | prev:method]
      <original chunk text>

    Parameters
    ----------
    chunks : list[str]
        Chunk bodies in document order.
    doc_id : str
        Identifier embedded in each header.
    doc_index : DocIndex | None
        Optional index consulted via ``has_doc`` / ``get_chunk_topics``
        for section and previous-topic metadata. Without it, section
        falls back to "unknown" and the prev field is omitted.

    Returns
    -------
    list[str]
        One header-prefixed string per input chunk.
    """
    result: list[str] = []

    for i, chunk in enumerate(chunks):
        chunk_id = f"C{i}"

        # Section / previous-topic metadata, when the index knows this doc.
        section = "unknown"
        prev_topic = ""
        if doc_index and doc_index.has_doc(doc_id):
            topic_map = doc_index.get_chunk_topics(doc_id)
            topics = topic_map.get(chunk_id, [])
            if topics:
                # Fixed: the original's `topics[-1] if topics else "unknown"`
                # had a dead else-branch inside `if topics:`.
                section = topics[-1]  # last entry is the section role
            if i > 0:
                prev_topics = topic_map.get(f"C{i-1}", [])
                prev_topic = prev_topics[0] if prev_topics else ""

        header_parts = [f"doc:{doc_id}", f"chunk:{chunk_id}", f"section:{section}"]
        if prev_topic:
            header_parts.append(f"prev:{prev_topic}")

        header = "[Context | " + " | ".join(header_parts) + "]\n"
        result.append(header + chunk)

    return result


# ── Internal helpers ────────────────────────────────────────────────────────

def _split_sentences(text: str) -> list[str]:
    """Naive but fast sentence splitter."""
    raw = re.split(r'(?<=[.!?])\s+(?=[A-Z])', text)
    return [s.strip() for s in raw if s.strip()]


def _paragraph_split(text: str, max_chars: int = MAX_CHUNK_CHARS) -> list[str]:
    """Fallback splitter: paragraphs on blank lines, greedily packed to *max_chars*."""
    out: list[str] = []
    acc = ""
    for block in text.split("\n\n"):
        para = block.strip()
        if not para:
            continue
        # +2 accounts for the "\n\n" separator that joining would add.
        if acc and len(acc) + len(para) + 2 > max_chars:
            out.append(acc)
            acc = para
        else:
            acc = f"{acc}\n\n{para}".strip() if acc else para
    if acc:
        out.append(acc)
    # A text with no non-blank paragraphs is returned whole.
    return out or [text]