tktm8's picture
Update src/qa/chain.py
069f493 verified
"""
RAGベースの質問応答用QA Chainモジュール
"""
import os
import logging
from pathlib import Path
from typing import List, Optional, Tuple
from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from .prompt import QA_TEMPLATE, CHARACTER_TEMPLATE, QA_TEMPLATE_WITH_HISTORY
logger = logging.getLogger(__name__)
class QAChain:
"""RAGベースの質問応答用QA Chain"""
def __init__(
self,
persist_dir: str = "data/vector_store",
model_name: str = "text-embedding-3-small",
k: int = 10,
verbose: bool = False,
llm_model: str = "gpt-4o-mini",
llm_temperature: float = 0.3,
llm_max_tokens: Optional[int] = None,
max_history_turns: int = 3,
max_history_chars: int = 10000
):
"""
永続化されたベクトルストアを使ってQA Chainを初期化
Args:
persist_dir: ベクトルストアの保存ディレクトリ
model_name: 埋め込みモデル名
k: 検索する文書の数
verbose: 詳細ログの出力
llm_model: 使用するLLMモデル名
llm_temperature: LLMの温度パラメーター(0-2)
llm_max_tokens: 最大トークン数(Noneで自動)
max_history_turns: 保持する最大会話ターン数(デフォルト: 10)
max_history_chars: 履歴の最大文字数(デフォルト: 10000)
"""
self.persist_dir = persist_dir
self.model_name = model_name
self.k = k
self.verbose = verbose
self.llm_model = llm_model
self.llm_temperature = llm_temperature
self.llm_max_tokens = llm_max_tokens
self.max_history_turns = max_history_turns
self.max_history_chars = max_history_chars
self.conversation_history = [] # 配列形式で管理
# 環境変数の読み込み(より堅牢な実装)
try:
load_dotenv(dotenv_path=".env")
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("OPENAI_API_KEYが.envファイルに見つかりません")
except FileNotFoundError:
if self.verbose:
logger.warning(".envファイルが見つかりません。環境変数から読み込みを試みます。")
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("OPENAI_API_KEYが環境変数に設定されていません")
except Exception as e:
logger.error(f"環境変数の読み込み中にエラーが発生しました: {e}")
raise
self.db = None
self.retriever = None
self.rag_chain = None
self.model_with_history = None
self._setup_chain()
def _load_vector_store(self) -> Chroma:
"""永続化されたベクトルストアを読み込む"""
persist_path = Path(self.persist_dir)
if not persist_path.exists():
raise FileNotFoundError(
f"ベクトルストアが見つかりません: {self.persist_dir}\n"
"まず次のコマンドでベクトルストアを作成してください: python -m src.cli vector build"
)
embeddings = OpenAIEmbeddings(model=self.model_name)
db = Chroma(
persist_directory=self.persist_dir,
embedding_function=embeddings
)
if self.verbose:
logger.info(f"ベクトルストアを{self.persist_dir}から読み込みました")
return db
@staticmethod
def _format_docs_with_metadata(docs: List[Document]) -> str:
"""文書をメタデータ付きでコンテキスト用に整形"""
return '\n\n'.join(
f"[出典: {doc.metadata.get('title', 'Unknown Title')} - チャンク {doc.metadata.get('chunk_index', 'N/A')}]"
f"\nURL: {doc.metadata.get('source_url', 'Unknown Source')}"
f"\n内容: {doc.page_content}\n---"
for doc in docs
)
def _setup_chain(self):
"""RAGチェーン全体をセットアップ"""
try:
self.db = self._load_vector_store()
self.retriever = self.db.as_retriever(
search_type='similarity',
search_kwargs={'k': self.k}
)
# LLMのパラメーターを動的に構成
llm_params = {
'model': self.llm_model,
}
if self.llm_model == 'gpt-5-nano':
llm_params['temperature'] = 1.0
llm_params['reasoning_effort'] = 'minimal'
llm_params['verbosity'] = 'low'
else:
# その他のモデルでは指定された temperature を使用
if self.llm_temperature != 0.3:
llm_params['temperature'] = self.llm_temperature
if self.llm_max_tokens is not None:
llm_params['max_tokens'] = self.llm_max_tokens
model = ChatOpenAI(**llm_params)
format_docs = RunnableLambda(self._format_docs_with_metadata)
qa_prompt = ChatPromptTemplate.from_template(QA_TEMPLATE)
# 通常のRAGチェーン
self.rag_chain = {
'context': self.retriever | format_docs,
'question': RunnablePassthrough(),
} | qa_prompt | model | StrOutputParser()
# モデルを保存
self.model = model
self.model_with_history = model
self.format_docs_with_history = format_docs
if self.verbose:
logger.info("永続化ベクトルストアを使ってRAGチェーンを作成しました!")
except Exception as e:
logger.error(f"QAチェーンのセットアップ中にエラー: {e}")
raise
def ask(self, question: str) -> Tuple[str, List[Document]]:
"""
質問を投げて、回答と参照文書を取得
"""
if not self.model:
raise RuntimeError("モデルが正しく初期化されていません")
try:
# 文書を検索して整形
source_docs = self.retriever.invoke(question)
context = self._format_docs_with_metadata(source_docs)
# URLリストを作成(重複を除去)
source_urls = []
for doc in source_docs:
url = doc.metadata.get('source_url', '')
if url and url not in source_urls:
source_urls.append(url)
urls_text = '\n'.join(f"- {url}" for url in source_urls)
# プロンプトを構築して実行
prompt_input = {
'context': context,
'question': question,
'source_urls': urls_text
}
qa_prompt = ChatPromptTemplate.from_template(QA_TEMPLATE)
# プロンプトテンプレートを適用して回答を生成
messages = qa_prompt.invoke(prompt_input)
answer = self.model.invoke(messages).content
return answer, source_docs
except Exception as e:
logger.error(f"質問処理中にエラー: {e}")
raise
def _manage_history_window(self):
"""
Sliding Windowを使用して履歴を管理
最大ターン数と最大文字数の両方を考慮
"""
# ターン数の制限
if len(self.conversation_history) > self.max_history_turns:
self.conversation_history = self.conversation_history[-self.max_history_turns:]
# 文字数の制限(古い会話から削除)
total_chars = sum(len(turn) for turn in self.conversation_history)
while total_chars > self.max_history_chars and len(self.conversation_history) > 1:
removed = self.conversation_history.pop(0)
total_chars -= len(removed)
if self.verbose:
logger.info(f"履歴が制限を超えたため、古い会話を削除しました(削除文字数: {len(removed)})")
def _format_history_text(self) -> str:
"""
会話履歴配列を文字列に整形
"""
if not self.conversation_history:
return "まだ会話履歴はありません"
return "\n".join(self.conversation_history)
def ask_with_history(self, question: str, retry_count: int = 0) -> Tuple[str, List[Document]]:
"""
対話履歴を考慮した質問応答(Sliding Window機能付き)
Args:
question: 質問内容
retry_count: リトライ回数(内部使用)
"""
if not self.model_with_history:
raise RuntimeError("履歴付きモデルが正しく初期化されていません")
try:
# 履歴を文字列形式に変換
history_text = self._format_history_text()
# 文書を検索して整形
source_docs = self.retriever.invoke(question)
context = self._format_docs_with_metadata(source_docs)
# URLリストを作成(重複を除去)
source_urls = []
for doc in source_docs:
url = doc.metadata.get('source_url', '')
if url and url not in source_urls:
source_urls.append(url)
urls_text = '\n'.join(f"- {url}" for url in source_urls)
# プロンプトを手動で構築して実行
prompt_input = {
'context': context,
'question': question,
'conversation_history': history_text,
'source_urls': urls_text
}
qa_prompt_with_history = ChatPromptTemplate.from_messages([
("user", "context: {context}"),
("system", CHARACTER_TEMPLATE),
("system", QA_TEMPLATE_WITH_HISTORY),
("user", "質問: {question}")
])
# プロンプトテンプレートを適用して回答を生成
messages = qa_prompt_with_history.invoke(prompt_input)
answer = self.model_with_history.invoke(messages).content
# 新しい会話を履歴に追加
new_turn = f"ユーザー: {question}\nKurageSan®: {answer}"
self.conversation_history.append(new_turn)
# Sliding Windowを適用
self._manage_history_window()
if self.verbose:
total_chars = sum(len(turn) for turn in self.conversation_history)
logger.info(f"履歴付き質問処理完了。ターン数: {len(self.conversation_history)}, 合計文字数: {total_chars}")
return answer, source_docs
except Exception as e:
# トークン制限エラーの処理
error_message = str(e).lower()
if retry_count < 2 and ('maximum context length' in error_message or
'token' in error_message and 'limit' in error_message):
logger.warning(f"トークン制限エラーが発生しました。履歴を削減して再試行します(試行回数: {retry_count + 1})")
# 履歴を半分に削減
if len(self.conversation_history) > 1:
old_size = len(self.conversation_history)
self.conversation_history = self.conversation_history[old_size//2:]
logger.info(f"会話履歴を削減: {old_size} -> {len(self.conversation_history)} ターン")
else:
# 履歴が1つ以下の場合はクリア
self.conversation_history = []
logger.info("会話履歴を完全にクリアしました")
# リトライ
return self.ask_with_history(question, retry_count + 1)
logger.error(f"履歴付き質問処理中にエラー: {e}")
raise
def clear_history(self):
"""
対話履歴をクリア
"""
self.conversation_history = []
if self.verbose:
logger.info("対話履歴をクリアしました")
def get_history(self) -> str:
"""
現在の対話履歴を取得(デバッグ用)
"""
return self._format_history_text()
def search_similar(self, query: str, k: int = 5) -> List[Document]:
"""
類似文書を検索
"""
if not self.retriever:
raise RuntimeError("リトリーバーが正しく初期化されていません")
self.retriever.search_kwargs['k'] = k
return self.retriever.invoke(query)
def ask_question(question: str, persist_dir: str = "data/vector_store") -> None:
"""
質問を投げて、回答と参照情報を表示(後方互換用関数)
"""
try:
qa_chain = QAChain(persist_dir=persist_dir, verbose=True)
answer, source_docs = qa_chain.ask(question)
print(f"\n{'='*50}")
print(f"質問: {question}")
print(f"{'='*50}")
print(f"\n回答:\n{answer}")
print(f"\n参考にした文書 ({len(source_docs)}件):")
for i, doc in enumerate(source_docs, 1):
print(f"\n{i}. {doc.metadata.get('title', 'No Title')}")
print(f" URL: {doc.metadata.get('source_url', 'N/A')}")
print(f" チャンク: {doc.metadata.get('chunk_index', 'N/A')}")
except Exception as e:
print(f"エラー: {e}")