Rifqi Hafizuddin commited on
Commit ·
e49db60
1
Parent(s): 8218650
[NOTICKET] add db_client for querying
Browse files
src/api/v1/db_client.py
CHANGED
|
@@ -458,7 +458,7 @@ async def ingest_database_client(
|
|
| 458 |
db_type=client.db_type,
|
| 459 |
credentials=creds,
|
| 460 |
) as engine:
|
| 461 |
-
total = await db_pipeline_service.run(user_id=user_id, engine=engine)
|
| 462 |
except NotImplementedError as e:
|
| 463 |
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=str(e))
|
| 464 |
except Exception as e:
|
|
|
|
| 458 |
db_type=client.db_type,
|
| 459 |
credentials=creds,
|
| 460 |
) as engine:
|
| 461 |
+
total = await db_pipeline_service.run(user_id=user_id, client_id=client_id, engine=engine)
|
| 462 |
except NotImplementedError as e:
|
| 463 |
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=str(e))
|
| 464 |
except Exception as e:
|
src/pipeline/db_pipeline/db_pipeline_service.py
CHANGED
|
@@ -148,7 +148,7 @@ class DbPipelineService:
|
|
| 148 |
engine.dispose()
|
| 149 |
|
| 150 |
def _to_document(
|
| 151 |
-
self, user_id: str, table_name: str, entry: dict, updated_at: str
|
| 152 |
) -> LangChainDocument:
|
| 153 |
col = entry["col"]
|
| 154 |
return LangChainDocument(
|
|
@@ -156,6 +156,7 @@ class DbPipelineService:
|
|
| 156 |
metadata={
|
| 157 |
"user_id": user_id,
|
| 158 |
"source_type": "database",
|
|
|
|
| 159 |
"updated_at": updated_at,
|
| 160 |
"data": {
|
| 161 |
"table_name": table_name,
|
|
@@ -170,6 +171,7 @@ class DbPipelineService:
|
|
| 170 |
async def run(
|
| 171 |
self,
|
| 172 |
user_id: str,
|
|
|
|
| 173 |
engine: Engine,
|
| 174 |
exclude_tables: Optional[frozenset[str]] = None,
|
| 175 |
) -> int:
|
|
@@ -202,7 +204,7 @@ class DbPipelineService:
|
|
| 202 |
for table_name, columns in schema.items():
|
| 203 |
logger.info("profiling table", table=table_name, columns=len(columns))
|
| 204 |
entries = await asyncio.to_thread(profile_table, engine, table_name, columns)
|
| 205 |
-
docs = [self._to_document(user_id, table_name, e, updated_at) for e in entries]
|
| 206 |
if docs:
|
| 207 |
await vector_store.aadd_documents(docs)
|
| 208 |
total += len(docs)
|
|
|
|
| 148 |
engine.dispose()
|
| 149 |
|
| 150 |
def _to_document(
|
| 151 |
+
self, user_id: str, client_id: str, table_name: str, entry: dict, updated_at: str
|
| 152 |
) -> LangChainDocument:
|
| 153 |
col = entry["col"]
|
| 154 |
return LangChainDocument(
|
|
|
|
| 156 |
metadata={
|
| 157 |
"user_id": user_id,
|
| 158 |
"source_type": "database",
|
| 159 |
+
"database_client_id": client_id,
|
| 160 |
"updated_at": updated_at,
|
| 161 |
"data": {
|
| 162 |
"table_name": table_name,
|
|
|
|
| 171 |
async def run(
|
| 172 |
self,
|
| 173 |
user_id: str,
|
| 174 |
+
client_id: str,
|
| 175 |
engine: Engine,
|
| 176 |
exclude_tables: Optional[frozenset[str]] = None,
|
| 177 |
) -> int:
|
|
|
|
| 204 |
for table_name, columns in schema.items():
|
| 205 |
logger.info("profiling table", table=table_name, columns=len(columns))
|
| 206 |
entries = await asyncio.to_thread(profile_table, engine, table_name, columns)
|
| 207 |
+
docs = [self._to_document(user_id, client_id, table_name, e, updated_at) for e in entries]
|
| 208 |
if docs:
|
| 209 |
await vector_store.aadd_documents(docs)
|
| 210 |
total += len(docs)
|