# services/kb_creation.py
"""Knowledge-base construction and retrieval for SOP Word documents.

Word (.docx) files are split into heading-delimited sections and ~160-word
chunks, tagged with intent/topic/module metadata, and stored twice:

* in a persistent ChromaDB collection (dense MiniLM embeddings), and
* in a lightweight in-process BM25 index that is pickled next to the Chroma
  data so it survives restarts.

The module exposes semantic-only, BM25 and hybrid search, helpers that fetch
whole sections of a given file, and admin helpers to inspect / rebuild the KB.
"""
import os
import re
import pickle
import math
from collections import defaultdict
from typing import List, Dict, Any, Tuple, Optional

from docx import Document
from sentence_transformers import SentenceTransformer
import chromadb

# ------------------------------ ChromaDB setup ------------------------------
CHROMA_PATH = os.path.join(os.getcwd(), "chroma_db")
client = chromadb.PersistentClient(path=CHROMA_PATH)
collection = client.get_or_create_collection(name="knowledge_base")

# ------------------------------ Embedding model ------------------------------
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# ------------------------------ BM25 (lightweight) ------------------------------
BM25_INDEX_FILE = os.path.join(CHROMA_PATH, "bm25_index.pkl")
# One dict per chunk: {"id", "text", "tokens", "tf", "length", "meta"}.
bm25_docs: List[Dict[str, Any]] = []
# term -> positions (indexes into bm25_docs) of chunks containing the term.
bm25_inverted: Dict[str, List[int]] = {}
# term -> number of chunks containing the term (document frequency).
bm25_df: Dict[str, int] = {}
# Average chunk length in tokens.
bm25_avgdl: float = 0.0
bm25_ready: bool = False
BM25_K1 = 1.5
BM25_B = 0.75


# ------------------------------ Utilities ------------------------------
def _tokenize(text: str) -> List[str]:
    """Lower-case *text* and return its runs of ASCII letters/digits."""
    if not text:
        return []
    text = text.lower()
    return re.findall(r"[a-z0-9]+", text)


def _normalize_query(q: str) -> str:
    """Lower-case *q*, turn punctuation into spaces and collapse whitespace."""
    q = (q or "").strip().lower()
    q = re.sub(r"[^\w\s]", " ", q)
    q = re.sub(r"\s+", " ", q).strip()
    return q


def _tokenize_meta_value(val: Optional[str]) -> List[str]:
    """Tokenize a metadata string value that may be missing (None)."""
    return _tokenize(val or "")


# ------------------------------ DOCX parsing & chunking ------------------------------
# A leading bullet ("-", "*", "\u2022") or list number ("1." / "1)") followed by whitespace.
BULLET_RE = re.compile(r"^\s*(?:[\-\*\u2022]|\d+[.)])\s+", re.IGNORECASE)


def _split_by_sections(doc: Document) -> List[Tuple[str, List[str]]]:
    """Group non-empty paragraphs under the closest preceding "Heading N" paragraph.

    Paragraphs before the first heading are grouped under "Untitled Section".
    If nothing was collected, a single "Document" section holding every
    non-empty paragraph is returned.

    Returns:
        List of ``(section_title, [paragraph_text, ...])`` in document order.
    """
    sections: List[Tuple[str, List[str]]] = []
    current_title = None
    current_paras: List[str] = []
    for para in doc.paragraphs:
        text = (para.text or "").strip()
        style_name = (para.style.name if para.style else "") or ""
        is_heading = bool(re.match(r"Heading\s*\d+", style_name, flags=re.IGNORECASE))
        if is_heading and text:
            # A new heading closes the section collected so far.
            if current_title or current_paras:
                sections.append((current_title or "Untitled Section", current_paras))
            current_title = text
            current_paras = []
        elif text:
            current_paras.append(text)
    if current_title or current_paras:
        sections.append((current_title or "Untitled Section", current_paras))
    if not sections:
        all_text = [p.text.strip() for p in doc.paragraphs if p.text and p.text.strip()]
        sections = [("Document", all_text)]
    return sections


def _paragraphs_to_lines(paragraphs: List[str]) -> List[str]:
    """Preserve bullets/numbered list lines; split long paragraphs by sentence boundaries."""
    lines: List[str] = []
    for p in (paragraphs or []):
        p = (p or "").strip()
        if not p:
            continue
        if BULLET_RE.match(p):
            # Keep list items intact so a step is never split mid-sentence.
            lines.append(p)
            continue
        parts = [s.strip() for s in re.split(r"(?<=[.!?])\s+", p) if s.strip()]
        lines.extend(parts)
    return lines


