Spaces:
Sleeping
Sleeping
| # services/kb_creation.py | |
| import os | |
| import re | |
| import pickle | |
| from typing import List, Dict, Any, Tuple, Optional | |
| from docx import Document | |
| from sentence_transformers import SentenceTransformer | |
| import chromadb | |
# ------------------------------ ChromaDB setup ------------------------------
# NOTE: importing this module has side effects — it opens/creates the Chroma
# store under ./chroma_db and loads the sentence-transformer model.
CHROMA_PATH = os.path.join(os.getcwd(), "chroma_db")
client = chromadb.PersistentClient(path=CHROMA_PATH)
collection = client.get_or_create_collection(name="knowledge_base")
# ------------------------------ Embedding model ------------------------------
# Sentence embedding model used for both ingestion and query encoding.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# ------------------------------ BM25 (lightweight) ------------------------------
# In-process BM25 index, persisted as a pickle alongside the Chroma data.
BM25_INDEX_FILE = os.path.join(CHROMA_PATH, "bm25_index.pkl")
bm25_docs: List[Dict[str, Any]] = []       # per-chunk records: id/text/tokens/tf/length/meta
bm25_inverted: Dict[str, List[int]] = {}   # term -> indices into bm25_docs
bm25_df: Dict[str, int] = {}               # term -> document frequency
bm25_avgdl: float = 0.0                    # average chunk length in tokens
bm25_ready: bool = False                   # True once an index is built or loaded
BM25_K1 = 1.5    # BM25 term-frequency saturation parameter
BM25_B = 0.75    # BM25 length-normalization strength
| # ------------------------------ Utilities ------------------------------ | |
| def _tokenize(text: str) -> List[str]: | |
| if not text: | |
| return [] | |
| text = text.lower() | |
| return re.findall(r"[a-z0-9]+", text) | |
| def _normalize_query(q: str) -> str: | |
| q = (q or "").strip().lower() | |
| q = re.sub(r"[^\w\s]", " ", q) | |
| q = re.sub(r"\s+", " ", q).strip() | |
| return q | |
def _tokenize_meta_value(val: Optional[str]) -> List[str]:
    """Tokenize an optional metadata string; ``None`` yields no tokens."""
    return _tokenize("" if val is None else val)
# ------------------------------ DOCX parsing & chunking ------------------------------
# Matches a leading list marker: "-", "*", "•" (U+2022), or "1." / "1)" numbering.
# (re.IGNORECASE is a no-op here — the pattern has no letters — kept as-is.)
BULLET_RE = re.compile(r"^\s*(?:[\-\*\u2022]|\d+[.)])\s+", re.IGNORECASE)
def _split_by_sections(doc: Document) -> List[Tuple[str, List[str]]]:
    """Group a DOCX document into (heading, paragraph-texts) sections.

    Paragraphs styled "Heading N" open a new section; non-empty body
    paragraphs collect under the most recent heading. If no sections are
    produced, the entire document becomes one "Document" section.
    """
    heading_re = re.compile(r"Heading\s*\d+", re.IGNORECASE)
    out: List[Tuple[str, List[str]]] = []
    title: Optional[str] = None
    body: List[str] = []

    def _flush() -> None:
        # Emit the accumulated section, naming untitled leading text.
        if title or body:
            out.append((title or "Untitled Section", body))

    for para in doc.paragraphs:
        stripped = (para.text or "").strip()
        style = (para.style.name if para.style else "") or ""
        if stripped and heading_re.match(style):
            _flush()
            title, body = stripped, []
        elif stripped:
            body.append(stripped)
    _flush()
    if not out:
        everything = [p.text.strip() for p in doc.paragraphs if p.text and p.text.strip()]
        out = [("Document", everything)]
    return out
def _paragraphs_to_lines(paragraphs: List[str]) -> List[str]:
    """Flatten paragraphs into lines: bullet lines stay whole, prose is split at sentence ends."""
    sentence_end = re.compile(r"(?<=[.!?])\s+")
    out: List[str] = []
    for raw in paragraphs or []:
        para = (raw or "").strip()
        if not para:
            continue
        if BULLET_RE.match(para):
            out.append(para)
        else:
            out.extend(seg.strip() for seg in sentence_end.split(para) if seg.strip())
    return out
def _chunk_text_with_context(doc_title: str, section_title: str, paragraphs: List[str], max_words: int = 160) -> List[str]:
    """Pack section lines into word-bounded chunks of roughly *max_words* words.

    Bullet lines always open a fresh chunk so list items are not glued onto
    prose. ``doc_title``/``section_title`` are part of the public signature
    but unused by the current packing logic.
    """
    lines = _paragraphs_to_lines(paragraphs)
    chunks: List[str] = []
    buffer: List[str] = []
    words_in_buffer = 0

    def _flush() -> None:
        text = " ".join(buffer).strip()
        if text:
            chunks.append(text)

    for line in lines:
        n_words = len(line.split())
        overflows = words_in_buffer + n_words > max_words
        if overflows or (buffer and BULLET_RE.match(line)):
            _flush()
            buffer = [line]
            words_in_buffer = n_words
        else:
            buffer.append(line)
            words_in_buffer += n_words
    if buffer:
        _flush()
    if not chunks:
        # Fall back to one chunk holding everything (e.g. a single huge line).
        fallback = " ".join(lines).strip()
        if fallback:
            chunks = [fallback]
    return chunks
