Spaces:
Sleeping
Sleeping
Commit
·
bb76787
1
Parent(s):
5514fbd
doc id problem fixed
Browse files- index_retriever.py +12 -16
index_retriever.py
CHANGED
|
@@ -31,20 +31,21 @@ def keyword_filter_nodes(query, nodes, min_keyword_matches=1):
|
|
| 31 |
|
| 32 |
|
| 33 |
def normalize_doc_id(doc_id: str) -> str:
|
| 34 |
-
"""Normalize document ID for
|
| 35 |
doc_id = doc_id.upper().strip()
|
| 36 |
-
doc_id = re.sub(r'
|
| 37 |
doc_id = doc_id.replace("ГОСТР", "ГОСТ")
|
| 38 |
doc_id = doc_id.replace("GOSTR", "ГОСТ")
|
| 39 |
return doc_id
|
| 40 |
|
| 41 |
def base_number(doc_id: str) -> str:
|
| 42 |
-
"""Extract
|
| 43 |
-
|
|
|
|
| 44 |
return m.group(1) if m else ""
|
| 45 |
|
| 46 |
-
def filter_nodes_by_doc_id(nodes, doc_ids, threshold=0.
|
| 47 |
-
"""Filter nodes by
|
| 48 |
if not doc_ids:
|
| 49 |
return nodes
|
| 50 |
|
|
@@ -57,22 +58,17 @@ def filter_nodes_by_doc_id(nodes, doc_ids, threshold=0.75):
|
|
| 57 |
node_base = base_number(node_doc_id)
|
| 58 |
|
| 59 |
for q_doc, q_base in zip(doc_ids_norm, doc_ids_base):
|
| 60 |
-
#
|
| 61 |
if q_base and node_base and q_base == node_base:
|
| 62 |
filtered.append(node)
|
| 63 |
break
|
| 64 |
-
|
| 65 |
-
#
|
| 66 |
-
|
| 67 |
-
filtered.append(node)
|
| 68 |
-
break
|
| 69 |
-
|
| 70 |
-
# Weak fallback: contains or partial substring
|
| 71 |
-
if q_base in node_doc_id or q_doc in node_doc_id:
|
| 72 |
filtered.append(node)
|
| 73 |
break
|
| 74 |
|
| 75 |
-
return filtered if filtered else nodes
|
| 76 |
|
| 77 |
|
| 78 |
def extract_doc_id_from_query(query):
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
def normalize_doc_id(doc_id: str) -> str:
|
| 34 |
+
"""Normalize document ID - KEEP dots for numeric parts"""
|
| 35 |
doc_id = doc_id.upper().strip()
|
| 36 |
+
doc_id = re.sub(r'\s+', '', doc_id) # Remove spaces only
|
| 37 |
doc_id = doc_id.replace("ГОСТР", "ГОСТ")
|
| 38 |
doc_id = doc_id.replace("GOSTR", "ГОСТ")
|
| 39 |
return doc_id
|
| 40 |
|
| 41 |
def base_number(doc_id: str) -> str:
|
| 42 |
+
"""Extract full numeric pattern including all parts (e.g., '59023.6' from 'ГОСТ 59023.6')"""
|
| 43 |
+
# Match: 59023.6 or 59023.4 or 50.05.01 etc.
|
| 44 |
+
m = re.search(r'(\d+(?:\.\d+)*)', doc_id)
|
| 45 |
return m.group(1) if m else ""
|
| 46 |
|
| 47 |
+
def filter_nodes_by_doc_id(nodes, doc_ids, threshold=0.85):
|
| 48 |
+
"""Filter nodes by document ID with strict numeric matching"""
|
| 49 |
if not doc_ids:
|
| 50 |
return nodes
|
| 51 |
|
|
|
|
| 58 |
node_base = base_number(node_doc_id)
|
| 59 |
|
| 60 |
for q_doc, q_base in zip(doc_ids_norm, doc_ids_base):
|
| 61 |
+
# STRICT: base number must match exactly
|
| 62 |
if q_base and node_base and q_base == node_base:
|
| 63 |
filtered.append(node)
|
| 64 |
break
|
| 65 |
+
|
| 66 |
+
# STRICT: full normalized ID must match exactly or have very high similarity
|
| 67 |
+
elif SequenceMatcher(None, node_doc_id, q_doc).ratio() >= threshold:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
filtered.append(node)
|
| 69 |
break
|
| 70 |
|
| 71 |
+
return filtered if filtered else nodes
|
| 72 |
|
| 73 |
|
| 74 |
def extract_doc_id_from_query(query):
|