def _chunk_text_with_context(doc_title: str, section_title: str, paragraphs: List[str], max_words: int = 160) -> List[str]:
    """Smaller chunks (~*max_words* words), bullet-aware.

    Lines are packed greedily; a chunk is closed when adding the next line
    would exceed *max_words* or when the next line is a bullet/numbered item
    (so every list item starts a new chunk). *doc_title* and *section_title*
    are currently not used in the chunk text.
    """
    lines = _paragraphs_to_lines(paragraphs)
    chunks: List[str] = []
    current: List[str] = []
    current_len = 0
    for ln in lines:
        w = ln.split()
        if current_len + len(w) > max_words or (BULLET_RE.match(ln) and current):
            chunk = " ".join(current).strip()
            if chunk:
                chunks.append(chunk)
            current = [ln]
            current_len = len(w)
        else:
            current.append(ln)
            current_len += len(w)
    if current:
        chunk = " ".join(current).strip()
        if chunk:
            chunks.append(chunk)
    if not chunks:
        body = " ".join(lines).strip()
        if body:
            chunks = [body]
    return chunks


# ------------------------------ Intent & Module tagging ------------------------------
SECTION_STEPS_HINTS = ["process steps", "procedure", "how to", "workflow", "instructions", "steps"]
SECTION_ERRORS_HINTS = ["common errors", "resolution", "troubleshooting", "known issues", "common issues", "escalation", "escalation path", "permissions", "access"]
PERMISSION_TERMS = [
    "permission", "permissions", "access", "access right", "authorization", "authorisation",
    "role", "role access", "role mapping", "security", "security profile", "privilege",
    "insufficient", "not allowed", "not authorized", "denied", "restrict",
]
ERROR_TERMS = ["error", "issue", "fail", "failure", "not working", "cannot", "can't", "mismatch", "locked", "wrong", "denied"]
STEP_VERBS = ["navigate", "select", "scan", "verify", "confirm", "print", "move", "complete", "click", "open", "choose", "enter", "update", "save", "delete", "create", "attach", "assign"]
MODULE_VOCAB = {
    "receiving": [
        "receive", "receiving", "inbound receiving", "inbound", "goods receipt", "grn",
        "asn receiving", "unload", "check-in", "dock check-in",
    ],
    "appointments": [
        "appointment", "appointments", "schedule", "scheduling", "slot", "dock door",
        "appointment creation", "appointment details",
    ],
    "picking": ["pick", "picking", "pick release", "wave", "allocation"],
    "putaway": ["putaway", "staging", "put away", "location assignment"],
    "shipping": ["shipping", "ship confirm", "outbound", "load", "trailer"],
    "inventory": ["inventory", "adjustment", "cycle count", "count", "uom"],
    "replenishment": ["replenishment", "replenish"],
}
# NOTE(review): all term lists in this module are matched as plain substrings,
# so short terms (e.g. "role", "load", "count") also match inside longer words.


def _infer_intent_tag(section_title: str) -> str:
    """Map a section heading to "steps" / "errors" / "prereqs" / "purpose" / "neutral"."""
    st = (section_title or "").lower()
    if any(k in st for k in SECTION_STEPS_HINTS):
        return "steps"
    if any(k in st for k in SECTION_ERRORS_HINTS):
        return "errors"
    if "pre" in st and "requisite" in st:
        return "prereqs"
    if any(k in st for k in ["purpose", "overview", "introduction"]):
        return "purpose"
    if any(k in st for k in ["inbound receiving", "receiving", "goods receipt", "grn"]):
        return "steps"
    if any(k in st for k in ["appointment", "appointments", "schedule", "scheduling"]):
        return "steps"
    return "neutral"


def _derive_semantic_intent_from_text(text: str) -> Tuple[str, List[str]]:
    """Infer an intent and topic tags from chunk text.

    Permission vocabulary wins (intent "errors", tag "permissions", plus
    "role_access"/"security" when those words appear), then generic error
    vocabulary ("errors"), then step verbs ("steps", tag "procedure").

    Returns:
        ``(intent, sorted_unique_tags)``.
    """
    t = (text or "").lower()
    tags: List[str] = []
    intent = "neutral"
    if any(term in t for term in PERMISSION_TERMS):
        intent = "errors"
        tags.append("permissions")
        if "role" in t:
            tags.append("role_access")
        if "security" in t:
            tags.append("security")
    if intent == "neutral" and any(term in t for term in ERROR_TERMS):
        intent = "errors"
        tags.append("errors")
    if intent == "neutral" and any(v in t for v in STEP_VERBS):
        intent = "steps"
        tags.append("procedure")
    # Sorted so the stored "topic_tags" metadata string is deterministic.
    return intent, sorted(set(tags))


def _derive_module_tags(text: str, filename: str, section_title: str) -> List[str]:
    """Return the sorted MODULE_VOCAB modules mentioned in filename, section title or text.

    When no vocabulary term matches, a few looser keywords are tried as a
    fallback (at most one module is assigned in that case).
    """
    tokens = " ".join([filename or "", section_title or "", text or ""]).lower()
    found = [mod for mod, syns in MODULE_VOCAB.items() if any(s in tokens for s in syns)]
    if not found:
        if any(k in tokens for k in ("inventory", "adjust", "uom", "cycle")):
            found = ["inventory"]
        elif any(k in tokens for k in ("receive", "inbound", "goods receipt", "grn")):
            found = ["receiving"]
        elif any(k in tokens for k in ("appointment", "schedule", "dock")):
            found = ["appointments"]
    return sorted(set(found))