# ------------------------------ Intent & Module tagging ------------------------------
# Section-heading keywords that mark a section as procedural ("steps").
SECTION_STEPS_HINTS = ["process steps", "procedure", "how to", "workflow", "instructions", "steps"]
# Section-heading keywords that mark a section as troubleshooting ("errors").
SECTION_ERRORS_HINTS = ["common errors", "resolution", "troubleshooting", "known issues", "common issues", "escalation", "escalation path", "permissions", "access"]
# Body-text terms suggesting a permissions/access problem (treated as "errors" intent).
PERMISSION_TERMS = [
    "permission", "permissions", "access", "access right", "authorization", "authorisation",
    "role", "role access", "role mapping", "security", "security profile", "privilege", "insufficient",
    "not allowed", "not authorized", "denied", "restrict"
]
# Generic failure vocabulary for "errors" intent detection.
ERROR_TERMS = ["error", "issue", "fail", "failure", "not working", "cannot", "can't", "mismatch", "locked", "wrong", "denied"]
# Imperative verbs typical of step-by-step instructions ("steps" intent).
STEP_VERBS = ["navigate", "select", "scan", "verify", "confirm", "print", "move", "complete", "click", "open", "choose", "enter", "update", "save", "delete", "create", "attach", "assign"]
# Module name -> substrings identifying it in filenames, headings, or chunk text.
MODULE_VOCAB = {
    "receiving": [
        "receive", "receiving", "inbound receiving", "inbound", "goods receipt", "grn",
        "asn receiving", "unload", "check-in", "dock check-in"
    ],
    "appointments": [
        "appointment", "appointments", "schedule", "scheduling", "slot", "dock door",
        "appointment creation", "appointment details"
    ],
    "picking": ["pick", "picking", "pick release", "wave", "allocation"],
    "putaway": ["putaway", "staging", "put away", "location assignment"],
    "shipping": ["shipping", "ship confirm", "outbound", "load", "trailer"],
    "inventory": ["inventory", "adjustment", "cycle count", "count", "uom"],
    "replenishment": ["replenishment", "replenish"],
}
def _infer_intent_tag(section_title: str) -> str:
    """Classify a section heading into steps/errors/prereqs/purpose/neutral."""
    heading = (section_title or "").lower()

    def _mentions(*needles: str) -> bool:
        return any(n in heading for n in needles)

    if _mentions(*SECTION_STEPS_HINTS):
        return "steps"
    if _mentions(*SECTION_ERRORS_HINTS):
        return "errors"
    if "pre" in heading and "requisite" in heading:
        return "prereqs"
    if _mentions("purpose", "overview", "introduction"):
        return "purpose"
    # Receiving/appointment headings default to procedural content.
    if _mentions("inbound receiving", "receiving", "goods receipt", "grn"):
        return "steps"
    if _mentions("appointment", "appointments", "schedule", "scheduling"):
        return "steps"
    return "neutral"
def _derive_semantic_intent_from_text(text: str) -> Tuple[str, List[str]]:
    """Infer an intent tag plus topic tags from chunk body text.

    Priority: permission/access wording wins ("errors" + permission topics),
    then generic failure wording ("errors"), then instruction verbs ("steps").

    Returns:
        (intent, topic_tags) where intent is "errors" / "steps" / "neutral"
        and topic_tags is sorted and de-duplicated. Sorting replaces the
        previous ``list(set(...))``, whose order depended on hash seeding and
        made the persisted topic_tags metadata non-reproducible across runs;
        it also matches the sorted output of _derive_module_tags().
    """
    t = (text or "").lower()
    tags: List[str] = []
    intent = "neutral"
    if any(term in t for term in PERMISSION_TERMS):
        intent = "errors"
        tags.append("permissions")
        if "role" in t:
            tags.append("role_access")
        if "security" in t:
            tags.append("security")
    if intent == "neutral" and any(term in t for term in ERROR_TERMS):
        intent = "errors"
        tags.append("errors")
    if intent == "neutral" and any(v in t for v in STEP_VERBS):
        intent = "steps"
        tags.append("procedure")
    return intent, sorted(set(tags))
def _derive_module_tags(text: str, filename: str, section_title: str) -> List[str]:
    """Tag a chunk with WMS module names found in its filename, heading, or body."""
    haystack = " ".join([filename or "", section_title or "", text or ""]).lower()
    matched = [mod for mod, syns in MODULE_VOCAB.items() if any(s in haystack for s in syns)]
    if not matched:
        # Fallback heuristics for text the vocabulary missed.
        if any(k in haystack for k in ("inventory", "adjust", "uom", "cycle")):
            matched = ["inventory"]
        elif any(k in haystack for k in ("receive", "inbound", "goods receipt", "grn")):
            matched = ["receiving"]
        elif any(k in haystack for k in ("appointment", "schedule", "dock")):
            matched = ["appointments"]
    return sorted(set(matched))
