Spaces:
Paused
Paused
| """ | |
| CRUD layer (Unit of Work pattern). | |
| Provides create/read/update/delete functions for the | |
| DocumentSource, IndexingQueue and IndexVersion tables. | |
| Every function takes a synchronous Session as an argument. | |
| The functions in this module never commit internally. | |
| Controlling the transaction's commit/rollback is the responsibility of the caller (service layer). | |
| To keep composite operations atomic, they only flush (sending the SQL to the DB), | |
| while the final commit decision is left to the caller. | |
| """ | |
| import uuid | |
| from datetime import datetime, timezone | |
| from typing import Any, Dict, List, Optional | |
| from sqlalchemy import func, select, update | |
| from sqlalchemy.orm import Session | |
| from src.inference.db.models import DocumentSource, IndexingQueue, IndexVersion | |
# ---------------------------------------------------------------------------
# Constant definitions
# ---------------------------------------------------------------------------

# Upper bound applied to caller-supplied ``limit`` values in list queries.
MAX_LIMIT = 1000

# Column names that get_document_sources() accepts as filter keys; any other
# key in the filter dict is ignored.
_ALLOWED_FILTER_COLUMNS = frozenset(
    (
        "source_type",
        "source_id",
        "status",
        "category",
        "source_name",
        "embedding_version",
        "version",
    )
)

# Fields that update_document_source() refuses to overwrite.
_IMMUTABLE_FIELDS = frozenset(("id", "created_at"))

# Status values accepted by update_queue_status().
_VALID_QUEUE_STATUSES = frozenset(
    (
        "pending",
        "processing",
        "completed",
        "skipped",
        "failed",
    )
)
| # ============================================================================ | |
| # DocumentSource CRUD | |
| # ============================================================================ | |
def create_document_source(db: Session, **kwargs: Any) -> DocumentSource:
    """Create a new document source record.

    A ``DocumentSource`` is built from ``kwargs``, added to the session,
    flushed and then refreshed before being returned. No commit is issued
    here.
    """
    record = DocumentSource(**kwargs)
    db.add(record)
    # Flush first, then refresh so the returned instance is reloaded from
    # the database.
    db.flush()
    db.refresh(record)
    return record
def get_document_source(db: Session, doc_id: uuid.UUID) -> Optional[DocumentSource]:
    """Look up a single document source by its ID via ``Session.get``."""
    found = db.get(DocumentSource, doc_id)
    return found
def get_document_sources(
    db: Session,
    filters: Optional[Dict[str, Any]] = None,
    skip: int = 0,
    limit: int = 100,
) -> List[DocumentSource]:
    """Return the document sources that match the given filters.

    Results are ordered by ``created_at`` descending.

    Parameters
    ----------
    filters : dict, optional
        Mapping of column name to the value it must equal, e.g.
        ``{"source_type": "case", "status": "active"}``. Only keys listed in
        ``_ALLOWED_FILTER_COLUMNS`` are applied; other keys are ignored.
    skip : int
        Number of rows to skip (pagination offset). Negative values are
        treated as 0.
    limit : int
        Maximum number of rows to return. Clamped to ``[0, MAX_LIMIT]``.

    Returns
    -------
    list of DocumentSource
    """
    # Clamp both ends: a negative LIMIT means "no limit" on SQLite (which
    # would bypass MAX_LIMIT) and is an error on PostgreSQL, as is a
    # negative OFFSET.
    limit = max(0, min(limit, MAX_LIMIT))
    skip = max(0, skip)

    stmt = select(DocumentSource)
    for col_name, value in (filters or {}).items():
        # Whitelist check keeps arbitrary attribute names out of getattr().
        if col_name in _ALLOWED_FILTER_COLUMNS:
            stmt = stmt.where(getattr(DocumentSource, col_name) == value)

    stmt = stmt.order_by(DocumentSource.created_at.desc()).offset(skip).limit(limit)
    return list(db.scalars(stmt).all())