# ------------------------------ Ingestion ------------------------------
def ingest_documents(folder_path: str) -> None:
    """(Re)build the KB from every ``.docx`` file directly inside *folder_path*.

    Each chunk is upserted into the Chroma collection under the id
    ``"<filename>:<section_index>:<chunk_index>"`` and added to a freshly
    rebuilt in-memory BM25 index, which is then pickled to BM25_INDEX_FILE.
    Word lock files (``~$*.docx``) are ignored and files that cannot be opened
    are skipped with a warning. If no ``.docx`` file is found the function
    returns without touching the existing indexes.

    Chunks from earlier ingestions whose ids are no longer produced are not
    removed from Chroma; use ``reset_kb`` for a clean rebuild.

    Raises:
        OSError: if *folder_path* cannot be listed.
    """
    global bm25_docs, bm25_inverted, bm25_df, bm25_avgdl, bm25_ready
    print(f"[KB] Checking folder: {folder_path}")
    files = sorted(
        f for f in os.listdir(folder_path)
        if f.lower().endswith('.docx') and not f.startswith("~$")
    )
    print(f"[KB] Found {len(files)} Word files: {files}")
    if not files:
        print("[KB] WARNING: No .docx files found. Please check the folder path.")
        return

    bm25_docs, bm25_inverted, bm25_df = [], {}, {}
    bm25_avgdl, bm25_ready = 0.0, False

    for file in files:
        file_path = os.path.join(folder_path, file)
        doc_title = os.path.splitext(file)[0]
        try:
            doc = Document(file_path)
        except Exception as e:
            # One corrupt / locked file must not abort ingestion of the others.
            print(f"[KB] WARNING: Could not open {file}: {e}")
            continue
        sections = _split_by_sections(doc)
        total_chunks = 0
        for s_idx, (section_title, paras) in enumerate(sections):
            chunks = _chunk_text_with_context(doc_title, section_title, paras, max_words=160)
            total_chunks += len(chunks)
            base_intent = _infer_intent_tag(section_title)
            for c_idx, chunk in enumerate(chunks):
                derived_intent, topic_tags = _derive_semantic_intent_from_text(chunk)
                # Text-level "errors" overrides the heading; otherwise text only
                # refines sections whose heading gave no intent.
                final_intent = base_intent
                if derived_intent == "errors":
                    final_intent = "errors"
                elif base_intent == "neutral" and derived_intent in ("steps", "prereqs"):
                    final_intent = derived_intent
                module_tags = _derive_module_tags(chunk, file, section_title)
                embedding = model.encode(chunk).tolist()
                doc_id = f"{file}:{s_idx}:{c_idx}"
                meta = {
                    "filename": file,
                    "section": section_title,
                    "chunk_index": c_idx,
                    "title": doc_title,
                    "collection": "SOP",
                    "intent_tag": final_intent,
                    "topic_tags": ", ".join(topic_tags) if topic_tags else "",
                    "module_tags": ", ".join(module_tags) if module_tags else "",
                }
                try:
                    # upsert (not add) so re-ingesting an edited file replaces
                    # the stored embedding/text instead of keeping the old one.
                    collection.upsert(ids=[doc_id], embeddings=[embedding], documents=[chunk], metadatas=[meta])
                except Exception as e2:
                    print(f"[KB] ERROR: Upsert failed for {doc_id}: {e2}")

                tokens = _tokenize(chunk)
                tf: Dict[str, int] = {}
                for tkn in tokens:
                    tf[tkn] = tf.get(tkn, 0) + 1
                idx = len(bm25_docs)
                bm25_docs.append({
                    "id": doc_id,
                    "text": chunk,
                    "tokens": tokens,
                    "tf": tf,
                    "length": len(tokens),
                    "meta": meta,
                })
                # tf keys are unique, so each chunk counts once per term in df.
                for term in tf:
                    bm25_inverted.setdefault(term, []).append(idx)
                    bm25_df[term] = bm25_df.get(term, 0) + 1
        print(f"[KB] Ingested {file} → {total_chunks} chunks")

    N = len(bm25_docs)
    if N > 0:
        bm25_avgdl = sum(d["length"] for d in bm25_docs) / float(N)
        bm25_ready = True
    payload = {
        "bm25_docs": bm25_docs,
        "bm25_inverted": bm25_inverted,
        "bm25_df": bm25_df,
        "bm25_avgdl": bm25_avgdl,
        "BM25_K1": BM25_K1,
        "BM25_B": BM25_B,
    }
    os.makedirs(CHROMA_PATH, exist_ok=True)
    with open(BM25_INDEX_FILE, "wb") as f:
        pickle.dump(payload, f)
    print(f"[KB] BM25 index saved: {BM25_INDEX_FILE}")
    print(f"[KB] Documents ingested. Total entries in Chroma: {collection.count()}")


