"""
재순위화 검색 구현 모듈 (Re-ranking search implementation module).
"""
| |
|
| | import logging |
| | from typing import List, Dict, Any, Optional, Union, Callable |
| | from .base_retriever import BaseRetriever |
| |
|
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
| |
|
class ReRanker(BaseRetriever):
    """
    Two-stage retriever: candidate retrieval followed by re-ranking.

    ``base_retriever`` supplies candidate documents. They are then re-ordered
    by the user-supplied ``rerank_fn`` if one was given, otherwise by
    ``rerank_model`` (an object exposing a CrossEncoder-style ``predict``).
    If neither is configured, the candidates are returned in their original
    order, truncated to ``top_k``.
    """

    def __init__(
        self,
        base_retriever: BaseRetriever,
        rerank_model: Optional[Union[str, Any]] = None,
        rerank_fn: Optional[Callable] = None,
        rerank_field: str = "text",
        rerank_batch_size: int = 32
    ):
        """
        Initialize the ReRanker.

        Args:
            base_retriever: Retriever used for the first (candidate) stage.
            rerank_model: Re-ranking model. Either a model name, loaded with
                ``sentence_transformers.CrossEncoder``, or an already-built
                model object exposing ``predict``. Ignored when
                ``rerank_fn`` is given.
            rerank_fn: Custom re-ranking function, called as
                ``rerank_fn(query, documents)``. Takes precedence over
                ``rerank_model``.
            rerank_field: Document key whose value is scored against the query.
            rerank_batch_size: Batch size passed to ``rerank_model.predict``.

        Raises:
            ImportError: If ``rerank_model`` is a model name and
                sentence-transformers is not installed.
        """
        self.base_retriever = base_retriever
        self.rerank_field = rerank_field
        self.rerank_batch_size = rerank_batch_size
        self.rerank_fn = rerank_fn
        # A custom function wins; a model is only kept when no function is given.
        if rerank_fn is None and rerank_model is not None:
            self.rerank_model = self._load_rerank_model(rerank_model)
        else:
            self.rerank_model = None

    @staticmethod
    def _load_rerank_model(rerank_model: Union[str, Any]) -> Any:
        """Return a model instance: load a CrossEncoder by name, or pass an instance through."""
        if not isinstance(rerank_model, str):
            # A pre-built model object needs no sentence-transformers import.
            return rerank_model
        try:
            from sentence_transformers import CrossEncoder
        except ImportError:
            logger.warning(
                "sentence-transformers 패키지가 설치되지 않았습니다. "
                "pip install sentence-transformers 명령으로 설치하세요."
            )
            raise
        logger.info(f"재순위화 모델 '{rerank_model}' 로드 중...")
        return CrossEncoder(rerank_model)

    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """
        Add documents to the underlying base retriever.

        Args:
            documents: Documents to add; forwarded unchanged to
                ``base_retriever.add_documents``.
        """
        self.base_retriever.add_documents(documents)

    def search(self, query: str, top_k: int = 5, first_stage_k: int = 30, **kwargs) -> List[Dict[str, Any]]:
        """
        Run a two-stage search: base retrieval followed by re-ranking.

        Args:
            query: Search query.
            top_k: Maximum number of results to return after re-ranking.
            first_stage_k: Number of candidates to request from the base
                retriever. At least ``top_k`` candidates are always
                requested, so re-ranking can fill the requested result size.
            **kwargs: Extra arguments forwarded to ``base_retriever.search``.

        Returns:
            Up to ``top_k`` documents. Documents from the model path are
            copies carrying a float ``"rerank_score"``, sorted by descending
            score. Documents lacking ``rerank_field`` are skipped on that path.
        """
        candidate_k = max(first_stage_k, top_k)
        logger.info(f"기본 검색기로 {candidate_k}개 문서 검색 중...")
        initial_results = self.base_retriever.search(query, top_k=candidate_k, **kwargs)

        if not initial_results:
            logger.warning("첫 번째 단계 검색 결과가 없습니다.")
            return []

        if len(initial_results) < candidate_k:
            logger.info(f"요청한 {candidate_k}개보다 적은 {len(initial_results)}개 결과를 검색했습니다.")

        if self.rerank_fn is not None:
            logger.info("사용자 정의 함수로 재순위화 중...")
            return self.rerank_fn(query, initial_results)[:top_k]

        if self.rerank_model is not None:
            logger.info("CrossEncoder 모델로 재순위화 중...")
            return self._rerank_with_model(query, initial_results)[:top_k]

        logger.info("재순위화 모델/함수가 없어 초기 검색 결과를 그대로 반환합니다.")
        return initial_results[:top_k]

    def _rerank_with_model(self, query: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Score documents with ``rerank_model`` and return scored copies, best first."""
        scorable = []
        for doc in documents:
            if self.rerank_field not in doc:
                logger.warning(f"문서에 필드 '{self.rerank_field}'가 없습니다.")
                continue
            scorable.append(doc)

        if not scorable:
            # Nothing to score; avoid calling predict() on an empty batch (it may raise).
            return []

        text_pairs = [[query, doc[self.rerank_field]] for doc in scorable]
        scores = self.rerank_model.predict(
            text_pairs,
            batch_size=self.rerank_batch_size,
            show_progress_bar=len(text_pairs) > 10,
        )

        # Pair each score with the document that produced it. Copy the dicts
        # so the base retriever's own documents are not mutated.
        scored = [
            {**doc, "rerank_score": float(score)}
            for doc, score in zip(scorable, scores)
        ]
        scored.sort(key=lambda d: d["rerank_score"], reverse=True)
        return scored
| |
|