def update_document_source(
    db: Session, doc_id: uuid.UUID, **kwargs: Any
) -> Optional[DocumentSource]:
    """Update a document source record.

    The attributes to change are passed as keyword arguments. Keys listed in
    ``_IMMUTABLE_FIELDS`` and keys that are not mapped attributes of
    ``DocumentSource`` are ignored.

    Returns the refreshed record, or ``None`` when no record has the given
    ID. Changes are flushed but not committed.
    """
    # Function-scope import: the module header only imports func/select/update.
    from sqlalchemy import inspect as sa_inspect

    doc = db.get(DocumentSource, doc_id)
    if doc is None:
        return None

    # A bare hasattr() check would also accept non-column attributes such as
    # ``metadata``, methods, or ``_sa_instance_state``; restrict updates to
    # attributes the ORM mapper actually knows about.
    mapped_keys = set(sa_inspect(DocumentSource).attrs.keys())
    for key, value in kwargs.items():
        if key in _IMMUTABLE_FIELDS or key not in mapped_keys:
            continue
        setattr(doc, key, value)

    db.flush()
    db.refresh(doc)
    return doc
def delete_document_source(db: Session, doc_id: uuid.UUID) -> bool:
    """Delete a document source record.

    Returns ``True`` when the record existed and its deletion was flushed,
    ``False`` when no record has the given ID.
    """
    target = db.get(DocumentSource, doc_id)
    if target is not None:
        db.delete(target)
        db.flush()
        return True
    return False
def get_by_source_type_and_id(
    db: Session, source_type: str, source_id: str
) -> List[DocumentSource]:
    """Fetch documents by their ``source_type`` + ``source_id`` combination.

    Several chunks of the same document may match, so a list is returned,
    ordered by ``chunk_index``.
    """
    same_type = DocumentSource.source_type == source_type
    same_id = DocumentSource.source_id == source_id
    query = (
        select(DocumentSource)
        .where(same_type, same_id)
        .order_by(DocumentSource.chunk_index)
    )
    matches = db.scalars(query).all()
    return list(matches)
| # ============================================================================ | |
| # IndexingQueue CRUD | |
| # ============================================================================ | |
def create_indexing_queue_item(db: Session, **kwargs: Any) -> IndexingQueue:
    """Add a new entry to the indexing queue.

    The ``IndexingQueue`` row is built from ``kwargs``, added to the session,
    flushed and refreshed before being returned. No commit is issued here.
    """
    queued = IndexingQueue(**kwargs)
    db.add(queued)
    db.flush()
    db.refresh(queued)
    return queued
def get_pending_items(db: Session, limit: int = 50) -> List[IndexingQueue]:
    """Return queue entries whose status is ``"pending"``.

    Entries are ordered by ``priority`` descending, then by ``created_at``
    ascending. ``limit`` is clamped to ``[0, MAX_LIMIT]``.
    """
    # Clamp the lower bound too: a negative LIMIT means "no limit" on SQLite
    # (bypassing MAX_LIMIT) and is rejected by PostgreSQL.
    limit = max(0, min(limit, MAX_LIMIT))
    stmt = (
        select(IndexingQueue)
        .where(IndexingQueue.status == "pending")
        .order_by(IndexingQueue.priority.desc(), IndexingQueue.created_at)
        .limit(limit)
    )
    return list(db.scalars(stmt).all())
