[NOTICKET] add updated_at metadata & prevent duplicates when adding new embeddings

#10
by rhbt6767 - opened
src/pipeline/db_pipeline/db_pipeline_service.py CHANGED
@@ -10,12 +10,14 @@ async vector writes stay on the event loop.
10
 
11
  import asyncio
12
  from contextlib import contextmanager
 
13
  from typing import Any, Iterator, Optional
14
 
15
  from langchain_core.documents import Document as LangChainDocument
16
- from sqlalchemy import URL, create_engine
17
  from sqlalchemy.engine import Engine
18
 
 
19
  from src.db.postgres.vector_store import get_vector_store
20
  from src.middlewares.logging import get_logger
21
  from src.models.credentials import DbType
@@ -146,7 +148,7 @@ class DbPipelineService:
146
  engine.dispose()
147
 
148
  def _to_document(
149
- self, user_id: str, table_name: str, entry: dict
150
  ) -> LangChainDocument:
151
  col = entry["col"]
152
  return LangChainDocument(
@@ -154,6 +156,7 @@ class DbPipelineService:
154
  metadata={
155
  "user_id": user_id,
156
  "source_type": "database",
 
157
  "data": {
158
  "table_name": table_name,
159
  "column_name": col["name"],
@@ -178,13 +181,28 @@ class DbPipelineService:
178
  vector_store = get_vector_store()
179
  logger.info("db pipeline start", user_id=user_id)
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  schema = await asyncio.to_thread(get_schema, engine, exclude_tables)
182
 
 
183
  total = 0
184
  for table_name, columns in schema.items():
185
  logger.info("profiling table", table=table_name, columns=len(columns))
186
  entries = await asyncio.to_thread(profile_table, engine, table_name, columns)
187
- docs = [self._to_document(user_id, table_name, e) for e in entries]
188
  if docs:
189
  await vector_store.aadd_documents(docs)
190
  total += len(docs)
 
10
 
11
  import asyncio
12
  from contextlib import contextmanager
13
+ from datetime import datetime, timezone, timedelta
14
  from typing import Any, Iterator, Optional
15
 
16
  from langchain_core.documents import Document as LangChainDocument
17
+ from sqlalchemy import URL, create_engine, text
18
  from sqlalchemy.engine import Engine
19
 
20
+ from src.db.postgres.connection import _pgvector_engine
21
  from src.db.postgres.vector_store import get_vector_store
22
  from src.middlewares.logging import get_logger
23
  from src.models.credentials import DbType
 
148
  engine.dispose()
149
 
150
  def _to_document(
151
+ self, user_id: str, table_name: str, entry: dict, updated_at: str
152
  ) -> LangChainDocument:
153
  col = entry["col"]
154
  return LangChainDocument(
 
156
  metadata={
157
  "user_id": user_id,
158
  "source_type": "database",
159
+ "updated_at": updated_at,
160
  "data": {
161
  "table_name": table_name,
162
  "column_name": col["name"],
 
181
  vector_store = get_vector_store()
182
  logger.info("db pipeline start", user_id=user_id)
183
 
184
+ async with _pgvector_engine.begin() as conn:
185
+ result = await conn.execute(
186
+ text(
187
+ "DELETE FROM langchain_pg_embedding "
188
+ "WHERE cmetadata->>'user_id' = :user_id "
189
+ " AND cmetadata->>'source_type' = 'database' "
190
+ " AND collection_id = ("
191
+ " SELECT uuid FROM langchain_pg_collection WHERE name = 'document_embeddings'"
192
+ " )"
193
+ ),
194
+ {"user_id": user_id},
195
+ )
196
+ logger.info("cleared old db embeddings", user_id=user_id, deleted=result.rowcount)
197
+
198
  schema = await asyncio.to_thread(get_schema, engine, exclude_tables)
199
 
200
+ updated_at = datetime.now(timezone(timedelta(hours=7))).isoformat()
201
  total = 0
202
  for table_name, columns in schema.items():
203
  logger.info("profiling table", table=table_name, columns=len(columns))
204
  entries = await asyncio.to_thread(profile_table, engine, table_name, columns)
205
+ docs = [self._to_document(user_id, table_name, e, updated_at) for e in entries]
206
  if docs:
207
  await vector_store.aadd_documents(docs)
208
  total += len(docs)