ZedLow committed on
Commit
8132eba
·
verified ·
1 Parent(s): a19d371

Update rag/pipeline.py

Browse files
Files changed (1) hide show
  1. rag/pipeline.py +33 -15
rag/pipeline.py CHANGED
@@ -23,41 +23,59 @@ def _route_companies(
23
  "msft": "Microsoft"
24
  }
25
 
26
- labels = list(settings.router_labels)
27
- entities = router_model.predict_entities(query, labels, threshold=settings.router_threshold)
 
 
 
 
28
 
29
  detected_targets = []
30
  unsupported_targets = []
31
 
32
  for e in entities:
33
- name_clean = (e.get("text") or "").lower().strip()
34
- found_match = False
35
 
36
- for key, canonical_name in allowed_companies.items():
 
 
37
  if key in name_clean:
38
  detected_targets.append(canonical_name)
39
  found_match = True
40
  break
41
 
42
  if not found_match:
43
- unsupported_targets.append(e.get("text"))
44
 
 
 
 
 
 
 
 
45
  detected_targets = list(set(detected_targets))
46
- unsupported_targets = list(set(unsupported_targets))
47
 
 
 
 
 
 
 
 
48
  if unsupported_targets:
 
49
  return [], (
50
- f"⛔ **Out of Scope:** I detected a request for **{', '.join(unsupported_targets)}**. "
51
  "This system only has access to **Microsoft** and **Apple** data."
52
  )
53
 
54
- if not detected_targets:
55
- return [], (
56
- "❓ **Ambiguous Query:** I could not identify a specific company (Apple or Microsoft). "
57
- "Please name the company you want to analyze."
58
- )
59
-
60
- return detected_targets, None
61
 
62
  def _filter_docs(
63
  dataset: List[Dict[str, Any]],
 
23
  "msft": "Microsoft"
24
  }
25
 
26
+ try:
27
+ labels = list(settings.router_labels)
28
+ entities = router_model.predict_entities(query, labels, threshold=settings.router_threshold)
29
+ except Exception as e:
30
+ print(f"Warning: GLiNER failed with error {e}, falling back to keywords.")
31
+ entities = []
32
 
33
  detected_targets = []
34
  unsupported_targets = []
35
 
36
  for e in entities:
37
+ text = e.get("text", "")
38
+ name_clean = text.lower().strip().strip(".,?!'s")
39
 
40
+ found_match = False
41
+ for key, canonical_name in allowed_companies.items():
42
+
43
  if key in name_clean:
44
  detected_targets.append(canonical_name)
45
  found_match = True
46
  break
47
 
48
  if not found_match:
49
+ unsupported_targets.append(text)
50
 
51
+
52
+ if not detected_targets:
53
+ query_lower = query.lower()
54
+ for key, canonical_name in allowed_companies.items():
55
+ if key in query_lower:
56
+ detected_targets.append(canonical_name)
57
+
58
  detected_targets = list(set(detected_targets))
 
59
 
60
+
61
+
62
+
63
+ if detected_targets:
64
+ return detected_targets, None
65
+
66
+
67
  if unsupported_targets:
68
+ unique_unsupported = list(set(unsupported_targets))
69
  return [], (
70
+ f"⛔ **Out of Scope:** I detected a request for **{', '.join(unique_unsupported)}**. "
71
  "This system only has access to **Microsoft** and **Apple** data."
72
  )
73
 
74
+ # Cas C : Le désert total (vraiment rien trouvé)
75
+ return [], (
76
+ "❓ **Ambiguous Query:** I could not identify a specific company (Apple or Microsoft). "
77
+ "Please name the company explicitly in your question."
78
+ )
 
 
79
 
80
  def _filter_docs(
81
  dataset: List[Dict[str, Any]],