Rifqi Hafizuddin committed on
Commit
e13a901
·
1 Parent(s): d913315

[NOTICKET][DB] menyesuaikan format struktur db_pipeline sesuai dengan file lain

Browse files
src/pipeline/db_pipeline/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
- from src.pipeline.db_pipeline.pipeline import run_db_pipeline
2
 
3
- __all__ = ["run_db_pipeline"]
 
1
+ from src.pipeline.db_pipeline.db_pipeline_service import DbPipelineService, db_pipeline_service
2
 
3
+ __all__ = ["DbPipelineService", "db_pipeline_service"]
src/pipeline/db_pipeline/connector.py DELETED
@@ -1,74 +0,0 @@
1
- """Connectors for user-provided databases.
2
-
3
- The pipeline does not own user credentials — an API layer (outside this folder)
4
- builds an Engine via `connect(...)` and passes it to `run_db_pipeline`. Use
5
- `engine_scope(...)` for guaranteed disposal of the connection pool.
6
- """
7
-
8
- from contextlib import contextmanager
9
- from typing import Iterator, Literal
10
-
11
- from sqlalchemy import URL, create_engine
12
- from sqlalchemy.engine import Engine
13
-
14
- from src.middlewares.logging import get_logger
15
-
16
- logger = get_logger("db_connector")
17
-
18
- DbType = Literal["postgresql", "mysql", "sqlserver"]
19
-
20
-
21
def get_postgres_engine(
    host: str, port: int, dbname: str, username: str, password: str
) -> Engine:
    """Create a SQLAlchemy engine for a PostgreSQL database.

    The connection URL is assembled from its individual components with
    ``URL.create`` instead of string formatting, so a password containing
    special characters never has to be spliced into a URL string by hand.

    Args:
        host: Database server hostname.
        port: Database server port.
        dbname: Name of the database to connect to.
        username: Login user.
        password: Login password.

    Returns:
        An Engine using the ``postgresql+psycopg2`` driver.
    """
    return create_engine(
        URL.create(
            "postgresql+psycopg2",
            host=host,
            port=port,
            database=dbname,
            username=username,
            password=password,
        )
    )
34
-
35
-
36
def connect(
    db_type: DbType,
    host: str,
    port: int,
    dbname: str,
    username: str,
    password: str,
) -> Engine:
    """Build a SQLAlchemy engine for a user-provided database.

    Only ``"postgresql"`` is implemented; the password is never logged.

    Args:
        db_type: Which database flavour to connect to.
        host: Database server hostname.
        port: Database server port.
        dbname: Name of the database to connect to.
        username: Login user.
        password: Login password.

    Returns:
        The Engine produced by ``get_postgres_engine``.

    Raises:
        NotImplementedError: For ``"sqlserver"`` and ``"mysql"``.
        ValueError: For any other ``db_type``.
    """
    logger.info("connecting to user db", db_type=db_type, host=host, port=port, dbname=dbname)

    if db_type == "postgresql":
        return get_postgres_engine(host, port, dbname, username, password)

    # Recognised engines whose connector has not been written yet.
    not_yet_supported = (
        ("sqlserver", "SQL Server support coming soon"),
        ("mysql", "MySQL support coming soon"),
    )
    for pending_type, message in not_yet_supported:
        if db_type == pending_type:
            raise NotImplementedError(message)

    raise ValueError(f"Unsupported db_type: {db_type}")
54
-
55
-
56
@contextmanager
def engine_scope(
    db_type: DbType,
    host: str,
    port: int,
    dbname: str,
    username: str,
    password: str,
) -> Iterator[Engine]:
    """Context manager that yields an Engine and always disposes of it.

    The engine is built by ``connect(...)`` with the given arguments. Its
    connection pool is disposed when the ``with`` body exits, whether
    normally or through an exception, so pools are not left open between
    pipeline runs.

    Yields:
        The Engine returned by ``connect(...)``.
    """
    user_engine = connect(
        db_type=db_type,
        host=host,
        port=port,
        dbname=dbname,
        username=username,
        password=password,
    )
    try:
        yield user_engine
    finally:
        # Runs on both success and failure of the caller's `with` body.
        user_engine.dispose()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/pipeline/db_pipeline/db_pipeline.py DELETED
