# src/rag.py
"""
RAG module (updated)
 - FAISS retrieval (sentence-transformers embeddings)
 - Cross-encoder reranker (optional)
 - Prompt template with sentence-aware snippet trimming
 - Generation (OpenAI preferred) or local Flan-T5 fallback
 - Post-processing: trim answers to N concise sentences, strip trailing "..." and placeholders
 - Deduplication / diversity of contexts by URL
 - Procedural snippet handling (take next sentence if top is a header/list)
"""
from pathlib import Path
import ujson as json
import numpy as np
import textwrap
import os
import faiss
import re

# embeddings & models
from sentence_transformers import SentenceTransformer, CrossEncoder

try:
    import openai
except Exception:
    openai = None

# local generator (transformers)
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

ROOT = Path(__file__).parent.parent.resolve()
INDEX_PATH = ROOT.joinpath("faiss_index.bin")
META_PATH = ROOT.joinpath("docs_meta.jsonl")

EMBED_MODEL = "all-MiniLM-L6-v2"          # embeddings model (fast)
CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # reranker

# local text generator model (CPU)
LOCAL_GEN_MODEL = "google/flan-t5-small"

# tune here: how many sentences to keep in final answer
MAX_ANSWER_SENTENCES = 2

# diversify: max chunks per url allowed in final top_candidates
MAX_CHUNKS_PER_URL = 2

# lazy-loaded resources
_index = None
_meta = None
_embed_model = None
_cross_encoder = None
_local_generator = None

# -------------------------
# Utilities: sentence helpers & tidy answer
# -------------------------
RE_SENT_SPLIT = re.compile(r'([.!?])\s+')          # split and keep punctuation
RE_ANGLE_PLACEHOLDER = re.compile(r"<[^>]{1,200}>")
RE_DOUBLE_DASH_ID = re.compile(r"\b[a-zA-Z0-9_-]{3,}--\d{3,}\b")
RE_MULTI_DOTS = re.compile(r"\.{2,}")
RE_WHITESPACE = re.compile(r"\s+")

def split_into_sentences(text: str):
    if not text:
        return []
    parts = RE_SENT_SPLIT.split(text)
    sentences = []
    for i in range(0, len(parts), 2):
        chunk = parts[i].strip()
        punct = parts[i+1] if (i+1)<len(parts) else ""
        sentence = (chunk + punct).strip()
        if sentence:
            sentences.append(sentence)
    return sentences
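# Illustrative behavior of the splitter above (not a doctest):
#   split_into_sentences("Go to Settings. Click Save!")
#   -> ["Go to Settings.", "Click Save!"]
# The capturing group in RE_SENT_SPLIT keeps the terminal punctuation attached
# to each sentence.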

# -------------------------
# NEW: sentence-level extraction helpers with procedural handling
# -------------------------
def get_top_sentences_from_passage(passage_text: str, query: str, embed_model, top_n: int = 1):
    """
    Given a passage text, split into sentences and return top_n sentences by cosine similarity.
    For 'procedural' outputs where the top sentence is a header/menu (short or 'how to'), also
    include the next sentence to provide actionable content.
    """
    if not passage_text:
        return []
    sents = split_into_sentences(passage_text)
    if not sents:
        return []
    if len(sents) <= top_n:
        return sents[:top_n]

    # embed query + sentences
    q_emb = embed_model.encode([query], convert_to_numpy=True)
    s_embs = embed_model.encode(sents, convert_to_numpy=True)

    def norm(x):
        n = np.linalg.norm(x)
        return x / (n + 1e-10)
    qn = norm(q_emb[0])
    sims = [float(np.dot(qn, norm(se))) for se in s_embs]
    idxs = sorted(range(len(sents)), key=lambda i: sims[i], reverse=True)[:top_n]

    # if top sentence looks like a header/menu (very short, contains 'how to' or ends with ':'), also include the next sentence
    out = []
    for idx in idxs:
        out.append(sents[idx])
        # heuristic: if that sentence is short or contains "how to" or looks like heading, add next sentence if exists
        s = sents[idx].lower()
        word_count = len(s.split())
        if (word_count <= 6 or 'how to' in s or s.endswith(':')) and (idx + 1) < len(sents):
            out.append(sents[idx+1])
    # dedupe while preserving order
    seen = set()
    final = []
    for x in out:
        if x not in seen:
            final.append(x)
            seen.add(x)
    # return everything selected; truncating back to top_n here would drop the
    # follow-up sentence the header heuristic just added
    return final
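# Illustrative: given the passage "How to enable SSO. Open the Admin Console
# and select Security.", if the header-like first sentence ranks highest the
# heuristic fires (short / contains "how to"), so the follow-up sentence with
# the actionable content is returned as well.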

def extract_fallback_from_contexts(contexts: list, query: str) -> str:
    """
    Deterministic fallback: find the single best sentence across all contexts
    (by cosine similarity to the query) and return it verbatim.
    """
    _, _, embed_model = load_index_and_meta()
    best = None
    best_score = -1.0
    q_emb = embed_model.encode([query], convert_to_numpy=True)[0]

    def norm(x):
        n = np.linalg.norm(x)
        return x / (n + 1e-10)

    qn = norm(q_emb)

    for c in contexts:
        sents = split_into_sentences(c.get("text",""))
        if not sents:
            continue
        s_embs = embed_model.encode(sents, convert_to_numpy=True)
        for s, se in zip(sents, s_embs):
            sc = float(np.dot(qn, norm(se)))
            if sc > best_score:
                best_score = sc
                best = s
    return best or ""

def tidy_answer(ans: str, max_sentences: int = MAX_ANSWER_SENTENCES) -> str:
    if not ans:
        return ans
    a = ans
    a = RE_ANGLE_PLACEHOLDER.sub(" ", a)
    a = RE_DOUBLE_DASH_ID.sub(" ", a)
    a = a.replace("…", ". ")
    a = RE_MULTI_DOTS.sub(". ", a)
    a = RE_WHITESPACE.sub(" ", a).strip()
    sents = split_into_sentences(a)
    if not sents:
        snippet = a[:300].strip()
        if not snippet.endswith("."):
            snippet = snippet.rstrip(" .,") + "."
        return snippet
    take = sents[:max_sentences]
    out = " ".join(take).strip()
    if out and out[-1] not in ".!?":
        out = out.rstrip(" .,") + "."
    return out
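# Illustrative (with MAX_ANSWER_SENTENCES = 2): the "..." run, the <...>
# placeholder and the extra sentence are all stripped:
#   tidy_answer("Enable SSO in Settings... Paste the <metadata-url> value. Save.")
#   -> "Enable SSO in Settings. Paste the value."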

# -------------------------
# Loading helpers
# -------------------------
def load_index_and_meta():
    global _index, _meta, _embed_model
    if _index is None:
        if not INDEX_PATH.exists():
            raise FileNotFoundError(f"FAISS index not found at {INDEX_PATH}. Run src/indexer.py first.")
        _index = faiss.read_index(str(INDEX_PATH))
    if _meta is None:
        if not META_PATH.exists():
            raise FileNotFoundError(f"Meta file not found at {META_PATH}. Run src/indexer.py first.")
        _meta = [json.loads(l) for l in META_PATH.read_text(encoding="utf-8").splitlines() if l.strip()]
    if _embed_model is None:
        _embed_model = SentenceTransformer(EMBED_MODEL)
    return _index, _meta, _embed_model

def get_cross_encoder():
    global _cross_encoder
    if _cross_encoder is None:
        _cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL)
    return _cross_encoder

def get_local_generator():
    global _local_generator
    if _local_generator is None:
        tok = AutoTokenizer.from_pretrained(LOCAL_GEN_MODEL)
        model = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_GEN_MODEL)
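        # device=-1 runs the pipeline on CPU; pass a GPU index (e.g. 0) for CUDA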
        _local_generator = pipeline("text2text-generation", model=model, tokenizer=tok, device=-1)
    return _local_generator

# -------------------------
# Embedding + retrieval
# -------------------------
def embed_query(q: str):
    _, _, em = load_index_and_meta()
    emb = em.encode(q)
    emb = np.asarray(emb, dtype="float32")
    if emb.ndim == 1:
        emb = emb.reshape(1, -1)
    faiss.normalize_L2(emb)
    return emb
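# Note: normalize_L2 on the query assumes the FAISS index was built over
# L2-normalized embeddings with an inner-product metric, so that index.search
# returns cosine similarities. A minimal sketch of the expected build step
# (hypothetical; the actual code lives in src/indexer.py):
#
#   embs = SentenceTransformer(EMBED_MODEL).encode(texts, convert_to_numpy=True).astype("float32")
#   faiss.normalize_L2(embs)
#   index = faiss.IndexFlatIP(embs.shape[1])
#   index.add(embs)
#   faiss.write_index(index, str(INDEX_PATH))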

def retrieve_candidates(query: str, top_k: int = 50):
    index, meta, _ = load_index_and_meta()
    emb = embed_query(query)
    D, I = index.search(emb, top_k)
    results = []
    if I.size == 0:
        return results
    for score, idx in zip(D[0], I[0]):
        if idx < 0:
            continue
        m = meta[idx]
        results.append({"score": float(score), "url": m["url"], "title": m.get("title",""), "text": m["text"], "id": idx})
    return results

# -------------------------
# Reranking
# -------------------------
def rerank_with_cross(query: str, candidates: list, top_n: int = 5):
    if not candidates:
        return []
    cross = get_cross_encoder()
    inputs = [(query, c["text"]) for c in candidates]
    try:
        scores = cross.predict(inputs)
    except Exception as e:
        # fallback: if the cross-encoder fails, keep top_n by the original FAISS score
        print("Cross-encoder scoring failed, using FAISS order:", e)
        candidates.sort(key=lambda x: x.get("score", 0), reverse=True)
        return candidates[:top_n]
    for c, s in zip(candidates, scores):
        c["rerank_score"] = float(s)
    candidates.sort(key=lambda x: x["rerank_score"], reverse=True)
    return candidates[:top_n]
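# Note: for ms-marco cross-encoder checkpoints, predict() returns one raw
# relevance score per (query, passage) pair; only the relative ordering is
# meaningful, which is all the sort above relies on.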

# -------------------------
# Snippet trimming helpers
# -------------------------
def trim_snippet_to_sentence(snippet: str, max_chars: int = 800) -> str:
    if not snippet:
        return snippet
    s = snippet.replace("\n", " ").strip()
    if len(s) <= max_chars:
        return s
    head = s[:max_chars]
    last_dot = max(head.rfind("."), head.rfind("!"), head.rfind("?"))
    # rfind returns -1 when no boundary exists; only cut at a sentence boundary
    # if it falls past 40% of the budget, otherwise fall back to a word boundary
    if last_dot > int(max_chars * 0.4):
        return head[:last_dot+1].strip()
    cut = head.rsplit(" ", 1)[0]
    return cut.strip()
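# Illustrative: a 900-char snippet whose last sentence boundary sits at char
# 750 is cut there; if the only boundary were before char 320 (40% of the
# default 800-char budget), the cut falls back to the last word boundary.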

# -------------------------
# Prompt template (stronger)
# -------------------------
def build_prompt(query: str, contexts: list, sentences_per_context: int = 1):
    """
    Build a prompt that is explicit about producing an ACTIONABLE, concise answer.
    We include one short sentence per source (selected by semantic similarity).
    """
    _, _, embed_model = load_index_and_meta()
    sources_block = []
    for i, c in enumerate(contexts, start=1):
        passage = c.get("text", "").strip()
        best_sents = get_top_sentences_from_passage(passage, query, embed_model, top_n=sentences_per_context)
        snippet = " ".join(best_sents)
        snippet = trim_snippet_to_sentence(snippet, max_chars=500)
        if snippet and snippet[-1] not in ".!?":
            snippet = snippet.rstrip(" .") + "."
        sources_block.append(f"[SRC_{i}] URL: {c.get('url')}\n[SRC_{i}] TEXT: {snippet}")

    sources_text = "\n\n".join(sources_block)

    prompt = textwrap.dedent(f"""
    Use only the following snippets to produce a concise, ACTIONABLE answer (1-2 short sentences) that directly answers the question. 
    For "how-to" queries, produce concrete steps or exact fields to set where possible. Do NOT invent facts or add information not present in snippets. 
    If snippets do not contain an answer, reply: "I don't know — please consult the documentation." Then list the Source URLs used.

    {sources_text}

    Question:
    {query}

    Answer (be concise and then list Sources used as URLs):
    """).strip()
    return prompt
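# Illustrative rendered prompt (abridged, hypothetical URLs):
#
#   Use only the following snippets to produce a concise, ACTIONABLE answer ...
#
#   [SRC_1] URL: https://docs.example.com/sso/okta
#   [SRC_1] TEXT: In the Admin Console, go to Security > Identity Providers.
#
#   [SRC_2] URL: https://docs.example.com/sso/saml
#   [SRC_2] TEXT: Paste the IdP metadata URL and click Save.
#
#   Question:
#   How do I configure SAML SSO with Okta?
#
#   Answer (be concise and then list Sources used as URLs):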

# -------------------------
# Generation (OpenAI or local) with fallback formatting
# -------------------------
def generate_answer_with_context(question: str, contexts: list, use_openai: bool = False):
    prompt = build_prompt(question, contexts, sentences_per_context=1)

    # Option A: OpenAI (preferred)
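    # Note: openai.ChatCompletion.create targets the legacy openai<1.0 SDK;
    # on openai>=1.0 the equivalent call is OpenAI().chat.completions.create(...)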
    if use_openai and openai is not None and os.environ.get("OPENAI_API_KEY"):
        try:
            resp = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role":"system","content":"You are a strict assistant. Use ONLY the provided documentation snippets to answer. Do not hallucinate."},
                    {"role":"user","content": prompt}
                ],
                temperature=0.0,
                max_tokens=200,
                stop=None
            )
            raw_answer = resp["choices"][0]["message"]["content"].strip()
            answer = tidy_answer(raw_answer, max_sentences=MAX_ANSWER_SENTENCES)
            # fallback if model returned unhelpful phrasing
            if answer.lower().startswith("i'm") or "find the source" in answer.lower() or answer.lower().startswith("see"):
                fallback = extract_fallback_from_contexts(contexts, question)
                fallback = fallback.strip()
                if fallback and not fallback.endswith(('.', '!', '?')):
                    fallback = fallback + '.'
                if 'okta' in fallback.lower() or 'authenticator' in fallback.lower():
                    fallback = "Enable Okta SAML SSO: " + fallback
                return tidy_answer(fallback, max_sentences=MAX_ANSWER_SENTENCES), [c["url"] for c in contexts]
            used_urls = [c["url"] for c in contexts if c["url"] in raw_answer]
            if not used_urls:
                used_urls = [c["url"] for c in contexts]
            return answer, used_urls
        except Exception as e:
            print("OpenAI generation failed:", e)

    # Option B: Local generator fallback
    gen = get_local_generator()
    gen_kwargs = {
        "max_length": 200,
        "num_beams": 4,
        "do_sample": False,
        "no_repeat_ngram_size": 3,
        "early_stopping": True
    }
    out = gen(prompt, **gen_kwargs)
    raw_answer = out[0].get("generated_text","").strip()
    answer = tidy_answer(raw_answer, max_sentences=MAX_ANSWER_SENTENCES)
    if answer.lower().startswith("i'm") or "find the source" in answer.lower() or answer.lower().startswith("see"):
        fallback = extract_fallback_from_contexts(contexts, question)
        fallback = fallback.strip()
        if fallback and not fallback.endswith(('.', '!', '?')):
            fallback = fallback + '.'
        if 'okta' in fallback.lower() or 'authenticator' in fallback.lower():
            fallback = "Enable Okta SAML SSO: " + fallback
        return tidy_answer(fallback, max_sentences=MAX_ANSWER_SENTENCES), [c["url"] for c in contexts]
    return answer, [c["url"] for c in contexts]

# -------------------------
# Top-level handler with dedup/diversify
# -------------------------
def handle_rag_query(query: str, top_k: int = 5, use_openai: bool = False, rerank_candidates: int = 50):
    candidates = retrieve_candidates(query, top_k=rerank_candidates)
    if not candidates:
        return {"answer": "No relevant documentation found.", "sources": [], "retrieved": []}

    try:
        top_candidates = rerank_with_cross(query, candidates, top_n=rerank_candidates)
    except Exception as e:
        print("Reranker failed, falling back to FAISS order:", e)
        top_candidates = candidates[:rerank_candidates]

    # Now pick final top_k diversified by URL:
    # strategy: prefer at most MAX_CHUNKS_PER_URL per url; prefer higher rerank_score
    # first, group by URL preserving order
    url_counts = {}
    diversified = []
    for c in top_candidates:
        url = c.get("url")
        cnt = url_counts.get(url, 0)
        if cnt < MAX_CHUNKS_PER_URL:
            diversified.append(c)
            url_counts[url] = cnt + 1
        # stop early once we have collected enough diversified chunks to fill top_k
        if len(diversified) >= top_k * MAX_CHUNKS_PER_URL:
            break

    # final trimming: ensure at most one chunk per URL until we fill top_k
    seen_urls = set()
    unique_candidates = []
    for c in diversified:
        u = c.get("url")
        if u in seen_urls:
            continue
        unique_candidates.append(c)
        seen_urls.add(u)
        if len(unique_candidates) >= top_k:
            break
    # if we don't have enough unique URLs, allow second chunks (already in diversified)
    if len(unique_candidates) < top_k:
        # fill from diversified preserving order, skipping already-selected chunks
        selected_ids = {c["id"] for c in unique_candidates}
        for c in diversified:
            if c["id"] in selected_ids:
                continue
            unique_candidates.append(c)
            selected_ids.add(c["id"])
            if len(unique_candidates) >= top_k:
                break
    final_candidates = unique_candidates[:top_k]
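    # Illustrative: with MAX_CHUNKS_PER_URL = 2 and top_k = 3, reranked chunks
    # [A1, A2, A3, B1, C1] (letters = URLs) diversify to [A1, A2, B1, C1];
    # the unique-URL pass then picks [A1, B1, C1].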

    # generate answer using final candidates
    answer, urls = generate_answer_with_context(query, final_candidates, use_openai=use_openai)

    return {"answer": answer, "sources": urls, "retrieved": final_candidates}

# small test if run as script
if __name__ == "__main__":
    q = "How do I configure SAML SSO with Okta?"
    print("Running test query:", q)
    res = handle_rag_query(q, top_k=3, use_openai=False)
    print("ANSWER:\n", res["answer"])
    print("SOURCES:\n", res["sources"])
    for r in res["retrieved"][:3]:
        print("----\n", r["url"], "\n", r["text"][:300])