Feat: Delete embeddings when deleting a document
Browse files
src/document/document_service.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
"""Service for managing documents."""
|
| 2 |
|
| 3 |
from sqlalchemy.ext.asyncio import AsyncSession
|
| 4 |
-
from sqlalchemy import select, delete
|
|
|
|
| 5 |
from src.db.postgres.models import Document
|
| 6 |
from src.storage.az_blob.az_blob import blob_storage
|
| 7 |
from src.middlewares.logging import get_logger
|
|
@@ -77,6 +78,20 @@ class DocumentService:
|
|
| 77 |
# Delete from blob storage
|
| 78 |
await blob_storage.delete_file(document.blob_name)
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
# Delete from database
|
| 81 |
await db.execute(
|
| 82 |
delete(Document).where(Document.id == document_id)
|
|
|
|
| 1 |
"""Service for managing documents."""
|
| 2 |
|
| 3 |
from sqlalchemy.ext.asyncio import AsyncSession
|
| 4 |
+
from sqlalchemy import select, delete, text
|
| 5 |
+
from src.db.postgres.connection import _pgvector_engine
|
| 6 |
from src.db.postgres.models import Document
|
| 7 |
from src.storage.az_blob.az_blob import blob_storage
|
| 8 |
from src.middlewares.logging import get_logger
|
|
|
|
| 78 |
# Delete from blob storage
|
| 79 |
await blob_storage.delete_file(document.blob_name)
|
| 80 |
|
| 81 |
+
async with _pgvector_engine.begin() as conn:
|
| 82 |
+
await conn.execute(
|
| 83 |
+
text("""
|
| 84 |
+
DELETE FROM langchain_pg_embedding
|
| 85 |
+
WHERE cmetadata->>'user_id' = :user_id
|
| 86 |
+
AND cmetadata->>'source_type' = 'document'
|
| 87 |
+
AND cmetadata->'data'->>'document_id' = :doc_id
|
| 88 |
+
AND collection_id = (
|
| 89 |
+
SELECT uuid FROM langchain_pg_collection WHERE name = 'document_embeddings'
|
| 90 |
+
)
|
| 91 |
+
"""),
|
| 92 |
+
{"user_id": document.user_id, "doc_id": document_id},
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
# Delete from database
|
| 96 |
await db.execute(
|
| 97 |
delete(Document).where(Document.id == document_id)
|
src/pipeline/db_pipeline/db_pipeline_service.py
CHANGED
|
@@ -112,23 +112,23 @@ class DbPipelineService:
|
|
| 112 |
location=credentials.get("location", "US"),
|
| 113 |
)
|
| 114 |
|
| 115 |
-
if db_type == "snowflake":
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
|
| 133 |
raise NotImplementedError(f"Unsupported db_type: {db_type}")
|
| 134 |
|
|
|
|
| 112 |
location=credentials.get("location", "US"),
|
| 113 |
)
|
| 114 |
|
| 115 |
+
# if db_type == "snowflake":
|
| 116 |
+
# from snowflake.sqlalchemy import URL as SnowflakeURL
|
| 117 |
+
|
| 118 |
+
# url = SnowflakeURL(
|
| 119 |
+
# account=credentials["account"],
|
| 120 |
+
# user=credentials["username"],
|
| 121 |
+
# password=credentials["password"],
|
| 122 |
+
# database=credentials["database"],
|
| 123 |
+
# schema=(
|
| 124 |
+
# credentials.get("db_schema")
|
| 125 |
+
# or credentials.get("schema")
|
| 126 |
+
# or "PUBLIC"
|
| 127 |
+
# ),
|
| 128 |
+
# warehouse=credentials["warehouse"],
|
| 129 |
+
# role=credentials.get("role") or "",
|
| 130 |
+
# )
|
| 131 |
+
# return create_engine(url)
|
| 132 |
|
| 133 |
raise NotImplementedError(f"Unsupported db_type: {db_type}")
|
| 134 |
|