ishaq101 committed on
Commit
6b29672
·
1 Parent(s): f1f4f28

Feat: Delete embeddings when deleting a document

Browse files
src/document/document_service.py CHANGED
@@ -1,7 +1,8 @@
1
  """Service for managing documents."""
2
 
3
  from sqlalchemy.ext.asyncio import AsyncSession
4
- from sqlalchemy import select, delete
 
5
  from src.db.postgres.models import Document
6
  from src.storage.az_blob.az_blob import blob_storage
7
  from src.middlewares.logging import get_logger
@@ -77,6 +78,20 @@ class DocumentService:
77
  # Delete from blob storage
78
  await blob_storage.delete_file(document.blob_name)
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  # Delete from database
81
  await db.execute(
82
  delete(Document).where(Document.id == document_id)
 
1
  """Service for managing documents."""
2
 
3
  from sqlalchemy.ext.asyncio import AsyncSession
4
+ from sqlalchemy import select, delete, text
5
+ from src.db.postgres.connection import _pgvector_engine
6
  from src.db.postgres.models import Document
7
  from src.storage.az_blob.az_blob import blob_storage
8
  from src.middlewares.logging import get_logger
 
78
  # Delete from blob storage
79
  await blob_storage.delete_file(document.blob_name)
80
 
81
+ async with _pgvector_engine.begin() as conn:
82
+ await conn.execute(
83
+ text("""
84
+ DELETE FROM langchain_pg_embedding
85
+ WHERE cmetadata->>'user_id' = :user_id
86
+ AND cmetadata->>'source_type' = 'document'
87
+ AND cmetadata->'data'->>'document_id' = :doc_id
88
+ AND collection_id = (
89
+ SELECT uuid FROM langchain_pg_collection WHERE name = 'document_embeddings'
90
+ )
91
+ """),
92
+ {"user_id": document.user_id, "doc_id": document_id},
93
+ )
94
+
95
  # Delete from database
96
  await db.execute(
97
  delete(Document).where(Document.id == document_id)
src/pipeline/db_pipeline/db_pipeline_service.py CHANGED
@@ -112,23 +112,23 @@ class DbPipelineService:
112
  location=credentials.get("location", "US"),
113
  )
114
 
115
- if db_type == "snowflake":
116
- from snowflake.sqlalchemy import URL as SnowflakeURL
117
-
118
- url = SnowflakeURL(
119
- account=credentials["account"],
120
- user=credentials["username"],
121
- password=credentials["password"],
122
- database=credentials["database"],
123
- schema=(
124
- credentials.get("db_schema")
125
- or credentials.get("schema")
126
- or "PUBLIC"
127
- ),
128
- warehouse=credentials["warehouse"],
129
- role=credentials.get("role") or "",
130
- )
131
- return create_engine(url)
132
 
133
  raise NotImplementedError(f"Unsupported db_type: {db_type}")
134
 
 
112
  location=credentials.get("location", "US"),
113
  )
114
 
115
+ # if db_type == "snowflake":
116
+ # from snowflake.sqlalchemy import URL as SnowflakeURL
117
+
118
+ # url = SnowflakeURL(
119
+ # account=credentials["account"],
120
+ # user=credentials["username"],
121
+ # password=credentials["password"],
122
+ # database=credentials["database"],
123
+ # schema=(
124
+ # credentials.get("db_schema")
125
+ # or credentials.get("schema")
126
+ # or "PUBLIC"
127
+ # ),
128
+ # warehouse=credentials["warehouse"],
129
+ # role=credentials.get("role") or "",
130
+ # )
131
+ # return create_engine(url)
132
 
133
  raise NotImplementedError(f"Unsupported db_type: {db_type}")
134