Spaces:
Sleeping
Sleeping
Update services/kb_creation.py
Browse files- services/kb_creation.py +44 -19
services/kb_creation.py
CHANGED
|
@@ -13,14 +13,13 @@ client = chromadb.PersistentClient(path=CHROMA_PATH)
|
|
| 13 |
collection = client.get_or_create_collection(name="knowledge_base")
|
| 14 |
|
| 15 |
# --------------------------- Embedding model ---------------------------
|
| 16 |
-
# model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2') # optional
|
| 17 |
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
| 18 |
|
| 19 |
# --------------------------- BM25 (lightweight) ---------------------------
|
| 20 |
BM25_INDEX_FILE = os.path.join(CHROMA_PATH, "bm25_index.pkl")
|
| 21 |
-
bm25_docs: List[Dict[str, Any]] = []
|
| 22 |
-
bm25_inverted: Dict[str, List[int]] = {}
|
| 23 |
-
bm25_df: Dict[str, int] = {}
|
| 24 |
bm25_avgdl: float = 0.0
|
| 25 |
bm25_ready: bool = False
|
| 26 |
BM25_K1 = 1.5
|
|
@@ -228,12 +227,14 @@ def bm25_search(query: str, top_k: int = 50) -> List[Tuple[int, float]]:
|
|
| 228 |
q_terms = _tokenize(norm)
|
| 229 |
if not q_terms:
|
| 230 |
return []
|
|
|
|
| 231 |
candidates = set()
|
| 232 |
for t in q_terms:
|
| 233 |
for idx in bm25_inverted.get(t, []):
|
| 234 |
candidates.add(idx)
|
| 235 |
if not candidates:
|
| 236 |
candidates = set(range(len(bm25_docs)))
|
|
|
|
| 237 |
scored = []
|
| 238 |
for idx in candidates:
|
| 239 |
s = _bm25_score_for_doc(q_terms, idx)
|
|
@@ -248,7 +249,7 @@ def search_knowledge_base(query: str, top_k: int = 10) -> dict:
|
|
| 248 |
res = collection.query(
|
| 249 |
query_embeddings=[query_embedding],
|
| 250 |
n_results=top_k,
|
| 251 |
-
include=['documents', 'metadatas', 'distances']
|
| 252 |
)
|
| 253 |
docs_ll = res.get("documents", [[]]) or [[]]
|
| 254 |
metas_ll = res.get("metadatas", [[]]) or [[]]
|
|
@@ -329,9 +330,6 @@ def _meta_overlap(meta: Dict[str, Any], q_terms: List[str]) -> float:
|
|
| 329 |
return inter / max(1, len(qset))
|
| 330 |
|
| 331 |
def _action_weight(text: str, actions: List[str]) -> float:
|
| 332 |
-
"""
|
| 333 |
-
Boost if text contains target action verb(s); penalize if text dominated by other actions.
|
| 334 |
-
"""
|
| 335 |
if not actions:
|
| 336 |
return 0.0
|
| 337 |
t = (text or "").lower()
|
|
@@ -339,8 +337,7 @@ def _action_weight(text: str, actions: List[str]) -> float:
|
|
| 339 |
for act in actions:
|
| 340 |
for syn in ACTION_SYNONYMS.get(act, [act]):
|
| 341 |
if syn in t:
|
| 342 |
-
score += 1.0
|
| 343 |
-
# Penalize conflicting actions: e.g., query 'create' but text has 'delete' heavily
|
| 344 |
conflicts = {"create": ["delete"], "delete": ["create"], "update": ["delete"], "navigate": [], "perform": []}
|
| 345 |
for act in actions:
|
| 346 |
for bad in conflicts.get(act, []):
|
|
@@ -350,14 +347,10 @@ def _action_weight(text: str, actions: List[str]) -> float:
|
|
| 350 |
return score
|
| 351 |
|
| 352 |
def hybrid_search_knowledge_base(query: str, top_k: int = 10, alpha: float = 0.6, beta: float = 0.4) -> dict:
|
| 353 |
-
"""
|
| 354 |
-
final = alpha * semantic_sim + beta * bm25_norm + gamma * meta_overlap + delta * intent_boost + epsilon * action_weight
|
| 355 |
-
+ document-level voting prior.
|
| 356 |
-
"""
|
| 357 |
norm_query = _normalize_query(query)
|
| 358 |
q_terms = _tokenize(norm_query)
|
| 359 |
user_intent = _detect_user_intent(query)
|
| 360 |
-
actions = _extract_actions(query)
|
| 361 |
|
| 362 |
sem_res = search_knowledge_base(norm_query, top_k=max(top_k, 30))
|
| 363 |
sem_docs = sem_res.get("documents", [])
|
|
@@ -388,9 +381,9 @@ def hybrid_search_knowledge_base(query: str, top_k: int = 10, alpha: float = 0.6
|
|
| 388 |
|
| 389 |
union_ids = set(sem_ids) | set(bm25_id_to_norm.keys())
|
| 390 |
|
| 391 |
-
gamma = 0.25
|
| 392 |
-
delta = 0.35
|
| 393 |
-
epsilon = 0.30
|
| 394 |
|
| 395 |
combined_records_ext: List[Tuple[str, float, float, str, Dict[str, Any], float, float, float]] = []
|
| 396 |
for cid in union_ids:
|
|
@@ -412,7 +405,7 @@ def hybrid_search_knowledge_base(query: str, top_k: int = 10, alpha: float = 0.6
|
|
| 412 |
|
| 413 |
m_overlap = _meta_overlap(meta, q_terms)
|
| 414 |
intent_boost = _intent_weight(meta, user_intent)
|
| 415 |
-
act_wt = _action_weight(text, actions)
|
| 416 |
|
| 417 |
final_score = alpha * sem_sim + beta * bm25_sim + gamma * m_overlap + delta * intent_boost + epsilon * act_wt
|
| 418 |
|
|
@@ -469,3 +462,35 @@ def hybrid_search_knowledge_base(query: str, top_k: int = 10, alpha: float = 0.6
|
|
| 469 |
"user_intent": user_intent,
|
| 470 |
"actions": actions,
|
| 471 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
collection = client.get_or_create_collection(name="knowledge_base")
|
| 14 |
|
| 15 |
# --------------------------- Embedding model ---------------------------
|
|
|
|
| 16 |
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
| 17 |
|
| 18 |
# --------------------------- BM25 (lightweight) ---------------------------
|
| 19 |
BM25_INDEX_FILE = os.path.join(CHROMA_PATH, "bm25_index.pkl")
|
| 20 |
+
bm25_docs: List[Dict[str, Any]] = []
|
| 21 |
+
bm25_inverted: Dict[str, List[int]] = {}
|
| 22 |
+
bm25_df: Dict[str, int] = {}
|
| 23 |
bm25_avgdl: float = 0.0
|
| 24 |
bm25_ready: bool = False
|
| 25 |
BM25_K1 = 1.5
|
|
|
|
| 227 |
q_terms = _tokenize(norm)
|
| 228 |
if not q_terms:
|
| 229 |
return []
|
| 230 |
+
|
| 231 |
candidates = set()
|
| 232 |
for t in q_terms:
|
| 233 |
for idx in bm25_inverted.get(t, []):
|
| 234 |
candidates.add(idx)
|
| 235 |
if not candidates:
|
| 236 |
candidates = set(range(len(bm25_docs)))
|
| 237 |
+
|
| 238 |
scored = []
|
| 239 |
for idx in candidates:
|
| 240 |
s = _bm25_score_for_doc(q_terms, idx)
|
|
|
|
| 249 |
res = collection.query(
|
| 250 |
query_embeddings=[query_embedding],
|
| 251 |
n_results=top_k,
|
| 252 |
+
include=['documents', 'metadatas', 'distances']
|
| 253 |
)
|
| 254 |
docs_ll = res.get("documents", [[]]) or [[]]
|
| 255 |
metas_ll = res.get("metadatas", [[]]) or [[]]
|
|
|
|
| 330 |
return inter / max(1, len(qset))
|
| 331 |
|
| 332 |
def _action_weight(text: str, actions: List[str]) -> float:
|
|
|
|
|
|
|
|
|
|
| 333 |
if not actions:
|
| 334 |
return 0.0
|
| 335 |
t = (text or "").lower()
|
|
|
|
| 337 |
for act in actions:
|
| 338 |
for syn in ACTION_SYNONYMS.get(act, [act]):
|
| 339 |
if syn in t:
|
| 340 |
+
score += 1.0
|
|
|
|
| 341 |
conflicts = {"create": ["delete"], "delete": ["create"], "update": ["delete"], "navigate": [], "perform": []}
|
| 342 |
for act in actions:
|
| 343 |
for bad in conflicts.get(act, []):
|
|
|
|
| 347 |
return score
|
| 348 |
|
| 349 |
def hybrid_search_knowledge_base(query: str, top_k: int = 10, alpha: float = 0.6, beta: float = 0.4) -> dict:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
norm_query = _normalize_query(query)
|
| 351 |
q_terms = _tokenize(norm_query)
|
| 352 |
user_intent = _detect_user_intent(query)
|
| 353 |
+
actions = _extract_actions(query)
|
| 354 |
|
| 355 |
sem_res = search_knowledge_base(norm_query, top_k=max(top_k, 30))
|
| 356 |
sem_docs = sem_res.get("documents", [])
|
|
|
|
| 381 |
|
| 382 |
union_ids = set(sem_ids) | set(bm25_id_to_norm.keys())
|
| 383 |
|
| 384 |
+
gamma = 0.25
|
| 385 |
+
delta = 0.35
|
| 386 |
+
epsilon = 0.30
|
| 387 |
|
| 388 |
combined_records_ext: List[Tuple[str, float, float, str, Dict[str, Any], float, float, float]] = []
|
| 389 |
for cid in union_ids:
|
|
|
|
| 405 |
|
| 406 |
m_overlap = _meta_overlap(meta, q_terms)
|
| 407 |
intent_boost = _intent_weight(meta, user_intent)
|
| 408 |
+
act_wt = _action_weight(text, actions)
|
| 409 |
|
| 410 |
final_score = alpha * sem_sim + beta * bm25_sim + gamma * m_overlap + delta * intent_boost + epsilon * act_wt
|
| 411 |
|
|
|
|
| 462 |
"user_intent": user_intent,
|
| 463 |
"actions": actions,
|
| 464 |
}
|
| 465 |
+
|
| 466 |
+
# --- Runtime info & reset helpers (optional admin) ---
|
| 467 |
+
def get_kb_runtime_info() -> Dict[str, Any]:
|
| 468 |
+
return {
|
| 469 |
+
"chroma_path": CHROMA_PATH,
|
| 470 |
+
"chroma_exists": os.path.isdir(CHROMA_PATH),
|
| 471 |
+
"bm25_index_file": BM25_INDEX_FILE,
|
| 472 |
+
"bm25_index_exists": os.path.isfile(BM25_INDEX_FILE),
|
| 473 |
+
"collection_count": collection.count(),
|
| 474 |
+
"bm25_ready": bm25_ready,
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
def reset_kb(folder_path: str) -> Dict[str, Any]:
|
| 478 |
+
result = {"status": "OK", "message": "KB reset and re-ingested"}
|
| 479 |
+
try:
|
| 480 |
+
try:
|
| 481 |
+
client.delete_collection(name="knowledge_base")
|
| 482 |
+
except Exception:
|
| 483 |
+
pass
|
| 484 |
+
global collection
|
| 485 |
+
collection = client.get_or_create_collection(name="knowledge_base")
|
| 486 |
+
try:
|
| 487 |
+
if os.path.isfile(BM25_INDEX_FILE):
|
| 488 |
+
os.remove(BM25_INDEX_FILE)
|
| 489 |
+
except Exception as e:
|
| 490 |
+
result.setdefault("warnings", []).append(f"bm25 index delete: {e}")
|
| 491 |
+
os.makedirs(CHROMA_PATH, exist_ok=True)
|
| 492 |
+
ingest_documents(folder_path)
|
| 493 |
+
result["info"] = get_kb_runtime_info()
|
| 494 |
+
return result
|
| 495 |
+
except Exception as e:
|
| 496 |
+
return {"status": "ERROR", "error": f"{e}", "info": get_kb_runtime_info()}
|