Chatbot-Backend / services /kb_creation.py
srilakshu012456's picture
Update services/kb_creation.py
4345c26 verified
# services/kb_creation.py
import os
import re
import pickle
from typing import List, Dict, Any, Tuple, Optional
from docx import Document
from sentence_transformers import SentenceTransformer
import chromadb
# ------------------------------ ChromaDB setup ------------------------------
# The store lives in ./chroma_db under the process's current working directory,
# so its location depends on where the service is started from.
CHROMA_PATH = os.path.join(os.getcwd(), "chroma_db")
# NOTE: the client and the collection are created at import time (module side effect).
client = chromadb.PersistentClient(path=CHROMA_PATH)
collection = client.get_or_create_collection(name="knowledge_base")
# ------------------------------ Embedding model ------------------------------
# Loaded once at import; used to embed both ingested chunks and search queries.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# ------------------------------ BM25 (lightweight) ------------------------------
# Pickle file that persists the in-memory BM25 structures declared below.
BM25_INDEX_FILE = os.path.join(CHROMA_PATH, "bm25_index.pkl")
# One record per ingested chunk: id, text, tokens, tf, length, meta.
bm25_docs: List[Dict[str, Any]] = []
# term -> positions in bm25_docs of the chunks containing that term.
bm25_inverted: Dict[str, List[int]] = {}
# term -> number of chunks containing that term (document frequency).
bm25_df: Dict[str, int] = {}
# Mean chunk length in tokens; recomputed at the end of ingestion.
bm25_avgdl: float = 0.0
# True once at least one chunk is indexed (after ingestion or a successful load).
bm25_ready: bool = False
# k1 and b constants of the BM25 formula applied in _bm25_score_for_doc.
BM25_K1 = 1.5
BM25_B = 0.75
# ------------------------------ Utilities ------------------------------
def _tokenize(text: str) -> List[str]:
if not text:
return []
text = text.lower()
return re.findall(r"[a-z0-9]+", text)
def _normalize_query(q: str) -> str:
q = (q or "").strip().lower()
q = re.sub(r"[^\w\s]", " ", q)
q = re.sub(r"\s+", " ", q).strip()
return q
def _tokenize_meta_value(val: Optional[str]) -> List[str]:
    """Tokenize a metadata value, treating a missing/empty value as ``""``."""
    text = val if val else ""
    return _tokenize(text)
# ------------------------------ DOCX parsing & chunking ------------------------------
# Matches a leading list marker: "-", "*", the bullet character U+2022, or digits
# followed by "." or ")", in each case followed by at least one whitespace character.
# The pattern contains no letters, so re.IGNORECASE has no effect here.
BULLET_RE = re.compile(r"^\s*(?:[\-\*\u2022]|\d+[.)])\s+", re.IGNORECASE)
def _split_by_sections(doc: Document) -> List[Tuple[str, List[str]]]:
sections: List[Tuple[str, List[str]]] = []
current_title = None
current_paras: List[str] = []
for para in doc.paragraphs:
text = (para.text or "").strip()
style_name = (para.style.name if para.style else "") or ""
is_heading = bool(re.match(r"Heading\s*\d+", style_name, flags=re.IGNORECASE))
if is_heading and text:
if current_title or current_paras:
sections.append((current_title or "Untitled Section", current_paras))
current_title = text
current_paras = []
else:
if text:
current_paras.append(text)
if current_title or current_paras:
sections.append((current_title or "Untitled Section", current_paras))
if not sections:
all_text = [p.text.strip() for p in doc.paragraphs if p.text and p.text.strip()]
sections = [("Document", all_text)]
return sections
def _paragraphs_to_lines(paragraphs: List[str]) -> List[str]:
    """Flatten paragraphs into lines, keeping list items whole.

    A paragraph that starts with a bullet or number marker is kept as one
    line. Any other paragraph is split after ``.``, ``!`` or ``?`` followed by
    whitespace. Empty paragraphs and empty fragments are skipped.
    """
    sentence_boundary = r"(?<=[.!?])\s+"
    out: List[str] = []
    for raw in paragraphs or []:
        para = (raw or "").strip()
        if not para:
            continue
        if BULLET_RE.match(para):
            out.append(para)
        else:
            fragments = (piece.strip() for piece in re.split(sentence_boundary, para))
            out.extend(piece for piece in fragments if piece)
    return out
def _chunk_text_with_context(doc_title: str, section_title: str, paragraphs: List[str], max_words: int = 160) -> List[str]:
    """Pack a section's lines into chunks of at most about *max_words* words.

    A new chunk starts when adding the next line would exceed *max_words*, or
    when the next line is a bullet/numbered item and the current chunk already
    has content. A single line longer than *max_words* becomes its own chunk.
    ``doc_title`` and ``section_title`` are accepted but not used here.
    """
    lines = _paragraphs_to_lines(paragraphs)
    pieces: List[str] = []
    pending: List[str] = []
    pending_words = 0

    def emit(buffer: List[str]) -> None:
        # Join the buffered lines into one chunk; skip blank results.
        joined = " ".join(buffer).strip()
        if joined:
            pieces.append(joined)

    for line in lines:
        n_words = len(line.split())
        overflows = pending_words + n_words > max_words
        bullet_break = bool(pending) and BULLET_RE.match(line) is not None
        if overflows or bullet_break:
            emit(pending)
            pending, pending_words = [line], n_words
        else:
            pending.append(line)
            pending_words += n_words

    if pending:
        emit(pending)
    if not pieces:
        # Last resort: the whole section as a single chunk.
        emit(lines)
    return pieces
