[NOTICKET] add updated_at metadata & prevent duplicates when adding new embeddings
#10
by rhbt6767 - opened
src/pipeline/db_pipeline/db_pipeline_service.py
CHANGED
|
@@ -10,12 +10,14 @@ async vector writes stay on the event loop.
|
|
| 10 |
|
| 11 |
import asyncio
|
| 12 |
from contextlib import contextmanager
|
|
|
|
| 13 |
from typing import Any, Iterator, Optional
|
| 14 |
|
| 15 |
from langchain_core.documents import Document as LangChainDocument
|
| 16 |
-
from sqlalchemy import URL, create_engine
|
| 17 |
from sqlalchemy.engine import Engine
|
| 18 |
|
|
|
|
| 19 |
from src.db.postgres.vector_store import get_vector_store
|
| 20 |
from src.middlewares.logging import get_logger
|
| 21 |
from src.models.credentials import DbType
|
|
@@ -146,7 +148,7 @@ class DbPipelineService:
|
|
| 146 |
engine.dispose()
|
| 147 |
|
| 148 |
def _to_document(
|
| 149 |
-
self, user_id: str, table_name: str, entry: dict
|
| 150 |
) -> LangChainDocument:
|
| 151 |
col = entry["col"]
|
| 152 |
return LangChainDocument(
|
|
@@ -154,6 +156,7 @@ class DbPipelineService:
|
|
| 154 |
metadata={
|
| 155 |
"user_id": user_id,
|
| 156 |
"source_type": "database",
|
|
|
|
| 157 |
"data": {
|
| 158 |
"table_name": table_name,
|
| 159 |
"column_name": col["name"],
|
|
@@ -178,13 +181,28 @@ class DbPipelineService:
|
|
| 178 |
vector_store = get_vector_store()
|
| 179 |
logger.info("db pipeline start", user_id=user_id)
|
| 180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
schema = await asyncio.to_thread(get_schema, engine, exclude_tables)
|
| 182 |
|
|
|
|
| 183 |
total = 0
|
| 184 |
for table_name, columns in schema.items():
|
| 185 |
logger.info("profiling table", table=table_name, columns=len(columns))
|
| 186 |
entries = await asyncio.to_thread(profile_table, engine, table_name, columns)
|
| 187 |
-
docs = [self._to_document(user_id, table_name, e) for e in entries]
|
| 188 |
if docs:
|
| 189 |
await vector_store.aadd_documents(docs)
|
| 190 |
total += len(docs)
|
|
|
|
| 10 |
|
| 11 |
import asyncio
|
| 12 |
from contextlib import contextmanager
|
| 13 |
+
from datetime import datetime, timezone, timedelta
|
| 14 |
from typing import Any, Iterator, Optional
|
| 15 |
|
| 16 |
from langchain_core.documents import Document as LangChainDocument
|
| 17 |
+
from sqlalchemy import URL, create_engine, text
|
| 18 |
from sqlalchemy.engine import Engine
|
| 19 |
|
| 20 |
+
from src.db.postgres.connection import _pgvector_engine
|
| 21 |
from src.db.postgres.vector_store import get_vector_store
|
| 22 |
from src.middlewares.logging import get_logger
|
| 23 |
from src.models.credentials import DbType
|
|
|
|
| 148 |
engine.dispose()
|
| 149 |
|
| 150 |
def _to_document(
|
| 151 |
+
self, user_id: str, table_name: str, entry: dict, updated_at: str
|
| 152 |
) -> LangChainDocument:
|
| 153 |
col = entry["col"]
|
| 154 |
return LangChainDocument(
|
|
|
|
| 156 |
metadata={
|
| 157 |
"user_id": user_id,
|
| 158 |
"source_type": "database",
|
| 159 |
+
"updated_at": updated_at,
|
| 160 |
"data": {
|
| 161 |
"table_name": table_name,
|
| 162 |
"column_name": col["name"],
|
|
|
|
| 181 |
vector_store = get_vector_store()
|
| 182 |
logger.info("db pipeline start", user_id=user_id)
|
| 183 |
|
| 184 |
+
async with _pgvector_engine.begin() as conn:
|
| 185 |
+
result = await conn.execute(
|
| 186 |
+
text(
|
| 187 |
+
"DELETE FROM langchain_pg_embedding "
|
| 188 |
+
"WHERE cmetadata->>'user_id' = :user_id "
|
| 189 |
+
" AND cmetadata->>'source_type' = 'database' "
|
| 190 |
+
" AND collection_id = ("
|
| 191 |
+
" SELECT uuid FROM langchain_pg_collection WHERE name = 'document_embeddings'"
|
| 192 |
+
" )"
|
| 193 |
+
),
|
| 194 |
+
{"user_id": user_id},
|
| 195 |
+
)
|
| 196 |
+
logger.info("cleared old db embeddings", user_id=user_id, deleted=result.rowcount)
|
| 197 |
+
|
| 198 |
schema = await asyncio.to_thread(get_schema, engine, exclude_tables)
|
| 199 |
|
| 200 |
+
updated_at = datetime.now(timezone(timedelta(hours=7))).isoformat()
|
| 201 |
total = 0
|
| 202 |
for table_name, columns in schema.items():
|
| 203 |
logger.info("profiling table", table=table_name, columns=len(columns))
|
| 204 |
entries = await asyncio.to_thread(profile_table, engine, table_name, columns)
|
| 205 |
+
docs = [self._to_document(user_id, table_name, e, updated_at) for e in entries]
|
| 206 |
if docs:
|
| 207 |
await vector_store.aadd_documents(docs)
|
| 208 |
total += len(docs)
|