| # ------------------------------ Ingestion ------------------------------ | |
def ingest_documents(folder_path: str) -> None:
    """Ingest every .docx in *folder_path*: split, chunk, tag, embed, index.

    Side effects:
      - upserts one Chroma entry per chunk into the module-level `collection`
      - rebuilds ALL module-level BM25 state from scratch
      - pickles the BM25 state to BM25_INDEX_FILE
    """
    print(f"[KB] Checking folder: {folder_path}")
    files = [f for f in os.listdir(folder_path) if f.lower().endswith('.docx')]
    print(f"[KB] Found {len(files)} Word files: {files}")
    if not files:
        print("[KB] WARNING: No .docx files found. Please check the folder path.")
        return
    # Reset the in-memory BM25 index; it is rebuilt during this pass.
    global bm25_docs, bm25_inverted, bm25_df, bm25_avgdl, bm25_ready
    bm25_docs, bm25_inverted, bm25_df = [], {}, {}
    bm25_avgdl, bm25_ready = 0.0, False
    for file in files:
        file_path = os.path.join(folder_path, file)
        doc_title = os.path.splitext(file)[0]
        doc = Document(file_path)
        sections = _split_by_sections(doc)
        total_chunks = 0
        for s_idx, (section_title, paras) in enumerate(sections):
            chunks = _chunk_text_with_context(doc_title, section_title, paras, max_words=160)
            total_chunks += len(chunks)
            base_intent = _infer_intent_tag(section_title)
            for c_idx, chunk in enumerate(chunks):
                derived_intent, topic_tags = _derive_semantic_intent_from_text(chunk)
                # Heading-based intent wins, except: body-level "errors" always
                # wins, and a neutral heading defers to body-derived steps/prereqs.
                final_intent = base_intent
                if derived_intent == "errors":
                    final_intent = "errors"
                elif base_intent == "neutral" and derived_intent in ("steps", "prereqs"):
                    final_intent = derived_intent
                module_tags = _derive_module_tags(chunk, file, section_title)
                embedding = model.encode(chunk).tolist()
                # Stored ID uses the NUMERIC section index.
                # NOTE(review): search_knowledge_base() synthesizes fallback IDs
                # with the section *title* instead, so synthesized IDs will not
                # match these stored ones — confirm which format is intended.
                doc_id = f"{file}:{s_idx}:{c_idx}"
                meta = {
                    "filename": file,
                    "section": section_title,
                    "chunk_index": c_idx,
                    "title": doc_title,
                    "collection": "SOP",
                    "intent_tag": final_intent,
                    # Chroma metadata values must be scalars, hence CSV strings.
                    "topic_tags": ", ".join(topic_tags) if topic_tags else "",
                    "module_tags": ", ".join(module_tags) if module_tags else "",
                }
                # Add; on duplicate-ID failure, delete + re-add (poor man's upsert).
                try:
                    collection.add(ids=[doc_id], embeddings=[embedding], documents=[chunk], metadatas=[meta])
                except Exception:
                    try:
                        collection.delete(ids=[doc_id])
                        collection.add(ids=[doc_id], embeddings=[embedding], documents=[chunk], metadatas=[meta])
                    except Exception as e2:
                        print(f"[KB] ERROR: Upsert failed for {doc_id}: {e2}")
                # ---- BM25 bookkeeping for this chunk ----
                tokens = _tokenize(chunk)
                tf: Dict[str, int] = {}  # term -> frequency within this chunk
                for tkn in tokens:
                    tf[tkn] = tf.get(tkn, 0) + 1
                idx = len(bm25_docs)
                bm25_docs.append({
                    "id": doc_id,
                    "text": chunk,
                    "tokens": tokens,
                    "tf": tf,
                    "length": len(tokens),
                    "meta": meta,
                })
                # tf.keys() is already unique per chunk, so `seen` is redundant
                # but harmless; each term counts at most once toward bm25_df.
                seen = set()
                for term in tf.keys():
                    bm25_inverted.setdefault(term, []).append(idx)
                    if term not in seen:
                        bm25_df[term] = bm25_df.get(term, 0) + 1
                        seen.add(term)
        print(f"[KB] Ingested {file} → {total_chunks} chunks")
    N = len(bm25_docs)
    if N > 0:
        bm25_avgdl = sum(d["length"] for d in bm25_docs) / float(N)
        bm25_ready = True
        # Persist the BM25 state so restarts can skip re-ingestion.
        payload = {
            "bm25_docs": bm25_docs,
            "bm25_inverted": bm25_inverted,
            "bm25_df": bm25_df,
            "bm25_avgdl": bm25_avgdl,
            "BM25_K1": BM25_K1,
            "BM25_B": BM25_B,
        }
        os.makedirs(CHROMA_PATH, exist_ok=True)
        with open(BM25_INDEX_FILE, "wb") as f:
            pickle.dump(payload, f)
        print(f"[KB] BM25 index saved: {BM25_INDEX_FILE}")
    print(f"[KB] Documents ingested. Total entries in Chroma: {collection.count()}")