# ------------------------------ Intent & Module tagging ------------------------------
# Section-title substrings that make _infer_intent_tag return "steps".
SECTION_STEPS_HINTS = ["process steps", "procedure", "how to", "workflow", "instructions", "steps"]
# Section-title substrings that make _infer_intent_tag return "errors" (checked after the steps hints).
SECTION_ERRORS_HINTS = ["common errors", "resolution", "troubleshooting", "known issues", "common issues", "escalation", "escalation path", "permissions", "access"]
# Chunk-text substrings that mark a chunk as "errors" with a "permissions" topic tag.
# All term lists in this section are matched as plain substrings, not whole words.
PERMISSION_TERMS = [
"permission", "permissions", "access", "access right", "authorization", "authorisation",
"role", "role access", "role mapping", "security", "security profile", "privilege", "insufficient",
"not allowed", "not authorized", "denied", "restrict"
]
# Chunk-text substrings that mark a chunk as "errors" when no permission term matched.
ERROR_TERMS = ["error", "issue", "fail", "failure", "not working", "cannot", "can't", "mismatch", "locked", "wrong", "denied"]
# Chunk-text substrings that mark a chunk as "steps" when nothing above matched.
STEP_VERBS = ["navigate", "select", "scan", "verify", "confirm", "print", "move", "complete", "click", "open", "choose", "enter", "update", "save", "delete", "create", "attach", "assign"]
# Module name -> substrings whose presence tags a chunk (or a query) with that module.
MODULE_VOCAB = {
"receiving": [
"receive", "receiving", "inbound receiving", "inbound", "goods receipt", "grn",
"asn receiving", "unload", "check-in", "dock check-in"
],
"appointments": [
"appointment", "appointments", "schedule", "scheduling", "slot", "dock door",
"appointment creation", "appointment details"
],
"picking": ["pick", "picking", "pick release", "wave", "allocation"],
"putaway": ["putaway", "staging", "put away", "location assignment"],
"shipping": ["shipping", "ship confirm", "outbound", "load", "trailer"],
"inventory": ["inventory", "adjustment", "cycle count", "count", "uom"],
"replenishment": ["replenishment", "replenish"],
}
def _infer_intent_tag(section_title: str) -> str:
    """Classify a section heading as steps / errors / prereqs / purpose / neutral.

    Checks run in that priority order using substring matches on the
    lower-cased heading. Headings that only mention receiving or appointment
    vocabulary fall back to "steps".
    """
    heading = (section_title or "").lower()

    def mentions(keywords) -> bool:
        return any(word in heading for word in keywords)

    if mentions(SECTION_STEPS_HINTS):
        return "steps"
    if mentions(SECTION_ERRORS_HINTS):
        return "errors"
    if "pre" in heading and "requisite" in heading:
        return "prereqs"
    if mentions(("purpose", "overview", "introduction")):
        return "purpose"
    # Domain headings with no explicit cue are assumed to describe a procedure.
    domain_words = (
        "inbound receiving", "receiving", "goods receipt", "grn",
        "appointment", "appointments", "schedule", "scheduling",
    )
    return "steps" if mentions(domain_words) else "neutral"
def _derive_semantic_intent_from_text(text: str) -> Tuple[str, List[str]]:
    """Derive an intent label and topic tags from a chunk's own wording.

    Precedence (substring matches on the lower-cased text):
      1. any PERMISSION_TERMS -> intent "errors", tag "permissions"
         (plus "role_access" / "security" when those words occur);
      2. any ERROR_TERMS      -> intent "errors", tag "errors";
      3. any STEP_VERBS       -> intent "steps",  tag "procedure";
      4. otherwise            -> intent "neutral", no tags.

    Returns:
        ``(intent, tags)``. Tags are unique and returned in the order they
        were derived, so the value is identical on every run. The earlier
        ``list(set(...))`` round-trip made the order depend on string hash
        randomisation, which leaked into the stored ``topic_tags`` metadata.
    """
    lowered = (text or "").lower()
    tags: List[str] = []
    intent = "neutral"

    if any(term in lowered for term in PERMISSION_TERMS):
        intent = "errors"
        tags.append("permissions")
        if "role" in lowered:
            tags.append("role_access")
        if "security" in lowered:
            tags.append("security")
    if intent == "neutral" and any(term in lowered for term in ERROR_TERMS):
        intent = "errors"
        tags.append("errors")
    if intent == "neutral" and any(verb in lowered for verb in STEP_VERBS):
        intent = "steps"
        tags.append("procedure")

    # dict.fromkeys de-duplicates while preserving insertion order.
    return intent, list(dict.fromkeys(tags))
def _derive_module_tags(text: str, filename: str, section_title: str) -> List[str]:
    """Return the sorted module names suggested by filename, heading and text.

    Every MODULE_VOCAB entry with a synonym occurring in the combined,
    lower-cased text is reported. When none matches, a looser keyword fallback
    picks at most one of inventory / receiving / appointments, tried in that
    order.
    """
    haystack = " ".join([filename or "", section_title or "", text or ""]).lower()

    matched = {
        module
        for module, synonyms in MODULE_VOCAB.items()
        if any(s in haystack for s in synonyms)
    }
    if matched:
        return sorted(matched)

    fallback_cues = (
        ("inventory", ("inventory", "adjust", "uom", "cycle")),
        ("receiving", ("receive", "inbound", "goods receipt", "grn")),
        ("appointments", ("appointment", "schedule", "dock")),
    )
    for module, cues in fallback_cues:
        if any(cue in haystack for cue in cues):
            return [module]
    return []
