srilakshu012456 committed on
Commit
0ffd0e8
·
verified ·
1 Parent(s): 14a4368

Update services/kb_creation.py

Browse files
Files changed (1) hide show
  1. services/kb_creation.py +312 -8
services/kb_creation.py CHANGED
@@ -365,14 +365,318 @@ def search_knowledge_base(query: str, top_k: int = 10) -> dict:
365
  "ids": ids,
366
  }
367
 
368
- # ------------------------------ Hybrid search (improved + exact-match rerank) ------------------------------
369
- # (unchanged from your version; omitted for brevity here)
370
- # NOTE: Keep your existing 'hybrid_search_knowledge_base' implementation as-is.
371
- # It already returns best_doc, user_intent, etc.
372
- from collections import defaultdict
373
-
374
- # (Paste your existing hybrid_search_knowledge_base implementation here unchanged.)
375
- # ── For brevity in this reply we keep your original code intact. ──
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
 
377
  # ------------------------------ Section fetch helpers ------------------------------
378
  def get_section_text(filename: str, section: str) -> str:
 
365
  "ids": ids,
366
  }
367
 
368
# ------------------------------ Hybrid search (generic + intent-aware) ------------------------------
# Synonym sets used to recognise which CRUD/navigation action the user's query asks for.
ACTION_SYNONYMS = {
    "create": ["create", "creation", "add", "new", "generate"],
    "update": ["update", "modify", "change", "edit"],
    "delete": ["delete", "remove"],
    "navigate": ["navigate", "go to", "open"],
}

# Keywords signalling that the user is reporting an error / permission / access problem.
ERROR_INTENT_TERMS = [
    "error", "issue", "fail", "not working", "resolution", "fix",
    "permission", "permissions", "access", "no access", "authorization", "authorisation",
    "role", "role mapping", "not authorized", "permission denied", "insufficient privileges",
    "escalation", "escalation path", "access right", "mismatch", "locked", "wrong",
]
381
+
382
def _detect_user_intent(query: str) -> str:
    """Classify the user's query intent.

    Priority order matters: error/permission language wins over procedural
    language, which wins over prerequisite and purpose questions.

    Args:
        query: raw user query (may be None/empty).

    Returns:
        One of "errors", "steps", "prereqs", "purpose", or "neutral".
    """
    q = (query or "").lower()
    tokens = set(q.split())
    if any(k in q for k in ERROR_INTENT_TERMS):
        return "errors"
    # "do"/"perform" are checked as whole words: a plain substring test would
    # misfire on e.g. "document", "download", or "performance".
    if any(k in q for k in ["steps", "procedure", "how to", "navigate", "process"]) \
            or tokens & {"do", "perform"}:
        return "steps"
    # "prerequisite" covers the singular form (previously only the hyphenated
    # and plural spellings matched); "requirement" also matches "requirements".
    if any(k in q for k in ["pre-requisite", "prerequisite", "requirement"]):
        return "prereqs"
    if any(k in q for k in ["purpose", "overview", "introduction"]):
        return "purpose"
    return "neutral"
393
+
394
def _extract_actions(query: str) -> List[str]:
    """Return the sorted, de-duplicated action names whose synonyms appear in *query*.

    Note: substring matching is intentional (so "creation" hits "create"),
    which also means e.g. "address" matches "add".
    """
    q = (query or "").lower()
    # sorted() always returns a list, so the original trailing "or []" was redundant.
    found = {act for act, syns in ACTION_SYNONYMS.items() if any(s in q for s in syns)}
    return sorted(found)
401
+
402
def _extract_modules_from_query(query: str) -> List[str]:
    """Return the sorted module names from MODULE_VOCAB whose synonyms occur in *query*."""
    q = (query or "").lower()
    matched = set()
    for module_name, synonyms in MODULE_VOCAB.items():
        for syn in synonyms:
            if syn in q:
                matched.add(module_name)
                break
    return sorted(matched)