# ------------------------------ BM25 load ------------------------------
def _load_bm25_index() -> None:
    """Load the pickled BM25 index into the module globals, if the file exists.

    Failures are reported and leave the current globals untouched. The stored
    BM25_K1/BM25_B values are not restored; the module constants are used.

    SECURITY: pickle.load can execute arbitrary code. Only ever point
    BM25_INDEX_FILE at a file written by ``ingest_documents``.
    """
    global bm25_docs, bm25_inverted, bm25_df, bm25_avgdl, bm25_ready
    if not os.path.exists(BM25_INDEX_FILE):
        return
    try:
        with open(BM25_INDEX_FILE, "rb") as f:
            payload = pickle.load(f)
        bm25_docs = payload.get("bm25_docs", [])
        bm25_inverted = payload.get("bm25_inverted", {})
        bm25_df = payload.get("bm25_df", {})
        bm25_avgdl = payload.get("bm25_avgdl", 0.0)
        bm25_ready = len(bm25_docs) > 0
        if bm25_ready:
            print(f"[KB] BM25 index loaded: {BM25_INDEX_FILE} (docs={len(bm25_docs)})")
    except Exception as e:
        print(f"[KB] WARNING: Could not load BM25 index: {e}")


_load_bm25_index()


# ------------------------------ BM25 search ------------------------------
def _bm25_score_for_doc(query_terms: List[str], doc_idx: int) -> float:
    """Okapi BM25 score of ``bm25_docs[doc_idx]`` for *query_terms*.

    IDF is ``log(1 + (N - df + 0.5) / (df + 0.5))``. Repeated query terms are
    counted once per occurrence. Returns 0.0 when the index is not ready or
    *doc_idx* is out of range.
    """
    if not bm25_ready or doc_idx < 0 or doc_idx >= len(bm25_docs):
        return 0.0
    doc = bm25_docs[doc_idx]
    n_docs = len(bm25_docs)
    dl = doc["length"] or 1
    # Document-length normalisation is the same for every term of this doc.
    length_norm = 1 - BM25_B + BM25_B * (dl / (bm25_avgdl or 1.0))
    score = 0.0
    for term in query_terms:
        df = bm25_df.get(term, 0)
        if df == 0:
            continue
        tf = doc["tf"].get(term, 0)
        if tf == 0:
            continue
        log_arg = (n_docs - df + 0.5) / (df + 0.5) + 1.0
        # Only an inconsistent index (df > N) can make the argument non-positive.
        idf = math.log(log_arg) if log_arg > 0 else 1.0
        denom = tf + BM25_K1 * length_norm
        score += idf * ((tf * (BM25_K1 + 1)) / (denom or 1.0))
    return score


def bm25_search(query: str, top_k: int = 50) -> List[Tuple[int, float]]:
    """Return up to *top_k* ``(bm25_docs index, score)`` pairs with score > 0, best first."""
    if not bm25_ready:
        return []
    q_terms = _tokenize(_normalize_query(query))
    if not q_terms:
        return []
    candidates = set()
    for t in q_terms:
        candidates.update(bm25_inverted.get(t, []))
    # No chunk contains any query term, so every BM25 score would be 0.
    if not candidates:
        return []
    scored = []
    for idx in candidates:
        s = _bm25_score_for_doc(q_terms, idx)
        if s > 0:
            scored.append((idx, s))
    scored.sort(key=lambda x: x[1], reverse=True)
    return scored[:top_k]


# ------------------------------ Semantic-only ------------------------------
def search_knowledge_base(query: str, top_k: int = 10) -> dict:
    """Embed *query* and return the *top_k* nearest chunks from Chroma.

    Returns:
        dict with "documents", "metadatas", "distances" and "ids" lists (one
        entry per hit, nearest first). "ids" are the Chroma record ids, i.e.
        the same ``"<filename>:<section_index>:<chunk_index>"`` ids used by the
        BM25 index, so semantic and BM25 hits can be joined. Only if the Chroma
        response carries no usable ids are they synthesized from metadata as
        ``"<filename>:<section title>:<chunk_index>"``.
    """
    query_embedding = model.encode(query).tolist()
    res = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k,
        # 'ids' is not a valid include value: Chroma always returns ids.
        include=['documents', 'metadatas', 'distances'],
    )
    documents = (res.get("documents", [[]]) or [[]])[0]
    metadatas = (res.get("metadatas", [[]]) or [[]])[0]
    distances = (res.get("distances", [[]]) or [[]])[0]
    chroma_ids = (res.get("ids", [[]]) or [[]])[0] or []

    ids: List[str] = []
    if documents:
        if len(chroma_ids) == len(documents):
            ids = [str(cid) for cid in chroma_ids]
        else:
            for i, m in enumerate(metadatas):
                m = m or {}
                fn = m.get("filename", "unknown")
                sec = m.get("section", "section")
                idx = m.get("chunk_index", i)
                ids.append(f"{fn}:{sec}:{idx}")

    print(f"[KB] search → {len(documents)} docs (top_k={top_k}); first distance: {distances[0] if distances else 'n/a'}; ids={len(ids)}")
    return {
        "documents": documents,
        "metadatas": metadatas,
        "distances": distances,
        "ids": ids,
    }