@@ -1,68 +0,0 @@
1
- """End-to-end DB ingestion pipeline: introspect user's DB -> profile columns ->
2
- build text -> embed + store in the shared PGVector collection.
3
-
4
- Each column becomes one LangChainDocument with metadata tagging user_id and
5
- source_type='database', so it is retrievable via the existing retriever.
6
- """
7
-
8
- import asyncio
9
- from typing import Optional
10
-
11
- from langchain_core.documents import Document as LangChainDocument
12
- from sqlalchemy.engine import Engine
13
-
14
- from src.db.postgres.vector_store import get_vector_store
15
- from src.middlewares.logging import get_logger
16
- from src.pipeline.db_pipeline.extractor import get_schema, profile_table
17
-
18
- logger = get_logger("db_pipeline")
19
-
20
-
21
def _to_document(user_id: str, table_name: str, entry: dict) -> LangChainDocument:
    """Wrap one profiled column in a LangChainDocument.

    Args:
        user_id: Owner recorded in the document metadata.
        table_name: Table the column belongs to.
        entry: Mapping with a ``"text"`` value (used as page content) and a
            ``"col"`` mapping holding ``"name"``, ``"type"`` and, optionally,
            ``"is_primary_key"`` and ``"foreign_key"``.

    Returns:
        A document whose metadata is tagged ``source_type="database"``.
    """
    column = entry["col"]
    column_info = {
        "table_name": table_name,
        "column_name": column["name"],
        "column_type": column["type"],
        # Optional keys: absent means "not a primary key" / no foreign key.
        "is_primary_key": column.get("is_primary_key", False),
        "foreign_key": column.get("foreign_key"),
    }
    metadata = {
        "user_id": user_id,
        "source_type": "database",
        "data": column_info,
    }
    return LangChainDocument(page_content=entry["text"], metadata=metadata)
