ZedLow commited on
Commit
fb6bf03
·
verified ·
1 Parent(s): ee3f04c

Create routing.py

Browse files
Files changed (1) hide show
  1. rag/routing.py +58 -0
rag/routing.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+ import re
3
+
4
+ from rag.config import Settings
5
+ from rag.logging_utils import get_logger
6
+
7
+ logger = get_logger(__name__)
8
+
9
+ _COMMON_FALSE_POSITIVES = {
10
+ "revenue", "income", "balance", "sheet", "notes", "q1", "q2", "q3", "q4",
11
+ "fy", "cash", "assets", "liabilities", "growth", "profit", "loss",
12
+ }
13
+
14
+ def _looks_like_company_token(s: str) -> bool:
15
+ s = s.strip().lower()
16
+ if not s or len(s) < 3:
17
+ return False
18
+ if s in _COMMON_FALSE_POSITIVES:
19
+ return False
20
+ # avoid pure numbers / quarter tokens etc
21
+ if re.fullmatch(r"[\d\W_]+", s):
22
+ return False
23
+ return True
24
+
25
+ def route_companies(query: str, router_model, settings: Settings) -> Tuple[List[str], List[str]]:
26
+ """
27
+ Returns (detected_companies_in_scope, out_of_scope_mentions)
28
+ Soft behavior:
29
+ - If out-of-scope is detected, we DON'T hard-fail automatically.
30
+ We'll mention it + still allow broad search, unless user insists.
31
+ """
32
+ labels = list(settings.router_labels)
33
+ entities = router_model.predict_entities(query, labels, threshold=settings.router_threshold)
34
+
35
+ detected = set()
36
+ out_of_scope = set()
37
+
38
+ for e in entities:
39
+ text = (e.get("text") or "").strip()
40
+ if not _looks_like_company_token(text):
41
+ continue
42
+ name = text.lower()
43
+
44
+ if "microsoft" in name or "msft" in name:
45
+ detected.add("Microsoft")
46
+ elif "apple" in name or "aapl" in name:
47
+ detected.add("Apple")
48
+ else:
49
+ out_of_scope.add(text)
50
+
51
+ # Also fallback on simple keyword scan (super robust, low cost)
52
+ ql = query.lower()
53
+ if "microsoft" in ql or "msft" in ql:
54
+ detected.add("Microsoft")
55
+ if "apple" in ql or "aapl" in ql:
56
+ detected.add("Apple")
57
+
58
+ return sorted(detected), sorted(out_of_scope)