Spaces:
Sleeping
Sleeping
| """ | |
| RAGベースの質問応答用QA Chainモジュール | |
| """ | |
| import os | |
| import logging | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple | |
| from dotenv import load_dotenv | |
| from langchain_chroma import Chroma | |
| from langchain_core.documents import Document | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.runnables import RunnableLambda, RunnablePassthrough | |
| from langchain_openai import OpenAIEmbeddings, ChatOpenAI | |
| from .prompt import QA_TEMPLATE, CHARACTER_TEMPLATE, QA_TEMPLATE_WITH_HISTORY | |
| logger = logging.getLogger(__name__) | |
| class QAChain: | |
| """RAGベースの質問応答用QA Chain""" | |
| def __init__( | |
| self, | |
| persist_dir: str = "data/vector_store", | |
| model_name: str = "text-embedding-3-small", | |
| k: int = 10, | |
| verbose: bool = False, | |
| llm_model: str = "gpt-4o-mini", | |
| llm_temperature: float = 0.3, | |
| llm_max_tokens: Optional[int] = None, | |
| max_history_turns: int = 3, | |
| max_history_chars: int = 10000 | |
| ): | |
| """ | |
| 永続化されたベクトルストアを使ってQA Chainを初期化 | |
| Args: | |
| persist_dir: ベクトルストアの保存ディレクトリ | |
| model_name: 埋め込みモデル名 | |
| k: 検索する文書の数 | |
| verbose: 詳細ログの出力 | |
| llm_model: 使用するLLMモデル名 | |
| llm_temperature: LLMの温度パラメーター(0-2) | |
| llm_max_tokens: 最大トークン数(Noneで自動) | |
| max_history_turns: 保持する最大会話ターン数(デフォルト: 10) | |
| max_history_chars: 履歴の最大文字数(デフォルト: 10000) | |
| """ | |
| self.persist_dir = persist_dir | |
| self.model_name = model_name | |
| self.k = k | |
| self.verbose = verbose | |
| self.llm_model = llm_model | |
| self.llm_temperature = llm_temperature | |
| self.llm_max_tokens = llm_max_tokens | |
| self.max_history_turns = max_history_turns | |
| self.max_history_chars = max_history_chars | |
| self.conversation_history = [] # 配列形式で管理 | |
| # 環境変数の読み込み(より堅牢な実装) | |
| try: | |
| load_dotenv(dotenv_path=".env") | |
| api_key = os.getenv("OPENAI_API_KEY") | |
| if not api_key: | |
| raise ValueError("OPENAI_API_KEYが.envファイルに見つかりません") | |
| except FileNotFoundError: | |
| if self.verbose: | |
| logger.warning(".envファイルが見つかりません。環境変数から読み込みを試みます。") | |
| api_key = os.getenv("OPENAI_API_KEY") | |
| if not api_key: | |
| raise ValueError("OPENAI_API_KEYが環境変数に設定されていません") | |
| except Exception as e: | |
| logger.error(f"環境変数の読み込み中にエラーが発生しました: {e}") | |
| raise | |
| self.db = None | |
| self.retriever = None | |
| self.rag_chain = None | |
| self.model_with_history = None | |
| self._setup_chain() | |
| def _load_vector_store(self) -> Chroma: | |
| """永続化されたベクトルストアを読み込む""" | |
| persist_path = Path(self.persist_dir) | |
| if not persist_path.exists(): | |
| raise FileNotFoundError( | |
| f"ベクトルストアが見つかりません: {self.persist_dir}\n" | |
| "まず次のコマンドでベクトルストアを作成してください: python -m src.cli vector build" | |
| ) | |
| embeddings = OpenAIEmbeddings(model=self.model_name) | |
| db = Chroma( | |
| persist_directory=self.persist_dir, | |
| embedding_function=embeddings | |
| ) | |
| if self.verbose: | |
| logger.info(f"ベクトルストアを{self.persist_dir}から読み込みました") | |
| return db | |
| def _format_docs_with_metadata(docs: List[Document]) -> str: | |
| """文書をメタデータ付きでコンテキスト用に整形""" | |
| return '\n\n'.join( | |
| f"[出典: {doc.metadata.get('title', 'Unknown Title')} - チャンク {doc.metadata.get('chunk_index', 'N/A')}]" | |
| f"\nURL: {doc.metadata.get('source_url', 'Unknown Source')}" | |
| f"\n内容: {doc.page_content}\n---" | |
| for doc in docs | |
| ) | |
| def _setup_chain(self): | |
| """RAGチェーン全体をセットアップ""" | |
| try: | |
| self.db = self._load_vector_store() | |
| self.retriever = self.db.as_retriever( | |
| search_type='similarity', | |
| search_kwargs={'k': self.k} | |
| ) | |
| # LLMのパラメーターを動的に構成 | |
| llm_params = { | |
| 'model': self.llm_model, | |
| } | |
| if self.llm_model == 'gpt-5-nano': | |
| llm_params['temperature'] = 1.0 | |
| llm_params['reasoning_effort'] = 'minimal' | |
| llm_params['verbosity'] = 'low' | |
| else: | |
| # その他のモデルでは指定された temperature を使用 | |
| if self.llm_temperature != 0.3: | |
| llm_params['temperature'] = self.llm_temperature | |
| if self.llm_max_tokens is not None: | |
| llm_params['max_tokens'] = self.llm_max_tokens | |
| model = ChatOpenAI(**llm_params) | |
| format_docs = RunnableLambda(self._format_docs_with_metadata) | |
| qa_prompt = ChatPromptTemplate.from_template(QA_TEMPLATE) | |
| # 通常のRAGチェーン | |
| self.rag_chain = { | |
| 'context': self.retriever | format_docs, | |
| 'question': RunnablePassthrough(), | |
| } | qa_prompt | model | StrOutputParser() | |
| # モデルを保存 | |
| self.model = model | |
| self.model_with_history = model | |
| self.format_docs_with_history = format_docs | |
| if self.verbose: | |
| logger.info("永続化ベクトルストアを使ってRAGチェーンを作成しました!") | |
| except Exception as e: | |
| logger.error(f"QAチェーンのセットアップ中にエラー: {e}") | |
| raise | |
| def ask(self, question: str) -> Tuple[str, List[Document]]: | |
| """ | |
| 質問を投げて、回答と参照文書を取得 | |
| """ | |
| if not self.model: | |
| raise RuntimeError("モデルが正しく初期化されていません") | |
| try: | |
| # 文書を検索して整形 | |
| source_docs = self.retriever.invoke(question) | |
| context = self._format_docs_with_metadata(source_docs) | |
| # URLリストを作成(重複を除去) | |
| source_urls = [] | |
| for doc in source_docs: | |
| url = doc.metadata.get('source_url', '') | |
| if url and url not in source_urls: | |
| source_urls.append(url) | |
| urls_text = '\n'.join(f"- {url}" for url in source_urls) | |
| # プロンプトを構築して実行 | |
| prompt_input = { | |
| 'context': context, | |
| 'question': question, | |
| 'source_urls': urls_text | |
| } | |
| qa_prompt = ChatPromptTemplate.from_template(QA_TEMPLATE) | |
| # プロンプトテンプレートを適用して回答を生成 | |
| messages = qa_prompt.invoke(prompt_input) | |
| answer = self.model.invoke(messages).content | |
| return answer, source_docs | |
| except Exception as e: | |
| logger.error(f"質問処理中にエラー: {e}") | |
| raise | |
| def _manage_history_window(self): | |
| """ | |
| Sliding Windowを使用して履歴を管理 | |
| 最大ターン数と最大文字数の両方を考慮 | |
| """ | |
| # ターン数の制限 | |
| if len(self.conversation_history) > self.max_history_turns: | |
| self.conversation_history = self.conversation_history[-self.max_history_turns:] | |
| # 文字数の制限(古い会話から削除) | |
| total_chars = sum(len(turn) for turn in self.conversation_history) | |
| while total_chars > self.max_history_chars and len(self.conversation_history) > 1: | |
| removed = self.conversation_history.pop(0) | |
| total_chars -= len(removed) | |
| if self.verbose: | |
| logger.info(f"履歴が制限を超えたため、古い会話を削除しました(削除文字数: {len(removed)})") | |
| def _format_history_text(self) -> str: | |
| """ | |
| 会話履歴配列を文字列に整形 | |
| """ | |
| if not self.conversation_history: | |
| return "まだ会話履歴はありません" | |
| return "\n".join(self.conversation_history) | |
| def ask_with_history(self, question: str, retry_count: int = 0) -> Tuple[str, List[Document]]: | |
| """ | |
| 対話履歴を考慮した質問応答(Sliding Window機能付き) | |
| Args: | |
| question: 質問内容 | |
| retry_count: リトライ回数(内部使用) | |
| """ | |
| if not self.model_with_history: | |
| raise RuntimeError("履歴付きモデルが正しく初期化されていません") | |
| try: | |
| # 履歴を文字列形式に変換 | |
| history_text = self._format_history_text() | |
| # 文書を検索して整形 | |
| source_docs = self.retriever.invoke(question) | |
| context = self._format_docs_with_metadata(source_docs) | |
| # URLリストを作成(重複を除去) | |
| source_urls = [] | |
| for doc in source_docs: | |
| url = doc.metadata.get('source_url', '') | |
| if url and url not in source_urls: | |
| source_urls.append(url) | |
| urls_text = '\n'.join(f"- {url}" for url in source_urls) | |
| # プロンプトを手動で構築して実行 | |
| prompt_input = { | |
| 'context': context, | |
| 'question': question, | |
| 'conversation_history': history_text, | |
| 'source_urls': urls_text | |
| } | |
| qa_prompt_with_history = ChatPromptTemplate.from_messages([ | |
| ("user", "context: {context}"), | |
| ("system", CHARACTER_TEMPLATE), | |
| ("system", QA_TEMPLATE_WITH_HISTORY), | |
| ("user", "質問: {question}") | |
| ]) | |
| # プロンプトテンプレートを適用して回答を生成 | |
| messages = qa_prompt_with_history.invoke(prompt_input) | |
| answer = self.model_with_history.invoke(messages).content | |
| # 新しい会話を履歴に追加 | |
| new_turn = f"ユーザー: {question}\nKurageSan®: {answer}" | |
| self.conversation_history.append(new_turn) | |
| # Sliding Windowを適用 | |
| self._manage_history_window() | |
| if self.verbose: | |
| total_chars = sum(len(turn) for turn in self.conversation_history) | |
| logger.info(f"履歴付き質問処理完了。ターン数: {len(self.conversation_history)}, 合計文字数: {total_chars}") | |
| return answer, source_docs | |
| except Exception as e: | |
| # トークン制限エラーの処理 | |
| error_message = str(e).lower() | |
| if retry_count < 2 and ('maximum context length' in error_message or | |
| 'token' in error_message and 'limit' in error_message): | |
| logger.warning(f"トークン制限エラーが発生しました。履歴を削減して再試行します(試行回数: {retry_count + 1})") | |
| # 履歴を半分に削減 | |
| if len(self.conversation_history) > 1: | |
| old_size = len(self.conversation_history) | |
| self.conversation_history = self.conversation_history[old_size//2:] | |
| logger.info(f"会話履歴を削減: {old_size} -> {len(self.conversation_history)} ターン") | |
| else: | |
| # 履歴が1つ以下の場合はクリア | |
| self.conversation_history = [] | |
| logger.info("会話履歴を完全にクリアしました") | |
| # リトライ | |
| return self.ask_with_history(question, retry_count + 1) | |
| logger.error(f"履歴付き質問処理中にエラー: {e}") | |
| raise | |
| def clear_history(self): | |
| """ | |
| 対話履歴をクリア | |
| """ | |
| self.conversation_history = [] | |
| if self.verbose: | |
| logger.info("対話履歴をクリアしました") | |
| def get_history(self) -> str: | |
| """ | |
| 現在の対話履歴を取得(デバッグ用) | |
| """ | |
| return self._format_history_text() | |
| def search_similar(self, query: str, k: int = 5) -> List[Document]: | |
| """ | |
| 類似文書を検索 | |
| """ | |
| if not self.retriever: | |
| raise RuntimeError("リトリーバーが正しく初期化されていません") | |
| self.retriever.search_kwargs['k'] = k | |
| return self.retriever.invoke(query) | |
| def ask_question(question: str, persist_dir: str = "data/vector_store") -> None: | |
| """ | |
| 質問を投げて、回答と参照情報を表示(後方互換用関数) | |
| """ | |
| try: | |
| qa_chain = QAChain(persist_dir=persist_dir, verbose=True) | |
| answer, source_docs = qa_chain.ask(question) | |
| print(f"\n{'='*50}") | |
| print(f"質問: {question}") | |
| print(f"{'='*50}") | |
| print(f"\n回答:\n{answer}") | |
| print(f"\n参考にした文書 ({len(source_docs)}件):") | |
| for i, doc in enumerate(source_docs, 1): | |
| print(f"\n{i}. {doc.metadata.get('title', 'No Title')}") | |
| print(f" URL: {doc.metadata.get('source_url', 'N/A')}") | |
| print(f" チャンク: {doc.metadata.get('chunk_index', 'N/A')}") | |
| except Exception as e: | |
| print(f"エラー: {e}") | |