| | """ |
| | 벡터 검색 구현 모듈 (vector search implementation module) |
| | """ |
| |
|
| | import os |
| | import numpy as np |
| | from typing import List, Dict, Any, Optional, Union, Tuple |
| | import logging |
| | from sentence_transformers import SentenceTransformer |
| | from .base_retriever import BaseRetriever |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
class VectorRetriever(BaseRetriever):
    """
    Embedding-based (dense vector) retriever.

    Documents are dicts; the text stored under ``embedding_field`` is encoded
    with the embedding model and documents are ranked by cosine similarity to
    the encoded query.
    """

    # Model used when no usable model name is given (constructor or saved state).
    _DEFAULT_MODEL_NAME = "paraphrase-multilingual-MiniLM-L12-v2"

    def __init__(
        self,
        embedding_model: Optional[Union[str, SentenceTransformer]] = "paraphrase-multilingual-MiniLM-L12-v2",
        documents: Optional[List[Dict[str, Any]]] = None,
        embedding_field: str = "text",
        embedding_device: str = "cpu"
    ):
        """
        Initialize the retriever.

        Args:
            embedding_model: Model name to load with ``SentenceTransformer``,
                or an already-constructed model instance. ``None`` falls back
                to the default model name.
            documents: Optional initial documents, passed to ``add_documents``.
            embedding_field: Document key whose text is embedded.
            embedding_device: Device for a model loaded by name ('cpu' or 'cuda').
        """
        self.embedding_field = embedding_field
        # Model name recorded so save()/load() can recreate the model. It stays
        # None for pre-built instances, whose name cannot be determined here.
        self.model_name: Optional[str] = None

        if embedding_model is None:
            # The type hint allows None; storing None would only fail later,
            # on the first encode() call.
            embedding_model = self._DEFAULT_MODEL_NAME

        if isinstance(embedding_model, str):
            logger.info(f"임베딩 모델 '{embedding_model}' 로드 중...")
            self.model_name = embedding_model
            self.embedding_model = SentenceTransformer(embedding_model, device=embedding_device)
        else:
            self.embedding_model = embedding_model

        # Row i of document_embeddings is the embedding of documents[i].
        self.documents: List[Dict[str, Any]] = []
        self.document_embeddings: Optional[np.ndarray] = None

        if documents:
            self.add_documents(documents)

    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """
        Add documents to the retriever and compute their embeddings.

        Documents lacking ``embedding_field`` are skipped with a warning.

        Args:
            documents: Documents to add.
        """
        if not documents:
            logger.warning("추가할 문서가 없습니다.")
            return

        valid_docs: List[Dict[str, Any]] = []
        document_texts = []
        for doc in documents:
            if self.embedding_field not in doc:
                logger.warning(f"문서에 필드 '{self.embedding_field}'가 없습니다. 건너뜁니다.")
                continue
            valid_docs.append(doc)
            document_texts.append(doc[self.embedding_field])

        if not document_texts:
            logger.warning(f"임베딩할 텍스트가 없습니다. 모든 문서에 '{self.embedding_field}' 필드가 있는지 확인하세요.")
            return

        logger.info(f"{len(document_texts)}개 문서의 임베딩 생성 중...")
        new_embeddings = np.asarray(self.embedding_model.encode(document_texts, show_progress_bar=True))

        if self.document_embeddings is None:
            self.document_embeddings = new_embeddings
        else:
            self.document_embeddings = np.vstack([self.document_embeddings, new_embeddings])

        # Commit the documents only after encoding succeeded, so a failing
        # encode() cannot leave documents without matching embedding rows.
        self.documents.extend(valid_docs)

        logger.info(f"총 {len(self.documents)}개 문서, {self.document_embeddings.shape[0]}개 임베딩 저장됨")

    def search(self, query: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]:
        """
        Run a cosine-similarity search for ``query``.

        Args:
            query: Search query text.
            top_k: Maximum number of results; values <= 0 return no results.
            **kwargs: Extra search parameters (currently ignored).

        Returns:
            Copies of the best-matching documents, highest score first, each
            with its cosine similarity stored under the "score" key.
        """
        if not self.documents or self.document_embeddings is None:
            logger.warning("검색할 문서가 없습니다.")
            return []

        # Guard explicitly: with slicing, ``[-0:]`` would return every document
        # and a negative top_k would slice from the wrong end.
        if top_k <= 0:
            return []

        query_embedding = np.asarray(self.embedding_model.encode(query))
        doc_embeddings = self.document_embeddings

        dots = doc_embeddings @ query_embedding
        norms = np.linalg.norm(doc_embeddings, axis=1) * np.linalg.norm(query_embedding)
        # Cosine similarity is undefined for zero-norm vectors. Score them 0
        # instead of NaN, which would otherwise be ranked first after sorting.
        scores = np.zeros(dots.shape, dtype=float)
        np.divide(dots, norms, out=scores, where=norms > 0)

        # Only rows that have a corresponding document can be returned.
        candidate_count = min(len(self.documents), scores.shape[0])
        limit = min(top_k, candidate_count)
        # Sort by descending score; the stable sort keeps insertion order among ties.
        top_indices = np.argsort(-scores[:candidate_count], kind="stable")[:limit]

        results = []
        for idx in top_indices:
            doc = self.documents[idx].copy()
            doc["score"] = float(scores[idx])
            results.append(doc)

        return results

    def save(self, directory: str) -> None:
        """
        Save the retriever state to ``directory``.

        Writes documents.json, embeddings.npy (when embeddings exist) and
        model_info.json (model name, embedding dimension, embedding field).

        Args:
            directory: Target directory; created if missing.
        """
        import json

        os.makedirs(directory, exist_ok=True)

        with open(os.path.join(directory, "documents.json"), "w", encoding="utf-8") as f:
            json.dump(self.documents, f, ensure_ascii=False, indent=2)

        embeddings_path = os.path.join(directory, "embeddings.npy")
        if self.document_embeddings is not None:
            np.save(embeddings_path, self.document_embeddings)
        elif os.path.exists(embeddings_path):
            # Remove a file left by an earlier save; load() would otherwise
            # pair those stale embeddings with the documents written above.
            os.remove(embeddings_path)

        if hasattr(self.embedding_model, "get_sentence_embedding_dimension"):
            embedding_dim = self.embedding_model.get_sentence_embedding_dimension()
        elif self.document_embeddings is not None:
            embedding_dim = int(self.document_embeddings.shape[-1])
        else:
            embedding_dim = None

        model_info = {
            # None when the model was passed as an instance; load() then warns
            # and falls back to the default model rather than trusting a guess.
            "model_name": self.model_name,
            "embedding_dim": embedding_dim,
            "embedding_field": self.embedding_field,
        }

        with open(os.path.join(directory, "model_info.json"), "w", encoding="utf-8") as f:
            json.dump(model_info, f)

        logger.info(f"검색기 상태를 '{directory}'에 저장했습니다.")

    @classmethod
    def load(cls, directory: str, embedding_model: Optional[Union[str, SentenceTransformer]] = None) -> "VectorRetriever":
        """
        Load a retriever state previously written by ``save``.

        Args:
            directory: Directory containing the saved state.
            embedding_model: Model to use; when omitted, the saved model name
                is used, or the default model if no valid name was saved.

        Returns:
            The restored VectorRetriever instance.
        """
        import json

        with open(os.path.join(directory, "model_info.json"), "r", encoding="utf-8") as f:
            model_info = json.load(f)

        if embedding_model is None:
            saved_name = model_info.get("model_name")
            if isinstance(saved_name, str) and saved_name:
                embedding_model = saved_name
            else:
                logger.warning("유효한 모델 이름을 찾을 수 없습니다. 기본 모델을 사용합니다.")
                embedding_model = cls._DEFAULT_MODEL_NAME

        # States saved before "embedding_field" was recorded used the default "text".
        retriever = cls(
            embedding_model=embedding_model,
            embedding_field=model_info.get("embedding_field", "text"),
        )

        with open(os.path.join(directory, "documents.json"), "r", encoding="utf-8") as f:
            retriever.documents = json.load(f)

        embeddings_path = os.path.join(directory, "embeddings.npy")
        if os.path.exists(embeddings_path):
            retriever.document_embeddings = np.load(embeddings_path)
            if retriever.document_embeddings.shape[0] != len(retriever.documents):
                logger.warning(
                    f"문서 수({len(retriever.documents)})와 임베딩 수"
                    f"({retriever.document_embeddings.shape[0]})가 일치하지 않습니다."
                )

        logger.info(f"검색기 상태를 '{directory}'에서 로드했습니다.")
        return retriever
| |
|