| # ------------------------------ BM25 load ------------------------------ | |
def _load_bm25_index() -> None:
    """Restore the BM25 globals from BM25_INDEX_FILE, if present.

    Best-effort: any failure is logged and leaves the index unready.
    NOTE: pickle.load is acceptable only because this file is a local
    artifact written by ingest_documents(); never point it at untrusted data.
    """
    global bm25_docs, bm25_inverted, bm25_df, bm25_avgdl, bm25_ready
    if not os.path.exists(BM25_INDEX_FILE):
        return
    try:
        with open(BM25_INDEX_FILE, "rb") as f:
            payload = pickle.load(f)
        bm25_docs = payload.get("bm25_docs", [])
        bm25_inverted = payload.get("bm25_inverted", {})
        bm25_df = payload.get("bm25_df", {})
        bm25_avgdl = payload.get("bm25_avgdl", 0.0)
        bm25_ready = len(bm25_docs) > 0
        if bm25_ready:
            print(f"[KB] BM25 index loaded: {BM25_INDEX_FILE} (docs={len(bm25_docs)})")
    except Exception as e:
        print(f"[KB] WARNING: Could not load BM25 index: {e}")

# Warm-load any previously persisted index at import time (best-effort).
_load_bm25_index()
| # ------------------------------ BM25 search ------------------------------ | |
def _bm25_score_for_doc(query_terms: List[str], doc_idx: int) -> float:
    """Okapi BM25 score of the chunk at *doc_idx* for the given query terms.

    Uses the module-level BM25 globals (df table, average doc length, K1/B).
    Returns 0.0 when the index is not ready or *doc_idx* is out of range.

    Fix: ``import math`` was executed inside the per-term loop, wrapped in a
    try/except that could never usefully fire (the log argument is always
    > 0 here); the import and the loop-invariant N/avgdl are now hoisted.
    """
    import math  # local import: file-level imports are managed elsewhere

    if not bm25_ready or doc_idx < 0 or doc_idx >= len(bm25_docs):
        return 0.0
    doc = bm25_docs[doc_idx]
    dl = doc["length"] or 1          # guard against zero-length chunks
    N = len(bm25_docs)               # corpus size (loop-invariant)
    avgdl = bm25_avgdl or 1.0        # guard against an unset average length
    score = 0.0
    for term in query_terms:
        df = bm25_df.get(term, 0)
        if df == 0:
            continue
        tf = doc["tf"].get(term, 0)
        if tf == 0:
            continue
        # Non-negative IDF variant: log(1 + (N - df + 0.5) / (df + 0.5)).
        idf = math.log(((N - df + 0.5) / (df + 0.5)) + 1.0)
        denom = tf + BM25_K1 * (1 - BM25_B + BM25_B * (dl / avgdl))
        score += idf * ((tf * (BM25_K1 + 1)) / (denom or 1.0))
    return score
def bm25_search(query: str, top_k: int = 50) -> List[Tuple[int, float]]:
    """Return up to *top_k* (doc_index, score) pairs, best first.

    Candidates come from the inverted index; if no query term is indexed,
    every chunk is scored. Returns [] when the index is not ready or the
    query has no tokens.
    """
    if not bm25_ready:
        return []
    terms = _tokenize(_normalize_query(query))
    if not terms:
        return []
    candidate_ids: set = set()
    for term in terms:
        candidate_ids.update(bm25_inverted.get(term, []))
    if not candidate_ids:
        candidate_ids = set(range(len(bm25_docs)))
    scored = ((i, _bm25_score_for_doc(terms, i)) for i in candidate_ids)
    ranked = sorted((p for p in scored if p[1] > 0), key=lambda p: p[1], reverse=True)
    return ranked[:top_k]
| # ------------------------------ Semantic-only ------------------------------ | |
def search_knowledge_base(query: str, top_k: int = 10) -> dict:
    """Pure embedding search against the Chroma collection.

    Returns a dict of parallel lists: documents, metadatas, distances, ids.

    Fix: prefer the IDs Chroma itself returns (QueryResult always carries
    "ids" even though 'ids' is not a valid `include` key), so they match the
    stored "filename:section_index:chunk_index" IDs used by the BM25 records
    and can be joined by hybrid_search_knowledge_base(). Only when the client
    returns no ID list do we fall back to synthesizing
    "filename:section_title:chunk_index" from metadata.
    """
    query_embedding = model.encode(query).tolist()
    res = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k,
        include=['documents', 'metadatas', 'distances']  # 'ids' is not a valid include key
    )
    documents = (res.get("documents", [[]]) or [[]])[0]
    metadatas = (res.get("metadatas", [[]]) or [[]])[0]
    distances = (res.get("distances", [[]]) or [[]])[0]
    # Real stored IDs, when the client provides them (it normally does).
    ids: List[str] = list((res.get("ids", [[]]) or [[]])[0] or [])
    if not ids and documents:
        # Fallback: synthesize IDs from metadata (filename:section:chunk_index).
        ids = [
            f"{(m or {}).get('filename', 'unknown')}:{(m or {}).get('section', 'section')}:{(m or {}).get('chunk_index', i)}"
            for i, m in enumerate(metadatas)
        ]
    print(f"[KB] search → {len(documents)} docs (top_k={top_k}); first distance: {distances[0] if distances else 'n/a'}; ids={len(ids)}")
    return {
        "documents": documents,
        "metadatas": metadatas,
        "distances": distances,
        "ids": ids,
    }