# ------------------------------ Ingestion ------------------------------
def ingest_documents(folder_path: str) -> None:
    """Index every ``.docx`` in *folder_path* into Chroma and rebuild the BM25 index.

    For each document the text is split into sections and chunks. Every chunk
    is embedded, tagged (intent, topics, modules) and written to the Chroma
    collection under the id ``"<filename>:<section index>:<chunk index>"``.
    The same chunks feed a fresh in-memory BM25 index, which is pickled to
    ``BM25_INDEX_FILE`` when at least one chunk was produced.

    If the folder contains no ``.docx`` files the function returns early and
    leaves both the collection and the in-memory BM25 state untouched.

    Raises:
        OSError: if *folder_path* cannot be listed.
        Exception: whatever ``Document(...)`` raises for an unreadable file.
    """
    global bm25_docs, bm25_inverted, bm25_df, bm25_avgdl, bm25_ready

    print(f"[KB] Checking folder: {folder_path}")
    # "~$name.docx" files are Word's lock files for open documents; they are not
    # real documents and would make Document() fail, aborting the whole run.
    files = [
        f for f in os.listdir(folder_path)
        if f.lower().endswith('.docx') and not f.startswith('~$')
    ]
    print(f"[KB] Found {len(files)} Word files: {files}")
    if not files:
        print("[KB] WARNING: No .docx files found. Please check the folder path.")
        return

    # Start the BM25 index from scratch; it is rebuilt below.
    bm25_docs, bm25_inverted, bm25_df = [], {}, {}
    bm25_avgdl, bm25_ready = 0.0, False

    for file in files:
        file_path = os.path.join(folder_path, file)
        doc_title = os.path.splitext(file)[0]
        doc = Document(file_path)
        sections = _split_by_sections(doc)
        total_chunks = 0

        for s_idx, (section_title, paras) in enumerate(sections):
            chunks = _chunk_text_with_context(doc_title, section_title, paras, max_words=160)
            total_chunks += len(chunks)
            base_intent = _infer_intent_tag(section_title)

            for c_idx, chunk in enumerate(chunks):
                # The chunk's own wording can override the heading-based intent:
                # "errors" always wins; steps/prereqs only fill a neutral heading.
                derived_intent, topic_tags = _derive_semantic_intent_from_text(chunk)
                final_intent = base_intent
                if derived_intent == "errors":
                    final_intent = "errors"
                elif base_intent == "neutral" and derived_intent in ("steps", "prereqs"):
                    final_intent = derived_intent

                module_tags = _derive_module_tags(chunk, file, section_title)
                embedding = model.encode(chunk).tolist()
                doc_id = f"{file}:{s_idx}:{c_idx}"
                # Tag lists are stored as comma-separated strings.
                meta = {
                    "filename": file,
                    "section": section_title,
                    "chunk_index": c_idx,
                    "title": doc_title,
                    "collection": "SOP",
                    "intent_tag": final_intent,
                    "topic_tags": ", ".join(topic_tags) if topic_tags else "",
                    "module_tags": ", ".join(module_tags) if module_tags else "",
                }

                # upsert overwrites an existing id in a single call, replacing
                # the earlier add / delete / add sequence.
                try:
                    collection.upsert(
                        ids=[doc_id],
                        embeddings=[embedding],
                        documents=[chunk],
                        metadatas=[meta],
                    )
                except Exception as exc:
                    # Best effort: report and keep going so one bad chunk does
                    # not abort the whole ingestion.
                    print(f"[KB] ERROR: Upsert failed for {doc_id}: {exc}")

                # --- BM25 bookkeeping for this chunk ---
                tokens = _tokenize(chunk)
                tf: Dict[str, int] = {}
                for tkn in tokens:
                    tf[tkn] = tf.get(tkn, 0) + 1

                idx = len(bm25_docs)
                bm25_docs.append({
                    "id": doc_id,
                    "text": chunk,
                    "tokens": tokens,
                    "tf": tf,
                    "length": len(tokens),
                    "meta": meta,
                })
                # tf keys are unique, so each term counts once per chunk.
                for term in tf:
                    bm25_inverted.setdefault(term, []).append(idx)
                    bm25_df[term] = bm25_df.get(term, 0) + 1

        print(f"[KB] Ingested {file}: {total_chunks} chunks")

    N = len(bm25_docs)
    if N > 0:
        bm25_avgdl = sum(d["length"] for d in bm25_docs) / float(N)
        bm25_ready = True
        payload = {
            "bm25_docs": bm25_docs,
            "bm25_inverted": bm25_inverted,
            "bm25_df": bm25_df,
            "bm25_avgdl": bm25_avgdl,
            "BM25_K1": BM25_K1,
            "BM25_B": BM25_B,
        }
        os.makedirs(CHROMA_PATH, exist_ok=True)
        with open(BM25_INDEX_FILE, "wb") as f:
            pickle.dump(payload, f)
        print(f"[KB] BM25 index saved: {BM25_INDEX_FILE}")

    print(f"[KB] Documents ingested. Total entries in Chroma: {collection.count()}")