# ------------------------------ Hybrid search (generic + intent-aware) ------------------------------
ACTION_SYNONYMS = {
    "create": ["create", "creation", "add", "new", "generate"],
    "update": ["update", "modify", "change", "edit"],
    "delete": ["delete", "remove"],
    "navigate": ["navigate", "go to", "open"],
}
ERROR_INTENT_TERMS = [
    "error", "issue", "fail", "not working", "resolution", "fix",
    "permission", "permissions", "access", "no access", "authorization", "authorisation",
    "role", "role mapping", "not authorized", "permission denied", "insufficient privileges",
    "escalation", "escalation path", "access right", "mismatch", "locked", "wrong",
]
# "Next step" style phrasing is treated as a steps query.
_NEXT_STEP_TERMS = ('next step', 'what next', 'whats next', 'then what', 'following step', 'continue', 'after', 'proceed')
# Wider synonym table used only to DETECT actions in a query; scoring
# (_action_weight) deliberately keeps using the narrower ACTION_SYNONYMS.
_QUERY_ACTION_SYNONYMS = {
    "create": ("create", "creation", "add", "new", "generate", "setup", "set up", "register"),
    "update": ("update", "modify", "change", "edit", "amend"),
    "delete": ("delete", "remove", "cancel", "void"),
    "navigate": ("navigate", "go to", "open"),
}
# Action -> actions whose vocabulary in a chunk is penalised.
_ACTION_CONFLICTS = {"create": ["delete"], "delete": ["create"], "update": ["delete"], "navigate": []}

# One hybrid-search candidate:
# (id, final_score, distance, text, meta, meta_overlap, intent_boost,
#  action_wt, module_wt, phrase_wt, heading_bonus (always 0.0), literal_wt)
_Record = Tuple[str, float, float, str, Dict[str, Any], float, float, float, float, float, float, float]


def _detect_user_intent(query: str) -> str:
    """Classify *query* as "errors", "steps", "prereqs", "purpose" or "neutral".

    Checks run in that priority order using substring matching.
    NOTE(review): 'do' also matches inside words such as "dock" or "window",
    which pushes many queries to "steps"; consider word-boundary matching.
    """
    q = (query or '').lower()
    if any(k in q for k in ERROR_INTENT_TERMS):
        return 'errors'
    if any(k in q for k in _NEXT_STEP_TERMS):
        return 'steps'
    if any(k in q for k in ['steps', 'procedure', 'how to', 'navigate', 'process', 'do', 'perform']):
        return 'steps'
    if any(k in q for k in ['pre-requisite', 'prerequisites', 'requirement', 'requirements']):
        return 'prereqs'
    if any(k in q for k in ['purpose', 'overview', 'introduction']):
        return 'purpose'
    return 'neutral'


def _extract_actions(query: str) -> List[str]:
    """Return the sorted canonical actions (create/update/delete/navigate) mentioned in *query*."""
    q = (query or "").lower()
    # Any cue such as "steps for <syn>" already contains <syn>, so direct
    # synonym matching covers those phrasings as well.
    found = [act for act, syns in _QUERY_ACTION_SYNONYMS.items() if any(s in q for s in syns)]
    return sorted(set(found))


def _extract_modules_from_query(query: str) -> List[str]:
    """Return the sorted MODULE_VOCAB modules mentioned in *query*."""
    q = (query or "").lower()
    found = [mod for mod, syns in MODULE_VOCAB.items() if any(s in q for s in syns)]
    return sorted(set(found))


def _action_weight(text: str, actions: List[str]) -> float:
    """+1.0 per matching action synonym in *text*, -0.8 per synonym of a conflicting action."""
    if not actions:
        return 0.0
    t = (text or "").lower()
    score = 0.0
    for act in actions:
        for syn in ACTION_SYNONYMS.get(act, [act]):
            if syn in t:
                score += 1.0
    for act in actions:
        for bad in _ACTION_CONFLICTS.get(act, []):
            for syn in ACTION_SYNONYMS.get(bad, [bad]):
                if syn in t:
                    score -= 0.8
    return score


def _module_weight(meta: Dict[str, Any], user_modules: List[str]) -> float:
    """0.7 per module shared between query and chunk, -0.8 when none is shared, 0.0 without query modules."""
    if not user_modules:
        return 0.0
    raw = (meta or {}).get("module_tags", "") or ""
    doc_modules = [m.strip() for m in raw.split(",") if m.strip()] if isinstance(raw, str) else (raw or [])
    overlap = len(set(user_modules) & set(doc_modules))
    if overlap == 0:
        return -0.8
    return 0.7 * overlap


def _intent_weight(meta: dict, user_intent: str) -> float:
    """Score how well a chunk's intent tag / section / topics fit the user's intent."""
    tag = (meta or {}).get("intent_tag", "neutral")
    if user_intent == "neutral":
        return 0.0
    if tag == user_intent:
        return 1.0
    if tag in ["purpose", "prereqs"] and user_intent in ["steps", "errors"]:
        return -0.6
    st = ((meta or {}).get("section", "") or "").lower()
    topics = (meta or {}).get("topic_tags", "") or ""
    topic_list = [t.strip() for t in topics.split(",") if t.strip()]
    if user_intent == "errors" and (
        any(k in st for k in ["common errors", "known issues", "common issues", "errors", "escalation", "permissions", "access"])
        or ("permissions" in topic_list)
    ):
        return 1.10
    if user_intent == "steps" and any(k in st for k in ["process steps", "procedure", "instructions", "workflow"]):
        return 0.75
    return -0.2