| # ------------------------------ Hybrid search (generic + intent-aware) ------------------------------ | |
# Module-level action synonym map, consumed by _action_weight().
# NOTE(review): _extract_actions() shadows this with a LARGER local map
# (extra synonyms such as "setup"/"cancel"), so action detection and action
# weighting can disagree — confirm whether they should share one table.
ACTION_SYNONYMS = {
    "create": ["create", "creation", "add", "new", "generate"],
    "update": ["update", "modify", "change", "edit"],
    "delete": ["delete", "remove"],
    "navigate": ["navigate", "go to", "open"],
}
# Query-side vocabulary that flips the detected user intent to "errors".
ERROR_INTENT_TERMS = [
    "error", "issue", "fail", "not working", "resolution", "fix",
    "permission", "permissions", "access", "no access", "authorization", "authorisation",
    "role", "role mapping", "not authorized", "permission denied", "insufficient privileges",
    "escalation", "escalation path", "access right", "mismatch", "locked", "wrong"
]
def _detect_user_intent(query: str) -> str:
    """Map a raw user query to one of: errors / steps / prereqs / purpose / neutral."""
    text = (query or '').lower()

    def _mentions(terms) -> bool:
        return any(t in text for t in terms)

    if _mentions(ERROR_INTENT_TERMS):
        return 'errors'
    # "What's next?"-style follow-ups are treated as procedural queries.
    if _mentions(('next step', 'what next', 'whats next', 'then what',
                  'following step', 'continue', 'after', 'proceed')):
        return 'steps'
    if _mentions(('steps', 'procedure', 'how to', 'navigate', 'process', 'do', 'perform')):
        return 'steps'
    if _mentions(('pre-requisite', 'prerequisites', 'requirement', 'requirements')):
        return 'prereqs'
    if _mentions(('purpose', 'overview', 'introduction')):
        return 'purpose'
    return 'neutral'
| def _extract_actions(query: str) -> List[str]: | |
| q = (query or "").lower() | |
| found = [] | |
| ACTION_SYNONYMS = { | |
| "create": ("create", "creation", "add", "new", "generate", "setup", "set up", "register"), | |
| "update": ("update", "modify", "change", "edit", "amend"), | |
| "delete": ("delete", "remove", "cancel", "void"), | |
| "navigate": ("navigate", "go to", "open"), | |
| } | |
| # direct synonyms | |
| for act, syns in ACTION_SYNONYMS.items(): | |
| if any(s in q for s in syns): | |
| found.append(act) | |
| # extra cues | |
| if "steps for" in q or "procedure for" in q or "how to" in q: | |
| # pick the action that follows these cues | |
| for act, syns in ACTION_SYNONYMS.items(): | |
| if any(("steps for " + s) in q for s in syns) or any(("procedure for " + s) in q for s in syns): | |
| found.append(act) | |
| return sorted(set(found)) or [] | |
def _extract_modules_from_query(query: str) -> List[str]:
    """Return the sorted WMS module names whose synonyms appear in the query."""
    text = (query or "").lower()
    return sorted({mod for mod, syns in MODULE_VOCAB.items() if any(s in text for s in syns)})
def _action_weight(text: str, actions: List[str]) -> float:
    """Score how well *text* supports the requested actions.

    +1.0 per matching synonym of each requested action (module-level
    ACTION_SYNONYMS), -0.8 per synonym of a conflicting action that also
    appears (e.g. "delete" wording when the user asked about "create").
    """
    if not actions:
        return 0.0
    body = (text or "").lower()

    def _synonyms(action: str):
        return ACTION_SYNONYMS.get(action, [action])

    reward = sum(1.0 for act in actions for syn in _synonyms(act) if syn in body)
    conflicts = {"create": ["delete"], "delete": ["create"], "update": ["delete"], "navigate": []}
    penalty = sum(
        0.8
        for act in actions
        for bad in conflicts.get(act, [])
        for syn in _synonyms(bad)
        if syn in body
    )
    return reward - penalty
