|
|
""" |
|
|
RAGベースの質問応答用QA Chainモジュール |
|
|
""" |
|
|
|
|
|
import os |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import List, Optional, Tuple |
|
|
|
|
|
from dotenv import load_dotenv |
|
|
from langchain_chroma import Chroma |
|
|
from langchain_core.documents import Document |
|
|
from langchain_core.output_parsers import StrOutputParser |
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
from langchain_core.runnables import RunnableLambda, RunnablePassthrough |
|
|
from langchain_openai import OpenAIEmbeddings, ChatOpenAI |
|
|
|
|
|
from .prompt import QA_TEMPLATE, CHARACTER_TEMPLATE, QA_TEMPLATE_WITH_HISTORY |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class QAChain:
    """Retrieval-augmented question answering over a persisted Chroma vector store.

    An instance loads the vector store, builds a similarity retriever and a
    ChatOpenAI model, and keeps a sliding window of past conversation turns
    that ``ask_with_history`` feeds back into the prompt.
    """

    def __init__(
        self,
        persist_dir: str = "data/vector_store",
        model_name: str = "text-embedding-3-small",
        k: int = 15,
        verbose: bool = False,
        llm_model: str = "gpt-4o-mini",
        llm_temperature: float = 0.3,
        llm_max_tokens: Optional[int] = None,
        max_history_turns: int = 10,
        max_history_chars: int = 10000
    ):
        """Initialise the QA chain from a persisted vector store.

        Args:
            persist_dir: Directory the vector store was persisted to.
            model_name: Embedding model name.
            k: Number of documents to retrieve per question.
            verbose: Emit informational log messages.
            llm_model: Chat model name.
            llm_temperature: Sampling temperature for the chat model (0-2).
            llm_max_tokens: Maximum number of tokens (None leaves it unset).
            max_history_turns: Maximum number of conversation turns kept.
                Zero or a negative value keeps no history.
            max_history_chars: Maximum total characters kept in the history.

        Raises:
            ValueError: If OPENAI_API_KEY is not available.
            FileNotFoundError: If ``persist_dir`` does not exist.
        """
        self.persist_dir = persist_dir
        self.model_name = model_name
        self.k = k
        self.verbose = verbose
        self.llm_model = llm_model
        self.llm_temperature = llm_temperature
        self.llm_max_tokens = llm_max_tokens
        self.max_history_turns = max_history_turns
        self.max_history_chars = max_history_chars
        # One formatted string per completed question/answer turn, oldest first.
        self.conversation_history: List[str] = []

        self._ensure_api_key()

        self.db = None
        self.retriever = None
        self.rag_chain = None
        self.model = None
        self.model_with_history = None
        self.format_docs_with_history = None
        self._setup_chain()

    def _ensure_api_key(self) -> None:
        """Load ``.env`` when present and verify that OPENAI_API_KEY is set.

        Raises:
            ValueError: If OPENAI_API_KEY is missing after loading.
        """
        try:
            # load_dotenv() does not raise for a missing file, so the file is
            # checked explicitly to choose the right warning and error text.
            env_file_found = Path(".env").is_file()
            load_dotenv(dotenv_path=".env")
            if not env_file_found and self.verbose:
                logger.warning(".envファイルが見つかりません。環境変数から読み込みを試みます。")
            if not os.getenv("OPENAI_API_KEY"):
                if env_file_found:
                    raise ValueError("OPENAI_API_KEYが.envファイルに見つかりません")
                raise ValueError("OPENAI_API_KEYが環境変数に設定されていません")
        except Exception as e:
            logger.error(f"環境変数の読み込み中にエラーが発生しました: {e}")
            raise

    def _load_vector_store(self) -> "Chroma":
        """Open the persisted vector store.

        Raises:
            FileNotFoundError: If ``persist_dir`` does not exist.
        """
        persist_path = Path(self.persist_dir)
        if not persist_path.exists():
            raise FileNotFoundError(
                f"ベクトルストアが見つかりません: {self.persist_dir}\n"
                "まず次のコマンドでベクトルストアを作成してください: python -m src.cli vector build"
            )
        embeddings = OpenAIEmbeddings(model=self.model_name)
        db = Chroma(
            persist_directory=self.persist_dir,
            embedding_function=embeddings
        )
        if self.verbose:
            logger.info(f"ベクトルストアを{self.persist_dir}から読み込みました")
        return db

    @staticmethod
    def _format_docs_with_metadata(docs: List["Document"]) -> str:
        """Render documents, with title/chunk/URL metadata, as one context string."""
        return '\n\n'.join(
            f"[出典: {doc.metadata.get('title', 'Unknown Title')} - チャンク {doc.metadata.get('chunk_index', 'N/A')}]"
            f"\nURL: {doc.metadata.get('source_url', 'Unknown Source')}"
            f"\n内容: {doc.page_content}\n---"
            for doc in docs
        )

    @staticmethod
    def _format_source_urls(docs: List["Document"]) -> str:
        """Return the distinct non-empty source URLs of *docs* as a bullet list.

        URLs keep the order in which they first appear.
        """
        urls = (doc.metadata.get('source_url', '') for doc in docs)
        unique_urls = dict.fromkeys(url for url in urls if url)
        return '\n'.join(f"- {url}" for url in unique_urls)

    @staticmethod
    def _is_context_length_error(error: Exception) -> bool:
        """Tell whether *error* looks like a context-window overflow.

        Rate-limit errors also mention "tokens" and "limit", but trimming the
        history cannot resolve them, so they are excluded.
        """
        message = str(error).lower()
        if 'rate limit' in message:
            return False
        return 'maximum context length' in message or (
            'token' in message and 'limit' in message
        )

    def _setup_chain(self):
        """Build the retriever, the chat model and the RAG chain.

        Any exception is logged and re-raised.
        """
        try:
            self.db = self._load_vector_store()
            self.retriever = self.db.as_retriever(
                search_type='similarity',
                search_kwargs={'k': self.k}
            )

            llm_params = {
                'model': self.llm_model,
            }
            if self.llm_model == 'gpt-5-nano':
                # This model gets fixed settings; llm_temperature is ignored.
                llm_params['temperature'] = 1.0
                llm_params['reasoning_effort'] = 'minimal'
                llm_params['verbosity'] = 'low'
            else:
                # NOTE(review): a temperature equal to the 0.3 default is never
                # sent, so ChatOpenAI's own default applies in that case --
                # confirm this is intended.
                if self.llm_temperature != 0.3:
                    llm_params['temperature'] = self.llm_temperature

            if self.llm_max_tokens is not None:
                llm_params['max_tokens'] = self.llm_max_tokens

            model = ChatOpenAI(**llm_params)
            format_docs = RunnableLambda(self._format_docs_with_metadata)
            qa_prompt = ChatPromptTemplate.from_template(QA_TEMPLATE)

            # NOTE(review): ask() passes a 'source_urls' variable to
            # QA_TEMPLATE that this chain does not supply -- verify the chain
            # is still usable with the current template.
            self.rag_chain = {
                'context': self.retriever | format_docs,
                'question': RunnablePassthrough(),
            } | qa_prompt | model | StrOutputParser()

            # The same model instance serves both ask() and ask_with_history().
            self.model = model
            self.model_with_history = model
            self.format_docs_with_history = format_docs

            if self.verbose:
                logger.info("永続化ベクトルストアを使ってRAGチェーンを作成しました!")
        except Exception as e:
            logger.error(f"QAチェーンのセットアップ中にエラー: {e}")
            raise

    def _retrieve_context(self, question: str) -> Tuple[List["Document"], str, str]:
        """Retrieve documents for *question* and build the prompt inputs.

        Returns:
            The retrieved documents, the formatted context text and the
            bullet list of distinct source URLs.
        """
        source_docs = self.retriever.invoke(question)
        context = self._format_docs_with_metadata(source_docs)
        urls_text = self._format_source_urls(source_docs)
        return source_docs, context, urls_text

    def ask(self, question: str) -> Tuple[str, List["Document"]]:
        """Answer *question* without conversation history.

        Returns:
            The answer text and the documents retrieved for it.

        Raises:
            RuntimeError: If the model has not been initialised.
        """
        if not self.model:
            raise RuntimeError("モデルが正しく初期化されていません")
        try:
            source_docs, context, urls_text = self._retrieve_context(question)
            prompt_input = {
                'context': context,
                'question': question,
                'source_urls': urls_text
            }
            qa_prompt = ChatPromptTemplate.from_template(QA_TEMPLATE)
            messages = qa_prompt.invoke(prompt_input)
            answer = self.model.invoke(messages).content
            return answer, source_docs
        except Exception as e:
            logger.error(f"質問処理中にエラー: {e}")
            raise

    def _manage_history_window(self):
        """Trim the history to the configured turn and character limits.

        The turn limit is applied first. The oldest turns are then dropped
        until the total length fits ``max_history_chars``; that step always
        leaves at least one turn in place.
        """
        if self.max_history_turns <= 0:
            # A slice of [-0:] would keep everything, so handle this explicitly.
            self.conversation_history = []
            return

        if len(self.conversation_history) > self.max_history_turns:
            self.conversation_history = self.conversation_history[-self.max_history_turns:]

        total_chars = sum(len(turn) for turn in self.conversation_history)
        while total_chars > self.max_history_chars and len(self.conversation_history) > 1:
            removed = self.conversation_history.pop(0)
            total_chars -= len(removed)
            if self.verbose:
                logger.info(f"履歴が制限を超えたため、古い会話を削除しました(削除文字数: {len(removed)})")

    def _format_history_text(self) -> str:
        """Join the stored turns into one string, or return a placeholder when empty."""
        if not self.conversation_history:
            return "まだ会話履歴はありません"
        return "\n".join(self.conversation_history)

    def ask_with_history(self, question: str, retry_count: int = 0) -> Tuple[str, List["Document"]]:
        """Answer *question* taking the conversation history into account.

        On success the new turn is appended to the history and the sliding
        window is applied. If the model call fails with what looks like a
        context-length error and there is history left to trim, the history
        is halved (or cleared when only one turn remains) and the call is
        retried, at most twice.

        Args:
            question: The user's question.
            retry_count: Number of retries already made (internal use).

        Returns:
            The answer text and the documents retrieved for it.

        Raises:
            RuntimeError: If the model has not been initialised.
        """
        if not self.model_with_history:
            raise RuntimeError("履歴付きモデルが正しく初期化されていません")

        try:
            history_text = self._format_history_text()
            source_docs, context, urls_text = self._retrieve_context(question)

            prompt_input = {
                'context': context,
                'question': question,
                'conversation_history': history_text,
                'source_urls': urls_text
            }
            qa_prompt_with_history = ChatPromptTemplate.from_messages([
                ("system", CHARACTER_TEMPLATE),
                ("system", QA_TEMPLATE_WITH_HISTORY),
            ])

            messages = qa_prompt_with_history.invoke(prompt_input)
            answer = self.model_with_history.invoke(messages).content

            new_turn = f"ユーザー: {question}\nKurageSan®: {answer}"
            self.conversation_history.append(new_turn)
            self._manage_history_window()

            if self.verbose:
                total_chars = sum(len(turn) for turn in self.conversation_history)
                logger.info(f"履歴付き質問処理完了。ターン数: {len(self.conversation_history)}, 合計文字数: {total_chars}")

            return answer, source_docs

        except Exception as e:
            # Retrying only helps when there is history left to drop.
            can_retry = (
                retry_count < 2
                and bool(self.conversation_history)
                and self._is_context_length_error(e)
            )
            if can_retry:
                logger.warning(f"トークン制限エラーが発生しました。履歴を削減して再試行します(試行回数: {retry_count + 1})")

                if len(self.conversation_history) > 1:
                    old_size = len(self.conversation_history)
                    # Keep the newer half of the turns.
                    self.conversation_history = self.conversation_history[old_size//2:]
                    logger.info(f"会話履歴を削減: {old_size} -> {len(self.conversation_history)} ターン")
                else:
                    self.conversation_history = []
                    logger.info("会話履歴を完全にクリアしました")

                return self.ask_with_history(question, retry_count + 1)

            logger.error(f"履歴付き質問処理中にエラー: {e}")
            raise

    def clear_history(self):
        """Discard all stored conversation turns."""
        self.conversation_history = []
        if self.verbose:
            logger.info("対話履歴をクリアしました")

    def get_history(self) -> str:
        """Return the current conversation history as text (for debugging)."""
        return self._format_history_text()

    def search_similar(self, query: str, k: int = 5) -> List["Document"]:
        """Return the *k* documents most similar to *query*.

        The retriever's configured ``k`` is overridden only for this call and
        restored afterwards, so later questions keep retrieving ``self.k``
        documents.

        Raises:
            RuntimeError: If the retriever has not been initialised.
        """
        if not self.retriever:
            raise RuntimeError("リトリーバーが正しく初期化されていません")
        search_kwargs = self.retriever.search_kwargs
        missing = object()
        previous_k = search_kwargs.get('k', missing)
        search_kwargs['k'] = k
        try:
            return self.retriever.invoke(query)
        finally:
            if previous_k is missing:
                search_kwargs.pop('k', None)
            else:
                search_kwargs['k'] = previous_k
|
|
|
|
|
|
|
|
def ask_question(question: str, persist_dir: str = "data/vector_store") -> None:
    """Ask a single question and print the answer with its source documents.

    Kept as a backward-compatible convenience function. A verbose QAChain is
    created for every call, and any exception is reported on stdout rather
    than propagated.

    Args:
        question: The question to ask.
        persist_dir: Directory of the persisted vector store.
    """
    separator = '=' * 50
    try:
        chain = QAChain(persist_dir=persist_dir, verbose=True)
        answer, source_docs = chain.ask(question)

        # Header with the question, followed by the answer.
        print(f"\n{separator}")
        print(f"質問: {question}")
        print(f"{separator}")
        print(f"\n回答:\n{answer}")

        # Numbered list of the documents the answer was based on.
        print(f"\n参考にした文書 ({len(source_docs)}件):")
        for position, document in enumerate(source_docs, start=1):
            meta = document.metadata
            print(f"\n{position}. {meta.get('title', 'No Title')}")
            print(f"   URL: {meta.get('source_url', 'N/A')}")
            print(f"   チャンク: {meta.get('chunk_index', 'N/A')}")
    except Exception as e:
        print(f"エラー: {e}")
|
|
|