# ------------------------------ BM25 load ------------------------------
def _load_bm25_index() -> None:
    """Restore the in-memory BM25 structures from ``BM25_INDEX_FILE`` if it exists.

    A missing file is a silent no-op. Any failure while reading or unpacking
    the file is reported as a warning and otherwise ignored.
    """
    global bm25_docs, bm25_inverted, bm25_df, bm25_avgdl, bm25_ready

    if not os.path.exists(BM25_INDEX_FILE):
        return

    try:
        # NOTE: pickle executes code embedded in the file; this file is
        # expected to be one written by ingest_documents, not external input.
        with open(BM25_INDEX_FILE, "rb") as fh:
            snapshot = pickle.load(fh)
        bm25_docs = snapshot.get("bm25_docs", [])
        bm25_inverted = snapshot.get("bm25_inverted", {})
        bm25_df = snapshot.get("bm25_df", {})
        bm25_avgdl = snapshot.get("bm25_avgdl", 0.0)
        bm25_ready = len(bm25_docs) > 0
        if bm25_ready:
            print(f"[KB] BM25 index loaded: {BM25_INDEX_FILE} (docs={len(bm25_docs)})")
    except Exception as exc:
        print(f"[KB] WARNING: Could not load BM25 index: {exc}")
# Runs at import time: populate the in-memory BM25 structures from the saved index, if any.
_load_bm25_index()
# ------------------------------ BM25 search ------------------------------
def _bm25_score_for_doc(query_terms: List[str], doc_idx: int) -> float:
    """Return the BM25 score of indexed chunk *doc_idx* for *query_terms*.

    Returns 0.0 when the index is not ready or *doc_idx* is out of range.
    Terms absent from the corpus or from the chunk contribute nothing. The IDF
    is ``log(1 + (N - df + 0.5) / (df + 0.5))``, falling back to 1.0 if the
    logarithm cannot be computed.
    """
    import math

    if not bm25_ready or not (0 <= doc_idx < len(bm25_docs)):
        return 0.0

    entry = bm25_docs[doc_idx]
    doc_len = entry["length"] or 1
    corpus_size = len(bm25_docs)
    # Length-normalisation part of the denominator; it does not depend on the term.
    length_norm = BM25_K1 * (1 - BM25_B + BM25_B * (doc_len / (bm25_avgdl or 1.0)))

    total = 0.0
    for term in query_terms:
        doc_freq = bm25_df.get(term, 0)
        if doc_freq == 0:
            continue
        freq = entry["tf"].get(term, 0)
        if freq == 0:
            continue
        ratio = (corpus_size - doc_freq + 0.5) / (doc_freq + 0.5)
        try:
            idf = math.log(ratio + 1.0)
        except Exception:
            idf = 1.0
        denom = freq + length_norm
        total += idf * ((freq * (BM25_K1 + 1)) / (denom or 1.0))
    return total
def bm25_search(query: str, top_k: int = 50) -> List[Tuple[int, float]]:
    """Rank indexed chunks against *query* with BM25.

    Args:
        query: Free-text query; it is normalised and tokenised internally.
        top_k: Maximum number of hits to return. Non-positive values yield ``[]``.

    Returns:
        ``(position_in_bm25_docs, score)`` pairs with score > 0, best first.
        Equal scores are ordered by ascending position so results are
        reproducible. Empty when the index is not ready, the query has no
        tokens, or no query term occurs in the index.
    """
    if not bm25_ready or top_k <= 0:
        return []
    q_terms = _tokenize(_normalize_query(query))
    if not q_terms:
        return []

    # Only chunks that contain at least one query term can score above zero.
    candidates = {idx for term in q_terms for idx in bm25_inverted.get(term, ())}
    if not candidates:
        # No term is indexed, so every chunk would score 0 and be filtered out.
        # Skip the full-corpus scan that used to happen here.
        return []

    scored: List[Tuple[int, float]] = []
    for idx in candidates:
        score = _bm25_score_for_doc(q_terms, idx)
        if score > 0:
            scored.append((idx, score))

    scored.sort(key=lambda hit: (-hit[1], hit[0]))
    return scored[:top_k]
# ------------------------------ Semantic-only ------------------------------
def search_knowledge_base(query: str, top_k: int = 10) -> dict:
    """Run an embedding-only similarity search against the Chroma collection.

    Args:
        query: Text to embed and search for.
        top_k: Number of nearest chunks to request.

    Returns:
        A dict with parallel lists ``documents``, ``metadatas``, ``distances``
        and ``ids`` for the single query.

    The ids are the ones stored at ingestion time
    (``"<filename>:<section index>:<chunk index>"``), which the query result
    carries without being listed in ``include``. Ids used to be rebuilt here
    from the section *title*, which never matched the ingested ids, so hybrid
    search could not pair a semantic hit with its BM25 entry. Rebuilding from
    metadata is kept only as a fallback for a response without usable ids.
    """
    query_embedding = model.encode(query).tolist()
    res = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k,
        include=['documents', 'metadatas', 'distances'],
    )

    # Each field is a list with one entry per query; exactly one query was sent.
    documents = (res.get("documents", [[]]) or [[]])[0]
    metadatas = (res.get("metadatas", [[]]) or [[]])[0]
    distances = (res.get("distances", [[]]) or [[]])[0]
    ids: List[str] = list((res.get("ids") or [[]])[0])

    if len(ids) != len(documents):
        # Fallback: approximate ids from metadata so the lists stay parallel.
        ids = []
        for i, m in enumerate(metadatas):
            info = m or {}
            fn = info.get("filename", "unknown")
            sec = info.get("section", "section")
            idx = info.get("chunk_index", i)
            ids.append(f"{fn}:{sec}:{idx}")

    print(f"[KB] search → {len(documents)} docs (top_k={top_k}); first distance: {distances[0] if distances else 'n/a'}; ids={len(ids)}")
    return {
        "documents": documents,
        "metadatas": metadatas,
        "distances": distances,
        "ids": ids,
    }
