File size: 14,288 Bytes
fb05e78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
069f493
fb05e78
dc16323
fb05e78
 
069f493
fb05e78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
069f493
fb05e78
 
069f493
fb05e78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
"""
RAGベースの質問応答用QA Chainモジュール
"""

import os
import logging
from pathlib import Path
from typing import List, Optional, Tuple

from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_openai import OpenAIEmbeddings, ChatOpenAI

from .prompt import QA_TEMPLATE, CHARACTER_TEMPLATE, QA_TEMPLATE_WITH_HISTORY

logger = logging.getLogger(__name__)


class QAChain:
    """RAGベースの質問応答用QA Chain"""

    def __init__(
        self,
        persist_dir: str = "data/vector_store",
        model_name: str = "text-embedding-3-small",
        k: int = 10,
        verbose: bool = False,
        llm_model: str = "gpt-4o-mini",
        llm_temperature: float = 0.3,
        llm_max_tokens: Optional[int] = None,
        max_history_turns: int = 3,
        max_history_chars: int = 10000
    ):
        """
        永続化されたベクトルストアを使ってQA Chainを初期化
        
        Args:
            persist_dir: ベクトルストアの保存ディレクトリ
            model_name: 埋め込みモデル名
            k: 検索する文書の数
            verbose: 詳細ログの出力
            llm_model: 使用するLLMモデル名
            llm_temperature: LLMの温度パラメーター(0-2)
            llm_max_tokens: 最大トークン数(Noneで自動)
            max_history_turns: 保持する最大会話ターン数(デフォルト: 10)
            max_history_chars: 履歴の最大文字数(デフォルト: 10000)
        """
        self.persist_dir = persist_dir
        self.model_name = model_name
        self.k = k
        self.verbose = verbose
        self.llm_model = llm_model
        self.llm_temperature = llm_temperature
        self.llm_max_tokens = llm_max_tokens
        self.max_history_turns = max_history_turns
        self.max_history_chars = max_history_chars
        self.conversation_history = []  # 配列形式で管理

        # 環境変数の読み込み(より堅牢な実装)
        try:
            load_dotenv(dotenv_path=".env")
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                raise ValueError("OPENAI_API_KEYが.envファイルに見つかりません")
        except FileNotFoundError:
            if self.verbose:
                logger.warning(".envファイルが見つかりません。環境変数から読み込みを試みます。")
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                raise ValueError("OPENAI_API_KEYが環境変数に設定されていません")
        except Exception as e:
            logger.error(f"環境変数の読み込み中にエラーが発生しました: {e}")
            raise

        self.db = None
        self.retriever = None
        self.rag_chain = None
        self.model_with_history = None
        self._setup_chain()

    def _load_vector_store(self) -> Chroma:
        """永続化されたベクトルストアを読み込む"""
        persist_path = Path(self.persist_dir)
        if not persist_path.exists():
            raise FileNotFoundError(
                f"ベクトルストアが見つかりません: {self.persist_dir}\n"
                "まず次のコマンドでベクトルストアを作成してください: python -m src.cli vector build"
            )
        embeddings = OpenAIEmbeddings(model=self.model_name)
        db = Chroma(
            persist_directory=self.persist_dir,
            embedding_function=embeddings
        )
        if self.verbose:
            logger.info(f"ベクトルストアを{self.persist_dir}から読み込みました")
        return db

    @staticmethod
    def _format_docs_with_metadata(docs: List[Document]) -> str:
        """文書をメタデータ付きでコンテキスト用に整形"""
        return '\n\n'.join(
            f"[出典: {doc.metadata.get('title', 'Unknown Title')} - チャンク {doc.metadata.get('chunk_index', 'N/A')}]"
            f"\nURL: {doc.metadata.get('source_url', 'Unknown Source')}"
            f"\n内容: {doc.page_content}\n---"
            for doc in docs
        )

    def _setup_chain(self):
        """RAGチェーン全体をセットアップ"""
        try:
            self.db = self._load_vector_store()
            self.retriever = self.db.as_retriever(
                search_type='similarity',
                search_kwargs={'k': self.k}
            )
            
            # LLMのパラメーターを動的に構成
            llm_params = {
                'model': self.llm_model,
            }
            if self.llm_model == 'gpt-5-nano':
                llm_params['temperature'] = 1.0
                llm_params['reasoning_effort'] = 'minimal'
                llm_params['verbosity'] = 'low'
            else:
                # その他のモデルでは指定された temperature を使用
                if self.llm_temperature != 0.3:
                    llm_params['temperature'] = self.llm_temperature
            
            if self.llm_max_tokens is not None:
                llm_params['max_tokens'] = self.llm_max_tokens
            
            model = ChatOpenAI(**llm_params)
            format_docs = RunnableLambda(self._format_docs_with_metadata)
            qa_prompt = ChatPromptTemplate.from_template(QA_TEMPLATE)
            
            # 通常のRAGチェーン
            self.rag_chain = {
                'context': self.retriever | format_docs,
                'question': RunnablePassthrough(),
            } | qa_prompt | model | StrOutputParser()

            # モデルを保存
            self.model = model
            self.model_with_history = model
            self.format_docs_with_history = format_docs

            if self.verbose:
                logger.info("永続化ベクトルストアを使ってRAGチェーンを作成しました!")
        except Exception as e:
            logger.error(f"QAチェーンのセットアップ中にエラー: {e}")
            raise

    def ask(self, question: str) -> Tuple[str, List[Document]]:
        """
        質問を投げて、回答と参照文書を取得
        """
        if not self.model:
            raise RuntimeError("モデルが正しく初期化されていません")
        try:
            # 文書を検索して整形
            source_docs = self.retriever.invoke(question)
            context = self._format_docs_with_metadata(source_docs)
            
            # URLリストを作成(重複を除去)
            source_urls = []
            for doc in source_docs:
                url = doc.metadata.get('source_url', '')
                if url and url not in source_urls:
                    source_urls.append(url)
            urls_text = '\n'.join(f"- {url}" for url in source_urls)
            
            # プロンプトを構築して実行
            prompt_input = {
                'context': context,
                'question': question,
                'source_urls': urls_text  
            }
            qa_prompt = ChatPromptTemplate.from_template(QA_TEMPLATE)
            
            # プロンプトテンプレートを適用して回答を生成
            messages = qa_prompt.invoke(prompt_input)
            answer = self.model.invoke(messages).content
            
            return answer, source_docs
        except Exception as e:
            logger.error(f"質問処理中にエラー: {e}")
            raise

    def _manage_history_window(self):
        """
        Sliding Windowを使用して履歴を管理
        最大ターン数と最大文字数の両方を考慮
        """
        # ターン数の制限
        if len(self.conversation_history) > self.max_history_turns:
            self.conversation_history = self.conversation_history[-self.max_history_turns:]
        
        # 文字数の制限(古い会話から削除)
        total_chars = sum(len(turn) for turn in self.conversation_history)
        while total_chars > self.max_history_chars and len(self.conversation_history) > 1:
            removed = self.conversation_history.pop(0)
            total_chars -= len(removed)
            if self.verbose:
                logger.info(f"履歴が制限を超えたため、古い会話を削除しました(削除文字数: {len(removed)})")
    
    def _format_history_text(self) -> str:
        """
        会話履歴配列を文字列に整形
        """
        if not self.conversation_history:
            return "まだ会話履歴はありません"
        return "\n".join(self.conversation_history)

    def ask_with_history(self, question: str, retry_count: int = 0) -> Tuple[str, List[Document]]:
        """
        対話履歴を考慮した質問応答(Sliding Window機能付き)
        
        Args:
            question: 質問内容
            retry_count: リトライ回数(内部使用)
        """
        if not self.model_with_history:
            raise RuntimeError("履歴付きモデルが正しく初期化されていません")
        
        try:
            # 履歴を文字列形式に変換
            history_text = self._format_history_text()
            
            # 文書を検索して整形
            source_docs = self.retriever.invoke(question)
            context = self._format_docs_with_metadata(source_docs)
            
            # URLリストを作成(重複を除去)
            source_urls = []
            for doc in source_docs:
                url = doc.metadata.get('source_url', '')
                if url and url not in source_urls:
                    source_urls.append(url)
            urls_text = '\n'.join(f"- {url}" for url in source_urls)
            
            # プロンプトを手動で構築して実行
            prompt_input = {
                'context': context,
                'question': question,
                'conversation_history': history_text,
                'source_urls': urls_text  
            }
            qa_prompt_with_history = ChatPromptTemplate.from_messages([
                ("user", "context: {context}"),
                ("system", CHARACTER_TEMPLATE),
                ("system", QA_TEMPLATE_WITH_HISTORY),
                ("user", "質問: {question}")
            ])
            
            # プロンプトテンプレートを適用して回答を生成
            messages = qa_prompt_with_history.invoke(prompt_input)
            answer = self.model_with_history.invoke(messages).content
            
            # 新しい会話を履歴に追加
            new_turn = f"ユーザー: {question}\nKurageSan®: {answer}"
            self.conversation_history.append(new_turn)
            
            # Sliding Windowを適用
            self._manage_history_window()
            
            if self.verbose:
                total_chars = sum(len(turn) for turn in self.conversation_history)
                logger.info(f"履歴付き質問処理完了。ターン数: {len(self.conversation_history)}, 合計文字数: {total_chars}")
            
            return answer, source_docs
            
        except Exception as e:
            # トークン制限エラーの処理
            error_message = str(e).lower()
            if retry_count < 2 and ('maximum context length' in error_message or 
                                  'token' in error_message and 'limit' in error_message):
                logger.warning(f"トークン制限エラーが発生しました。履歴を削減して再試行します(試行回数: {retry_count + 1})")
                
                # 履歴を半分に削減
                if len(self.conversation_history) > 1:
                    old_size = len(self.conversation_history)
                    self.conversation_history = self.conversation_history[old_size//2:]
                    logger.info(f"会話履歴を削減: {old_size} -> {len(self.conversation_history)} ターン")
                else:
                    # 履歴が1つ以下の場合はクリア
                    self.conversation_history = []
                    logger.info("会話履歴を完全にクリアしました")
                
                # リトライ
                return self.ask_with_history(question, retry_count + 1)
            
            logger.error(f"履歴付き質問処理中にエラー: {e}")
            raise
    
    def clear_history(self):
        """
        対話履歴をクリア
        """
        self.conversation_history = []
        if self.verbose:
            logger.info("対話履歴をクリアしました")
    
    def get_history(self) -> str:
        """
        現在の対話履歴を取得(デバッグ用)
        """
        return self._format_history_text()

    def search_similar(self, query: str, k: int = 5) -> List[Document]:
        """
        類似文書を検索
        """
        if not self.retriever:
            raise RuntimeError("リトリーバーが正しく初期化されていません")
        self.retriever.search_kwargs['k'] = k
        return self.retriever.invoke(query)


def ask_question(question: str, persist_dir: str = "data/vector_store") -> None:
    """
    質問を投げて、回答と参照情報を表示(後方互換用関数)
    """
    try:
        qa_chain = QAChain(persist_dir=persist_dir, verbose=True)
        answer, source_docs = qa_chain.ask(question)

        print(f"\n{'='*50}")
        print(f"質問: {question}")
        print(f"{'='*50}")
        print(f"\n回答:\n{answer}")

        print(f"\n参考にした文書 ({len(source_docs)}件):")
        for i, doc in enumerate(source_docs, 1):
            print(f"\n{i}. {doc.metadata.get('title', 'No Title')}")
            print(f"   URL: {doc.metadata.get('source_url', 'N/A')}")
            print(f"   チャンク: {doc.metadata.get('chunk_index', 'N/A')}")
    except Exception as e:
        print(f"エラー: {e}")