|
|
import time
import traceback
from typing import List, Dict, Optional

from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_openai import ChatOpenAI
from langsmith import traceable

from src.utils.config import RAGConfig
from src.retriever.retriever import RAGRetriever
|
|
|
|
|
|
|
|
class RAGPipeline:
    """Conversational RAG pipeline built on a LangChain runnable chain.

    Retrieval is delegated to ``RAGRetriever`` and generation to
    ``ChatOpenAI``. The instance is stateful: it keeps the running chat
    history, the current retrieval settings, and the documents returned by
    the most recent retrieval.
    """

    # Score fields in priority order, paired with the label that is reported
    # as 'score_type' in the formatted sources.
    _SCORE_FIELDS = (
        ('rerank_score', 'rerank'),
        ('hybrid_score', 'hybrid'),
        ('relevance_score', 'embedding'),
    )

    def __init__(
        self,
        config: Optional[RAGConfig] = None,
        model: Optional[str] = None,
        top_k: Optional[int] = None,
    ):
        """Build the LLM client, retriever, prompt template and chain.

        Args:
            config: Settings object; a fresh ``RAGConfig()`` is used when falsy.
            model: Chat model name; falls back to ``config.LLM_MODEL_NAME``
                when falsy.
            top_k: Number of documents to retrieve; falls back to
                ``config.DEFAULT_TOP_K`` when falsy.
        """
        self.config = config or RAGConfig()
        self.model = model or self.config.LLM_MODEL_NAME
        self.top_k = top_k or self.config.DEFAULT_TOP_K

        # Retrieval settings; changed later by set_search_config() or by the
        # per-call arguments of generate_answer().
        self.search_mode = self.config.DEFAULT_SEARCH_MODE
        self.alpha = self.config.DEFAULT_ALPHA

        self.llm = ChatOpenAI(
            model=self.model,
            openai_api_key=self.config.OPENAI_API_KEY,
            timeout=60.0,
            max_retries=3,
        )

        self.retriever = RAGRetriever(config=self.config)

        # Each turn is stored as a {"role": ..., "content": ...} dict.
        self.chat_history: List[Dict] = []

        # Documents from the latest retrieval, kept so generate_answer() can
        # report them as sources after the chain has run.
        self._last_retrieved_docs = []

        # NOTE(review): the non-ASCII text in the string literals of this class
        # looks mis-decoded (UTF-8 read through a single-byte code page), and
        # some literals were split across physical lines. The text is kept as
        # found; a split inside a single-line literal is written as a newline
        # escape. Restore the literals from the original UTF-8 source.
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", """λΉμ μ 곡곡μ
μ°° RFPλ₯Ό λΆμνλ μ
μ°°λ©μ΄νΈ μ¬λ΄ λΆμκ°μ
λλ€. μ 곡λ 컨ν
μ€νΈλ§μΌλ‘ μꡬμ¬νΒ·μμ°Β·λμ κΈ°κ΄Β·μ μΆ λ°©μ λ±μ ꡬ쑰νν΄ μμ¬κ²°μ μ μ§μνμΈμ.

# κ·μΉ
- λ΅λ³μ νκ΅μ΄λ‘ μμ±ν©λλ€.
- 컨ν
μ€νΈ λ° λ΄μ©μ μΆμΈ‘νμ§ μμ΅λλ€.
- μ λ³΄κ° μμΌλ©΄ "λ¬Έμμμ ν΄λΉ μ 보λ₯Ό μ°Ύμ μ μμ΅λλ€."λΌκ³ λ°νλλ€.
- μ¬λ¬ λ¬Έμλ₯Ό λΉκ΅ν λλ λ¬Έμλ³ μ°¨μ΄λ₯Ό ν λλ λͺ©λ‘μΌλ‘ μ 리ν©λλ€.
- μ«μμλ κ°λ₯ν λ¨μλ₯Ό ν¬ν¨ν©λλ€.
- μ§μ λν λ§₯λ½μ λ°μν©λλ€.

