Rifqi Hafizuddin commited on
Commit ·
15cd3a7
1
Parent(s): 2c8a3e8
[NOTICKET] minor fix in chat.py, add package for query, change schema used to hybrid (cosine+bm25)
Browse files- pyproject.toml +2 -0
- src/api/v1/chat.py +9 -1
- src/rag/retrievers/schema.py +1 -1
pyproject.toml
CHANGED
|
@@ -79,6 +79,8 @@ dependencies = [
|
|
| 79 |
"jsonpatch>=1.33",
|
| 80 |
"pymongo>=4.14.0",
|
| 81 |
"psycopg2>=2.9.11",
|
|
|
|
|
|
|
| 82 |
# --- User-DB connectors (db_pipeline) ---
|
| 83 |
"pymysql>=1.1.1",
|
| 84 |
"pymssql>=2.3.0",
|
|
|
|
| 79 |
"jsonpatch>=1.33",
|
| 80 |
"pymongo>=4.14.0",
|
| 81 |
"psycopg2>=2.9.11",
|
| 82 |
+
# --- SQL parsing / guardrails ---
|
| 83 |
+
"sqlglot>=25.0.0",
|
| 84 |
# --- User-DB connectors (db_pipeline) ---
|
| 85 |
"pymysql>=1.1.1",
|
| 86 |
"pymssql>=2.3.0",
|
src/api/v1/chat.py
CHANGED
|
@@ -61,7 +61,7 @@ def _extract_sources(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
| 61 |
seen = set()
|
| 62 |
sources = []
|
| 63 |
for result in results:
|
| 64 |
-
if "document_id" in result["metadata"]
|
| 65 |
meta = result["metadata"]
|
| 66 |
key = (meta.get("data", {}).get("document_id"), meta.get("data", {}).get("page_label"))
|
| 67 |
if key not in seen:
|
|
@@ -182,12 +182,20 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
|
|
| 182 |
|
| 183 |
if not intent_result.get("needs_search"):
|
| 184 |
retrieval_task.cancel()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
raw_results = []
|
| 186 |
else:
|
| 187 |
search_query = intent_result.get("search_query", request.message)
|
| 188 |
logger.info(f"Searching for: {search_query}")
|
| 189 |
if search_query != request.message:
|
| 190 |
retrieval_task.cancel()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
raw_results = await retriever.retrieve(
|
| 192 |
query=search_query,
|
| 193 |
user_id=request.user_id,
|
|
|
|
| 61 |
seen = set()
|
| 62 |
sources = []
|
| 63 |
for result in results:
|
| 64 |
+
if "document_id" in result["metadata"].get("data", {}):
|
| 65 |
meta = result["metadata"]
|
| 66 |
key = (meta.get("data", {}).get("document_id"), meta.get("data", {}).get("page_label"))
|
| 67 |
if key not in seen:
|
|
|
|
| 182 |
|
| 183 |
if not intent_result.get("needs_search"):
|
| 184 |
retrieval_task.cancel()
|
| 185 |
+
try:
|
| 186 |
+
await retrieval_task
|
| 187 |
+
except asyncio.CancelledError:
|
| 188 |
+
pass
|
| 189 |
raw_results = []
|
| 190 |
else:
|
| 191 |
search_query = intent_result.get("search_query", request.message)
|
| 192 |
logger.info(f"Searching for: {search_query}")
|
| 193 |
if search_query != request.message:
|
| 194 |
retrieval_task.cancel()
|
| 195 |
+
try:
|
| 196 |
+
await retrieval_task
|
| 197 |
+
except asyncio.CancelledError:
|
| 198 |
+
pass
|
| 199 |
raw_results = await retriever.retrieve(
|
| 200 |
query=search_query,
|
| 201 |
user_id=request.user_id,
|
src/rag/retrievers/schema.py
CHANGED
|
@@ -31,7 +31,7 @@ logger = get_logger("schema_retriever")
|
|
| 31 |
_TABULAR_FILE_TYPES = ("csv", "xlsx")
|
| 32 |
|
| 33 |
Strategy = Literal["dense_no_threshold", "dense_dot", "dense_l2", "hybrid", "hybrid_bm25"]
|
| 34 |
-
ACTIVE_STRATEGY: Strategy = "
|
| 35 |
|
| 36 |
|
| 37 |
class SchemaRetriever(BaseRetriever):
|
|
|
|
| 31 |
_TABULAR_FILE_TYPES = ("csv", "xlsx")
|
| 32 |
|
| 33 |
Strategy = Literal["dense_no_threshold", "dense_dot", "dense_l2", "hybrid", "hybrid_bm25"]
|
| 34 |
+
ACTIVE_STRATEGY: Strategy = "hybrid_bm25"
|
| 35 |
|
| 36 |
|
| 37 |
class SchemaRetriever(BaseRetriever):
|