| def _module_weight(meta: Dict[str, Any], user_modules: List[str]) -> float: | |
| if not user_modules: | |
| return 0.0 | |
| raw = (meta or {}).get("module_tags", "") or "" | |
| doc_modules = [m.strip() for m in raw.split(",") if m.strip()] if isinstance(raw, str) else (raw or []) | |
| overlap = len(set(user_modules) & set(doc_modules)) | |
| if overlap == 0: | |
| return -0.8 | |
| return 0.7 * overlap | |
| def _intent_weight(meta: dict, user_intent: str) -> float: | |
| tag = (meta or {}).get("intent_tag", "neutral") | |
| if user_intent == "neutral": | |
| return 0.0 | |
| if tag == user_intent: | |
| return 1.0 | |
| if tag in ["purpose", "prereqs"] and user_intent in ["steps", "errors"]: | |
| return -0.6 | |
| st = ((meta or {}).get("section", "") or "").lower() | |
| topics = (meta or {}).get("topic_tags", "") or "" | |
| topic_list = [t.strip() for t in topics.split(",") if t.strip()] | |
| if user_intent == "errors" and ( | |
| any(k in st for k in ["common errors", "known issues", "common issues", "errors", "escalation", "permissions", "access"]) | |
| or ("permissions" in topic_list) | |
| ): | |
| return 1.10 | |
| if user_intent == "steps" and any(k in st for k in ["process steps", "procedure", "instructions", "workflow"]): | |
| return 0.75 | |
| return -0.2 | |
def _meta_overlap(meta: Dict[str, Any], q_terms: List[str]) -> float:
    """Fraction of the query's terms found among the chunk's metadata tokens."""
    token_pool = set()
    for field in ("filename", "title", "section"):
        token_pool.update(_tokenize_meta_value(meta.get(field)))
    token_pool.update(_tokenize_meta_value(meta.get("topic_tags") or ""))
    token_pool.update(_tokenize_meta_value(meta.get("module_tags") or ""))
    if not token_pool or not q_terms:
        return 0.0
    query_set = set(q_terms)
    return len(token_pool & query_set) / max(1, len(query_set))
| def _make_ngrams(tokens: List[str], n: int) -> List[str]: | |
| return [" ".join(tokens[i:i+n]) for i in range(len(tokens) - n + 1)] | |
def _phrase_boost_score(text: str, q_terms: List[str]) -> float:
    """Boost for query bigrams (+0.40) / trigrams (+0.70) found verbatim in *text*, capped at 2.0."""
    if not text or not q_terms:
        return 0.0
    haystack = text.lower()
    total = 0.0
    for size, bonus in ((2, 0.40), (3, 0.70)):
        for gram in _make_ngrams(q_terms, size):
            if gram and gram in haystack:
                total += bonus
    return min(total, 2.0)
def _literal_query_match_boost(text: str, query_norm: str) -> float:
    """Boost literal containment of the normalized query (+0.8) and/or any
    bigram of its long (>2 char) tokens (+0.8, counted once); capped at 1.6."""
    haystack = (text or "").lower()
    needle = (query_norm or "").lower()
    total = 0.0
    if needle and needle in haystack:
        total += 0.8
    long_tokens = [tok for tok in needle.split() if len(tok) > 2]
    if any(bg in haystack for bg in _make_ngrams(long_tokens, 2)):
        total += 0.8
    return min(total, 1.6)