409
+
410
def _action_weight(text: str, actions: List[str]) -> float:
    """Score *text* for agreement with the requested *actions*.

    +1.0 per action synonym found in the text; -0.8 per synonym of a
    conflicting action (e.g. "delete" wording when the user asked to
    "create"). Returns 0.0 when no actions were requested.
    """
    if not actions:
        return 0.0
    haystack = (text or "").lower()
    score = 0.0
    for act in actions:
        score += sum(1.0 for syn in ACTION_SYNONYMS.get(act, [act]) if syn in haystack)
    # Penalise text that talks about the opposite operation.
    conflicts = {"create": ["delete"], "delete": ["create"], "update": ["delete"], "navigate": []}
    for act in actions:
        for opposing in conflicts.get(act, []):
            score -= sum(0.8 for syn in ACTION_SYNONYMS.get(opposing, [opposing]) if syn in haystack)
    return score
426
+
427
+ def _module_weight(meta: Dict[str, Any], user_modules: List[str]) -> float:
428
+ if not user_modules:
429
+ return 0.0
430
+ raw = (meta or {}).get("module_tags", "") or ""
431
+ doc_modules = [m.strip() for m in raw.split(",") if m.strip()] if isinstance(raw, str) else (raw or [])
432
+ overlap = len(set(user_modules) & set(doc_modules))
433
+ if overlap == 0:
434
+ return -0.8
435
+ return 0.7 * overlap
436
+
437
+ def _intent_weight(meta: dict, user_intent: str) -> float:
438
+ tag = (meta or {}).get("intent_tag", "neutral")
439
+ if user_intent == "neutral":
440
+ return 0.0
441
+ if tag == user_intent:
442
+ return 1.0
443
+ if tag in ["purpose", "prereqs"] and user_intent in ["steps", "errors"]:
444
+ return -0.6
445
+ st = ((meta or {}).get("section", "") or "").lower()
446
+ topics = (meta or {}).get("topic_tags", "") or ""
447
+ topic_list = [t.strip() for t in topics.split(",") if t.strip()]
448
+ if user_intent == "errors" and (
449
+ any(k in st for k in ["common errors", "known issues", "common issues", "errors", "escalation", "permissions", "access"])
450
+ or ("permissions" in topic_list)
451
+ ):
452
+ return 1.10
453
+ if user_intent == "steps" and any(k in st for k in ["process steps", "procedure", "instructions", "workflow"]):
454
+ return 0.75
455
+ return -0.2
456
+
457
def _meta_overlap(meta: Dict[str, Any], q_terms: List[str]) -> float:
    """Fraction of the query terms that also appear among the chunk's metadata tokens.

    Metadata tokens are drawn from filename, title, section, topic_tags and
    module_tags. Returns 0.0 when either side is empty.
    """
    token_sources = (
        meta.get("filename"),
        meta.get("title"),
        meta.get("section"),
        meta.get("topic_tags") or "",
        meta.get("module_tags") or "",
    )
    meta_tokens = set()
    for value in token_sources:
        meta_tokens.update(_tokenize_meta_value(value))
    if not meta_tokens or not q_terms:
        return 0.0
    query_set = set(q_terms)
    return len(meta_tokens & query_set) / max(1, len(query_set))