# ------------------------------ Hybrid search (generic + intent-aware) ------------------------------
# Canonical action -> substrings that signal it. _action_weight reads this table
# to score chunk text. It now carries the same synonyms that _extract_actions
# recognises in queries, so an action detected from e.g. "cancel" or "set up"
# can also be rewarded (or penalised as a conflict) in the chunk text.
# NOTE(review): entries are matched as plain substrings, so short ones can hit
# inside longer words (e.g. "void" in "avoid", "new" in "renew").
ACTION_SYNONYMS = {
    "create": ["create", "creation", "add", "new", "generate", "setup", "set up", "register"],
    "update": ["update", "modify", "change", "edit", "amend"],
    "delete": ["delete", "remove", "cancel", "void"],
    "navigate": ["navigate", "go to", "open"],
}
# Query substrings that make _detect_user_intent classify a query as "errors".
ERROR_INTENT_TERMS = [
    "error", "issue", "fail", "not working", "resolution", "fix",
    "permission", "permissions", "access", "no access", "authorization", "authorisation",
    "role", "role mapping", "not authorized", "permission denied", "insufficient privileges",
    "escalation", "escalation path", "access right", "mismatch", "locked", "wrong"
]
def _detect_user_intent(query: str) -> str:
    """Classify a query as 'errors', 'steps', 'prereqs', 'purpose' or 'neutral'.

    Cue groups are tried in that priority order; the first group with a cue
    occurring in the lower-cased query wins.

    NOTE(review): cues are plain substrings, so short ones such as 'do',
    'fix' or 'role' also match inside longer words ('dock', 'prefix',
    'controller'). Confirm that this is intended.
    """
    text = (query or '').lower()

    cue_groups = (
        ('errors', ERROR_INTENT_TERMS),
        # "What comes next" phrasing is treated as a request for steps.
        ('steps', ('next step', 'what next', 'whats next', 'then what', 'following step', 'continue', 'after', 'proceed')),
        ('steps', ('steps', 'procedure', 'how to', 'navigate', 'process', 'do', 'perform')),
        ('prereqs', ('pre-requisite', 'prerequisites', 'requirement', 'requirements')),
        ('purpose', ('purpose', 'overview', 'introduction')),
    )
    for intent, cues in cue_groups:
        if any(cue in text for cue in cues):
            return intent
    return 'neutral'
def _extract_actions(query: str) -> List[str]:
q = (query or "").lower()
found = []
ACTION_SYNONYMS = {
"create": ("create", "creation", "add", "new", "generate", "setup", "set up", "register"),
"update": ("update", "modify", "change", "edit", "amend"),
"delete": ("delete", "remove", "cancel", "void"),
"navigate": ("navigate", "go to", "open"),
}
# direct synonyms
for act, syns in ACTION_SYNONYMS.items():
if any(s in q for s in syns):
found.append(act)
# extra cues
if "steps for" in q or "procedure for" in q or "how to" in q:
# pick the action that follows these cues
for act, syns in ACTION_SYNONYMS.items():
if any(("steps for " + s) in q for s in syns) or any(("procedure for " + s) in q for s in syns):
found.append(act)
return sorted(set(found)) or []
def _extract_modules_from_query(query: str) -> List[str]:
    """Return the sorted MODULE_VOCAB modules with a synonym occurring in *query*."""
    lowered = (query or "").lower()
    hits = {
        module
        for module, synonyms in MODULE_VOCAB.items()
        if any(word in lowered for word in synonyms)
    }
    return sorted(hits)
def _action_weight(text: str, actions: List[str]) -> float:
    """Score how well *text* matches the requested *actions*.

    Each synonym of a requested action found in the text adds 1.0; each
    synonym of a conflicting action (create vs. delete, update vs. delete)
    subtracts 0.8. Returns 0.0 when no actions were requested.
    """
    if not actions:
        return 0.0

    haystack = (text or "").lower()
    opposing = {"create": ["delete"], "delete": ["create"], "update": ["delete"], "navigate": []}

    def present(action: str) -> List[str]:
        # Synonyms of `action` occurring in the text; an unknown action stands for itself.
        return [syn for syn in ACTION_SYNONYMS.get(action, [action]) if syn in haystack]

    weight = 0.0
    for action in actions:
        for _ in present(action):
            weight += 1.0
    for action in actions:
        for rival in opposing.get(action, []):
            for _ in present(rival):
                weight -= 0.8
    return weight
def _module_weight(meta: Dict[str, Any], user_modules: List[str]) -> float:
if not user_modules:
return 0.0
raw = (meta or {}).get("module_tags", "") or ""
doc_modules = [m.strip() for m in raw.split(",") if m.strip()] if isinstance(raw, str) else (raw or [])
overlap = len(set(user_modules) & set(doc_modules))
if overlap == 0:
return -0.8
return 0.7 * overlap
def _intent_weight(meta: dict, user_intent: str) -> float:
tag = (meta or {}).get("intent_tag", "neutral")
if user_intent == "neutral":
return 0.0
if tag == user_intent:
return 1.0
if tag in ["purpose", "prereqs"] and user_intent in ["steps", "errors"]:
return -0.6
st = ((meta or {}).get("section", "") or "").lower()
topics = (meta or {}).get("topic_tags", "") or ""
topic_list = [t.strip() for t in topics.split(",") if t.strip()]
if user_intent == "errors" and (
any(k in st for k in ["common errors", "known issues", "common issues", "errors", "escalation", "permissions", "access"])
or ("permissions" in topic_list)
):
return 1.10
if user_intent == "steps" and any(k in st for k in ["process steps", "procedure", "instructions", "workflow"]):
return 0.75
return -0.2
def _meta_overlap(meta: Dict[str, Any], q_terms: List[str]) -> float:
    """Return the fraction of distinct query terms found in the chunk's metadata.

    Metadata tokens come from filename, title, section, topic tags and module
    tags. Returns 0.0 when either side has no tokens.
    """
    meta_tokens = set()
    for field in ("filename", "title", "section", "topic_tags", "module_tags"):
        meta_tokens.update(_tokenize_meta_value(meta.get(field)))

    if not meta_tokens or not q_terms:
        return 0.0

    distinct_terms = set(q_terms)
    shared = meta_tokens & distinct_terms
    return len(shared) / max(1, len(distinct_terms))