def hybrid_search_knowledge_base(query: str, top_k: int = 10, alpha: float = 0.6, beta: float = 0.4) -> dict:
    """
    Hybrid retrieval (embeddings + BM25) with intent-, action-, module-, and phrase-aware reranking.
    Returns top items plus doc-level prior and intent for downstream formatting.

    alpha/beta weight the semantic and BM25 similarity channels; the other
    channel weights (gamma..iota) are fixed constants below.
    """
    norm_query = _normalize_query(query)
    q_terms = _tokenize(norm_query)
    user_intent = _detect_user_intent(query)
    actions = _extract_actions(query)
    user_modules = _extract_modules_from_query(query)
    # semantic (embeddings) search via Chroma
    sem_res = search_knowledge_base(norm_query, top_k=max(top_k, 40))
    sem_docs = sem_res.get("documents", [])
    sem_metas = sem_res.get("metadatas", [])
    sem_dists = sem_res.get("distances", [])
    sem_ids = sem_res.get("ids", [])

    def dist_to_sim(d: Optional[float]) -> float:
        # Map a distance (0 = closest) onto (0, 1]; larger = more similar.
        if d is None:
            return 0.0
        try:
            return 1.0 / (1.0 + float(d))
        except Exception:
            return 0.0

    sem_sims = [dist_to_sim(d) for d in sem_dists]
    # BM25 search, normalized to [0, 1] by the best BM25 score in this batch.
    bm25_hits = bm25_search(norm_query, top_k=max(80, top_k * 6))
    bm25_max = max([s for _, s in bm25_hits], default=1.0)
    bm25_norm_pairs = [(idx, (score / bm25_max) if bm25_max > 0 else 0.0) for idx, score in bm25_hits]
    bm25_id_to_norm, bm25_id_to_text, bm25_id_to_meta = {}, {}, {}
    for idx, nscore in bm25_norm_pairs:
        d = bm25_docs[idx]
        bm25_id_to_norm[d["id"]] = nscore
        bm25_id_to_text[d["id"]] = d["text"]
        bm25_id_to_meta[d["id"]] = d["meta"]
    # union of candidate IDs (semantic + bm25)
    # NOTE(review): semantic IDs may be synthesized as
    # "filename:section_title:chunk_index" while BM25 IDs are the stored
    # "filename:section_index:chunk_index". If the formats differ, the same
    # chunk appears twice in the union and its two channel scores never fuse
    # — confirm the ID formats actually line up.
    union_ids = set(sem_ids) | set(bm25_id_to_norm.keys())
    # weights for the reranking channels
    gamma = 0.30  # meta overlap
    delta = 0.55  # intent boost
    epsilon = 0.30  # action weight
    zeta = 0.65  # module weight
    eta = 0.50  # phrase-level boost
    theta = 0.00  # optional heading alignment bonus not used
    iota = 0.60  # literal query match boost
    # Record tuple layout:
    # (id, final_score, distance, text, meta,
    #  meta_overlap, intent_boost, action_wt, module_wt, phrase_wt, unused, literal_wt)
    combined_records_ext: List[Tuple[str, float, float, str, Dict[str, Any], float, float, float, float, float, float, float]] = []
    for cid in union_ids:
        # pick semantic fields if present; fallback to bm25
        # (list membership + .index() is O(n) per candidate; fine at this scale)
        if cid in sem_ids:
            pos = sem_ids.index(cid)
            sem_sim = sem_sims[pos] if pos < len(sem_sims) else 0.0
            sem_dist = sem_dists[pos] if pos < len(sem_dists) else None
            sem_text = sem_docs[pos] if pos < len(sem_docs) else ""
            sem_meta = sem_metas[pos] if pos < len(sem_metas) else {}
        else:
            sem_sim, sem_dist, sem_text, sem_meta = 0.0, None, "", {}
        bm25_sim = bm25_id_to_norm.get(cid, 0.0)
        bm25_text = bm25_id_to_text.get(cid, "")
        bm25_meta = bm25_id_to_meta.get(cid, {})
        text = sem_text if sem_text else bm25_text
        meta = sem_meta if sem_meta else bm25_meta
        m_overlap = _meta_overlap(meta, q_terms)
        intent_boost = _intent_weight(meta, user_intent)
        act_wt = _action_weight(text, actions)
        mod_wt = _module_weight(meta, user_modules)
        phrase_wt = _phrase_boost_score(text, q_terms)
        literal_wt = _literal_query_match_boost(text, norm_query)
        final_score = (
            alpha * sem_sim
            + beta * bm25_sim
            + gamma * m_overlap
            + delta * intent_boost
            + epsilon * act_wt
            + zeta * mod_wt
            + eta * phrase_wt
            + theta * 0.0
            + iota * literal_wt
        )
        combined_records_ext.append(
            (cid, final_score, (sem_dist if sem_dist is not None else 999.0), text, meta,
             m_overlap, intent_boost, act_wt, mod_wt, phrase_wt, 0.0, literal_wt)
        )
    # exact-match rerank for errors (push lines containing query phrases)
    if user_intent == "errors":
        exact_hits = []
        toks = [tok for tok in norm_query.split() if len(tok) > 2]
        bigrams = _make_ngrams(toks, 2)
        for rec in combined_records_ext:
            text_lower = (rec[3] or "").lower()
            if norm_query and norm_query in text_lower:
                exact_hits.append(rec)
                continue
            if any(bg in text_lower for bg in bigrams):
                exact_hits.append(rec)
        if exact_hits:
            # (rec-in-list comparison makes this O(n^2); acceptable at this scale)
            rest = [r for r in combined_records_ext if r not in exact_hits]
            exact_hits.sort(key=lambda x: x[1], reverse=True)
            rest.sort(key=lambda x: x[1], reverse=True)
            combined_records_ext = exact_hits + rest
    # doc-level prior: prefer docs with more aligned chunks
    from collections import defaultdict as _dd
    doc_groups: Dict[str, List[Tuple[str, float, float, str, Dict[str, Any], float, float, float, float, float, float, float]]] = _dd(list)
    for rec in combined_records_ext:
        meta = rec[4] or {}
        fn = meta.get("filename", "unknown")
        doc_groups[fn].append(rec)

    def doc_prior(recs: List[Tuple[str, float, float, str, Dict[str, Any], float, float, float, float, float, float, float]]) -> float:
        # Aggregate per-document evidence: positive channels add up, negative
        # intent/action contributions return (dampened, x0.3) as a penalty.
        total_score = sum(r[1] for r in recs)
        total_overlap = sum(r[5] for r in recs)
        total_intent = sum(max(0.0, r[6]) for r in recs)
        total_action = sum(max(0.0, r[7]) for r in recs)
        total_module = sum(r[8] for r in recs)
        total_phrase = sum(r[9] for r in recs)
        total_literal = sum(r[11] for r in recs)
        total_penalty = sum(min(0.0, r[6]) for r in recs) + sum(min(0.0, r[7]) for r in recs)
        # Flat bonus when any chunk comes from an error-ish section heading.
        errors_section_bonus = 0.0
        if any("error" in ((r[4] or {}).get("section", "")).lower() or
              "known issues" in ((r[4] or {}).get("section", "")).lower() or
              "common issues" in ((r[4] or {}).get("section", "")).lower() for r in recs):
            errors_section_bonus = 0.5
        return (
            total_score
            + 0.4 * total_overlap
            + 0.7 * total_intent
            + 0.5 * total_action
            + 0.8 * total_module
            + 0.6 * total_phrase
            + 0.7 * total_literal
            + errors_section_bonus
            + 0.3 * total_penalty
        )

    # The winning document's chunks are surfaced first, then everything else.
    best_doc, best_doc_prior = None, -1.0
    for fn, recs in doc_groups.items():
        p = doc_prior(recs)
        if p > best_doc_prior:
            best_doc_prior, best_doc = p, fn
    best_recs = sorted(doc_groups.get(best_doc, []), key=lambda x: x[1], reverse=True)
    other_recs: List[Tuple[str, float, float, str, Dict[str, Any], float, float, float, float, float, float, float]] = []
    for fn, recs in doc_groups.items():
        if fn == best_doc:
            continue
        other_recs.extend(recs)
    other_recs.sort(key=lambda x: x[1], reverse=True)
    reordered = best_recs + other_recs
    top = reordered[:top_k]
    documents = [t[3] for t in top]
    metadatas = [t[4] for t in top]
    distances = [t[2] for t in top]
    ids = [t[0] for t in top]
    combined_scores = [t[1] for t in top]
    return {
        "documents": documents,
        "metadatas": metadatas,
        "distances": distances,
        "ids": ids,
        "combined_scores": combined_scores,
        "best_doc": best_doc,
        "best_doc_prior": best_doc_prior,
        "user_intent": user_intent,
        "actions": actions,
    }