def update_queue_status(
    db: Session,
    item_id: uuid.UUID,
    status: str,
    skip_reason: Optional[str] = None,
) -> Optional[IndexingQueue]:
    """Change the status of a queue entry.

    When the new status is ``completed``, ``failed`` or ``skipped``,
    ``processed_at`` is set to the current UTC time. ``skip_reason`` is only
    written when it is not ``None``.

    Returns the refreshed entry, or ``None`` when no entry has the given ID.

    Raises
    ------
    ValueError
        If ``status`` is not one of ``_VALID_QUEUE_STATUSES``. This check
        runs before the entry is looked up.
    """
    if status not in _VALID_QUEUE_STATUSES:
        allowed = ", ".join(sorted(_VALID_QUEUE_STATUSES))
        # NOTE(review): the message text is reproduced exactly as it appears
        # in this file.
        raise ValueError(
            f"์ ํจํ์ง ์์ ์ํ: {status!r}. "
            f"ํ์ฉ ๊ฐ: {allowed}"
        )

    entry = db.get(IndexingQueue, item_id)
    if entry is None:
        return None

    finished_states = ("completed", "failed", "skipped")
    entry.status = status
    if skip_reason is not None:
        entry.skip_reason = skip_reason
    if status in finished_states:
        entry.processed_at = datetime.now(timezone.utc)

    db.flush()
    db.refresh(entry)
    return entry
def get_queue_stats(db: Session) -> Dict[str, int]:
    """Count queue entries per status.

    Returns
    -------
    dict
        Mapping of status to row count, e.g.
        ``{"pending": 10, "processing": 2, "completed": 50, ...}``.
    """
    per_status = select(IndexingQueue.status, func.count()).group_by(
        IndexingQueue.status
    )
    stats: Dict[str, int] = {}
    for queue_status, total in db.execute(per_status).all():
        stats[queue_status] = total
    return stats
| # ============================================================================ | |
| # IndexVersion CRUD | |
| # ============================================================================ | |
def create_index_version(db: Session, **kwargs: Any) -> IndexVersion:
    """Create a new index version record.

    The ``IndexVersion`` row is built from ``kwargs``, added to the session,
    flushed and refreshed before being returned. No commit is issued here.
    """
    version = IndexVersion(**kwargs)
    db.add(version)
    db.flush()
    db.refresh(version)
    return version
def get_active_version(db: Session, index_type: str) -> Optional[IndexVersion]:
    """Return the active version for the given ``index_type``.

    There should be at most one active version per ``index_type``. If more
    than one row matches, the one with the latest ``built_at`` is returned.
    """
    criteria = (
        IndexVersion.index_type == index_type,
        IndexVersion.is_active.is_(True),
    )
    query = (
        select(IndexVersion)
        .where(*criteria)
        .order_by(IndexVersion.built_at.desc())
        .limit(1)
    )
    return db.scalars(query).first()
def deactivate_versions(db: Session, index_type: str) -> int:
    """Deactivate every active version of the given ``index_type``.

    Intended to be called before activating a new index so that only one
    version is active at a time.

    Returns
    -------
    int
        Number of rows that were deactivated.
    """
    currently_active = (
        IndexVersion.index_type == index_type,
        IndexVersion.is_active.is_(True),
    )
    deactivation = update(IndexVersion).where(*currently_active).values(is_active=False)
    outcome = db.execute(deactivation)
    db.flush()
    return outcome.rowcount  # type: ignore[return-value]
def activate_version(db: Session, version_id: uuid.UUID) -> Optional[IndexVersion]:
    """Activate a specific index version.

    Existing active versions of the same ``index_type`` are deactivated
    first, then the target is marked active. Returns the refreshed version,
    or ``None`` when no version has the given ID.

    Race-condition guard:
        A ``SELECT ... FOR UPDATE`` over all versions of the same
        ``index_type`` is issued before the deactivate/activate steps, so a
        concurrent caller waits for the lock to be released instead of
        producing several active versions.
        (PostgreSQL only; SQLite does not support FOR UPDATE.)
    """
    target = db.get(IndexVersion, version_id)
    if target is None:
        return None

    # Take row-level locks on every version sharing the target's index_type
    # (PostgreSQL only).
    same_type_rows = select(IndexVersion).where(
        IndexVersion.index_type == target.index_type
    )
    db.execute(same_type_rows.with_for_update())

    # With the locks held, clear the existing active flag(s) and then set it
    # on the target. The order of these two steps matters.
    deactivate_versions(db, target.index_type)
    target.is_active = True

    db.flush()
    db.refresh(target)
    return target