Spaces:
Sleeping
Sleeping
Create routing.py
Browse files- rag/routing.py +58 -0
rag/routing.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
from rag.config import Settings
|
| 5 |
+
from rag.logging_utils import get_logger
|
| 6 |
+
|
| 7 |
+
# Module-level logger, obtained from the project's logging helper by passing this module's __name__.
logger = get_logger(__name__)
|
| 8 |
+
|
| 9 |
+
_COMMON_FALSE_POSITIVES = {
|
| 10 |
+
"revenue", "income", "balance", "sheet", "notes", "q1", "q2", "q3", "q4",
|
| 11 |
+
"fy", "cash", "assets", "liabilities", "growth", "profit", "loss",
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
def _looks_like_company_token(s: str) -> bool:
|
| 15 |
+
s = s.strip().lower()
|
| 16 |
+
if not s or len(s) < 3:
|
| 17 |
+
return False
|
| 18 |
+
if s in _COMMON_FALSE_POSITIVES:
|
| 19 |
+
return False
|
| 20 |
+
# avoid pure numbers / quarter tokens etc
|
| 21 |
+
if re.fullmatch(r"[\d\W_]+", s):
|
| 22 |
+
return False
|
| 23 |
+
return True
|
| 24 |
+
|
| 25 |
+
# Canonical in-scope company name -> pattern matching any of its aliases.
# The lookarounds forbid a letter directly before or after the alias, so a
# match embedded in a longer word ("pineapple", "Applebee's") is rejected,
# while "Apple's", "AAPL." and "Microsoft365" are still recognised.
_IN_SCOPE_COMPANY_PATTERNS = {
    "Microsoft": re.compile(
        r"(?<![^\W\d_])(?:microsoft|msft)(?![^\W\d_])", re.IGNORECASE
    ),
    "Apple": re.compile(
        r"(?<![^\W\d_])(?:apple|aapl)(?![^\W\d_])", re.IGNORECASE
    ),
}


def _in_scope_companies(text: str) -> List[str]:
    """Return the canonical names of the in-scope companies mentioned in *text*."""
    return [
        company
        for company, pattern in _IN_SCOPE_COMPANY_PATTERNS.items()
        if pattern.search(text)
    ]


def route_companies(query: str, router_model, settings: Settings) -> Tuple[List[str], List[str]]:
    """Split the companies mentioned in *query* into in-scope and out-of-scope.

    Args:
        query: The user's question. An empty, blank or ``None`` query yields
            two empty lists without calling the model.
        router_model: Object exposing
            ``predict_entities(text, labels, threshold=...)`` that returns an
            iterable of mappings with a ``"text"`` entry.
        settings: Supplies ``router_labels`` (the labels handed to the model)
            and ``router_threshold`` (the model's score threshold).

    Returns:
        ``(detected_companies_in_scope, out_of_scope_mentions)``, both sorted.
        In-scope entries are canonical names ("Apple", "Microsoft");
        out-of-scope entries keep the text as the model returned it.

    Soft behavior: an out-of-scope mention is only reported here, never
    treated as a hard failure. Per the original design note, the caller is
    expected to mention it and still allow a broad search unless the user
    insists.
    """
    if not query or not query.strip():
        return [], []

    entities = router_model.predict_entities(
        query, list(settings.router_labels), threshold=settings.router_threshold
    )

    detected = set()
    out_of_scope = set()

    for entity in entities:
        text = (entity.get("text") or "").strip()
        if not _looks_like_company_token(text):
            continue

        matches = _in_scope_companies(text)
        if matches:
            detected.update(matches)
        else:
            out_of_scope.add(text)

    # Cheap keyword fallback over the whole query, in case the model missed
    # an alias of an in-scope company.
    detected.update(_in_scope_companies(query))

    return sorted(detected), sorted(out_of_scope)
|