| # ------------------------------ Section fetch helpers ------------------------------ | |
def get_section_text(filename: str, section: str) -> str:
    """Concatenate all indexed chunk texts for an exact (filename, section) pair."""
    matching = [
        (entry.get("text") or "").strip()
        for entry in bm25_docs
        if entry.get("meta", {}).get("filename") == filename
        and entry.get("meta", {}).get("section") == section
    ]
    return "\n\n".join(t for t in matching if t).strip()
def get_best_steps_section_text(filename: str) -> str:
    """Concatenate every chunk of *filename* whose intent tag is "steps"."""
    pieces: List[str] = []
    for entry in bm25_docs:
        meta = entry.get("meta", {})
        if meta.get("filename") != filename or meta.get("intent_tag") != "steps":
            continue
        body = (entry.get("text") or "").strip()
        if body:
            pieces.append(body)
    return "\n\n".join(pieces).strip()
def get_best_errors_section_text(filename: str) -> str:
    """Concatenate every chunk of *filename* that looks error/permission-related.

    A chunk qualifies via its "errors" intent tag, an error-ish marker in
    its section heading, or a "permissions" topic tag.
    """
    section_markers = ("error", "escalation", "permission", "access",
                       "known issues", "common issues", "errors")
    pieces: List[str] = []
    for entry in bm25_docs:
        meta = entry.get("meta", {})
        if meta.get("filename") != filename:
            continue
        heading = (meta.get("section") or "").lower()
        topics = {t.strip() for t in (meta.get("topic_tags") or "").split(",") if t.strip()}
        if (
            meta.get("intent_tag") == "errors"
            or any(marker in heading for marker in section_markers)
            or "permissions" in topics
        ):
            body = (entry.get("text") or "").strip()
            if body:
                pieces.append(body)
    return "\n\n".join(pieces).strip()
def get_escalation_text(filename: str) -> str:
    """
    Return concatenated text from any 'Escalation' section in the given SOP file.
    Works across future SOPs—only relies on the heading name containing 'escalation'.
    """
    pieces: List[str] = []
    for entry in bm25_docs:
        meta = entry.get("meta", {})
        if meta.get("filename") != filename:
            continue
        if "escalation" not in (meta.get("section") or "").lower():
            continue
        body = (entry.get("text") or "").strip()
        if body:
            pieces.append(body)
    return "\n\n".join(pieces).strip()
| # ------------------------------ Admin helpers ------------------------------ | |
def get_kb_runtime_info() -> Dict[str, Any]:
    """Snapshot of KB health: paths, index presence, Chroma count, BM25 readiness."""
    info: Dict[str, Any] = {}
    info["chroma_path"] = CHROMA_PATH
    info["chroma_exists"] = os.path.isdir(CHROMA_PATH)
    info["bm25_index_file"] = BM25_INDEX_FILE
    info["bm25_index_exists"] = os.path.isfile(BM25_INDEX_FILE)
    info["collection_count"] = collection.count()
    info["bm25_ready"] = bm25_ready
    return info
def reset_kb(folder_path: str) -> Dict[str, Any]:
    """Drop and rebuild the whole KB: Chroma collection, BM25 index, re-ingest.

    Returns a status dict; soft failures while deleting the old BM25 file go
    under "warnings", hard failures produce a "status": "ERROR" result.
    """
    result = {"status": "OK", "message": "KB reset and re-ingested"}
    try:
        # Best-effort drop of the old collection (it may not exist yet).
        try:
            client.delete_collection(name="knowledge_base")
        except Exception:
            pass
        # Rebind the module-level handle to a fresh, empty collection.
        global collection
        collection = client.get_or_create_collection(name="knowledge_base")
        try:
            if os.path.isfile(BM25_INDEX_FILE):
                os.remove(BM25_INDEX_FILE)
        except Exception as e:
            result.setdefault("warnings", []).append(f"bm25 index delete: {e}")
        os.makedirs(CHROMA_PATH, exist_ok=True)
        ingest_documents(folder_path)
        result["info"] = get_kb_runtime_info()
        return result
    except Exception as e:
        return {"status": "ERROR", "error": f"{e}", "info": get_kb_runtime_info()}