# λ΅λ³ νμ
1. ν μ€ μμ½: μ§λ¬Έ ν΅μ¬μ νλ λ¬Έμ₯μΌλ‘ μμ±ν©λλ€.
2. μμΈ λ΅λ³: [μꡬμ¬ν], [λμ κΈ°κ΄], [μμ°], [μ μΆ νμ/λ°©λ²], [νκ° κΈ°μ€] λ± λ¬Έμμμ νμΈλ νλͺ©λ§ μ 리ν©λλ€.
3. κ·Όκ±° μ 보: μ λ΅λ³μ κ·Όκ±°κ° λ λ¬Έμ₯μ΄λ λ¬Έλ¨μ μμ½ν©λλ€.
4. λΆμ‘±ν μ 보: λ¬Έμμμ μ°Ύμ μ μλ νλͺ©μ "λ¬Έμμμ νμΈ λΆκ°"λ‘ νκΈ°ν©λλ€."""),
            MessagesPlaceholder(variable_name="chat_history"),
            ("user", """# 컨ν
μ€νΈ
{context}

# μ§λ¬Έ
{question}

μ κ·μΉμ λ°λΌ λ΅λ³νμΈμ."""),
        ])

        # The raw query string fans out to the three prompt inputs: it is
        # retrieved against, passed through as the question, and ignored by
        # the history branch (which reads instance state instead).
        self.chain = (
            {
                "context": RunnableLambda(self._retrieve_and_format),
                "question": RunnablePassthrough(),
                "chat_history": RunnableLambda(lambda _: self._get_chat_history()),
            }
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

        print("β\nRAG νμ΄νλΌμΈ μ΄κΈ°ν μλ£")
        print(f" - λͺ¨λΈ: {self.model}")
        print(f" - κΈ°λ³Έ top_k: {self.top_k}")
        print(f" - κ²μ λͺ¨λ: {self.search_mode}")

    def _update_search_settings(
        self,
        search_mode: Optional[str] = None,
        top_k: Optional[int] = None,
        alpha: Optional[float] = None,
    ) -> None:
        """Store every retrieval setting that was given (i.e. is not None)."""
        if search_mode is not None:
            self.search_mode = search_mode
        if top_k is not None:
            self.top_k = top_k
        if alpha is not None:
            self.alpha = alpha

    def _get_chat_history(self) -> List:
        """Convert the stored history dicts into LangChain message objects.

        Entries whose role is "user" become ``HumanMessage``; every other
        role becomes ``AIMessage``.
        """
        return [
            HumanMessage(content=turn["content"])
            if turn["role"] == "user"
            else AIMessage(content=turn["content"])
            for turn in self.chat_history
        ]

    def _retrieve_and_format(self, query: str) -> str:
        """Retrieve documents for ``query`` and return them as context text.

        "hybrid" and "hybrid_rerank" use the matching retriever methods; any
        other mode, including "embedding", uses the plain ``search``. The
        retrieved documents are also stored on ``self._last_retrieved_docs``.
        """
        if self.search_mode == "hybrid":
            docs = self.retriever.hybrid_search(
                query, top_k=self.top_k, alpha=self.alpha
            )
        elif self.search_mode == "hybrid_rerank":
            docs = self.retriever.hybrid_search_with_rerank(
                query, top_k=self.top_k, alpha=self.alpha
            )
        else:
            docs = self.retriever.search(query, top_k=self.top_k)

        self._last_retrieved_docs = docs
        return self._format_context(docs)

    def _format_context(self, retrieved_docs: list) -> str:
        """Join the documents' 'content' into one numbered context string.

        Returns a fixed "nothing found" message when ``retrieved_docs`` is
        empty.
        """
        if not retrieved_docs:
            return "κ΄λ ¨ λ¬Έμλ₯Ό μ°Ύμ μ μμ΅λλ€."

        return "\n".join(
            f"[λ¬Έμ {i}]\n{doc['content']}\n"
            for i, doc in enumerate(retrieved_docs, 1)
        )

    def _format_sources(self, retrieved_docs: list) -> list:
        """Convert retrieved documents into the 'sources' entries of a result.

        Each entry carries content, metadata, filename and organization (the
        last two default to 'N/A'), plus the first score found in
        ``_SCORE_FIELDS`` order. When no score field is present the score is
        0 and the type is 'unknown'.
        """
        sources = []
        for doc in retrieved_docs:
            source_info = {
                'content': doc['content'],
                'metadata': doc['metadata'],
                'filename': doc.get('filename', 'N/A'),
                'organization': doc.get('organization', 'N/A'),
            }

            for field, label in self._SCORE_FIELDS:
                if field in doc:
                    source_info['score'] = doc[field]
                    source_info['score_type'] = label
                    break
            else:
                source_info['score'] = 0
                source_info['score_type'] = 'unknown'

            sources.append(source_info)
        return sources

    @traceable(
        name="RAG_Generate_Answer",
        metadata={"component": "generator", "version": "2.0"},
    )
    def generate_answer(
        self,
        query: str,
        top_k: Optional[int] = None,
        search_mode: Optional[str] = None,
        alpha: Optional[float] = None,
    ) -> dict:
        """Answer ``query`` through the chain and record the turn in history.

        Any of ``top_k``, ``search_mode`` and ``alpha`` that is given is
        stored on the instance, so it also applies to later calls.

        Args:
            query: The question.
            top_k: Number of documents to retrieve.
            search_mode: "embedding", "hybrid" or "hybrid_rerank".
            alpha: Weight passed to the hybrid retriever methods.

        Returns:
            dict with 'answer', 'sources', 'search_mode', 'elapsed_time'
            (seconds) and 'usage'. 'usage.total_tokens' is a rough estimate
            from whitespace-separated word counts; the prompt and completion
            counts are always 0.

        Raises:
            RuntimeError: Wraps any exception raised while answering.
        """
        try:
            # perf_counter is monotonic, so the duration cannot be skewed by
            # wall-clock adjustments.
            start_time = time.perf_counter()

            self._update_search_settings(
                search_mode=search_mode, top_k=top_k, alpha=alpha
            )

            answer = self.chain.invoke(query)
            elapsed_time = time.perf_counter() - start_time

            # Build everything that can still fail before touching history,
            # so a failure here does not leave a recorded turn behind.
            sources = self._format_sources(self._last_retrieved_docs)
            estimated_tokens = len(query.split()) + len(answer.split()) * 2

            self.chat_history.append({"role": "user", "content": query})
            self.chat_history.append({"role": "assistant", "content": answer})

            return {
                'answer': answer,
                'sources': sources,
                'search_mode': self.search_mode,
                'elapsed_time': elapsed_time,
                'usage': {
                    'total_tokens': estimated_tokens,
                    'prompt_tokens': 0,
                    'completion_tokens': 0,
                },
            }

        except Exception as e:
            print(f"β λ΅λ³ μμ± μ€ν¨: {e}")
            traceback.print_exc()
            raise RuntimeError(f"λ΅λ³ μμ± μ€ν¨: {e}") from e

    def chat(self, query: str) -> str:
        """Answer ``query`` with the current settings and return only the text.

        Args:
            query: The question.

        Returns:
            The 'answer' field of ``generate_answer(query)``.
        """
        return self.generate_answer(query)['answer']

    def clear_history(self):
        """Drop all recorded chat turns."""
        self.chat_history = []
        print("ποΈ λν νμ€ν λ¦¬κ° μ΄κΈ°νλμμ΅λλ€.")

    def get_history(self) -> List[Dict]:
        """Return a copy of the chat history.

        Each turn dict is copied as well, so edits to the returned data do
        not reach the pipeline's own history.
        """
        return [dict(turn) for turn in self.chat_history]

    def set_search_config(
        self,
        search_mode: Optional[str] = None,
        top_k: Optional[int] = None,
        alpha: Optional[float] = None,
    ):
        """Change the retrieval settings that are given and print the result."""
        self._update_search_settings(
            search_mode=search_mode, top_k=top_k, alpha=alpha
        )

        print(f"π§ κ²μ μ€μ λ³κ²½: mode={self.search_mode}, top_k={self.top_k}, alpha={self.alpha}")

    def print_result(self, result: dict, query: Optional[str] = None):
        """Print a result dict from ``generate_answer`` to stdout.

        Args:
            result: Must contain 'answer' and 'sources'; 'search_mode' and
                'elapsed_time' are printed when present.
            query: The question, echoed above the answer when truthy.
        """
        divider = "=" * 60

        print("\n" + divider)
        if query:
            print(f"μ§λ¬Έ: {query}")
        print(f"κ²μ λͺ¨λ: {result.get('search_mode', 'N/A')}")
        if 'elapsed_time' in result:
            print(f"μμ μκ°: {result['elapsed_time']:.2f}μ΄")
        print(divider)
        print(f"\nπ¬ λ΅λ³:\n{result['answer']}")
        print(f"\nπ μ°Έκ³ λ¬Έμ ({len(result['sources'])}κ°):")
        for i, source in enumerate(result['sources'], 1):
            score = source.get('score', 0)
            score_type = source.get('score_type', '')
            print(f" [{i}] {source['filename']}")
            print(f" μ μ: {score:.3f} ({score_type})")
        print(divider)
|
|
|
|
|
|
|
|
|
|
|
def interactive_mode():
    """Run the console chat loop for the RAG pipeline.

    Builds a ``RAGPipeline`` from a default ``RAGConfig`` and then reads
    lines from stdin until an exit command is entered. ``clear`` resets the
    chat history and ``mode`` lets the user pick a search mode; any other
    non-empty input is answered with ``generate_answer`` and printed.

    End of input (closed stdin, Ctrl-D) and Ctrl-C end the session with the
    normal shutdown message instead of an unhandled traceback.
    """
    # NOTE(review): the non-ASCII text in these literals looks mis-decoded and
    # some literals were split across physical lines; the text is kept as
    # found, with each split written as a newline escape. Restore the literals
    # from the original UTF-8 source.
    banner = "=" * 60
    farewell = "μμ€ν\nμ μ’\nλ£ν©λλ€."
    exit_commands = ['quit', 'exit', 'μ’\nλ£', 'q']
    modes = {'1': 'embedding', '2': 'hybrid', '3': 'hybrid_rerank'}

    print(banner)
    print("λνν RAG μμ€ν\nμ΄κΈ°ν μ€...")
    print(banner)

    config = RAGConfig()
    pipeline = RAGPipeline(config=config)

    print("\n" + banner)
    print("λνν λͺ¨λ μμ")
    print("λͺ\nλ Ήμ΄: 'quit' (μ’\nλ£), 'clear' (νμ€ν 리 μ΄κΈ°ν), 'mode' (κ²μλͺ¨λ λ³κ²½)")
    print(banner)

    try:
        while True:
            user_query = input("\nμ§λ¬Έ: ").strip()
            if not user_query:
                continue

            command = user_query.lower()

            if command in exit_commands:
                print(farewell)
                break

            if command == 'clear':
                pipeline.clear_history()
                continue

            if command == 'mode':
                print("\nκ²μ λͺ¨λ μ ν:")
                print("1. embedding - μλ² λ© κ²μ")
                print("2. hybrid - BM25 + μλ² λ©")
                print("3. hybrid_rerank - Hybrid + Re-ranker (κΆμ₯)")
                choice = input("μ ν (1/2/3): ").strip()
                # An unrecognised choice leaves the current mode untouched.
                if choice in modes:
                    pipeline.set_search_config(search_mode=modes[choice])
                continue

            try:
                result = pipeline.generate_answer(query=user_query)
                pipeline.print_result(result, user_query)

                show_source = input("\nμ°Έμ‘° λ¬Έμ μμΈ λ³΄κΈ°? (y/n): ").strip().lower()
                if show_source == 'y':
                    for i, source in enumerate(result['sources'], 1):
                        print(f"\n{'='*40}")
                        print(f"[λ¬Έμ {i}] {source['filename']}")
                        print(f"λ°μ£ΌκΈ°κ΄: {source['organization']}")
                        print(f"λ΄μ©:\n{source['content'][:500]}...")
            except EOFError:
                # Closed stdin is not a per-question failure; let the outer
                # handler end the session.
                raise
            except Exception as e:
                # Keep the session alive when a single question fails.
                print(f"β μ€λ₯ λ°μ: {e}")
    except (EOFError, KeyboardInterrupt):
        # Start on a fresh line: the prompt (or ^C) has no trailing newline.
        print("\n" + farewell)
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: start the console chat loop only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    interactive_mode()
|
|
|