469
+
470
+ def _make_ngrams(tokens: List[str], n: int) -> List[str]:
471
+ return [" ".join(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]
472
+
473
def _phrase_boost_score(text: str, q_terms: List[str]) -> float:
    """Boost for multi-word query phrases found verbatim in *text*.

    +0.40 per matching bigram, +0.70 per matching trigram, capped at 2.0.
    """
    if not text or not q_terms:
        return 0.0
    haystack = (text or "").lower()
    score = 0.0
    for phrase in _make_ngrams(q_terms, 2):
        if phrase and phrase in haystack:
            score += 0.40
    for phrase in _make_ngrams(q_terms, 3):
        if phrase and phrase in haystack:
            score += 0.70
    return min(score, 2.0)
487
+
488
def _literal_query_match_boost(text: str, query_norm: str) -> float:
    """Boost when the normalized query, or one of its bigrams, appears verbatim in *text*.

    +0.8 for a whole-query match, +0.8 for the first matching bigram of
    tokens longer than two characters; capped at 1.6.
    """
    haystack = (text or "").lower()
    needle = (query_norm or "").lower()
    boost = 0.0
    if needle and needle in haystack:
        boost += 0.8
    long_tokens = [tok for tok in needle.split() if len(tok) > 2]
    # A single bigram hit is enough; do not stack multiple bigram boosts.
    for bigram in _make_ngrams(long_tokens, 2):
        if bigram in haystack:
            boost += 0.8
            break
    return min(boost, 1.6)
501
+
502
def hybrid_search_knowledge_base(query: str, top_k: int = 10, alpha: float = 0.6, beta: float = 0.4) -> dict:
    """
    Hybrid retrieval (embeddings + BM25) with intent-, action-, module-, and
    phrase-aware reranking.

    Args:
        query: raw user query.
        top_k: number of chunks to return.
        alpha: weight of the semantic (embedding) similarity.
        beta: weight of the normalized BM25 score.

    Returns:
        dict with "documents", "metadatas", "distances", "ids",
        "combined_scores", plus the doc-level winner ("best_doc",
        "best_doc_prior"), the detected "user_intent", and the extracted
        "actions" for downstream formatting.
    """
    from collections import defaultdict

    norm_query = _normalize_query(query)
    q_terms = _tokenize(norm_query)
    user_intent = _detect_user_intent(query)
    actions = _extract_actions(query)
    user_modules = _extract_modules_from_query(query)

    # --- semantic (embeddings) search via Chroma ---
    sem_res = search_knowledge_base(norm_query, top_k=max(top_k, 40))
    sem_docs = sem_res.get("documents", [])
    sem_metas = sem_res.get("metadatas", [])
    sem_dists = sem_res.get("distances", [])
    sem_ids = sem_res.get("ids", [])

    def dist_to_sim(d: Optional[float]) -> float:
        # Map a distance to (0, 1]: smaller distance -> higher similarity.
        if d is None:
            return 0.0
        try:
            return 1.0 / (1.0 + float(d))
        except Exception:
            return 0.0

    sem_sims = [dist_to_sim(d) for d in sem_dists]
    # O(1) id -> position lookup; the previous sem_ids.index(cid) inside the
    # candidate loop was an accidental O(n^2).
    sem_pos = {cid: i for i, cid in enumerate(sem_ids)}

    # --- BM25 search ---
    bm25_hits = bm25_search(norm_query, top_k=max(80, top_k * 6))
    bm25_max = max([s for _, s in bm25_hits], default=1.0)
    bm25_norm_pairs = [(idx, (score / bm25_max) if bm25_max > 0 else 0.0) for idx, score in bm25_hits]
    bm25_id_to_norm, bm25_id_to_text, bm25_id_to_meta = {}, {}, {}
    for idx, nscore in bm25_norm_pairs:
        d = bm25_docs[idx]
        bm25_id_to_norm[d["id"]] = nscore
        bm25_id_to_text[d["id"]] = d["text"]
        bm25_id_to_meta[d["id"]] = d["meta"]

    # union of candidate IDs (semantic + bm25)
    union_ids = set(sem_ids) | set(bm25_id_to_norm.keys())

    # reranking weights (alpha/beta come from the caller)
    gamma = 0.30    # meta overlap
    delta = 0.55    # intent boost
    epsilon = 0.30  # action weight
    zeta = 0.65    # module weight
    eta = 0.50     # phrase-level boost
    theta = 0.00   # optional heading alignment bonus (unused)
    iota = 0.60    # literal query match boost

    # record layout: (id, final_score, distance, text, meta,
    #                 m_overlap, intent, action, module, phrase, heading, literal)
    combined_records_ext: List[Tuple[str, float, float, str, Dict[str, Any], float, float, float, float, float, float, float]] = []

    for cid in union_ids:
        # prefer the semantic hit's fields; fall back to the BM25 copy
        if cid in sem_pos:
            pos = sem_pos[cid]
            sem_sim = sem_sims[pos] if pos < len(sem_sims) else 0.0
            sem_dist = sem_dists[pos] if pos < len(sem_dists) else None
            sem_text = sem_docs[pos] if pos < len(sem_docs) else ""
            sem_meta = sem_metas[pos] if pos < len(sem_metas) else {}
        else:
            sem_sim, sem_dist, sem_text, sem_meta = 0.0, None, "", {}

        bm25_sim = bm25_id_to_norm.get(cid, 0.0)
        text = sem_text if sem_text else bm25_id_to_text.get(cid, "")
        meta = sem_meta if sem_meta else bm25_id_to_meta.get(cid, {})

        m_overlap = _meta_overlap(meta, q_terms)
        intent_boost = _intent_weight(meta, user_intent)
        act_wt = _action_weight(text, actions)
        mod_wt = _module_weight(meta, user_modules)
        phrase_wt = _phrase_boost_score(text, q_terms)
        literal_wt = _literal_query_match_boost(text, norm_query)

        final_score = (
            alpha * sem_sim
            + beta * bm25_sim
            + gamma * m_overlap
            + delta * intent_boost
            + epsilon * act_wt
            + zeta * mod_wt
            + eta * phrase_wt
            + theta * 0.0
            + iota * literal_wt
        )
        combined_records_ext.append(
            (cid, final_score, (sem_dist if sem_dist is not None else 999.0), text, meta,
             m_overlap, intent_boost, act_wt, mod_wt, phrase_wt, 0.0, literal_wt)
        )

    # exact-match rerank for errors: push chunks containing the query phrase(s) first
    if user_intent == "errors":
        toks = [tok for tok in norm_query.split() if len(tok) > 2]
        bigrams = _make_ngrams(toks, 2)
        exact_hits = []
        for rec in combined_records_ext:
            text_lower = (rec[3] or "").lower()
            if (norm_query and norm_query in text_lower) or any(bg in text_lower for bg in bigrams):
                exact_hits.append(rec)
        if exact_hits:
            # Identity set instead of "rec not in exact_hits": avoids O(n^2)
            # list scans with deep tuple/dict comparisons.
            hit_ids = {id(rec) for rec in exact_hits}
            rest = [rec for rec in combined_records_ext if id(rec) not in hit_ids]
            exact_hits.sort(key=lambda x: x[1], reverse=True)
            rest.sort(key=lambda x: x[1], reverse=True)
            combined_records_ext = exact_hits + rest

    # doc-level prior: prefer documents whose chunks collectively align with the query
    doc_groups: Dict[str, List[Tuple[str, float, float, str, Dict[str, Any], float, float, float, float, float, float, float]]] = defaultdict(list)
    for rec in combined_records_ext:
        fn = (rec[4] or {}).get("filename", "unknown")
        doc_groups[fn].append(rec)

    def doc_prior(recs: List[Tuple[str, float, float, str, Dict[str, Any], float, float, float, float, float, float, float]]) -> float:
        # Aggregate chunk-level evidence; negative intent/action contributions
        # are folded back in (damped) as total_penalty.
        total_score = sum(r[1] for r in recs)
        total_overlap = sum(r[5] for r in recs)
        total_intent = sum(max(0.0, r[6]) for r in recs)
        total_action = sum(max(0.0, r[7]) for r in recs)
        total_module = sum(r[8] for r in recs)
        total_phrase = sum(r[9] for r in recs)
        total_literal = sum(r[11] for r in recs)
        total_penalty = sum(min(0.0, r[6]) for r in recs) + sum(min(0.0, r[7]) for r in recs)
        errors_section_bonus = 0.0
        error_markers = ("error", "known issues", "common issues")
        if any(any(m in ((r[4] or {}).get("section", "")).lower() for m in error_markers) for r in recs):
            errors_section_bonus = 0.5
        return (
            total_score
            + 0.4 * total_overlap
            + 0.7 * total_intent
            + 0.5 * total_action
            + 0.8 * total_module
            + 0.6 * total_phrase
            + 0.7 * total_literal
            + errors_section_bonus
            + 0.3 * total_penalty
        )

    best_doc, best_doc_prior = None, -1.0
    for fn, recs in doc_groups.items():
        p = doc_prior(recs)
        if p > best_doc_prior:
            best_doc_prior, best_doc = p, fn

    # Winning document's chunks first, then everything else, each sorted by score.
    best_recs = sorted(doc_groups.get(best_doc, []), key=lambda x: x[1], reverse=True)
    other_recs = [rec for fn, recs in doc_groups.items() if fn != best_doc for rec in recs]
    other_recs.sort(key=lambda x: x[1], reverse=True)

    top = (best_recs + other_recs)[:top_k]
    return {
        "documents": [t[3] for t in top],
        "metadatas": [t[4] for t in top],
        "distances": [t[2] for t in top],
        "ids": [t[0] for t in top],
        "combined_scores": [t[1] for t in top],
        "best_doc": best_doc,
        "best_doc_prior": best_doc_prior,
        "user_intent": user_intent,
        "actions": actions,
    }
680
 
681
  # ------------------------------ Section fetch helpers ------------------------------
682
  def get_section_text(filename: str, section: str) -> str: