Spaces:
Paused
Paused
| """ | |
| Reranker module with the remote-code-execution option enabled. | |
| """ | |
| from typing import List, Dict, Tuple | |
| import numpy as np | |
| from sentence_transformers import CrossEncoder | |
| from langchain.schema import Document | |
| from config import RERANKER_MODEL | |
class Reranker:
    """Re-rank retrieved documents with a Cross-Encoder model.

    NOTE(review): the messages printed by this class appear mis-encoded in
    this copy of the file; they are kept exactly as found. Confirm against
    the upstream source before changing them.
    """

    def __init__(self, model_name: str = RERANKER_MODEL):
        """Load the Cross-Encoder used for scoring.

        Args:
            model_name: Name of the Cross-Encoder model to load.
        """
        print(f"๋ฆฌ๋ญ์ปค ๋ชจ๋ธ ๋ก๋ ์ค: {model_name}")
        # SECURITY: trust_remote_code=True permits remote code execution
        # while the model is loaded. The original author marked the option
        # as required, so it is kept; only point model_name at model
        # repositories you trust.
        self.model = CrossEncoder(
            model_name,
            trust_remote_code=True,
        )
        print(f"๋ฆฌ๋ญ์ปค ๋ชจ๋ธ ๋ก๋ ์๋ฃ: {model_name}")

    def rerank(self, query: str, documents: List[Document], top_k: int = 3) -> List[Document]:
        """Re-order search results by Cross-Encoder score.

        Args:
            query: The search query.
            documents: Documents returned by the vector search.
            top_k: Maximum number of top-ranked documents to return. A value
                of zero or less yields an empty list.

        Returns:
            The highest-scoring documents, best first. Empty when
            ``documents`` is empty.
        """
        if not documents:
            return []

        # Build one (query, passage) pair per candidate for the Cross-Encoder.
        pairs = [(query, doc.page_content) for doc in documents]

        print(f"๋ฆฌ๋ญํน ์ํ ์ค: {len(documents)}๊ฐ ๋ฌธ์")
        scores = self.model.predict(pairs)

        # Highest score first; sorted() is stable, so documents with equal
        # scores keep their incoming order.
        ranked = sorted(
            zip(documents, scores),
            key=lambda pair: pair[1],
            reverse=True,
        )

        print(f"๋ฆฌ๋ญํน ์๋ฃ: ์์ {top_k}๊ฐ ๋ฌธ์ ์ ํ")

        # A negative top_k would slice from the end and return all but the
        # lowest-ranked documents instead of limiting the result; clamp it
        # so a non-positive top_k simply returns nothing.
        limit = max(top_k, 0)
        return [doc for doc, _ in ranked[:limit]]