def _meta_overlap(meta: Dict[str, Any], q_terms: List[str]) -> float:
    """Fraction of distinct query terms found among the chunk's metadata tokens."""
    fn_tokens = _tokenize_meta_value(meta.get("filename"))
    title_tokens = _tokenize_meta_value(meta.get("title"))
    section_tokens = _tokenize_meta_value(meta.get("section"))
    topic_tokens = _tokenize_meta_value((meta.get("topic_tags") or ""))
    module_tokens = _tokenize_meta_value((meta.get("module_tags") or ""))
    meta_tokens = set(fn_tokens + title_tokens + section_tokens + topic_tokens + module_tokens)
    if not meta_tokens or not q_terms:
        return 0.0
    qset = set(q_terms)
    inter = len(meta_tokens & qset)
    return inter / max(1, len(qset))


def _make_ngrams(tokens: List[str], n: int) -> List[str]:
    """Space-joined contiguous *n*-grams of *tokens* (empty when fewer than *n* tokens)."""
    return [" ".join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def _phrase_boost_score(text: str, q_terms: List[str]) -> float:
    """+0.40 per query bigram and +0.70 per query trigram found in *text*, capped at 2.0."""
    if not text or not q_terms:
        return 0.0
    low = (text or "").lower()
    bigrams = _make_ngrams(q_terms, 2)
    trigrams = _make_ngrams(q_terms, 3)
    score = 0.0
    for bg in bigrams:
        if bg and bg in low:
            score += 0.40
    for tg in trigrams:
        if tg and tg in low:
            score += 0.70
    return min(score, 2.0)


def _literal_query_match_boost(text: str, query_norm: str) -> float:
    """+0.8 if the whole query occurs in *text*, +0.8 if any bigram of its >2-char words does (max 1.6)."""
    t = (text or "").lower()
    q = (query_norm or "").lower()
    boost = 0.0
    if q and q in t:
        boost += 0.8
    toks = [tok for tok in q.split() if len(tok) > 2]
    bigrams = _make_ngrams(toks, 2)
    for bg in bigrams:
        if bg in t:
            boost += 0.8
            break
    return min(boost, 1.6)


def hybrid_search_knowledge_base(query: str, top_k: int = 10, alpha: float = 0.6, beta: float = 0.4) -> dict:
    """
    Hybrid retrieval (embeddings + BM25) with intent-, action-, module-, and
    phrase-aware reranking. Returns top items plus doc-level prior and intent
    for downstream formatting.

    Candidates are the union of the semantic and BM25 hits, joined on chunk
    id. Each candidate is scored as::

        alpha*sem_sim + beta*bm25_norm + 0.30*meta_overlap + 0.55*intent
        + 0.30*action + 0.65*module + 0.50*phrase + 0.60*literal

    with ``sem_sim = 1 / (1 + distance)`` and ``bm25_norm`` the BM25 score
    divided by the best BM25 score. Candidates are grouped by filename; the
    file with the highest doc-level prior ("best_doc") is listed first,
    followed by all remaining candidates. Each part is ordered by score,
    except that for error-intent queries, chunks containing the whole
    normalized query or one of its (>2-char word) bigrams come first.

    Returns:
        dict with "documents", "metadatas", "distances" (999.0 for BM25-only
        hits), "ids", "combined_scores" (each at most *top_k* long),
        "best_doc" (filename, or None when there are no candidates),
        "best_doc_prior" (-1.0 when there are no candidates), "user_intent"
        and "actions".
    """
    norm_query = _normalize_query(query)
    q_terms = _tokenize(norm_query)
    user_intent = _detect_user_intent(query)
    actions = _extract_actions(query)
    user_modules = _extract_modules_from_query(query)

    # semantic (embeddings) search via Chroma
    sem_res = search_knowledge_base(norm_query, top_k=max(top_k, 40))
    sem_docs = sem_res.get("documents", [])
    sem_metas = sem_res.get("metadatas", [])
    sem_dists = sem_res.get("distances", [])
    sem_ids = sem_res.get("ids", [])

    def dist_to_sim(d: Optional[float]) -> float:
        if d is None:
            return 0.0
        try:
            return 1.0 / (1.0 + float(d))
        except Exception:
            return 0.0

    sem_sims = [dist_to_sim(d) for d in sem_dists]
    # id -> first position in the semantic results (O(1) instead of list.index).
    sem_pos: Dict[str, int] = {}
    for pos, sid in enumerate(sem_ids):
        sem_pos.setdefault(sid, pos)

    # BM25 search
    bm25_hits = bm25_search(norm_query, top_k=max(80, top_k * 6))
    bm25_max = max((s for _, s in bm25_hits), default=1.0)
    bm25_id_to_norm: Dict[str, float] = {}
    bm25_id_to_text: Dict[str, str] = {}
    bm25_id_to_meta: Dict[str, Dict[str, Any]] = {}
    for idx, score in bm25_hits:
        d = bm25_docs[idx]
        bm25_id_to_norm[d["id"]] = (score / bm25_max) if bm25_max > 0 else 0.0
        bm25_id_to_text[d["id"]] = d["text"]
        bm25_id_to_meta[d["id"]] = d["meta"]

    # union of candidate IDs (semantic first, then BM25-only), in a stable order
    candidate_ids = list(dict.fromkeys(list(sem_ids) + list(bm25_id_to_norm)))

    # weights
    gamma = 0.30    # meta overlap
    delta = 0.55    # intent boost
    epsilon = 0.30  # action weight
    zeta = 0.65     # module weight
    eta = 0.50      # phrase-level boost
    iota = 0.60     # literal query match boost
    # (a heading-alignment bonus is reserved in record slot 10 but not used)

    records: List[_Record] = []
    for cid in candidate_ids:
        # pick semantic fields if present; fall back to BM25
        pos = sem_pos.get(cid)
        if pos is not None:
            sem_sim = sem_sims[pos] if pos < len(sem_sims) else 0.0
            sem_dist = sem_dists[pos] if pos < len(sem_dists) else None
            sem_text = sem_docs[pos] if pos < len(sem_docs) else ""
            sem_meta = sem_metas[pos] if pos < len(sem_metas) else {}
        else:
            sem_sim, sem_dist, sem_text, sem_meta = 0.0, None, "", {}

        bm25_sim = bm25_id_to_norm.get(cid, 0.0)
        text = sem_text if sem_text else bm25_id_to_text.get(cid, "")
        meta = sem_meta if sem_meta else bm25_id_to_meta.get(cid, {})

        m_overlap = _meta_overlap(meta, q_terms)
        intent_boost = _intent_weight(meta, user_intent)
        act_wt = _action_weight(text, actions)
        mod_wt = _module_weight(meta, user_modules)
        phrase_wt = _phrase_boost_score(text, q_terms)
        literal_wt = _literal_query_match_boost(text, norm_query)

        final_score = (
            alpha * sem_sim
            + beta * bm25_sim
            + gamma * m_overlap
            + delta * intent_boost
            + epsilon * act_wt
            + zeta * mod_wt
            + eta * phrase_wt
            + iota * literal_wt
        )
        records.append(
            (cid, final_score, (sem_dist if sem_dist is not None else 999.0), text, meta,
             m_overlap, intent_boost, act_wt, mod_wt, phrase_wt, 0.0, literal_wt)
        )

    # exact-match priority for errors (push chunks containing query phrases)
    exact_ids = set()
    if user_intent == "errors":
        toks = [tok for tok in norm_query.split() if len(tok) > 2]
        bigrams = _make_ngrams(toks, 2)
        for rec in records:
            text_lower = (rec[3] or "").lower()
            if (norm_query and norm_query in text_lower) or any(bg in text_lower for bg in bigrams):
                exact_ids.add(rec[0])

    def _rank_key(rec: _Record) -> Tuple[bool, float]:
        # Exact hits (errors intent only) first, then by combined score.
        return (rec[0] in exact_ids, rec[1])

    # doc-level prior: prefer docs with more aligned chunks
    doc_groups: Dict[str, List[_Record]] = defaultdict(list)
    for rec in records:
        meta = rec[4] or {}
        doc_groups[meta.get("filename", "unknown")].append(rec)

    def doc_prior(recs: List[_Record]) -> float:
        total_score = sum(r[1] for r in recs)
        total_overlap = sum(r[5] for r in recs)
        total_intent = sum(max(0.0, r[6]) for r in recs)
        total_action = sum(max(0.0, r[7]) for r in recs)
        total_module = sum(r[8] for r in recs)
        total_phrase = sum(r[9] for r in recs)
        total_literal = sum(r[11] for r in recs)
        total_penalty = sum(min(0.0, r[6]) for r in recs) + sum(min(0.0, r[7]) for r in recs)

        def _is_errors_section(rec: _Record) -> bool:
            sec = ((rec[4] or {}).get("section", "") or "").lower()
            return "error" in sec or "known issues" in sec or "common issues" in sec

        errors_section_bonus = 0.5 if any(_is_errors_section(r) for r in recs) else 0.0
        return (
            total_score
            + 0.4 * total_overlap
            + 0.7 * total_intent
            + 0.5 * total_action
            + 0.8 * total_module
            + 0.6 * total_phrase
            + 0.7 * total_literal
            + errors_section_bonus
            + 0.3 * total_penalty
        )

    # Highest prior wins even when all priors are negative; first doc wins ties.
    best_doc, best_doc_prior = None, -1.0
    if doc_groups:
        best_doc, best_doc_prior = max(
            ((fn, doc_prior(recs)) for fn, recs in doc_groups.items()),
            key=lambda item: item[1],
        )

    best_recs = sorted(doc_groups.get(best_doc, []), key=_rank_key, reverse=True)
    other_recs = sorted(
        (rec for fn, recs in doc_groups.items() if fn != best_doc for rec in recs),
        key=_rank_key,
        reverse=True,
    )
    top = (best_recs + other_recs)[:top_k]

    return {
        "documents": [t[3] for t in top],
        "metadatas": [t[4] for t in top],
        "distances": [t[2] for t in top],
        "ids": [t[0] for t in top],
        "combined_scores": [t[1] for t in top],
        "best_doc": best_doc,
        "best_doc_prior": best_doc_prior,
        "user_intent": user_intent,
        "actions": actions,
    }


# ------------------------------ Section fetch helpers ------------------------------
def get_section_text(filename: str, section: str) -> str:
    """Concatenate (blank-line separated) all BM25 chunks of *filename* whose section title equals *section*."""
    texts: List[str] = []
    for d in bm25_docs:
        m = d.get("meta", {})
        if m.get("filename") == filename and m.get("section") == section:
            t = (d.get("text") or "").strip()
            if t:
                texts.append(t)
    return "\n\n".join(texts).strip()


def get_best_steps_section_text(filename: str) -> str:
    """Concatenate all BM25 chunks of *filename* whose intent_tag is "steps"."""
    texts: List[str] = []
    for d in bm25_docs:
        m = d.get("meta", {})
        if m.get("filename") == filename and (m.get("intent_tag") == "steps"):
            t = (d.get("text") or "").strip()
            if t:
                texts.append(t)
    return "\n\n".join(texts).strip()


def get_best_errors_section_text(filename: str) -> str:
    """Concatenate all BM25 chunks of *filename* that look error/permission related.

    A chunk qualifies when its intent_tag is "errors", its section title
    mentions error / escalation / permission / access / known issues /
    common issues, or its topic tags include "permissions".
    """
    texts: List[str] = []
    for d in bm25_docs:
        m = d.get("meta", {})
        if m.get("filename") != filename:
            continue
        sec = (m.get("section") or "").lower()
        topics = (m.get("topic_tags") or "")
        topic_list = [t.strip() for t in topics.split(",") if t.strip()]
        is_errors_chunk = (
            m.get("intent_tag") == "errors"
            or any(k in sec for k in ("error", "escalation", "permission", "access", "known issues", "common issues"))
            or ("permissions" in topic_list)
        )
        if is_errors_chunk:
            t = (d.get("text") or "").strip()
            if t:
                texts.append(t)
    return "\n\n".join(texts).strip()


def get_escalation_text(filename: str) -> str:
    """
    Return concatenated text from any 'Escalation' section in the given SOP file.
    Works across future SOPs - only relies on the heading name containing 'escalation'.
    """
    texts: List[str] = []
    for d in bm25_docs:
        m = d.get("meta", {})
        if m.get("filename") == filename:
            sec = (m.get("section") or "").lower()
            if "escalation" in sec:
                t = (d.get("text") or "").strip()
                if t:
                    texts.append(t)
    return "\n\n".join(texts).strip()


# ------------------------------ Admin helpers ------------------------------
def get_kb_runtime_info() -> Dict[str, Any]:
    """Report storage paths, their existence, the Chroma entry count and BM25 readiness."""
    return {
        "chroma_path": CHROMA_PATH,
        "chroma_exists": os.path.isdir(CHROMA_PATH),
        "bm25_index_file": BM25_INDEX_FILE,
        "bm25_index_exists": os.path.isfile(BM25_INDEX_FILE),
        "collection_count": collection.count(),
        "bm25_ready": bm25_ready,
    }


def reset_kb(folder_path: str) -> Dict[str, Any]:
    """Drop the Chroma collection and BM25 index, then re-ingest *folder_path*.

    The in-memory BM25 state is cleared too, so no stale lexical hits survive
    even when *folder_path* contains no documents.

    Returns:
        ``{"status": "OK", "message": ..., "info": ...}`` (plus "warnings" if the
        old BM25 file could not be deleted), or
        ``{"status": "ERROR", "error": ..., "info": ...}`` on failure.
    """
    global collection, bm25_docs, bm25_inverted, bm25_df, bm25_avgdl, bm25_ready
    result = {"status": "OK", "message": "KB reset and re-ingested"}
    try:
        try:
            client.delete_collection(name="knowledge_base")
        except Exception:
            pass  # the collection may not exist yet
        collection = client.get_or_create_collection(name="knowledge_base")
        bm25_docs, bm25_inverted, bm25_df = [], {}, {}
        bm25_avgdl, bm25_ready = 0.0, False
        try:
            if os.path.isfile(BM25_INDEX_FILE):
                os.remove(BM25_INDEX_FILE)
        except Exception as e:
            result.setdefault("warnings", []).append(f"bm25 index delete: {e}")
        os.makedirs(CHROMA_PATH, exist_ok=True)
        ingest_documents(folder_path)
        result["info"] = get_kb_runtime_info()
        return result
    except Exception as e:
        return {"status": "ERROR", "error": f"{e}", "info": get_kb_runtime_info()}