""" RAGベースの質問応答用QA Chainモジュール """ import os import logging from pathlib import Path from typing import List, Optional, Tuple from dotenv import load_dotenv from langchain_chroma import Chroma from langchain_core.documents import Document from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableLambda, RunnablePassthrough from langchain_openai import OpenAIEmbeddings, ChatOpenAI from .prompt import QA_TEMPLATE, CHARACTER_TEMPLATE, QA_TEMPLATE_WITH_HISTORY logger = logging.getLogger(__name__) class QAChain: """RAGベースの質問応答用QA Chain""" def __init__( self, persist_dir: str = "data/vector_store", model_name: str = "text-embedding-3-small", k: int = 10, verbose: bool = False, llm_model: str = "gpt-4o-mini", llm_temperature: float = 0.3, llm_max_tokens: Optional[int] = None, max_history_turns: int = 3, max_history_chars: int = 10000 ): """ 永続化されたベクトルストアを使ってQA Chainを初期化 Args: persist_dir: ベクトルストアの保存ディレクトリ model_name: 埋め込みモデル名 k: 検索する文書の数 verbose: 詳細ログの出力 llm_model: 使用するLLMモデル名 llm_temperature: LLMの温度パラメーター(0-2) llm_max_tokens: 最大トークン数(Noneで自動) max_history_turns: 保持する最大会話ターン数(デフォルト: 10) max_history_chars: 履歴の最大文字数(デフォルト: 10000) """ self.persist_dir = persist_dir self.model_name = model_name self.k = k self.verbose = verbose self.llm_model = llm_model self.llm_temperature = llm_temperature self.llm_max_tokens = llm_max_tokens self.max_history_turns = max_history_turns self.max_history_chars = max_history_chars self.conversation_history = [] # 配列形式で管理 # 環境変数の読み込み(より堅牢な実装) try: load_dotenv(dotenv_path=".env") api_key = os.getenv("OPENAI_API_KEY") if not api_key: raise ValueError("OPENAI_API_KEYが.envファイルに見つかりません") except FileNotFoundError: if self.verbose: logger.warning(".envファイルが見つかりません。環境変数から読み込みを試みます。") api_key = os.getenv("OPENAI_API_KEY") if not api_key: raise ValueError("OPENAI_API_KEYが環境変数に設定されていません") except Exception as e: logger.error(f"環境変数の読み込み中にエラーが発生しました: {e}") raise self.db = None self.retriever = None self.rag_chain = None self.model_with_history = None self._setup_chain() def _load_vector_store(self) -> Chroma: """永続化されたベクトルストアを読み込む""" persist_path = Path(self.persist_dir) if not persist_path.exists(): raise FileNotFoundError( f"ベクトルストアが見つかりません: {self.persist_dir}\n" "まず次のコマンドでベクトルストアを作成してください: python -m src.cli vector build" ) embeddings = OpenAIEmbeddings(model=self.model_name) db = Chroma( persist_directory=self.persist_dir, embedding_function=embeddings ) if self.verbose: logger.info(f"ベクトルストアを{self.persist_dir}から読み込みました") return db @staticmethod def _format_docs_with_metadata(docs: List[Document]) -> str: """文書をメタデータ付きでコンテキスト用に整形""" return '\n\n'.join( f"[出典: {doc.metadata.get('title', 'Unknown Title')} - チャンク {doc.metadata.get('chunk_index', 'N/A')}]" f"\nURL: {doc.metadata.get('source_url', 'Unknown Source')}" f"\n内容: {doc.page_content}\n---" for doc in docs ) def _setup_chain(self): """RAGチェーン全体をセットアップ""" try: self.db = self._load_vector_store() self.retriever = self.db.as_retriever( search_type='similarity', search_kwargs={'k': self.k} ) # LLMのパラメーターを動的に構成 llm_params = { 'model': self.llm_model, } if self.llm_model == 'gpt-5-nano': llm_params['temperature'] = 1.0 llm_params['reasoning_effort'] = 'minimal' llm_params['verbosity'] = 'low' else: # その他のモデルでは指定された temperature を使用 if self.llm_temperature != 0.3: llm_params['temperature'] = self.llm_temperature if self.llm_max_tokens is not None: llm_params['max_tokens'] = self.llm_max_tokens model = ChatOpenAI(**llm_params) format_docs = RunnableLambda(self._format_docs_with_metadata) qa_prompt = ChatPromptTemplate.from_template(QA_TEMPLATE) # 通常のRAGチェーン self.rag_chain = { 'context': self.retriever | format_docs, 'question': RunnablePassthrough(), } | qa_prompt | model | StrOutputParser() # モデルを保存 self.model = model self.model_with_history = model self.format_docs_with_history = format_docs if self.verbose: logger.info("永続化ベクトルストアを使ってRAGチェーンを作成しました!") except Exception as e: logger.error(f"QAチェーンのセットアップ中にエラー: {e}") raise def ask(self, question: str) -> Tuple[str, List[Document]]: """ 質問を投げて、回答と参照文書を取得 """ if not self.model: raise RuntimeError("モデルが正しく初期化されていません") try: # 文書を検索して整形 source_docs = self.retriever.invoke(question) context = self._format_docs_with_metadata(source_docs) # URLリストを作成(重複を除去) source_urls = [] for doc in source_docs: url = doc.metadata.get('source_url', '') if url and url not in source_urls: source_urls.append(url) urls_text = '\n'.join(f"- {url}" for url in source_urls) # プロンプトを構築して実行 prompt_input = { 'context': context, 'question': question, 'source_urls': urls_text } qa_prompt = ChatPromptTemplate.from_template(QA_TEMPLATE) # プロンプトテンプレートを適用して回答を生成 messages = qa_prompt.invoke(prompt_input) answer = self.model.invoke(messages).content return answer, source_docs except Exception as e: logger.error(f"質問処理中にエラー: {e}") raise def _manage_history_window(self): """ Sliding Windowを使用して履歴を管理 最大ターン数と最大文字数の両方を考慮 """ # ターン数の制限 if len(self.conversation_history) > self.max_history_turns: self.conversation_history = self.conversation_history[-self.max_history_turns:] # 文字数の制限(古い会話から削除) total_chars = sum(len(turn) for turn in self.conversation_history) while total_chars > self.max_history_chars and len(self.conversation_history) > 1: removed = self.conversation_history.pop(0) total_chars -= len(removed) if self.verbose: logger.info(f"履歴が制限を超えたため、古い会話を削除しました(削除文字数: {len(removed)})") def _format_history_text(self) -> str: """ 会話履歴配列を文字列に整形 """ if not self.conversation_history: return "まだ会話履歴はありません" return "\n".join(self.conversation_history) def ask_with_history(self, question: str, retry_count: int = 0) -> Tuple[str, List[Document]]: """ 対話履歴を考慮した質問応答(Sliding Window機能付き) Args: question: 質問内容 retry_count: リトライ回数(内部使用) """ if not self.model_with_history: raise RuntimeError("履歴付きモデルが正しく初期化されていません") try: # 履歴を文字列形式に変換 history_text = self._format_history_text() # 文書を検索して整形 source_docs = self.retriever.invoke(question) context = self._format_docs_with_metadata(source_docs) # URLリストを作成(重複を除去) source_urls = [] for doc in source_docs: url = doc.metadata.get('source_url', '') if url and url not in source_urls: source_urls.append(url) urls_text = '\n'.join(f"- {url}" for url in source_urls) # プロンプトを手動で構築して実行 prompt_input = { 'context': context, 'question': question, 'conversation_history': history_text, 'source_urls': urls_text } qa_prompt_with_history = ChatPromptTemplate.from_messages([ ("user", "context: {context}"), ("system", CHARACTER_TEMPLATE), ("system", QA_TEMPLATE_WITH_HISTORY), ("user", "質問: {question}") ]) # プロンプトテンプレートを適用して回答を生成 messages = qa_prompt_with_history.invoke(prompt_input) answer = self.model_with_history.invoke(messages).content # 新しい会話を履歴に追加 new_turn = f"ユーザー: {question}\nKurageSan®: {answer}" self.conversation_history.append(new_turn) # Sliding Windowを適用 self._manage_history_window() if self.verbose: total_chars = sum(len(turn) for turn in self.conversation_history) logger.info(f"履歴付き質問処理完了。ターン数: {len(self.conversation_history)}, 合計文字数: {total_chars}") return answer, source_docs except Exception as e: # トークン制限エラーの処理 error_message = str(e).lower() if retry_count < 2 and ('maximum context length' in error_message or 'token' in error_message and 'limit' in error_message): logger.warning(f"トークン制限エラーが発生しました。履歴を削減して再試行します(試行回数: {retry_count + 1})") # 履歴を半分に削減 if len(self.conversation_history) > 1: old_size = len(self.conversation_history) self.conversation_history = self.conversation_history[old_size//2:] logger.info(f"会話履歴を削減: {old_size} -> {len(self.conversation_history)} ターン") else: # 履歴が1つ以下の場合はクリア self.conversation_history = [] logger.info("会話履歴を完全にクリアしました") # リトライ return self.ask_with_history(question, retry_count + 1) logger.error(f"履歴付き質問処理中にエラー: {e}") raise def clear_history(self): """ 対話履歴をクリア """ self.conversation_history = [] if self.verbose: logger.info("対話履歴をクリアしました") def get_history(self) -> str: """ 現在の対話履歴を取得(デバッグ用) """ return self._format_history_text() def search_similar(self, query: str, k: int = 5) -> List[Document]: """ 類似文書を検索 """ if not self.retriever: raise RuntimeError("リトリーバーが正しく初期化されていません") self.retriever.search_kwargs['k'] = k return self.retriever.invoke(query) def ask_question(question: str, persist_dir: str = "data/vector_store") -> None: """ 質問を投げて、回答と参照情報を表示(後方互換用関数) """ try: qa_chain = QAChain(persist_dir=persist_dir, verbose=True) answer, source_docs = qa_chain.ask(question) print(f"\n{'='*50}") print(f"質問: {question}") print(f"{'='*50}") print(f"\n回答:\n{answer}") print(f"\n参考にした文書 ({len(source_docs)}件):") for i, doc in enumerate(source_docs, 1): print(f"\n{i}. {doc.metadata.get('title', 'No Title')}") print(f" URL: {doc.metadata.get('source_url', 'N/A')}") print(f" チャンク: {doc.metadata.get('chunk_index', 'N/A')}") except Exception as e: print(f"エラー: {e}")