# --- Hugging Face blob-view header (scrape residue, not module code) ---
# Dongjin1203's picture | commit 255beb5 | "generator μ½”λ“œμˆ˜μ •" | raw / history blame | 14.9 kB
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import HumanMessage, AIMessage
from langsmith import traceable
import time
from typing import List, Dict
from src.utils.config import RAGConfig
from src.retriever.retriever import RAGRetriever
from src.router.query_router import QueryRouter
class RAGPipeline:
    """Conversational RAG pipeline built on LangChain chains.

    Each query is first classified by QueryRouter; non-document queries
    (greetings, thanks, out-of-scope) get canned replies, while document
    queries go through retrieval (RAGRetriever) and answer generation
    (ChatOpenAI), with a running chat history threaded into the prompt.
    """

    def __init__(self, config: RAGConfig = None, model: str = None, top_k: int = None):
        """Initialize the LLM, retriever, router, prompt, and LCEL chain.

        Args:
            config: RAG settings object; a fresh RAGConfig is created when None.
            model: LLM model name; falls back to config.LLM_MODEL_NAME.
            top_k: default number of chunks to retrieve; falls back to
                config.DEFAULT_TOP_K.
        """
        self.config = config or RAGConfig()
        self.model = model or self.config.LLM_MODEL_NAME
        self.top_k = top_k or self.config.DEFAULT_TOP_K
        # Search settings: mode ("embedding"/"hybrid"/"hybrid_rerank") and
        # the embedding-vs-BM25 mixing weight used by hybrid search.
        self.search_mode = self.config.DEFAULT_SEARCH_MODE
        self.alpha = self.config.DEFAULT_ALPHA
        # LLM initialization (LangChain ChatOpenAI wrapper)
        self.llm = ChatOpenAI(
            model=self.model,
            openai_api_key=self.config.OPENAI_API_KEY,
            timeout=60.0,
            max_retries=3
        )
        # Retriever and query router
        self.retriever = RAGRetriever(config=self.config)
        self.router = QueryRouter()
        # Canned replies for non-document query types; keys are assumed to
        # match the 'type' values emitted by QueryRouter.classify — confirm.
        self._direct_responses = {
            'greeting': "μ•ˆλ…•ν•˜μ„Έμš”! κ³΅κ³΅μž…μ°° RFP κ΄€λ ¨ κΆκΈˆν•œ 사항을 μ•Œλ €μ£Όμ‹œλ©΄ 자료λ₯Ό μ°Ύμ•„ λ“œλ¦΄κ²Œμš”.",
            'thanks': "도움이 λ˜μ—ˆλ‹€λ‹ˆ λ‹€ν–‰μž…λ‹ˆλ‹€. μΆ”κ°€λ‘œ κΆκΈˆν•œ 점이 있으면 μ–Έμ œλ“ μ§€ 말씀해 μ£Όμ„Έμš”!",
            'out_of_scope': "ν•΄λ‹Ή μ§ˆλ¬Έμ€ ν˜„μž¬ λ³΄μœ ν•œ μž…μ°°Β·μ‚¬μ—… λ¬Έμ„œμ—μ„œ 닀루지 μ•ŠμŠ΅λ‹ˆλ‹€. λ‹€λ₯Έ μ§ˆλ¬Έμ„ μ‹œλ„ν•΄ μ£Όμ„Έμš”."
        }
        # Conversation history: list of {"role": ..., "content": ...} dicts
        self.chat_history: List[Dict] = []
        # Last retrieval results, kept so 'sources' can be built after the
        # chain has run (the chain itself only returns the answer string).
        self._last_retrieved_docs: list = []
        # Prompt template (system rules + chat history + current question)
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", """당신은 κ³΅κ³΅μž…μ°° RFPλ₯Ό λΆ„μ„ν•˜λŠ” μž…μ°°λ©”μ΄νŠΈ 사내 λΆ„μ„κ°€μž…λ‹ˆλ‹€. 제곡된 μ»¨ν…μŠ€νŠΈλ§ŒμœΌλ‘œ μš”κ΅¬μ‚¬ν•­Β·μ˜ˆμ‚°Β·λŒ€μƒ κΈ°κ΄€Β·μ œμΆœ 방식 등을 ꡬ쑰화해 μ˜μ‚¬κ²°μ •μ„ μ§€μ›ν•˜μ„Έμš”.
# κ·œμΉ™
- 닡변은 ν•œκ΅­μ–΄λ‘œ μž‘μ„±ν•©λ‹ˆλ‹€.
- μ»¨ν…μŠ€νŠΈ λ°– λ‚΄μš©μ„ μΆ”μΈ‘ν•˜μ§€ μ•ŠμŠ΅λ‹ˆλ‹€.
- μ»¨ν…μŠ€νŠΈκ°€ λΉ„μ–΄μžˆκ±°λ‚˜ 질문과 직접 κ΄€λ ¨λœ 사싀이 μ—†μœΌλ©΄ "λ¬Έμ„œμ—μ„œ ν•΄λ‹Ή 정보λ₯Ό 찾을 수 μ—†μŠ΅λ‹ˆλ‹€." ν•œ λ¬Έμž₯으둜만 λ‹΅ν•©λ‹ˆλ‹€.
- μ—¬λŸ¬ λ¬Έμ„œλ₯Ό 비ꡐ할 λ•ŒλŠ” λ¬Έμ„œλ³„ 차이λ₯Ό ν‘œ λ˜λŠ” λͺ©λ‘μœΌλ‘œ μ •λ¦¬ν•©λ‹ˆλ‹€.
- μˆ«μžμ—λŠ” κ°€λŠ₯ν•œ λ‹¨μœ„λ₯Ό ν¬ν•¨ν•©λ‹ˆλ‹€.
- 직전 λŒ€ν™” λ§₯락을 λ°˜μ˜ν•˜λ˜, ν™•μΈλ˜μ§€ μ•Šμ€ λ‚΄μš©μ„ μΆ”λ‘ ν•΄ μΆ”κ°€ν•˜μ§€ μ•ŠμŠ΅λ‹ˆλ‹€.
# λ‹΅λ³€ ν˜•μ‹
1. ν•œ 쀄 μš”μ•½: 질문 핡심을 ν•œλ‘ λ¬Έμž₯으둜 μž‘μ„±ν•©λ‹ˆλ‹€.
2. 상세 λ‹΅λ³€: [μš”κ΅¬μ‚¬ν•­], [λŒ€μƒ κΈ°κ΄€], [μ˜ˆμ‚°], [제좜 ν˜•μ‹/방법], [평가 κΈ°μ€€] λ“± λ¬Έμ„œμ—μ„œ ν™•μΈλœ ν•­λͺ©λ§Œ μ •λ¦¬ν•©λ‹ˆλ‹€.
3. κ·Όκ±° 정보: μœ„ λ‹΅λ³€μ˜ κ·Όκ±°κ°€ 된 λ¬Έμž₯μ΄λ‚˜ 문단을 μš”μ•½ν•©λ‹ˆλ‹€.
4. λΆ€μ‘±ν•œ 정보: λ¬Έμ„œμ—μ„œ 찾을 수 μ—†λŠ” ν•­λͺ©μ€ "λ¬Έμ„œμ—μ„œ 확인 λΆˆκ°€"둜 ν‘œκΈ°ν•©λ‹ˆλ‹€."""),
            # Conversation history slot (filled by _get_chat_history)
            MessagesPlaceholder(variable_name="chat_history"),
            # Current question plus the retrieved context
            ("user", """# μ»¨ν…μŠ€νŠΈ
{context}
# 질문
{question}
μœ„ κ·œμΉ™μ— 따라 λ‹΅λ³€ν•˜μ„Έμš”.""")
        ])
        # LCEL chain: retrieve+format context -> prompt -> LLM -> plain str
        self.chain = (
            {
                "context": RunnableLambda(self._retrieve_and_format),
                "question": RunnablePassthrough(),
                "chat_history": RunnableLambda(lambda x: self._get_chat_history())
            }
            | self.prompt
            | self.llm
            | StrOutputParser()
        )
        print(f"βœ… RAG νŒŒμ΄ν”„λΌμΈ μ΄ˆκΈ°ν™” μ™„λ£Œ")
        print(f" - λͺ¨λΈ: {self.model}")
        print(f" - κΈ°λ³Έ top_k: {self.top_k}")
        print(f" - 검색 λͺ¨λ“œ: {self.search_mode}")

    def _get_chat_history(self) -> List:
        """Convert stored chat history dicts into LangChain message objects."""
        messages = []
        for msg in self.chat_history:
            if msg["role"] == "user":
                messages.append(HumanMessage(content=msg["content"]))
            else:
                # Any non-"user" role is treated as the assistant
                messages.append(AIMessage(content=msg["content"]))
        return messages

    def _retrieve_and_format(self, query: str) -> str:
        """Run retrieval for *query* and return the formatted context string.

        Side effect: stores the raw hits in self._last_retrieved_docs so
        generate_answer can build the 'sources' list afterwards.
        """
        # Dispatch on the configured search mode
        if self.search_mode == "embedding":
            docs = self.retriever.search(query, top_k=self.top_k)
        elif self.search_mode == "hybrid":
            docs = self.retriever.hybrid_search(query, top_k=self.top_k, alpha=self.alpha)
        elif self.search_mode == "hybrid_rerank":
            docs = self.retriever.hybrid_search_with_rerank(
                query, top_k=self.top_k, alpha=self.alpha
            )
        else:
            # Unknown mode falls back to plain embedding search
            docs = self.retriever.search(query, top_k=self.top_k)
        # Remember the hits for the sources payload
        self._last_retrieved_docs = docs
        # Format the hits as the prompt context
        return self._format_context(docs)

    def _format_context(self, retrieved_docs: list) -> str:
        """Render retrieved docs as a numbered context block for the prompt."""
        if not retrieved_docs:
            return "κ΄€λ ¨ λ¬Έμ„œλ₯Ό 찾을 수 μ—†μŠ΅λ‹ˆλ‹€."
        context_parts = []
        for i, doc in enumerate(retrieved_docs, 1):
            context_parts.append(f"[λ¬Έμ„œ {i}]\n{doc['content']}\n")
        return "\n".join(context_parts)

    def _format_sources(self, retrieved_docs: list) -> list:
        """Convert retrieved docs into the 'sources' payload returned to callers.

        Each source dict carries content/metadata/filename/organization plus
        a single normalized (score, score_type) pair.
        """
        sources = []
        for doc in retrieved_docs:
            source_info = {
                'content': doc['content'],
                'metadata': doc['metadata'],
                'filename': doc.get('filename', 'N/A'),
                'organization': doc.get('organization', 'N/A')
            }
            # Different search modes attach different score fields; pick the
            # most specific one available (rerank > hybrid > embedding).
            if 'rerank_score' in doc:
                source_info['score'] = doc['rerank_score']
                source_info['score_type'] = 'rerank'
            elif 'hybrid_score' in doc:
                source_info['score'] = doc['hybrid_score']
                source_info['score_type'] = 'hybrid'
            elif 'relevance_score' in doc:
                source_info['score'] = doc['relevance_score']
                source_info['score_type'] = 'embedding'
            else:
                source_info['score'] = 0
                source_info['score_type'] = 'unknown'
            sources.append(source_info)
        return sources

    @traceable(
        name="RAG_Generate_Answer",
        metadata={"component": "generator", "version": "2.0"}
    )
    def generate_answer(
        self,
        query: str,
        top_k: int = None,
        search_mode: str = None,
        alpha: float = None
    ) -> dict:
        """Generate an answer for *query* through the chain.

        Args:
            query: user question.
            top_k: number of chunks to retrieve (overrides the default).
            search_mode: search mode ("embedding", "hybrid", "hybrid_rerank").
            alpha: embedding weight (0~1) for hybrid search.

        Returns:
            dict with keys: answer, sources, search_mode, elapsed_time,
            usage, routing.

        Raises:
            RuntimeError: wraps any failure in routing, retrieval, or
                generation (original exception chained).

        NOTE(review): top_k/search_mode/alpha overrides are written back to
        the instance and persist for subsequent calls — confirm this
        stickiness is intended rather than per-call-only behavior.
        """
        try:
            start_time = time.time()
            classification = self.router.classify(query)
            query_type = classification.get('type', 'document')
            # Non-document queries are answered immediately, skipping retrieval
            if query_type != 'document':
                print(f"⏭️ λΌμš°ν„°: 검색 μƒλž΅ ({query_type})")
                answer = self._direct_responses.get(
                    query_type,
                    self._direct_responses['out_of_scope']
                )
                elapsed_time = time.time() - start_time
                self._last_retrieved_docs = []
                self.chat_history.append({"role": "user", "content": query})
                self.chat_history.append({"role": "assistant", "content": answer})
                return {
                    'answer': answer,
                    'sources': [],
                    'search_mode': 'none',
                    'elapsed_time': elapsed_time,
                    'usage': {
                        'total_tokens': 0,
                        'prompt_tokens': 0,
                        'completion_tokens': 0
                    },
                    'routing': classification
                }
            # Apply per-call overrides (persisted on the instance; see NOTE)
            if top_k is not None:
                self.top_k = top_k
            if search_mode is not None:
                self.search_mode = search_mode
            if alpha is not None:
                self.alpha = alpha
            # Run the chain (retrieval happens inside _retrieve_and_format)
            answer = self.chain.invoke(query)
            # If retrieval came back empty, replace the LLM output with a
            # safe fallback reply instead of a possibly-hallucinated answer
            if not self._last_retrieved_docs:
                answer = "λ¬Έμ„œμ—μ„œ κ΄€λ ¨ 정보λ₯Ό 찾을 수 μ—†μŠ΅λ‹ˆλ‹€. λ‹€λ₯Έ μ§ˆλ¬Έμ„ μž…λ ₯ν•΄ μ£Όμ„Έμš”."
                print("⚠️ 검색 κ²°κ³Ό μ—†μŒ - μ•ˆμ „ 응닡 λ°˜ν™˜")
            elapsed_time = time.time() - start_time
            # Record the turn in the conversation history
            self.chat_history.append({"role": "user", "content": query})
            self.chat_history.append({"role": "assistant", "content": answer})
            # Rough word-count heuristic — NOT a tokenizer count; exact usage
            # is not directly accessible through this LangChain chain
            estimated_tokens = len(query.split()) + len(answer.split()) * 2
            return {
                'answer': answer,
                'sources': self._format_sources(self._last_retrieved_docs),
                'search_mode': self.search_mode,
                'elapsed_time': elapsed_time,
                'usage': {
                    'total_tokens': estimated_tokens,
                    'prompt_tokens': 0,
                    'completion_tokens': 0
                },
                'routing': classification
            }
        except Exception as e:
            print(f"❌ λ‹΅λ³€ 생성 μ‹€νŒ¨: {e}")
            import traceback
            traceback.print_exc()
            raise RuntimeError(f"λ‹΅λ³€ 생성 μ‹€νŒ¨: {str(e)}") from e

    def chat(self, query: str) -> str:
        """Simple conversational interface.

        Args:
            query: user question.

        Returns:
            str: the answer text only (sources and metadata dropped).
        """
        result = self.generate_answer(query)
        return result['answer']

    def clear_history(self):
        """Reset the conversation history."""
        self.chat_history = []
        print("πŸ—‘οΈ λŒ€ν™” νžˆμŠ€ν† λ¦¬κ°€ μ΄ˆκΈ°ν™”λ˜μ—ˆμŠ΅λ‹ˆλ‹€.")

    def get_history(self) -> List[Dict]:
        """Return a shallow copy of the conversation history."""
        return self.chat_history.copy()

    def set_search_config(self, search_mode: str = None, top_k: int = None, alpha: float = None):
        """Update search settings; only non-None arguments are applied."""
        if search_mode is not None:
            self.search_mode = search_mode
        if top_k is not None:
            self.top_k = top_k
        if alpha is not None:
            self.alpha = alpha
        print(f"πŸ”§ 검색 μ„€μ • λ³€κ²½: mode={self.search_mode}, top_k={self.top_k}, alpha={self.alpha}")

    def print_result(self, result: dict, query: str = None):
        """Pretty-print a generate_answer result to stdout."""
        print("\n" + "="*60)
        if query:
            print(f"질문: {query}")
        print(f"검색 λͺ¨λ“œ: {result.get('search_mode', 'N/A')}")
        if 'elapsed_time' in result:
            print(f"μ†Œμš” μ‹œκ°„: {result['elapsed_time']:.2f}초")
        print("="*60)
        print(f"\nπŸ’¬ λ‹΅λ³€:\n{result['answer']}")
        print(f"\nπŸ“š μ°Έκ³  λ¬Έμ„œ ({len(result['sources'])}개):")
        for i, source in enumerate(result['sources'], 1):
            score = source.get('score', 0)
            score_type = source.get('score_type', '')
            print(f" [{i}] {source['filename']}")
            print(f" 점수: {score:.3f} ({score_type})")
        print("="*60)
# Interactive console runner
def interactive_mode():
    """Run the RAG system as an interactive console session.

    Commands: 'quit'/'exit'/'μ’…λ£Œ'/'q' exits, 'clear' resets history,
    'mode' switches the search mode; anything else is sent as a query.
    """
    divider = "=" * 60
    print(divider)
    print("λŒ€ν™”ν˜• RAG μ‹œμŠ€ν…œ μ΄ˆκΈ°ν™” 쀑...")
    print(divider)
    pipeline = RAGPipeline(config=RAGConfig())
    print("\n" + divider)
    print("λŒ€ν™”ν˜• λͺ¨λ“œ μ‹œμž‘")
    print("λͺ…λ Ήμ–΄: 'quit' (μ’…λ£Œ), 'clear' (νžˆμŠ€ν† λ¦¬ μ΄ˆκΈ°ν™”), 'mode' (검색λͺ¨λ“œ λ³€κ²½)")
    print(divider)
    exit_words = {'quit', 'exit', 'μ’…λ£Œ', 'q'}
    mode_choices = {'1': 'embedding', '2': 'hybrid', '3': 'hybrid_rerank'}
    while True:
        question = input("\n질문: ").strip()
        if not question:
            continue
        command = question.lower()
        if command in exit_words:
            print("μ‹œμŠ€ν…œμ„ μ’…λ£Œν•©λ‹ˆλ‹€.")
            break
        if command == 'clear':
            pipeline.clear_history()
            continue
        if command == 'mode':
            print("\n검색 λͺ¨λ“œ 선택:")
            print("1. embedding - μž„λ² λ”© 검색")
            print("2. hybrid - BM25 + μž„λ² λ”©")
            print("3. hybrid_rerank - Hybrid + Re-ranker (ꢌμž₯)")
            picked = input("선택 (1/2/3): ").strip()
            if picked in mode_choices:
                pipeline.set_search_config(search_mode=mode_choices[picked])
            continue
        try:
            result = pipeline.generate_answer(query=question)
            pipeline.print_result(result, question)
            # Optionally dump the full detail of each referenced source
            wants_detail = input("\nμ°Έμ‘° λ¬Έμ„œ 상세 보기? (y/n): ").strip().lower()
            if wants_detail == 'y':
                for idx, src in enumerate(result['sources'], 1):
                    print(f"\n{'='*40}")
                    print(f"[λ¬Έμ„œ {idx}] {src['filename']}")
                    print(f"λ°œμ£ΌκΈ°κ΄€: {src['organization']}")
                    print(f"λ‚΄μš©:\n{src['content'][:500]}...")
        except Exception as e:
            print(f"❌ 였λ₯˜ λ°œμƒ: {e}")
# Usage example: launch the interactive console when run as a script
if __name__ == "__main__":
    interactive_mode()