umyunsang's picture
Upload folder using huggingface_hub
9e65b56 verified
"""
CRUD ๋ ˆ์ด์–ด (Unit of Work ํŒจํ„ด).
DocumentSource, IndexingQueue, IndexVersion ํ…Œ์ด๋ธ”์— ๋Œ€ํ•œ
์ƒ์„ฑ/์กฐํšŒ/์ˆ˜์ •/์‚ญ์ œ ํ•จ์ˆ˜๋ฅผ ์ œ๊ณตํ•œ๋‹ค.
๋ชจ๋“  ํ•จ์ˆ˜๋Š” ๋™๊ธฐ Session์„ ์ธ์ž๋กœ ๋ฐ›๋Š”๋‹ค.
์ด ๋ชจ๋“ˆ์˜ ํ•จ์ˆ˜๋“ค์€ ๋‚ด๋ถ€์—์„œ commit์„ ์ˆ˜ํ–‰ํ•˜์ง€ ์•Š๋Š”๋‹ค.
ํŠธ๋žœ์žญ์…˜์˜ commit/rollback ์ œ์–ด๋Š” caller(์„œ๋น„์Šค ๊ณ„์ธต)์˜ ์ฑ…์ž„์ด๋‹ค.
๋ณตํ•ฉ ์ž‘์—…์˜ ์›์ž์„ฑ์„ ๋ณด์žฅํ•˜๊ธฐ ์œ„ํ•ด flush๋งŒ ์ˆ˜ํ–‰ํ•˜์—ฌ DB์— SQL์„ ์ „์†กํ•˜๋˜,
์ตœ์ข… ํ™•์ •์€ caller๊ฐ€ ๊ฒฐ์ •ํ•œ๋‹ค.
"""
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from sqlalchemy import func, select, update
from sqlalchemy.orm import Session
from src.inference.db.models import DocumentSource, IndexingQueue, IndexVersion
# ---------------------------------------------------------------------------
# ์ƒ์ˆ˜ ์ •์˜
# ---------------------------------------------------------------------------
MAX_LIMIT = 1000
_ALLOWED_FILTER_COLUMNS = frozenset(
{
"source_type",
"source_id",
"status",
"category",
"source_name",
"embedding_version",
"version",
}
)
_IMMUTABLE_FIELDS = frozenset({"id", "created_at"})
_VALID_QUEUE_STATUSES = frozenset(
{
"pending",
"processing",
"completed",
"skipped",
"failed",
}
)
# ============================================================================
# DocumentSource CRUD
# ============================================================================
def create_document_source(db: Session, **kwargs: Any) -> DocumentSource:
    """Create and persist a new document-source record.

    Sends the INSERT via flush (no commit) and refreshes the instance
    so database-generated defaults are populated before returning.
    """
    record = DocumentSource(**kwargs)
    db.add(record)
    db.flush()
    db.refresh(record)
    return record
def get_document_source(db: Session, doc_id: uuid.UUID) -> Optional[DocumentSource]:
    """Fetch a single document source by primary key, or None if absent."""
    return db.get(DocumentSource, doc_id)
def get_document_sources(
    db: Session,
    filters: Optional[Dict[str, Any]] = None,
    skip: int = 0,
    limit: int = 100,
) -> List[DocumentSource]:
    """Return document sources matching the given filter conditions.

    Results are ordered by created_at descending.

    Parameters
    ----------
    filters : dict, optional
        Column-name/value filter dictionary,
        e.g. {"source_type": "case", "status": "active"}.
        Keys outside _ALLOWED_FILTER_COLUMNS are silently ignored.
    skip : int
        Number of rows to skip (pagination offset). Negative values
        are clamped to 0.
    limit : int
        Maximum number of rows to return, capped at MAX_LIMIT and
        clamped at 0.
    """
    # Clamp pagination inputs: PostgreSQL rejects a negative LIMIT or
    # OFFSET, so normalize negatives to the empty/first-page behavior
    # instead of letting the query fail.
    limit = max(0, min(limit, MAX_LIMIT))
    skip = max(0, skip)
    stmt = select(DocumentSource)
    if filters:
        for col_name, value in filters.items():
            if col_name in _ALLOWED_FILTER_COLUMNS:
                stmt = stmt.where(getattr(DocumentSource, col_name) == value)
    stmt = stmt.order_by(DocumentSource.created_at.desc()).offset(skip).limit(limit)
    return list(db.scalars(stmt).all())
def update_document_source(
    db: Session, doc_id: uuid.UUID, **kwargs: Any
) -> Optional[DocumentSource]:
    """Apply the given column/value pairs to a document-source record.

    Immutable fields (id, created_at) and attribute names the model
    does not define are ignored. Returns the refreshed instance, or
    None when no record exists for doc_id.
    """
    record = db.get(DocumentSource, doc_id)
    if record is None:
        return None
    for field, new_value in kwargs.items():
        # Skip protected columns and unknown attribute names.
        if field in _IMMUTABLE_FIELDS or not hasattr(record, field):
            continue
        setattr(record, field, new_value)
    db.flush()
    db.refresh(record)
    return record
def delete_document_source(db: Session, doc_id: uuid.UUID) -> bool:
    """Delete a document-source record by ID; return True if a row was removed."""
    record = db.get(DocumentSource, doc_id)
    if record is None:
        return False
    db.delete(record)
    db.flush()
    return True
def get_by_source_type_and_id(
    db: Session, source_type: str, source_id: str
) -> List[DocumentSource]:
    """Look up documents by the (source_type, source_id) pair.

    A single logical document may be stored as multiple chunks, so a
    list ordered by chunk_index is returned rather than a single row.
    """
    query = (
        select(DocumentSource)
        .where(DocumentSource.source_type == source_type)
        .where(DocumentSource.source_id == source_id)
        .order_by(DocumentSource.chunk_index)
    )
    return list(db.scalars(query).all())
