| """ |
| Gradio + Groq API - 8種 RAG 策略 PDF 問答系統 |
需要安裝: pip install gradio groq pypdf sentence-transformers numpy faiss-cpu scikit-learn
| """ |
|
|
| import gradio as gr |
| from groq import Groq |
| import numpy as np |
| from sentence_transformers import SentenceTransformer |
| import faiss |
| from pypdf import PdfReader |
| import re |
| from sklearn.feature_extraction.text import TfidfVectorizer |
| from collections import Counter |
|
|
class MultiStrategyRAG:
    """PDF question-answering engine offering 8 selectable RAG retrieval strategies.

    Combines a Groq-hosted LLM (generation, reranking, query expansion, HyDE),
    a multilingual sentence-transformer with a FAISS index for dense retrieval,
    and scikit-learn TF-IDF for sparse keyword retrieval.
    """

    def __init__(self, api_key):
        """Create the API client and embedding model; indexes start empty.

        Args:
            api_key: Groq API key used for all chat-completion calls.
        """
        self.client = Groq(api_key=api_key)
        # Multilingual model so Chinese documents and queries embed sensibly.
        self.embedding_model = SentenceTransformer(
            'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
        )
        self.chunks = []           # text chunks of the currently loaded PDF
        self.embeddings = None     # dense vectors for self.chunks
        self.index = None          # FAISS exact-L2 index over the embeddings
        self.tfidf_vectorizer = None
        self.tfidf_matrix = None   # sparse TF-IDF matrix over self.chunks

    def load_pdf(self, pdf_file):
        """Load a PDF, chunk its text, and build the dense and sparse indexes.

        Args:
            pdf_file: path (or file-like object) accepted by pypdf.PdfReader.

        Returns:
            A human-readable status string (success or failure).
        """
        try:
            reader = PdfReader(pdf_file)
            full_text = ""

            for page in reader.pages:
                # FIX: extract_text() returns None for pages with no
                # extractable text; concatenating None raises TypeError.
                text = page.extract_text() or ""
                full_text += text + "\n"

            self.chunks = self._split_text(full_text, chunk_size=800, overlap=150)

            # Guard: an image-only PDF yields no chunks, which would make the
            # embedding/index steps below fail with a confusing error.
            if not self.chunks:
                return "❌ 載入失敗: PDF 中沒有可擷取的文字"

            self.embeddings = self.embedding_model.encode(
                self.chunks,
                convert_to_numpy=True
            )

            # Dense index: exact L2 search over all chunk embeddings.
            dimension = self.embeddings.shape[1]
            self.index = faiss.IndexFlatL2(dimension)
            self.index.add(self.embeddings.astype('float32'))

            # Sparse index for the keyword-based strategies.
            self.tfidf_vectorizer = TfidfVectorizer(max_features=1000)
            self.tfidf_matrix = self.tfidf_vectorizer.fit_transform(self.chunks)

            return f"✅ 成功載入 PDF!共 {len(reader.pages)} 頁,分割為 {len(self.chunks)} 個片段"

        except Exception as e:
            return f"❌ 載入失敗: {str(e)}"

    def _split_text(self, text, chunk_size, overlap):
        """Split text into overlapping, whitespace-normalized chunks.

        Args:
            text: raw text to split.
            chunk_size: maximum characters per chunk (before normalization).
            overlap: characters shared between consecutive chunks.

        Returns:
            List of non-empty chunk strings.
        """
        chunks = []
        text_length = len(text)
        # FIX: clamp the step to at least 1 and advance it unconditionally,
        # so the loop always terminates even if overlap >= chunk_size or a
        # window normalizes to the empty string.
        step = max(1, chunk_size - overlap)
        start = 0

        while start < text_length:
            # Collapse runs of whitespace so chunks embed/index cleanly.
            chunk = re.sub(r'\s+', ' ', text[start:start + chunk_size]).strip()
            if chunk:
                chunks.append(chunk)
            start += step

        return chunks

    def strategy_1_basic_similarity(self, query, top_k=3):
        """Strategy 1: basic dense (semantic) similarity search via FAISS."""
        query_vector = self.embedding_model.encode([query])
        # int() cast: FAISS requires an integer k (sliders may pass floats).
        _, indices = self.index.search(query_vector.astype('float32'), int(top_k))
        return [self.chunks[idx] for idx in indices[0]]

    def strategy_2_tfidf(self, query, top_k=3):
        """Strategy 2: sparse keyword search ranked by TF-IDF dot-product."""
        query_vector = self.tfidf_vectorizer.transform([query])
        similarities = (self.tfidf_matrix * query_vector.T).toarray().flatten()
        top_indices = similarities.argsort()[-int(top_k):][::-1]
        return [self.chunks[idx] for idx in top_indices]

    def strategy_3_hybrid(self, query, top_k=3):
        """Strategy 3: hybrid search merging semantic and TF-IDF candidates."""
        top_k = int(top_k)

        # Dense candidates (2x top_k to leave room after merging).
        query_vector = self.embedding_model.encode([query])
        _, sem_indices = self.index.search(query_vector.astype('float32'), top_k * 2)

        # Sparse candidates.
        query_tfidf = self.tfidf_vectorizer.transform([query])
        tfidf_scores = (self.tfidf_matrix * query_tfidf.T).toarray().flatten()
        tfidf_indices = tfidf_scores.argsort()[-top_k * 2:][::-1]

        # FIX: dedupe with dict.fromkeys, which preserves first-seen order
        # (semantic hits first) — set() gave a nondeterministic ordering.
        combined = list(dict.fromkeys(sem_indices[0].tolist() + tfidf_indices.tolist()))
        return [self.chunks[idx] for idx in combined[:top_k]]

    def strategy_4_reranking(self, query, top_k=3):
        """Strategy 4: retrieve 2x top_k dense candidates, rerank with the LLM."""
        top_k = int(top_k)
        candidates = self.strategy_1_basic_similarity(query, top_k=top_k * 2)

        reranked = []
        for chunk in candidates:
            prompt = f"問題:{query}\n\n文本:{chunk[:200]}...\n\n這段文本與問題的相關度(0-10):"

            try:
                response = self.client.chat.completions.create(
                    model="llama-3.1-8b-instant",
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=10,
                    temperature=0
                )
                reply = response.choices[0].message.content.strip()
                # Parse the first number in the reply; score 0 if none found.
                match = re.search(r'\d+', reply)
                score = float(match.group()) if match else 0.0
            except Exception:
                # API failure: keep the chunk but rank it last.
                score = 0.0
            reranked.append((chunk, score))

        reranked.sort(key=lambda x: x[1], reverse=True)
        return [chunk for chunk, _ in reranked[:top_k]]

    def strategy_5_multi_query(self, query, top_k=3):
        """Strategy 5: multi-query expansion — retrieve with LLM rephrasings."""
        expansion_prompt = f"將以下問題改寫成3個相關但不同角度的問題,用換行分隔:\n{query}"

        try:
            response = self.client.chat.completions.create(
                model="llama-3.1-8b-instant",
                messages=[{"role": "user", "content": expansion_prompt}],
                max_tokens=200,
                temperature=0.7
            )
            queries = [query] + response.choices[0].message.content.strip().split('\n')[:3]
        except Exception:
            # Expansion is best-effort: fall back to the original query only.
            queries = [query]

        # Retrieve a couple of chunks per query variant.
        all_chunks = []
        for q in queries:
            all_chunks.extend(self.strategy_1_basic_similarity(q, top_k=2))

        # Order-preserving dedupe, then truncate to top_k.
        return list(dict.fromkeys(all_chunks))[:int(top_k)]

    def strategy_6_contextual_compression(self, query, top_k=3):
        """Strategy 6: contextual compression — keep only the most relevant sentences."""
        chunks = self.strategy_1_basic_similarity(query, top_k=top_k)

        compressed = []
        for chunk in chunks:
            compress_prompt = f"從以下文本中提取與問題「{query}」最相關的1-2句話:\n\n{chunk}"

            try:
                response = self.client.chat.completions.create(
                    model="llama-3.1-8b-instant",
                    messages=[{"role": "user", "content": compress_prompt}],
                    max_tokens=150,
                    temperature=0
                )
                compressed.append(response.choices[0].message.content.strip())
            except Exception:
                # Compression is best-effort: fall back to a raw prefix.
                compressed.append(chunk[:300])

        return compressed

    def strategy_7_parent_child(self, query, top_k=3):
        """Strategy 7: parent-child — search small chunks, return their big parents.

        NOTE(review): re-embeds all small chunks on every call; fine for small
        documents, worth caching for large ones.
        """
        small_chunks = self._split_text(' '.join(self.chunks), chunk_size=300, overlap=50)
        small_embeddings = self.embedding_model.encode(small_chunks, convert_to_numpy=True)

        small_index = faiss.IndexFlatL2(small_embeddings.shape[1])
        small_index.add(small_embeddings.astype('float32'))

        query_vector = self.embedding_model.encode([query])
        _, indices = small_index.search(query_vector.astype('float32'), int(top_k))

        # Map each small hit back to the first parent chunk that contains it.
        results = []
        for idx in indices[0]:
            for big_chunk in self.chunks:
                if small_chunks[idx] in big_chunk:
                    results.append(big_chunk)
                    break

        return list(dict.fromkeys(results))[:int(top_k)]

    def strategy_8_hypothetical_answer(self, query, top_k=3):
        """Strategy 8: HyDE — embed a hypothetical LLM answer instead of the query."""
        hyde_prompt = f"請對以下問題給出一個假設性的答案(即使不確定):\n{query}"

        try:
            response = self.client.chat.completions.create(
                model="llama-3.1-8b-instant",
                messages=[{"role": "user", "content": hyde_prompt}],
                max_tokens=200,
                temperature=0.7
            )
            hypothetical_answer = response.choices[0].message.content
        except Exception:
            # HyDE is best-effort: fall back to searching with the raw query.
            hypothetical_answer = query

        query_vector = self.embedding_model.encode([hypothetical_answer])
        _, indices = self.index.search(query_vector.astype('float32'), int(top_k))

        return [self.chunks[idx] for idx in indices[0]]

    def generate_answer(self, query, strategy, top_k=3):
        """Retrieve context via the chosen strategy, then answer with the LLM.

        Args:
            query: the user's question.
            strategy: display label of the strategy (see the dict below).
            top_k: number of chunks to retrieve.

        Returns:
            (answer, source_info) strings; on failure the answer carries the
            error and source_info is empty.
        """
        if not self.chunks:
            return "❌ 請先上傳 PDF 檔案!", ""

        # FIX: Gradio sliders may deliver floats; FAISS and slicing need ints.
        top_k = int(top_k)

        # Display label -> retrieval function dispatch table.
        strategies = {
            "1. 基礎語意搜尋": self.strategy_1_basic_similarity,
            "2. TF-IDF 關鍵詞": self.strategy_2_tfidf,
            "3. 混合搜尋": self.strategy_3_hybrid,
            "4. 重新排序": self.strategy_4_reranking,
            "5. 多查詢擴展": self.strategy_5_multi_query,
            "6. 上下文壓縮": self.strategy_6_contextual_compression,
            "7. 父子文檔": self.strategy_7_parent_child,
            "8. 假設性答案 (HyDE)": self.strategy_8_hypothetical_answer,
        }

        # Unknown labels fall back to the basic strategy.
        retrieval_func = strategies.get(strategy, self.strategy_1_basic_similarity)
        relevant_chunks = retrieval_func(query, top_k)

        context = "\n\n---\n\n".join(relevant_chunks)

        prompt = f"""請根據以下上下文回答問題。如果上下文中沒有相關資訊,請說明無法回答。

上下文:
{context}

問題:{query}

請用繁體中文詳細回答:"""

        try:
            response = self.client.chat.completions.create(
                model="llama-3.1-8b-instant",
                messages=[
                    {"role": "system", "content": "你是專業的文件分析助手。"},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=1024,
                temperature=0.3
            )

            answer = response.choices[0].message.content
            source_info = f"📚 使用策略:{strategy}\n📄 檢索片段數:{len(relevant_chunks)}\n\n" + \
                          "=" * 50 + "\n相關文本片段:\n" + "=" * 50 + "\n\n" + context

            return answer, source_info

        except Exception as e:
            return f"❌ 生成答案失敗: {str(e)}", ""
|
|
|
|
| |
def create_interface():
    """Build and return the Gradio Blocks UI wired to a MultiStrategyRAG engine."""
    import os

    # SECURITY FIX: the API key was hard-coded in source (a leaked secret).
    # Read it from the environment instead: export GROQ_API_KEY=... first.
    api_key = os.environ.get("GROQ_API_KEY", "")
    rag = MultiStrategyRAG(api_key=api_key)

    def upload_pdf(file):
        # Gradio hands us a tempfile wrapper; .name is the on-disk path.
        if file is None:
            return "⚠️ 請選擇 PDF 檔案"
        return rag.load_pdf(file.name)

    def ask_question(query, strategy, top_k):
        return rag.generate_answer(query, strategy, top_k)

    with gr.Blocks(title="🤖 多策略 RAG PDF 問答系統", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🤖 多策略 RAG PDF 問答系統
        
        採用 **8 種不同的 RAG 策略**,為您的 PDF 文件提供智能問答服務!
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📤 步驟 1: 上傳 PDF")
                pdf_input = gr.File(
                    label="選擇 PDF 檔案",
                    file_types=[".pdf"]
                )
                upload_btn = gr.Button("🚀 載入文件", variant="primary")
                upload_status = gr.Textbox(label="載入狀態", interactive=False)

                gr.Markdown("### ⚙️ 步驟 2: 選擇 RAG 策略")
                # Labels must match the dispatch table in generate_answer().
                strategy_dropdown = gr.Dropdown(
                    choices=[
                        "1. 基礎語意搜尋",
                        "2. TF-IDF 關鍵詞",
                        "3. 混合搜尋",
                        "4. 重新排序",
                        "5. 多查詢擴展",
                        "6. 上下文壓縮",
                        "7. 父子文檔",
                        "8. 假設性答案 (HyDE)"
                    ],
                    value="1. 基礎語意搜尋",
                    label="RAG 策略"
                )

                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="檢索片段數量 (Top-K)"
                )

                gr.Markdown("""
                ### 📖 策略說明
                
                1. **基礎語意搜尋**: 使用向量相似度
                2. **TF-IDF 關鍵詞**: 基於詞頻統計
                3. **混合搜尋**: 結合語意與關鍵詞
                4. **重新排序**: LLM 重新評分
                5. **多查詢擴展**: 生成多個相關問題
                6. **上下文壓縮**: 提取最相關部分
                7. **父子文檔**: 小片段檢索大上下文
                8. **假設性答案**: 先生成答案再搜尋
                """)

            with gr.Column(scale=2):
                gr.Markdown("### 💬 步驟 3: 提問")
                question_input = gr.Textbox(
                    label="輸入您的問題",
                    placeholder="例如:這份文件的主要內容是什麼?",
                    lines=3
                )
                ask_btn = gr.Button("🔍 提問", variant="primary", size="lg")

                gr.Markdown("### 💡 答案")
                answer_output = gr.Textbox(
                    label="AI 回答",
                    lines=10,
                    interactive=False
                )

                with gr.Accordion("📚 查看檢索到的文本片段", open=False):
                    source_output = gr.Textbox(
                        label="相關來源",
                        lines=15,
                        interactive=False
                    )

        # Wire buttons to their handlers.
        upload_btn.click(
            fn=upload_pdf,
            inputs=[pdf_input],
            outputs=[upload_status]
        )

        ask_btn.click(
            fn=ask_question,
            inputs=[question_input, strategy_dropdown, top_k_slider],
            outputs=[answer_output, source_output]
        )

        # Clickable example questions.
        gr.Examples(
            examples=[
                ["這份文件的主要內容是什麼?"],
                ["文件中提到哪些重要概念?"],
                ["有哪些關鍵數據或統計資料?"],
                ["文件的結論是什麼?"]
            ],
            inputs=question_input
        )

    return demo
|
|
|
|
| if __name__ == "__main__": |
| demo = create_interface() |
| demo.launch( |
| share=True, |
| server_name="0.0.0.0", |
| server_port=7860 |
| ) |