Spaces:
Sleeping
Sleeping
| import hashlib | |
| import logging | |
| import os | |
| from typing import Dict, List, Optional | |
| from dotenv import load_dotenv # type: ignore[import] | |
| from qdrant_client import QdrantClient, models | |
| from src.vector_db.local_embeddings import LocalEmbeddingManager | |
# Local-development convenience: pull variables from a .env file (if present)
# into the process environment before anything reads them via os.getenv.
load_dotenv()

# Logger for this module; named after the module so handlers/levels can be
# configured per module by the application.
logger: logging.Logger = logging.getLogger(name=__name__)
class QdrantManager:
    """Question/answer vector cache backed by Qdrant Cloud.

    - Embedding generation: local BAAI/bge-m3 model (via ``LocalEmbeddingManager``).
    - Vector storage / search: Qdrant Cloud.

    NOTE(review): the ``async`` methods below call the synchronous
    ``QdrantClient`` and the embedding manager without ``await``, so each
    call blocks the running event loop until it returns. If that matters to
    callers, consider ``AsyncQdrantClient`` or offloading to a worker thread.
    """

    # Vector size used when this class creates the collection
    # (the bge-m3 embedding dimension).
    EMBEDDING_DIM = 1024

    def __init__(self, collection_name: str = "CodeWeaver") -> None:
        """Create the Qdrant Cloud client and make sure the collection exists.

        Reads ``QDRANT_URL`` and ``QDRANT_API_KEY`` from the environment.

        Args:
            collection_name: Name of the Qdrant collection used as the cache.

        Raises:
            ValueError: If either ``QDRANT_URL`` or ``QDRANT_API_KEY`` is unset
                or empty.
            Exception: Errors from checking for / creating the collection are
                logged and re-raised by ``_init_collection``.
        """
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")
        if not qdrant_url or not qdrant_api_key:
            raise ValueError(
                "QDRANT_URL ๋ฐ QDRANT_API_KEY ํ๊ฒฝ ๋ณ์๊ฐ ๋ชจ๋ ์ค์ ๋์ด ์์ด์ผ ํฉ๋๋ค."
            )

        # Initialization follows the shape used in the official Qdrant Cloud guide:
        # https://qdrant.tech/documentation/tutorials-and-examples/cloud-inference-hybrid-search/
        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
            timeout=30,
        )
        self.collection_name = collection_name
        self.embedding_manager = LocalEmbeddingManager()
        logger.info("QdrantManager ์ด๊ธฐํ: collection=%s, url=%s", collection_name, qdrant_url)

        # Create the collection if it does not exist yet.
        self._init_collection()

    def _init_collection(self) -> None:
        """Create the cache collection if it does not already exist.

        The collection uses ``EMBEDDING_DIM``-sized vectors with cosine
        distance. If ``create_collection`` fails but the collection exists
        afterwards (e.g. another worker created it between our existence
        check and our create call), that is treated as success.

        Raises:
            Exception: A failure of the initial existence check, or a creation
                failure after which the collection still does not exist.
        """
        try:
            exists = self.client.collection_exists(self.collection_name)
        except Exception as e:  # pragma: no cover - defensive code
            logger.error("Qdrant ์ปฌ๋ ์ ์กด์ฌ ์ฌ๋ถ ํ์ธ ์คํจ: %s", e, exc_info=True)
            raise

        if exists:
            logger.info("Qdrant ์ปฌ๋ ์ ์ด๋ฏธ ์กด์ฌ: %s", self.collection_name)
            return

        try:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=self.EMBEDDING_DIM,
                    distance=models.Distance.COSINE,
                ),
            )
        except Exception as e:
            # Check-then-create is racy across processes: if someone else won
            # the race, the collection now exists and there is nothing to fix.
            try:
                created_elsewhere = bool(
                    self.client.collection_exists(self.collection_name)
                )
            except Exception:
                # The re-check failed as well; report the original error below.
                created_elsewhere = False
            if created_elsewhere:
                logger.info("Qdrant ์ปฌ๋ ์ ์ด๋ฏธ ์กด์ฌ: %s", self.collection_name)
                return
            logger.error("Qdrant ์ปฌ๋ ์ ์์ฑ ์คํจ: %s", e, exc_info=True)
            raise e

        logger.info("Qdrant ์ปฌ๋ ์ ์์ฑ ์๋ฃ: %s", self.collection_name)

    async def get_embedding(self, text: str) -> List[float]:
        """Return the embedding of ``text`` from the local embedding model.

        Raises:
            Exception: Any embedding error is logged and re-raised.
        """
        try:
            embedding = self.embedding_manager.get_embedding(text)
            logger.debug("์๋ฒ ๋ฉ ์์ฑ ์๋ฃ (๊ธธ์ด=%d)", len(embedding))
            return embedding
        except Exception as e:
            logger.error("์๋ฒ ๋ฉ ์์ฑ ์คํจ: %s", e, exc_info=True)
            raise

    async def search_cache(
        self,
        question: str,
        threshold: float = 0.85,
    ) -> Optional[str]:
        """Look up a cached answer for ``question`` in Qdrant.

        Only the single nearest point is considered. Its answer is returned
        when its score is greater than or equal to ``threshold`` (inclusive)
        and its payload carries an ``"answer"`` value.

        Args:
            question: Question text to embed and search for.
            threshold: Minimum score (inclusive) the top hit must reach.

        Returns:
            The cached answer as ``str``. Returns ``None`` on a miss, if the
            payload has no answer, or if embedding or search fails (failures
            are logged, not raised).
        """
        try:
            embedding = await self.get_embedding(question)
        except Exception:
            # get_embedding already logged the failure; degrade to a cache miss.
            return None

        try:
            # Vector search via query_points (see the official Qdrant docs).
            # For a single-vector collection the vector itself is passed as `query`.
            # https://qdrant.tech/documentation/tutorials-and-examples/cloud-inference-hybrid-search/
            results = self.client.query_points(
                collection_name=self.collection_name,
                query=embedding,  # single-vector collection: pass the vector directly
                limit=1,
                with_payload=True,
            )
        except Exception as e:
            logger.error("Qdrant ์บ์ ๊ฒ์ ์คํจ: %s", e, exc_info=True)
            return None

        if not results.points:
            logger.info("์บ์ ๋ฏธ์ค: ๊ฒฐ๊ณผ ์์ (question=%s)", question)
            return None

        top = results.points[0]
        score = getattr(top, "score", None)
        payload = getattr(top, "payload", {}) or {}

        if score is None:
            logger.warning("๊ฒ์ ๊ฒฐ๊ณผ์ score๊ฐ ์์ต๋๋ค. payload=%s", payload)
            return None

        if score < threshold:
            logger.info(
                "์บ์ ๋ฏธ์ค: score(%.4f) < threshold(%.4f) (question=%s)",
                score,
                threshold,
                question,
            )
            return None

        answer = payload.get("answer")
        if answer is None:
            logger.info("์บ์ ํํธ์ด์ง๋ง payload์ answer๊ฐ ์์ต๋๋ค. payload=%s", payload)
            return None

        logger.info(
            "์บ์ ํํธ: score=%.4f, question=%s, answer_length=%d",
            score,
            question,
            len(str(answer)),
        )
        return str(answer)

    async def save_to_cache(self, question: str, answer: str) -> None:
        """Store a question/answer pair in the Qdrant cache.

        The point ID is derived deterministically from the question text, so
        saving the same question again upserts over the existing entry
        instead of creating a duplicate.

        Failures are logged and swallowed: an embedding failure skips the
        save, and an upsert failure is logged at error level.
        """
        try:
            embedding = await self.get_embedding(question)
        except Exception:
            # Without an embedding there is nothing to store.
            logger.warning("์๋ฒ ๋ฉ ์คํจ๋ก ์ธํด ์บ์์ ์ ์ฅํ์ง ์์. question=%s", question)
            return

        # Deterministic ID from the question hash instead of a random UUID:
        # same question -> same ID -> upsert overwrites -> no duplicates.
        #
        # Qdrant point IDs must be an unsigned int or a UUID, so the 64-char
        # sha256 hex digest is not used directly; its first 32 hex chars are
        # formatted as a UUID string (8-4-4-4-12).
        digest = hashlib.sha256(question.encode("utf-8")).hexdigest()
        point_id = f"{digest[:8]}-{digest[8:12]}-{digest[12:16]}-{digest[16:20]}-{digest[20:32]}"

        # Log when an existing entry is about to be overwritten. This lookup is
        # best-effort: if it fails, the upsert below is still attempted.
        try:
            existing = self.client.retrieve(
                collection_name=self.collection_name,
                ids=[point_id],
                with_payload=False,
                with_vectors=False,
            )
            if existing:
                logger.info("๊ธฐ์กด ์บ์ ์ํธ๋ฆฌ๋ฅผ ๋ฎ์ด์๋๋ค: point_id=%s", point_id)
        except Exception as e:
            # Previously swallowed silently; keep it non-fatal but visible.
            logger.debug(
                "Existing-entry lookup failed (continuing with upsert): point_id=%s, error=%s",
                point_id,
                e,
            )

        point = models.PointStruct(
            id=point_id,
            vector=embedding,
            payload={
                "question": question,
                "answer": answer,
            },
        )

        try:
            self.client.upsert(
                collection_name=self.collection_name,
                points=[point],
            )
            logger.info(
                "Qdrant ์บ์์ ์ ์ฅ ์๋ฃ (hash ID๋ก ์ค๋ณต ๋ฐฉ์ง): point_id=%s, question_length=%d, answer_length=%d",
                point_id,
                len(question),
                len(answer),
            )
        except Exception as e:
            logger.error("Qdrant ์บ์ ์ ์ฅ ์คํจ: %s", e, exc_info=True)

    async def get_cache_stats(self) -> Dict[str, int]:
        """Return cache statistics for the current collection.

        Returns:
            ``{"total_entries": <points_count>}`` on success. On failure the
            error is logged and ``{"total_entries": 0, "error": <message>}``
            is returned. Note that the ``"error"`` value is a ``str`` despite
            the ``Dict[str, int]`` annotation.
        """
        try:
            info = self.client.get_collection(self.collection_name)
            # qdrant_client's CollectionInfo exposes `points_count`; it may be
            # missing or None, in which case 0 is reported.
            points_count = getattr(info, "points_count", 0) or 0
            logger.debug(
                "Qdrant ์บ์ ํต๊ณ ์กฐํ: collection=%s, total_entries=%d",
                self.collection_name,
                points_count,
            )
            return {"total_entries": int(points_count)}
        except Exception as e:
            logger.error("Qdrant ์บ์ ํต๊ณ ์กฐํ ์คํจ: %s", e, exc_info=True)
            # Include the error message so callers can inspect what went wrong.
            return {
                "total_entries": 0,
                "error": str(e),  # type: ignore[dict-item]
            }