def _make_ngrams(tokens: List[str], n: int) -> List[str]:
return [" ".join(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]
def _phrase_boost_score(text: str, q_terms: List[str]) -> float:
    """Boost chunks that contain consecutive query terms verbatim.

    Each query bigram found in the lower-cased text adds 0.40 and each
    trigram adds 0.70; the total is capped at 2.0.
    """
    if not text or not q_terms:
        return 0.0

    haystack = (text or "").lower()
    total = 0.0
    # Bigrams are tallied before trigrams.
    for size, bonus in ((2, 0.40), (3, 0.70)):
        for phrase in _make_ngrams(q_terms, size):
            if phrase and phrase in haystack:
                total += bonus
    return min(total, 2.0)
def _literal_query_match_boost(text: str, query_norm: str) -> float:
    """Boost chunks that literally contain the query or part of it.

    Adds 0.8 when the whole normalised query occurs in the text, plus a single
    0.8 when any bigram of its words longer than two characters occurs. The
    result is capped at 1.6.
    """
    haystack = (text or "").lower()
    needle = (query_norm or "").lower()

    boost = 0.0
    if needle and needle in haystack:
        boost += 0.8

    long_words = [word for word in needle.split() if len(word) > 2]
    if any(pair in haystack for pair in _make_ngrams(long_words, 2)):
        boost += 0.8
    return min(boost, 1.6)
def hybrid_search_knowledge_base(query: str, top_k: int = 10, alpha: float = 0.6, beta: float = 0.4) -> dict:
    """
    Hybrid retrieval (embeddings + BM25) with intent-, action-, module-, and phrase-aware reranking.

    Args:
        query: Raw user query.
        top_k: Number of chunks to return.
        alpha: Weight of the semantic similarity ``1 / (1 + distance)``.
        beta: Weight of the BM25 score, normalised by the best BM25 hit.

    Returns:
        A dict with parallel lists ``documents``, ``metadatas``, ``distances``
        (999.0 for chunks found only by BM25), ``ids`` and ``combined_scores``,
        plus ``best_doc`` (filename with the highest document-level prior, or
        ``None`` when there are no candidates), ``best_doc_prior``,
        ``user_intent`` and ``actions``. Chunks of ``best_doc`` are listed
        first, then all other chunks, each group ordered by combined score.
    """
    from collections import defaultdict

    # Record layout: (id, score, distance, text, meta, meta_overlap, intent,
    #                 action, module, phrase, <unused>, literal)
    Record = Tuple[str, float, float, str, Dict[str, Any], float, float, float, float, float, float, float]

    norm_query = _normalize_query(query)
    q_terms = _tokenize(norm_query)
    # Intent, actions and modules are detected on the raw query text.
    user_intent = _detect_user_intent(query)
    actions = _extract_actions(query)
    user_modules = _extract_modules_from_query(query)

    # ---- semantic (embeddings) candidates via Chroma ----
    sem_res = search_knowledge_base(norm_query, top_k=max(top_k, 40))
    sem_docs = sem_res.get("documents", [])
    sem_metas = sem_res.get("metadatas", [])
    sem_dists = sem_res.get("distances", [])
    sem_ids = sem_res.get("ids", [])

    def dist_to_sim(d: Optional[float]) -> float:
        # Map a distance to (0, 1]; unusable values count as no similarity.
        if d is None:
            return 0.0
        try:
            return 1.0 / (1.0 + float(d))
        except Exception:
            return 0.0

    sem_sims = [dist_to_sim(d) for d in sem_dists]
    # id -> first position in the semantic result (same answer as list.index, O(1)).
    sem_pos: Dict[str, int] = {}
    for pos, sid in enumerate(sem_ids):
        sem_pos.setdefault(sid, pos)

    # ---- BM25 candidates, normalised by the best hit ----
    bm25_hits = bm25_search(norm_query, top_k=max(80, top_k * 6))
    bm25_max = max([s for _, s in bm25_hits], default=1.0)
    bm25_by_id: Dict[str, Tuple[float, str, Dict[str, Any]]] = {}
    for idx, score in bm25_hits:
        d = bm25_docs[idx]
        nscore = (score / bm25_max) if bm25_max > 0 else 0.0
        bm25_by_id[d["id"]] = (nscore, d["text"], d["meta"])

    # Union of candidates in a reproducible order (semantic first, then BM25-only),
    # so ties in the combined score do not depend on set iteration order.
    candidate_ids = list(dict.fromkeys([*sem_ids, *bm25_by_id]))

    # ---- feature weights ----
    gamma = 0.30    # meta overlap
    delta = 0.55    # intent boost
    epsilon = 0.30  # action weight
    zeta = 0.65     # module weight
    eta = 0.50      # phrase-level boost
    iota = 0.60     # literal query match boost

    combined_records_ext: List[Record] = []
    for cid in candidate_ids:
        # Prefer the semantic hit's text/metadata; fall back to the BM25 copy.
        pos = sem_pos.get(cid)
        if pos is not None:
            sem_sim = sem_sims[pos] if pos < len(sem_sims) else 0.0
            sem_dist = sem_dists[pos] if pos < len(sem_dists) else None
            sem_text = sem_docs[pos] if pos < len(sem_docs) else ""
            sem_meta = sem_metas[pos] if pos < len(sem_metas) else {}
        else:
            sem_sim, sem_dist, sem_text, sem_meta = 0.0, None, "", {}

        bm25_sim, bm25_text, bm25_meta = bm25_by_id.get(cid, (0.0, "", {}))
        text = sem_text if sem_text else bm25_text
        meta = (sem_meta if sem_meta else bm25_meta) or {}

        m_overlap = _meta_overlap(meta, q_terms)
        intent_boost = _intent_weight(meta, user_intent)
        act_wt = _action_weight(text, actions)
        mod_wt = _module_weight(meta, user_modules)
        phrase_wt = _phrase_boost_score(text, q_terms)
        literal_wt = _literal_query_match_boost(text, norm_query)

        final_score = (
            alpha * sem_sim
            + beta * bm25_sim
            + gamma * m_overlap
            + delta * intent_boost
            + epsilon * act_wt
            + zeta * mod_wt
            + eta * phrase_wt
            + iota * literal_wt
        )
        combined_records_ext.append(
            (cid, final_score, (sem_dist if sem_dist is not None else 999.0), text, meta,
             m_overlap, intent_boost, act_wt, mod_wt, phrase_wt, 0.0, literal_wt)
        )

    # ---- exact-match rerank for error queries ----
    # NOTE(review): the per-document regrouping below re-sorts by score, so this
    # ordering only influences ties and the order in which documents are first
    # seen. Confirm whether exact hits were meant to stay on top in the output.
    if user_intent == "errors":
        toks = [tok for tok in norm_query.split() if len(tok) > 2]
        bigrams = _make_ngrams(toks, 2)
        exact_hits: List[Record] = []
        for rec in combined_records_ext:
            text_lower = (rec[3] or "").lower()
            if (norm_query and norm_query in text_lower) or any(bg in text_lower for bg in bigrams):
                exact_hits.append(rec)
        if exact_hits:
            # Candidate ids are unique, so identity is enough to split the list.
            exact_identities = {id(rec) for rec in exact_hits}
            rest = [rec for rec in combined_records_ext if id(rec) not in exact_identities]
            exact_hits.sort(key=lambda x: x[1], reverse=True)
            rest.sort(key=lambda x: x[1], reverse=True)
            combined_records_ext = exact_hits + rest

    # ---- doc-level prior: prefer the document with the most aligned chunks ----
    doc_groups: Dict[str, List[Record]] = defaultdict(list)
    for rec in combined_records_ext:
        doc_groups[(rec[4] or {}).get("filename", "unknown")].append(rec)

    def doc_prior(recs: List[Record]) -> float:
        total_score = sum(r[1] for r in recs)
        total_overlap = sum(r[5] for r in recs)
        total_intent = sum(max(0.0, r[6]) for r in recs)
        total_action = sum(max(0.0, r[7]) for r in recs)
        total_module = sum(r[8] for r in recs)
        total_phrase = sum(r[9] for r in recs)
        total_literal = sum(r[11] for r in recs)
        # Negative intent/action contributions are tallied separately.
        total_penalty = sum(min(0.0, r[6]) for r in recs) + sum(min(0.0, r[7]) for r in recs)

        errors_section_bonus = 0.0
        for r in recs:
            section = ((r[4] or {}).get("section", "")).lower()
            if "error" in section or "known issues" in section or "common issues" in section:
                errors_section_bonus = 0.5
                break

        return (
            total_score
            + 0.4 * total_overlap
            + 0.7 * total_intent
            + 0.5 * total_action
            + 0.8 * total_module
            + 0.6 * total_phrase
            + 0.7 * total_literal
            + errors_section_bonus
            + 0.3 * total_penalty
        )

    # The first document always becomes the provisional best. Comparing against a
    # fixed -1.0 floor left best_doc unset whenever every prior fell below -1.
    # With no candidates at all the reported prior stays at -1.0 as before.
    best_doc, best_doc_prior = None, -1.0
    for fn, recs in doc_groups.items():
        p = doc_prior(recs)
        if best_doc is None or p > best_doc_prior:
            best_doc_prior, best_doc = p, fn

    best_recs = sorted(doc_groups.get(best_doc, []), key=lambda x: x[1], reverse=True)
    other_recs: List[Record] = []
    for fn, recs in doc_groups.items():
        if fn != best_doc:
            other_recs.extend(recs)
    other_recs.sort(key=lambda x: x[1], reverse=True)

    top = (best_recs + other_recs)[:top_k]
    return {
        "documents": [t[3] for t in top],
        "metadatas": [t[4] for t in top],
        "distances": [t[2] for t in top],
        "ids": [t[0] for t in top],
        "combined_scores": [t[1] for t in top],
        "best_doc": best_doc,
        "best_doc_prior": best_doc_prior,
        "user_intent": user_intent,
        "actions": actions,
    }
# ------------------------------ Section fetch helpers ------------------------------
def get_section_text(filename: str, section: str) -> str:
    """Return the text of all indexed chunks of *section* in *filename*.

    Chunks are taken from the in-memory BM25 records in index order and
    joined with blank lines. Returns ``""`` when nothing matches.
    """
    parts: List[str] = []
    for entry in bm25_docs:
        meta = entry.get("meta", {})
        if meta.get("filename") != filename or meta.get("section") != section:
            continue
        body = (entry.get("text") or "").strip()
        if body:
            parts.append(body)
    return "\n\n".join(parts).strip()
def get_best_steps_section_text(filename: str) -> str:
    """Return the text of every chunk of *filename* whose intent tag is "steps".

    Chunks are joined with blank lines in index order; ``""`` when none exist.
    """
    parts: List[str] = []
    for entry in bm25_docs:
        meta = entry.get("meta", {})
        is_step_chunk = meta.get("filename") == filename and meta.get("intent_tag") == "steps"
        if not is_step_chunk:
            continue
        body = (entry.get("text") or "").strip()
        if body:
            parts.append(body)
    return "\n\n".join(parts).strip()
def get_best_errors_section_text(filename: str) -> str:
    """Return the text of every error-related chunk of *filename*.

    A chunk qualifies when its intent tag is "errors", when its section
    heading mentions an error/escalation/permission/access keyword, or when
    its topic tags include "permissions". Chunks are joined with blank lines.
    """
    heading_cues = ("error", "escalation", "permission", "access", "known issues", "common issues", "errors")
    parts: List[str] = []
    for entry in bm25_docs:
        meta = entry.get("meta", {})
        heading = (meta.get("section") or "").lower()
        topics = (meta.get("topic_tags") or "")
        topic_list = [part.strip() for part in topics.split(",") if part.strip()]

        if meta.get("filename") != filename:
            continue
        is_error_chunk = (
            meta.get("intent_tag") == "errors"
            or any(cue in heading for cue in heading_cues)
            or "permissions" in topic_list
        )
        if not is_error_chunk:
            continue
        body = (entry.get("text") or "").strip()
        if body:
            parts.append(body)
    return "\n\n".join(parts).strip()
def get_escalation_text(filename: str) -> str:
    """
    Return the concatenated text of the chunks of *filename* whose section
    heading contains 'escalation' (case-insensitive).

    Only the heading name is consulted, so this works for any SOP that has
    such a section. Chunks are joined with blank lines in index order.
    """
    parts: List[str] = []
    for entry in bm25_docs:
        meta = entry.get("meta", {})
        if meta.get("filename") != filename:
            continue
        heading = (meta.get("section") or "").lower()
        if "escalation" not in heading:
            continue
        body = (entry.get("text") or "").strip()
        if body:
            parts.append(body)
    return "\n\n".join(parts).strip()
# ------------------------------ Admin helpers ------------------------------
def get_kb_runtime_info() -> Dict[str, Any]:
    """Report where the KB lives on disk and whether its parts are present.

    Includes the Chroma directory and BM25 index paths with existence flags,
    the collection's entry count, and whether the BM25 index is loaded.
    """
    info: Dict[str, Any] = {}
    info["chroma_path"] = CHROMA_PATH
    info["chroma_exists"] = os.path.isdir(CHROMA_PATH)
    info["bm25_index_file"] = BM25_INDEX_FILE
    info["bm25_index_exists"] = os.path.isfile(BM25_INDEX_FILE)
    info["collection_count"] = collection.count()
    info["bm25_ready"] = bm25_ready
    return info
def reset_kb(folder_path: str) -> Dict[str, Any]:
    """Drop the knowledge base and rebuild it from the documents in *folder_path*.

    Steps: delete and recreate the Chroma collection, remove the pickled BM25
    index, clear the in-memory BM25 state, then re-run ingestion.

    Returns:
        On success ``{"status": "OK", "message": ..., "info": ...}``, plus a
        ``"warnings"`` list when a cleanup step failed without stopping the
        reset. On failure ``{"status": "ERROR", "error": ..., "info": ...}``.
        This function does not raise.
    """
    global collection, bm25_docs, bm25_inverted, bm25_df, bm25_avgdl, bm25_ready

    result: Dict[str, Any] = {"status": "OK", "message": "KB reset and re-ingested"}
    try:
        try:
            client.delete_collection(name="knowledge_base")
        except Exception as e:
            # Not fatal (the collection is recreated just below), but record it
            # instead of hiding it.
            result.setdefault("warnings", []).append(f"collection delete: {e}")
        collection = client.get_or_create_collection(name="knowledge_base")

        try:
            if os.path.isfile(BM25_INDEX_FILE):
                os.remove(BM25_INDEX_FILE)
        except Exception as e:
            result.setdefault("warnings", []).append(f"bm25 index delete: {e}")

        # ingest_documents returns early when the folder has no .docx files and
        # would then leave the previous in-memory index in place, so clear it
        # here to keep it in step with the emptied collection.
        bm25_docs, bm25_inverted, bm25_df = [], {}, {}
        bm25_avgdl, bm25_ready = 0.0, False

        os.makedirs(CHROMA_PATH, exist_ok=True)
        ingest_documents(folder_path)
        result["info"] = get_kb_runtime_info()
        return result
    except Exception as e:
        failure: Dict[str, Any] = {"status": "ERROR", "error": f"{e}"}
        # Gathering runtime info queries the collection and can fail as well;
        # never let that mask the original error.
        try:
            failure["info"] = get_kb_runtime_info()
        except Exception as info_exc:
            failure["info"] = {"error": f"{info_exc}"}
        return failure