# ============================================================================
# IndexingQueue CRUD
# ============================================================================
def create_indexing_queue_item(db: Session, **kwargs: Any) -> IndexingQueue:
    """Enqueue a new indexing work item (flush only; caller commits)."""
    entry = IndexingQueue(**kwargs)
    db.add(entry)
    db.flush()
    db.refresh(entry)
    return entry
def get_pending_items(db: Session, limit: int = 50) -> List[IndexingQueue]:
    """Return queue items in 'pending' state, highest priority first.

    Ties in priority are broken by creation time (oldest first).

    Parameters
    ----------
    limit : int
        Maximum number of items to return, capped at MAX_LIMIT and
        clamped at 0 — a bare min() would pass a negative value
        through, which PostgreSQL rejects as an invalid LIMIT.
    """
    limit = max(0, min(limit, MAX_LIMIT))
    stmt = (
        select(IndexingQueue)
        .where(IndexingQueue.status == "pending")
        .order_by(IndexingQueue.priority.desc(), IndexingQueue.created_at)
        .limit(limit)
    )
    return list(db.scalars(stmt).all())
def update_queue_status(
    db: Session,
    item_id: uuid.UUID,
    status: str,
    skip_reason: Optional[str] = None,
) -> Optional[IndexingQueue]:
    """Transition a queue item to a new status.

    When the new status is terminal (completed / failed / skipped),
    processed_at is stamped with the current UTC time. Returns the
    refreshed item, or None when item_id does not exist.

    Raises
    ------
    ValueError
        If status is not one of _VALID_QUEUE_STATUSES.
    """
    if status not in _VALID_QUEUE_STATUSES:
        allowed = ", ".join(sorted(_VALID_QUEUE_STATUSES))
        raise ValueError(f"์œ ํšจํ•˜์ง€ ์•Š์€ ์ƒํƒœ: {status!r}. ํ—ˆ์šฉ ๊ฐ’: {allowed}")
    item = db.get(IndexingQueue, item_id)
    if item is None:
        return None
    item.status = status
    if skip_reason is not None:
        item.skip_reason = skip_reason
    # Terminal states record when processing finished.
    if status in ("completed", "failed", "skipped"):
        item.processed_at = datetime.now(timezone.utc)
    db.flush()
    db.refresh(item)
    return item
def get_queue_stats(db: Session) -> Dict[str, int]:
    """Count queue items grouped by status.

    Returns
    -------
    dict
        Status name mapped to row count, e.g.
        {"pending": 10, "processing": 2, "completed": 50, ...}.
        Statuses with no rows are absent from the result.
    """
    grouped = select(IndexingQueue.status, func.count()).group_by(IndexingQueue.status)
    # Each row is a (status, count) pair, which dict() consumes directly.
    return dict(db.execute(grouped).all())
# ============================================================================
# IndexVersion CRUD
# ============================================================================
def create_index_version(db: Session, **kwargs: Any) -> IndexVersion:
    """Insert a new index-version record (flush only; caller commits)."""
    version = IndexVersion(**kwargs)
    db.add(version)
    db.flush()
    db.refresh(version)
    return version
def get_active_version(db: Session, index_type: str) -> Optional[IndexVersion]:
    """Return the active version for index_type, or None if there is none.

    At most one active version per index_type is expected; should
    several exist, the most recently built one wins (built_at
    descending, limit 1).
    """
    query = (
        select(IndexVersion)
        .where(IndexVersion.index_type == index_type)
        .where(IndexVersion.is_active.is_(True))
        .order_by(IndexVersion.built_at.desc())
        .limit(1)
    )
    return db.scalars(query).first()
def deactivate_versions(db: Session, index_type: str) -> int:
    """Deactivate every active version of the given index_type.

    Called before activating a new index to maintain the
    single-active-version invariant.

    Returns
    -------
    int
        Number of records that were deactivated.
    """
    result = db.execute(
        update(IndexVersion)
        .where(IndexVersion.index_type == index_type)
        .where(IndexVersion.is_active.is_(True))
        .values(is_active=False)
    )
    db.flush()
    return result.rowcount  # type: ignore[return-value]
def activate_version(db: Session, version_id: uuid.UUID) -> Optional[IndexVersion]:
    """Activate a specific index version.

    Existing active versions of the same index_type are deactivated
    first, then the target is activated. Returns the refreshed
    version, or None when version_id does not exist.

    Race-condition prevention:
        Row-level locks are acquired on all versions of the same
        index_type via SELECT ... FOR UPDATE before the
        deactivate/activate steps run. A concurrent call waits until
        the locks are released, preventing multiple active versions
        from appearing.
        (PostgreSQL only -- SQLite does not support FOR UPDATE.)

    NOTE(review): the target row is fetched with db.get() *before* the
    lock is taken, so its attributes could be stale once the lock is
    acquired -- confirm whether a re-read after locking is needed.
    """
    ver = db.get(IndexVersion, version_id)
    if ver is None:
        return None
    # Acquire row-level locks on every version of this index_type
    # (PostgreSQL only).
    lock_stmt = (
        select(IndexVersion).where(IndexVersion.index_type == ver.index_type).with_for_update()
    )
    db.execute(lock_stmt)
    # With the locks held, deactivate the existing active versions of
    # the same type.
    deactivate_versions(db, ver.index_type)
    ver.is_active = True
    db.flush()
    db.refresh(ver)
    return ver