| """ |
| RAGベースの質問応答用QA Chainモジュール |
| """ |
|
|
| import os |
| import logging |
| from pathlib import Path |
| from typing import List, Optional, Tuple |
|
|
| from dotenv import load_dotenv |
| from langchain_chroma import Chroma |
| from langchain_core.documents import Document |
| from langchain_core.output_parsers import StrOutputParser |
| from langchain_core.prompts import ChatPromptTemplate |
| from langchain_core.runnables import RunnableLambda, RunnablePassthrough |
| from langchain_openai import OpenAIEmbeddings, ChatOpenAI |
|
|
| from .prompt import QA_TEMPLATE, CHARACTER_TEMPLATE, QA_TEMPLATE_WITH_HISTORY |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class QAChain:
    """RAG-based question-answering chain over a persisted Chroma vector store.

    Retrieves the k most similar chunks for a question, formats them together
    with their source metadata into a prompt, and answers with an OpenAI chat
    model. Also supports multi-turn conversations with a sliding-window
    history (bounded by turn count and total character count).
    """

    def __init__(
        self,
        persist_dir: str = "data/vector_store",
        model_name: str = "text-embedding-3-small",
        k: int = 5,
        verbose: bool = False,
        llm_model: str = "gpt-4o-mini",
        llm_temperature: Optional[float] = 0.3,
        llm_max_tokens: Optional[int] = None,
        max_history_turns: int = 5,
        max_history_chars: int = 10000
    ):
        """Initialize the QA chain from a persisted vector store.

        Args:
            persist_dir: Directory the vector store was persisted to.
            model_name: Embedding model name.
            k: Number of documents to retrieve per question.
            verbose: Emit informational logs when True.
            llm_model: Chat model name.
            llm_temperature: LLM temperature (0-2). The value 0.3 acts as a
                sentinel for "not set" and leaves the model's own default
                temperature in effect (see _setup_chain).
            llm_max_tokens: Maximum completion tokens (None = model default).
            max_history_turns: Maximum number of conversation turns kept
                (default: 5).
            max_history_chars: Maximum total characters kept in history
                (default: 10000).

        Raises:
            ValueError: If OPENAI_API_KEY is not found in .env or the
                process environment.
        """
        self.persist_dir = persist_dir
        self.model_name = model_name
        self.k = k
        self.verbose = verbose
        self.llm_model = llm_model
        self.llm_temperature = llm_temperature
        self.llm_max_tokens = llm_max_tokens
        self.max_history_turns = max_history_turns
        self.max_history_chars = max_history_chars
        # Each entry is one complete turn: "ユーザー: ...\nKurageSan®: ...".
        self.conversation_history: List[str] = []

        try:
            # load_dotenv() returns False (it does not raise) when .env is
            # absent, so a single os.getenv() check covers both the .env
            # case and the plain-environment case.
            load_dotenv(dotenv_path=".env")
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                raise ValueError("OPENAI_API_KEYが.envファイルに見つかりません")
        except Exception as e:
            logger.error(f"環境変数の読み込み中にエラーが発生しました: {e}")
            raise

        self.db = None
        self.retriever = None
        self.rag_chain = None
        self.model_with_history = None
        self._setup_chain()

    def _load_vector_store(self) -> Chroma:
        """Load the persisted vector store.

        Raises:
            FileNotFoundError: If persist_dir does not exist.
        """
        persist_path = Path(self.persist_dir)
        if not persist_path.exists():
            raise FileNotFoundError(
                f"ベクトルストアが見つかりません: {self.persist_dir}\n"
                "まず次のコマンドでベクトルストアを作成してください: python -m src.cli vector build"
            )
        embeddings = OpenAIEmbeddings(model=self.model_name)
        db = Chroma(
            persist_directory=self.persist_dir,
            embedding_function=embeddings
        )
        if self.verbose:
            logger.info(f"ベクトルストアを{self.persist_dir}から読み込みました")
        return db

    @staticmethod
    def _format_docs_with_metadata(docs: List[Document]) -> str:
        """Format documents (with title / chunk index / URL) into context text."""
        return '\n\n'.join(
            f"[出典: {doc.metadata.get('title', 'Unknown Title')} - チャンク {doc.metadata.get('chunk_index', 'N/A')}]"
            f"\nURL: {doc.metadata.get('source_url', 'Unknown Source')}"
            f"\n内容: {doc.page_content}\n---"
            for doc in docs
        )

    def _setup_chain(self):
        """Build the retriever, the LLM and the LCEL RAG chain."""
        try:
            self.db = self._load_vector_store()
            self.retriever = self.db.as_retriever(
                search_type='similarity',
                search_kwargs={'k': self.k}
            )

            # Model-specific parameter handling.
            llm_params = {
                'model': self.llm_model,
            }
            if self.llm_model == 'gpt-5-nano':
                # Temperature pinned to 1.0 for gpt-5-* models
                # (presumably an API restriction — confirm).
                llm_params['temperature'] = 1.0
                llm_params['reasoning_effort'] = 'high'
                llm_params['verbosity'] = 'low'
            elif self.llm_model == 'gpt-5-mini':
                llm_params['temperature'] = 1.0
            else:
                # NOTE(review): 0.3 is used as the "not set" sentinel, so an
                # explicit llm_temperature=0.3 also falls through to the
                # model's default temperature.
                if self.llm_temperature != 0.3:
                    llm_params['temperature'] = self.llm_temperature

            if self.llm_max_tokens is not None:
                llm_params['max_tokens'] = self.llm_max_tokens

            model = ChatOpenAI(**llm_params)
            format_docs = RunnableLambda(self._format_docs_with_metadata)
            qa_prompt = ChatPromptTemplate.from_template(QA_TEMPLATE)

            # Single-shot LCEL chain. Kept as a public attribute; note that
            # ask() builds its prompt manually to also inject source URLs.
            self.rag_chain = {
                'context': self.retriever | format_docs,
                'question': RunnablePassthrough(),
            } | qa_prompt | model | StrOutputParser()

            self.model = model
            self.model_with_history = model
            self.format_docs_with_history = format_docs

            if self.verbose:
                logger.info("永続化ベクトルストアを使ってRAGチェーンを作成しました!")
        except Exception as e:
            logger.error(f"QAチェーンのセットアップ中にエラー: {e}")
            raise

    def ask(self, question: str) -> Tuple[str, List[Document]]:
        """Answer a single question (no conversation history).

        Returns:
            Tuple of (answer text, retrieved source documents).
        """
        if not self.model:
            raise RuntimeError("モデルが正しく初期化されていません")
        try:
            # Retrieve supporting chunks and format them for the prompt.
            source_docs = self.retriever.invoke(question)
            context = self._format_docs_with_metadata(source_docs)

            # De-duplicated source URLs, preserving retrieval order.
            source_urls = []
            for doc in source_docs:
                url = doc.metadata.get('source_url', '')
                if url and url not in source_urls:
                    source_urls.append(url)
            urls_text = '\n'.join(f"- {url}" for url in source_urls)

            prompt_input = {
                'context': context,
                'question': question,
                'source_urls': urls_text
            }
            qa_prompt = ChatPromptTemplate.from_template(QA_TEMPLATE)

            messages = qa_prompt.invoke(prompt_input)
            answer = self.model.invoke(messages).content

            return answer, source_docs
        except Exception as e:
            logger.error(f"質問処理中にエラー: {e}")
            raise

    def _manage_history_window(self):
        """Trim the conversation history with a sliding window.

        Enforces both the maximum turn count and the maximum total character
        count; oldest turns are dropped first, and the character-based pass
        always keeps at least one turn.
        """
        # Turn-count cap: keep only the most recent turns.
        if len(self.conversation_history) > self.max_history_turns:
            self.conversation_history = self.conversation_history[-self.max_history_turns:]

        # Character cap: drop oldest turns until under the limit.
        total_chars = sum(len(turn) for turn in self.conversation_history)
        while total_chars > self.max_history_chars and len(self.conversation_history) > 1:
            removed = self.conversation_history.pop(0)
            total_chars -= len(removed)
            if self.verbose:
                logger.info(f"履歴が制限を超えたため、古い会話を削除しました(削除文字数: {len(removed)})")

    def _format_history_text(self) -> str:
        """Render the conversation-history list as a single string."""
        if not self.conversation_history:
            return "まだ会話履歴はありません"
        return "\n".join(self.conversation_history)

    def ask_with_history(self, question: str, retry_count: int = 0) -> Tuple[str, List[Document]]:
        """Answer a question taking the conversation history into account.

        Uses the sliding-window history and, on a token-limit error, halves
        the history and retries (at most twice).

        Args:
            question: Question text.
            retry_count: Retry counter (internal use).

        Returns:
            Tuple of (answer text, retrieved source documents).
        """
        if not self.model_with_history:
            raise RuntimeError("履歴付きモデルが正しく初期化されていません")

        try:
            history_text = self._format_history_text()

            source_docs = self.retriever.invoke(question)
            context = self._format_docs_with_metadata(source_docs)

            # De-duplicated source URLs, preserving retrieval order.
            source_urls = []
            for doc in source_docs:
                url = doc.metadata.get('source_url', '')
                if url and url not in source_urls:
                    source_urls.append(url)
            urls_text = '\n'.join(f"- {url}" for url in source_urls)

            prompt_input = {
                'context': context,
                'question': question,
                'conversation_history': history_text,
                'source_urls': urls_text
            }
            qa_prompt_with_history = ChatPromptTemplate.from_messages([
                ("user", "context: {context}"),
                ("system", CHARACTER_TEMPLATE),
                ("system", QA_TEMPLATE_WITH_HISTORY),
                ("user", "質問: {question}")
            ])

            messages = qa_prompt_with_history.invoke(prompt_input)
            answer = self.model_with_history.invoke(messages).content

            # Record the finished turn, then trim to the window limits.
            new_turn = f"ユーザー: {question}\nKurageSan®: {answer}"
            self.conversation_history.append(new_turn)

            self._manage_history_window()

            if self.verbose:
                total_chars = sum(len(turn) for turn in self.conversation_history)
                logger.info(f"履歴付き質問処理完了。ターン数: {len(self.conversation_history)}, 合計文字数: {total_chars}")

            return answer, source_docs

        except Exception as e:
            # On a token-limit error, shrink the history and retry.
            error_message = str(e).lower()
            if retry_count < 2 and ('maximum context length' in error_message or
                                    'token' in error_message and 'limit' in error_message):
                logger.warning(f"トークン制限エラーが発生しました。履歴を削減して再試行します(試行回数: {retry_count + 1})")

                if len(self.conversation_history) > 1:
                    # Keep the newer half of the history.
                    old_size = len(self.conversation_history)
                    self.conversation_history = self.conversation_history[old_size//2:]
                    logger.info(f"会話履歴を削減: {old_size} -> {len(self.conversation_history)} ターン")
                else:
                    self.conversation_history = []
                    logger.info("会話履歴を完全にクリアしました")

                return self.ask_with_history(question, retry_count + 1)

            logger.error(f"履歴付き質問処理中にエラー: {e}")
            raise

    def clear_history(self):
        """Clear the conversation history."""
        self.conversation_history = []
        if self.verbose:
            logger.info("対話履歴をクリアしました")

    def get_history(self) -> str:
        """Return the current conversation history as text (for debugging)."""
        return self._format_history_text()

    def search_similar(self, query: str, k: int = 5) -> List[Document]:
        """Return the k documents most similar to the query.

        The retriever's configured k is overridden only for this call and
        restored afterwards, so subsequent ask() calls keep using the value
        chosen at construction time (the original mutated it permanently).
        """
        if not self.retriever:
            raise RuntimeError("リトリーバーが正しく初期化されていません")
        previous_k = self.retriever.search_kwargs.get('k', self.k)
        self.retriever.search_kwargs['k'] = k
        try:
            return self.retriever.invoke(query)
        finally:
            self.retriever.search_kwargs['k'] = previous_k
|
|
|
|
def ask_question(question: str, persist_dir: str = "data/vector_store") -> None:
    """Ask one question and print the answer with its source documents.

    Backward-compatibility wrapper around QAChain; any failure is caught
    and printed rather than raised.
    """
    try:
        chain = QAChain(persist_dir=persist_dir, verbose=True)
        answer, docs = chain.ask(question)

        separator = '=' * 50
        print(f"\n{separator}")
        print(f"質問: {question}")
        print(f"{separator}")
        print(f"\n回答:\n{answer}")

        print(f"\n参考にした文書 ({len(docs)}件):")
        for index, doc in enumerate(docs, start=1):
            meta = doc.metadata
            print(f"\n{index}. {meta.get('title', 'No Title')}")
            print(f" URL: {meta.get('source_url', 'N/A')}")
            print(f" チャンク: {meta.get('chunk_index', 'N/A')}")
    except Exception as e:
        print(f"エラー: {e}")
|
|