# Source: Hugging Face Spaces file view (author: Dongjin1203),
# commit 4739096 "Initial commit for HF Spaces deployment", file size 10.9 kB.
from langchain_chroma import Chroma
from langchain_openai.embeddings import OpenAIEmbeddings
from langsmith import traceable
import time
import os
from rank_bm25 import BM25Okapi
import numpy as np
from sentence_transformers import CrossEncoder
from src.utils.config import RAGConfig
class RAGRetriever:
    """RAG retrieval system (hybrid BM25 + embedding search, cross-encoder re-ranking).

    On construction this creates the OpenAI embedding model, opens the
    persisted Chroma collection, builds an in-memory BM25 index over every
    chunk stored in that collection, and loads the ``BAAI/bge-reranker-base``
    cross-encoder.
    """

    def __init__(self, config: RAGConfig = None):
        """Build all retrieval components.

        Args:
            config: Retrieval settings; a default ``RAGConfig()`` is used when omitted.
        """
        self.config = config or RAGConfig()
        self.vectorstore = None
        self.embeddings = None
        self._initialize_embeddings()
        self._create_vectorstore()
        self._initialize_bm25()
        self._initialize_reranker()

    def _initialize_embeddings(self):
        """Initialize the OpenAI embedding model."""
        api_key = self.config.OPENAI_API_KEY
        # The key is exported through the environment (as before). os.environ
        # only accepts str, so an unset key is skipped instead of raising
        # TypeError; an already-exported OPENAI_API_KEY then still applies.
        if api_key:
            os.environ["OPENAI_API_KEY"] = api_key
        self.embeddings = OpenAIEmbeddings(
            model=self.config.EMBEDDING_MODEL_NAME
        )

    def _create_vectorstore(self):
        """Open the existing persisted Chroma collection."""
        self.vectorstore = Chroma(
            embedding_function=self.embeddings,
            persist_directory=self.config.DB_DIRECTORY,
            collection_name=self.config.COLLECTION_NAME,
        )

    def _initialize_bm25(self):
        """Build the BM25 index over every chunk currently in the collection.

        Caches the chunk texts, ids and metadata as index-aligned lists, plus
        two lookup tables: chunk text -> id and id -> list position.
        """
        all_docs = self.vectorstore.get()
        self.doc_texts = all_docs['documents']
        self.doc_ids = all_docs['ids']
        self.doc_metadatas = all_docs['metadatas']
        # NOTE: identical texts collapse onto the last id carrying that text.
        self.content_to_id = {text: doc_id for text, doc_id in zip(self.doc_texts, self.doc_ids)}
        # O(1) id -> position lookup used when formatting results.
        self.id_to_index = {doc_id: idx for idx, doc_id in enumerate(self.doc_ids)}
        # rank_bm25 averages document length by dividing by the corpus size,
        # so it cannot be built over an empty collection; hybrid_search
        # returns no results in that case instead.
        if self.doc_texts:
            tokenized_docs = [doc.split() for doc in self.doc_texts]
            self.bm25 = BM25Okapi(tokenized_docs)
        else:
            self.bm25 = None
        print(f"✅ BM25 인덱스 생성 완료: {len(self.doc_texts)}개 문서")

    def _initialize_reranker(self):
        """Load the cross-encoder re-ranker (BAAI/bge-reranker-base)."""
        self.reranker = CrossEncoder('BAAI/bge-reranker-base')
        print("✅ Re-ranker 초기화 완료 (bge-reranker-base)")

    @staticmethod
    def _min_max_normalize(scores):
        """Min-max scale ``scores`` into the [0, 1] range.

        Returns a float array. Every value is 0.5 when all scores are equal,
        and the result is an empty array for empty input.
        """
        scores = np.asarray(scores, dtype=float)
        if scores.size == 0:
            return scores
        min_score = scores.min()
        max_score = scores.max()
        if max_score == min_score:
            return np.full_like(scores, 0.5)
        return (scores - min_score) / (max_score - min_score)

    def _find_doc_id_by_content(self, content):
        """Return the id of the chunk whose text equals ``content``, or None."""
        return self.content_to_id.get(content)

    def _embedding_scores(self, query, k):
        """Return ``{doc_id: normalized similarity}`` for the ``k`` nearest chunks.

        Each Chroma distance becomes a similarity ``1 / (1 + distance)``. The
        similarities are then min-max normalized across the hits. Only ids
        present in the BM25 snapshot are kept.
        """
        raw_scores = {}
        for doc, distance in self.vectorstore.similarity_search_with_score(query, k=k):
            # Prefer the id Chroma returned, when the installed langchain
            # version populates Document.id; fall back to matching on text.
            doc_id = getattr(doc, 'id', None)
            if doc_id not in self.id_to_index:
                doc_id = self._find_doc_id_by_content(doc.page_content)
            if doc_id in self.id_to_index:
                raw_scores[doc_id] = 1 / (1 + distance)
        if not raw_scores:
            return {}
        normalized = self._min_max_normalize(list(raw_scores.values()))
        return {doc_id: float(score) for doc_id, score in zip(raw_scores, normalized)}

    def _rerank(self, query, documents, top_k):
        """Re-order candidate results by cross-encoder relevance.

        Args:
            query: Search query.
            documents: Candidate result dicts, each with a ``'content'`` key
                (e.g. the output of ``hybrid_search`` or ``search``).
            top_k: Number of documents to return.

        Returns:
            Up to ``top_k`` copies of the input dicts, each with an added float
            ``'rerank_score'``, sorted by that score in descending order.
        """
        if not documents:
            return []
        # 1. Build (query, document) pairs and score them with the CrossEncoder.
        pairs = [[query, doc['content']] for doc in documents]
        scores = self.reranker.predict(pairs)
        # 2. Attach scores to copies so the caller's dicts are left untouched.
        scored = [
            {**doc, 'rerank_score': float(score)}
            for doc, score in zip(documents, scores)
        ]
        # 3. Sort (stable) and truncate.
        scored.sort(key=lambda d: d['rerank_score'], reverse=True)
        return scored[:top_k]

    @traceable(
        name="RAG_Hybrid_Search",
        metadata={"component": "retriever", "version": "2.0"}
    )
    def hybrid_search(self, query, top_k=None, alpha=0.5):
        """Hybrid search: weighted blend of BM25 and embedding similarity.

        Both signals are min-max normalized to [0, 1] and combined as
        ``(1 - alpha) * bm25 + alpha * embedding``. Only the
        ``min(3 * top_k, corpus size)`` nearest embedding hits get an
        embedding score; every other chunk contributes 0 for that part.

        Args:
            query: Search query (whitespace-tokenized for BM25).
            top_k: Number of results; defaults to ``config.DEFAULT_TOP_K``.
            alpha: Embedding weight in [0, 1] (0 = pure BM25, 1 = pure embedding).

        Returns:
            Result dicts sorted by ``'hybrid_score'`` in descending order, with
            keys content, metadata, hybrid_score, bm25_score, embed_score,
            filename and organization. The list is empty if the collection
            was empty when the index was built.
        """
        start_time = time.perf_counter()
        if top_k is None:
            top_k = self.config.DEFAULT_TOP_K
        if self.bm25 is None:
            return []

        # 1. BM25 scores for every chunk (index-aligned with self.doc_ids).
        bm25_normalized = self._min_max_normalize(self.bm25.get_scores(query.split()))

        # 2-3. Normalized embedding scores for the nearest chunks only.
        embedding_scores = self._embedding_scores(query, min(top_k * 3, len(self.doc_texts)))

        # 4. Blend both signals into one vector aligned with self.doc_ids.
        embed_vector = np.zeros(len(self.doc_ids))
        for doc_id, score in embedding_scores.items():
            embed_vector[self.id_to_index[doc_id]] = score
        hybrid_scores = (1 - alpha) * bm25_normalized + alpha * embed_vector

        # 5. Top-k by hybrid score; the stable sort keeps corpus order on ties.
        top_indices = np.argsort(-hybrid_scores, kind='stable')[:top_k]

        # 6. Format results.
        formatted_results = []
        for position in top_indices:
            idx = int(position)
            doc_id = self.doc_ids[idx]
            # Chroma may store None for chunks without metadata.
            metadata = self.doc_metadatas[idx] or {}
            formatted_results.append({
                'content': self.doc_texts[idx],
                'metadata': metadata,
                'hybrid_score': float(hybrid_scores[idx]),
                'bm25_score': float(bm25_normalized[idx]),
                'embed_score': embedding_scores.get(doc_id, 0.0),
                'filename': metadata.get('파일명', 'N/A'),
                'organization': metadata.get('발주 기관', 'N/A'),
            })

        elapsed = time.perf_counter() - start_time
        print(f"🔍 Hybrid 검색 완료: {len(formatted_results)}개 (alpha={alpha}, {elapsed:.3f}초)")
        return formatted_results

    @traceable(
        name="RAG_Hybrid_Search_Rerank",
        metadata={"component": "retriever", "version": "3.0"}
    )
    def hybrid_search_with_rerank(self, query, top_k=None, alpha=0.5, rerank_candidates=None):
        """Hybrid search followed by cross-encoder re-ranking.

        Args:
            query: Search query.
            top_k: Number of final results; defaults to ``config.DEFAULT_TOP_K``.
            alpha: BM25/embedding weight passed to ``hybrid_search``.
            rerank_candidates: Number of hybrid candidates to re-rank
                (``top_k * 3`` when None).

        Returns:
            Up to ``top_k`` result dicts with an added ``'rerank_score'``.
        """
        start_time = time.perf_counter()
        if top_k is None:
            top_k = self.config.DEFAULT_TOP_K
        if rerank_candidates is None:
            rerank_candidates = top_k * 3

        # 1. Fetch candidates with hybrid search.
        candidates = self.hybrid_search(query, top_k=rerank_candidates, alpha=alpha)
        # 2. Re-rank them (an empty candidate list yields []).
        results = self._rerank(query, candidates, top_k)

        elapsed = time.perf_counter() - start_time
        print(f"🔄 Re-ranking 완료: {len(candidates)}개 → {len(results)}개 ({elapsed:.3f}초)")
        return results

    def search_with_mode(self, query, top_k=None, mode="hybrid_rerank", alpha=0.5):
        """Dispatch to a retrieval strategy by name.

        Modes: ``"embedding"``, ``"embedding_rerank"``, ``"bm25"``,
        ``"hybrid"`` and ``"hybrid_rerank"`` (default).

        Raises:
            ValueError: If ``mode`` is not one of the above.
        """
        if mode == "embedding":
            return self.search(query, top_k)
        elif mode == "embedding_rerank":
            return self.search_with_rerank(query, top_k)
        elif mode == "bm25":
            # alpha=0 removes the embedding term, leaving pure BM25.
            return self.hybrid_search(query, top_k, alpha=0.0)
        elif mode == "hybrid":
            return self.hybrid_search(query, top_k, alpha=alpha)
        elif mode == "hybrid_rerank":
            return self.hybrid_search_with_rerank(query, top_k, alpha)
        else:
            raise ValueError(f"Unknown mode: {mode}")

    @traceable(
        name="RAG_Retriever_Search",
        metadata={"component": "retriever", "version": "1.0"}
    )
    def search(self, query: str, top_k: int = None, filter_metadata: dict = None):
        """Embedding-only similarity search.

        Args:
            query: Search query.
            top_k: Number of results; defaults to ``config.DEFAULT_TOP_K``.
            filter_metadata: Optional Chroma metadata filter (ignored when empty).

        Returns:
            Result dicts with content, metadata, the raw Chroma ``distance``,
            ``relevance_score`` (``1 - distance``), filename and organization.
        """
        start_time = time.perf_counter()
        if top_k is None:
            top_k = self.config.DEFAULT_TOP_K

        search_kwargs = {'k': top_k}
        if filter_metadata:
            search_kwargs['filter'] = filter_metadata
        results = self.vectorstore.similarity_search_with_score(query, **search_kwargs)

        formatted_results = [
            {
                'content': doc.page_content,
                'metadata': doc.metadata,
                'distance': score,
                'relevance_score': 1 - score,
                'filename': doc.metadata.get('파일명', 'N/A'),
                'organization': doc.metadata.get('발주 기관', 'N/A'),
            }
            for doc, score in results
        ]

        elapsed = time.perf_counter() - start_time
        print(f"🔍 검색 완료: {len(results)}개 ({elapsed:.3f}초)")
        return formatted_results

    def search_with_rerank(self, query, top_k=None, rerank_candidates=None):
        """Embedding search followed by cross-encoder re-ranking.

        Args:
            query: Search query.
            top_k: Number of final results; defaults to ``config.DEFAULT_TOP_K``.
            rerank_candidates: Number of embedding candidates to re-rank
                (``top_k * 3`` when None).

        Returns:
            Up to ``top_k`` re-ranked result dicts.
        """
        start_time = time.perf_counter()
        if top_k is None:
            top_k = self.config.DEFAULT_TOP_K
        if rerank_candidates is None:
            rerank_candidates = top_k * 3

        # 1. Fetch candidates with embedding search.
        candidates = self.search(query, top_k=rerank_candidates)
        # 2. Re-rank them (an empty candidate list yields []).
        results = self._rerank(query, candidates, top_k)

        elapsed = time.perf_counter() - start_time
        print(f"🔄 Embedding + Re-ranking 완료: {len(candidates)}개 → {len(results)}개 ({elapsed:.3f}초)")
        return results

    def search_by_organization(self, query: str, organization: str, top_k: int = None):
        """Embedding search restricted to one ordering organization (발주 기관)."""
        return self.search(
            query, top_k=top_k, filter_metadata={'발주 기관': organization}
        )

    def get_retriever(self):
        """Return a LangChain retriever over the vector store for use in chains."""
        return self.vectorstore.as_retriever(
            search_type="similarity",
            search_kwargs={"k": self.config.DEFAULT_TOP_K}
        )