|
|
from langchain_chroma import Chroma |
|
|
from langchain_openai.embeddings import OpenAIEmbeddings |
|
|
from langsmith import traceable |
|
|
import time |
|
|
import os |
|
|
from rank_bm25 import BM25Okapi |
|
|
import numpy as np |
|
|
from sentence_transformers import CrossEncoder |
|
|
|
|
|
from src.utils.config import RAGConfig |
|
|
|
|
|
|
|
|
class RAGRetriever:
    """RAG retrieval system (hybrid search + re-ranker).

    Combines BM25 keyword scores with embedding similarity from a persisted
    Chroma vector store, and optionally re-orders candidates with a
    cross-encoder re-ranker (BAAI/bge-reranker-base).
    """

    def __init__(self, config: RAGConfig = None):
        """Set up embeddings, vector store, BM25 index, and re-ranker.

        Args:
            config: RAG settings; a default ``RAGConfig`` is built when None.
        """
        self.config = config or RAGConfig()
        self.vectorstore = None   # set by _create_vectorstore
        self.embeddings = None    # set by _initialize_embeddings

        self._initialize_embeddings()
        self._create_vectorstore()
        self._initialize_bm25()
        self._initialize_reranker()

    def _initialize_embeddings(self):
        """Initialize the OpenAI embedding model."""
        os.environ["OPENAI_API_KEY"] = self.config.OPENAI_API_KEY

        self.embeddings = OpenAIEmbeddings(
            model=self.config.EMBEDDING_MODEL_NAME
        )

    def _create_vectorstore(self):
        """Load the existing persisted Chroma vector store."""
        self.vectorstore = Chroma(
            embedding_function=self.embeddings,
            persist_directory=self.config.DB_DIRECTORY,
            collection_name=self.config.COLLECTION_NAME
        )

    def _initialize_bm25(self):
        """Build a BM25 index over every document in the vector store."""
        all_docs = self.vectorstore.get()

        self.doc_texts = all_docs['documents']
        self.doc_ids = all_docs['ids']
        self.doc_metadatas = all_docs['metadatas']

        # content -> id map, used to match embedding hits back to ids.
        # NOTE(review): documents with identical text collapse to one id.
        self.content_to_id = {text: doc_id for text, doc_id in zip(self.doc_texts, self.doc_ids)}
        # id -> position map for O(1) lookups; replaces the repeated
        # list.index scans that made result formatting O(n * k).
        self.id_to_index = {doc_id: i for i, doc_id in enumerate(self.doc_ids)}

        # Simple whitespace tokenization; assumes pre-segmented text.
        tokenized_docs = [doc.split() for doc in self.doc_texts]
        self.bm25 = BM25Okapi(tokenized_docs)

        print(f"✅ BM25 인덱스 생성 완료: {len(self.doc_texts)}개 문서")

    def _initialize_reranker(self):
        """Initialize the cross-encoder re-ranker."""
        self.reranker = CrossEncoder('BAAI/bge-reranker-base')
        print("✅ Re-ranker 초기화 완료 (bge-reranker-base)")

    @staticmethod
    def _min_max_normalize(scores):
        """Normalize scores to the [0, 1] range.

        Returns 0.5 everywhere when all scores are equal (a constant score
        vector carries no ranking signal), and an empty array unchanged.
        """
        scores = np.asarray(scores, dtype=float)
        if scores.size == 0:
            # .min()/.max() on an empty array would raise ValueError.
            return scores

        min_score = scores.min()
        max_score = scores.max()

        if max_score == min_score:
            return np.full_like(scores, 0.5, dtype=float)

        return (scores - min_score) / (max_score - min_score)

    def _find_doc_id_by_content(self, content):
        """Return the document id for *content*, or None when unknown."""
        return self.content_to_id.get(content, None)

    def _rerank(self, query, documents, top_k):
        """Re-order search results with the cross-encoder.

        Args:
            query: search query string.
            documents: result dicts from hybrid_search/search (need 'content').
            top_k: number of documents to return.

        Returns:
            Top-k documents sorted by descending 'rerank_score' (each input
            dict is annotated in place with that key).
        """
        if len(documents) == 0:
            return []

        # Cross-encoder scores each (query, passage) pair jointly.
        pairs = [[query, doc['content']] for doc in documents]
        scores = self.reranker.predict(pairs)

        for doc, score in zip(documents, scores):
            doc['rerank_score'] = float(score)

        ranked = sorted(documents,
                        key=lambda d: d['rerank_score'],
                        reverse=True)
        return ranked[:top_k]

    @traceable(
        name="RAG_Hybrid_Search",
        metadata={"component": "retriever", "version": "2.0"}
    )
    def hybrid_search(self, query, top_k=None, alpha=0.5):
        """Hybrid search: weighted combination of BM25 and embedding scores.

        Args:
            query: search query string.
            top_k: number of documents to return (config default when None).
            alpha: embedding weight in [0, 1]; BM25 gets (1 - alpha).

        Returns:
            List of result dicts with content, metadata, and score breakdown.
        """
        start_time = time.time()

        if top_k is None:
            top_k = self.config.DEFAULT_TOP_K

        # --- BM25 side: score every indexed document. ---
        tokenized_query = query.split()
        bm25_scores = self.bm25.get_scores(tokenized_query)
        bm25_normalized = self._min_max_normalize(bm25_scores)

        # --- Embedding side: over-fetch so normalization has spread. ---
        embedding_results = self.vectorstore.similarity_search_with_score(
            query, k=min(top_k * 3, len(self.doc_texts))
        )

        # Map hits back to ids; convert distance to a similarity in (0, 1].
        embedding_scores_raw = {}
        for doc, distance in embedding_results:
            doc_id = self._find_doc_id_by_content(doc.page_content)
            if doc_id is not None:
                embedding_scores_raw[doc_id] = 1 / (1 + distance)

        if embedding_scores_raw:
            embed_values = np.array(list(embedding_scores_raw.values()))
            embed_normalized = self._min_max_normalize(embed_values)
            embedding_scores = dict(zip(embedding_scores_raw.keys(), embed_normalized))
        else:
            embedding_scores = {}

        # --- Combine: docs outside the embedding top-k get embed score 0. ---
        hybrid_scores = {}
        for i, doc_id in enumerate(self.doc_ids):
            bm25_score = bm25_normalized[i]
            embed_score = embedding_scores.get(doc_id, 0)
            hybrid_scores[doc_id] = (1 - alpha) * bm25_score + alpha * embed_score

        top_ids = sorted(hybrid_scores.keys(),
                         key=lambda x: hybrid_scores[x],
                         reverse=True)[:top_k]

        # Format results; id_to_index gives O(1) positional lookup.
        formatted_results = []
        for doc_id in top_ids:
            idx = self.id_to_index[doc_id]
            formatted_results.append({
                'content': self.doc_texts[idx],
                'metadata': self.doc_metadatas[idx],
                'hybrid_score': float(hybrid_scores[doc_id]),
                'bm25_score': float(bm25_normalized[idx]),
                'embed_score': float(embedding_scores.get(doc_id, 0)),
                'filename': self.doc_metadatas[idx].get('파일명', 'N/A'),
                'organization': self.doc_metadatas[idx].get('발주 기관', 'N/A')
            })

        end_time = time.time()
        print(f"🔍 Hybrid 검색 완료: {len(formatted_results)}개 (alpha={alpha}, {end_time-start_time:.3f}초)")
        return formatted_results

    @traceable(
        name="RAG_Hybrid_Search_Rerank",
        metadata={"component": "retriever", "version": "3.0"}
    )
    def hybrid_search_with_rerank(self, query, top_k=None, alpha=0.5, rerank_candidates=None):
        """Hybrid search followed by cross-encoder re-ranking.

        Args:
            query: search query string.
            top_k: final number of documents to return.
            alpha: BM25/embedding weight (see hybrid_search).
            rerank_candidates: candidate pool size for re-ranking
                (defaults to top_k * 3).

        Returns:
            Top-k re-ranked result dicts.
        """
        start_time = time.time()

        if top_k is None:
            top_k = self.config.DEFAULT_TOP_K

        if rerank_candidates is None:
            rerank_candidates = top_k * 3

        # Retrieve a wider pool, then let the re-ranker pick the best top_k.
        candidates = self.hybrid_search(query, top_k=rerank_candidates, alpha=alpha)

        if len(candidates) > 0:
            results = self._rerank(query, candidates, top_k)
        else:
            results = []

        end_time = time.time()
        print(f"🔄 Re-ranking 완료: {len(candidates)}개 → {len(results)}개 ({end_time-start_time:.3f}초)")

        return results

    def search_with_mode(self, query, top_k=None, mode="hybrid_rerank", alpha=0.5):
        """Dispatch to a retrieval strategy by *mode*.

        Modes: "embedding" (pure vector search), "bm25" (alpha=0),
        "hybrid", "hybrid_rerank" (default).

        Raises:
            ValueError: for an unrecognized mode.
        """
        if mode == "embedding":
            return self.search(query, top_k)
        elif mode == "bm25":
            return self.hybrid_search(query, top_k, alpha=0.0)
        elif mode == "hybrid":
            return self.hybrid_search(query, top_k, alpha=alpha)
        elif mode == "hybrid_rerank":
            return self.hybrid_search_with_rerank(query, top_k, alpha)
        else:
            raise ValueError(f"Unknown mode: {mode}")

    @traceable(
        name="RAG_Retriever_Search",
        metadata={"component": "retriever", "version": "1.0"}
    )
    def search(self, query: str, top_k: int = None, filter_metadata: dict = None):
        """Pure embedding-based similarity search.

        Args:
            query: search query string.
            top_k: number of documents to return (config default when None).
            filter_metadata: optional Chroma metadata filter dict.

        Returns:
            List of result dicts with content, metadata, and scores.
        """
        start_time = time.time()
        if top_k is None:
            top_k = self.config.DEFAULT_TOP_K

        if filter_metadata:
            results = self.vectorstore.similarity_search_with_score(
                query, k=top_k, filter=filter_metadata
            )
        else:
            results = self.vectorstore.similarity_search_with_score(
                query, k=top_k
            )

        formatted_results = []
        for doc, score in results:
            formatted_results.append({
                'content': doc.page_content,
                'metadata': doc.metadata,
                'distance': score,
                # NOTE(review): 1 - distance can go negative for unbounded
                # distance metrics (e.g. L2) — confirm the collection's
                # similarity space before relying on this value.
                'relevance_score': 1 - score,
                'filename': doc.metadata.get('파일명', 'N/A'),
                'organization': doc.metadata.get('발주 기관', 'N/A')
            })

        end_time = time.time()
        print(f"🔍 검색 완료: {len(results)}개 ({end_time-start_time:.3f}초)")
        return formatted_results

    def search_with_rerank(self, query, top_k=None, rerank_candidates=None):
        """Embedding search followed by cross-encoder re-ranking.

        Args:
            query: search query string.
            top_k: final number of documents to return.
            rerank_candidates: candidate pool size (defaults to top_k * 3).

        Returns:
            Top-k re-ranked result dicts.
        """
        start_time = time.time()

        if top_k is None:
            top_k = self.config.DEFAULT_TOP_K

        if rerank_candidates is None:
            rerank_candidates = top_k * 3

        # Over-fetch, then re-rank down to top_k.
        candidates = self.search(query, top_k=rerank_candidates)

        if len(candidates) > 0:
            results = self._rerank(query, candidates, top_k)
        else:
            results = []

        end_time = time.time()
        print(f"🔄 Embedding + Re-ranking 완료: {len(candidates)}개 → {len(results)}개 ({end_time-start_time:.3f}초)")

        return results

    def search_by_organization(self, query: str, organization: str, top_k: int = None):
        """Search restricted to a single ordering organization."""
        return self.search(
            query, top_k=top_k, filter_metadata={'발주 기관': organization}
        )

    def get_retriever(self):
        """Return a LangChain-compatible retriever over the vector store."""
        return self.vectorstore.as_retriever(
            search_type="similarity",
            search_kwargs={"k": self.config.DEFAULT_TOP_K}
        )