37
-
38
-
39
async def run_db_pipeline(
    user_id: str,
    engine: Engine,
    exclude_tables: Optional[frozenset[str]] = None,
) -> int:
    """Ingest the schema of a user's database into the vector store.

    The schema is read with ``get_schema``, each table is profiled with
    ``profile_table``, and every resulting entry becomes one document that
    is written to the store returned by ``get_vector_store()``. The two
    synchronous helpers run via ``asyncio.to_thread``; the vector-store
    write is awaited directly.

    Args:
        user_id: Owner recorded on every stored document.
        engine: Engine for the user's database.
        exclude_tables: Passed through to ``get_schema``.

    Returns:
        Total number of documents written across all tables.
    """
    vector_store = get_vector_store()
    logger.info("db pipeline start", user_id=user_id)

    schema = await asyncio.to_thread(get_schema, engine, exclude_tables)

    ingested = 0
    for table, table_columns in schema.items():
        logger.info("profiling table", table=table, columns=len(table_columns))
        profiled = await asyncio.to_thread(profile_table, engine, table, table_columns)
        batch = [_to_document(user_id, table, item) for item in profiled]
        if not batch:
            # Nothing was profiled for this table; skip the write.
            continue
        await vector_store.aadd_documents(batch)
        ingested += len(batch)
        logger.info("ingested chunks", table=table, count=len(batch))

    logger.info("db pipeline complete", user_id=user_id, total=ingested)
    return ingested
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/pipeline/db_pipeline/db_pipeline_service.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Service for ingesting a user's external database into the vector store.
2
+
3
+ End-to-end flow: connect -> introspect schema -> profile columns -> build text
4
+ -> embed + store in the shared PGVector collection (tagged with
5
+ `source_type="database"`, retrievable via the same retriever used for docs).
6
+
7
+ Sync DB work (SQLAlchemy inspect, pandas read_sql) runs in a threadpool;
8
+ async vector writes stay on the event loop.
9
+ """
10
+
11
+ import asyncio
12
+ from contextlib import contextmanager
13
+ from typing import Iterator, Optional
14
+
15
+ from langchain_core.documents import Document as LangChainDocument
16
+ from sqlalchemy import URL, create_engine
17
+ from sqlalchemy.engine import Engine
18
+
19
+ from src.db.postgres.vector_store import get_vector_store
20
+ from src.middlewares.logging import get_logger
21
+ from src.models.credentials import DbType
22
+ from src.pipeline.db_pipeline.extractor import get_schema, profile_table
23
+
24
+ logger = get_logger("db_pipeline")
25
+
26
+
27
class DbPipelineService:
    """Ingest a user's external database into the vector store.

    Flow: build an engine, introspect the schema, profile each table's
    columns, and store one document per profiled entry.
    """

    def connect(
        self,
        db_type: DbType,
        host: str,
        port: int,
        database: str,
        username: str,
        password: str,
        ssl_mode: Optional[str] = None,
    ) -> Engine:
        """Build a SQLAlchemy engine for the user's database.

        ``"postgres"`` and ``"supabase"`` both use the
        ``postgresql+psycopg2`` driver. A truthy ``ssl_mode`` is forwarded
        as the ``sslmode`` URL query parameter; otherwise no query
        parameters are set. The password is never logged.

        Returns:
            An Engine for the requested database.

        Raises:
            NotImplementedError: For ``"mysql"``, ``"sqlserver"``,
                ``"bigquery"`` and ``"snowflake"``.
            ValueError: For any other ``db_type``.
        """
        logger.info(
            "connecting to user db", db_type=db_type, host=host, port=port, database=database
        )

        if db_type in ("postgres", "supabase"):
            return create_engine(
                URL.create(
                    "postgresql+psycopg2",
                    host=host,
                    port=port,
                    database=database,
                    username=username,
                    password=password,
                    query={"sslmode": ssl_mode} if ssl_mode else {},
                )
            )

        # Recognised engines whose connector has not been written yet.
        not_yet_supported = (
            ("mysql", "MySQL support coming soon"),
            ("sqlserver", "SQL Server support coming soon"),
            ("bigquery", "BigQuery support coming soon"),
            ("snowflake", "Snowflake support coming soon"),
        )
        for pending_type, message in not_yet_supported:
            if db_type == pending_type:
                raise NotImplementedError(message)

        raise ValueError(f"Unsupported db_type: {db_type}")

    @contextmanager
    def engine_scope(
        self,
        db_type: DbType,
        host: str,
        port: int,
        database: str,
        username: str,
        password: str,
        ssl_mode: Optional[str] = None,
    ) -> Iterator[Engine]:
        """Context manager that yields an Engine and always disposes of it.

        The engine comes from ``self.connect(...)``. Its connection pool is
        disposed when the ``with`` body exits, whether normally or through
        an exception.

        Yields:
            The Engine returned by ``self.connect(...)``.
        """
        user_engine = self.connect(
            db_type, host, port, database, username, password, ssl_mode
        )
        try:
            yield user_engine
        finally:
            # Runs on both success and failure of the caller's `with` body.
            user_engine.dispose()

    def _to_document(
        self, user_id: str, table_name: str, entry: dict
    ) -> LangChainDocument:
        """Wrap one profiled column in a LangChainDocument.

        ``entry["text"]`` becomes the page content; ``entry["col"]``
        supplies the column name, type, and optional key information
        stored under ``metadata["data"]``.
        """
        column = entry["col"]
        column_info = {
            "table_name": table_name,
            "column_name": column["name"],
            "column_type": column["type"],
            # Optional keys: absent means "not a primary key" / no foreign key.
            "is_primary_key": column.get("is_primary_key", False),
            "foreign_key": column.get("foreign_key"),
        }
        metadata = {
            "user_id": user_id,
            "source_type": "database",
            "data": column_info,
        }
        return LangChainDocument(page_content=entry["text"], metadata=metadata)

    async def run(
        self,
        user_id: str,
        engine: Engine,
        exclude_tables: Optional[frozenset[str]] = None,
    ) -> int:
        """Introspect, profile, and store the user's database schema.

        ``get_schema`` and ``profile_table`` are synchronous and run via
        ``asyncio.to_thread``; the vector-store write is awaited directly.

        Args:
            user_id: Owner recorded on every stored document.
            engine: Engine for the user's database.
            exclude_tables: Passed through to ``get_schema``.

        Returns:
            Total number of documents written across all tables.
        """
        vector_store = get_vector_store()
        logger.info("db pipeline start", user_id=user_id)

        schema = await asyncio.to_thread(get_schema, engine, exclude_tables)

        ingested = 0
        for table, table_columns in schema.items():
            logger.info("profiling table", table=table, columns=len(table_columns))
            profiled = await asyncio.to_thread(profile_table, engine, table, table_columns)
            batch = [self._to_document(user_id, table, item) for item in profiled]
            if not batch:
                # Nothing was profiled for this table; skip the write.
                continue
            await vector_store.aadd_documents(batch)
            ingested += len(batch)
            logger.info("ingested chunks", table=table, count=len(batch))

        logger.info("db pipeline complete", user_id=user_id, total=ingested)
        return ingested
146
+
147
+
148
+ db_pipeline_service = DbPipelineService()