# 8KindsRAG / app.py
# Source: Hugging Face Space by cjian2025 — "Update app.py", commit 91d1212 (verified)
"""
Gradio + Groq API - 8種 RAG 策略 PDF 問答系統
需要安装: pip install gradio groq pypdf sentence-transformers numpy faiss-cpu scikit-learn
"""
import os
import re
from collections import Counter

import faiss
import gradio as gr
import numpy as np
from groq import Groq
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
class MultiStrategyRAG:
    """PDF question answering with eight interchangeable retrieval (RAG) strategies.

    :meth:`load_pdf` chunks a PDF and builds both a dense FAISS index and a
    sparse TF-IDF index over the chunks. :meth:`generate_answer` dispatches to
    one of the eight ``strategy_*`` retrievers and asks a Groq-hosted LLM to
    answer the question from the retrieved context.
    """

    def __init__(self, api_key):
        """Create the Groq client and the multilingual sentence embedder.

        Args:
            api_key: Groq API key used for all chat-completion calls.
        """
        self.client = Groq(api_key=api_key)
        self.embedding_model = SentenceTransformer(
            'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
        )
        self.chunks = []            # large text chunks of the loaded PDF
        self.embeddings = None      # dense vectors, one row per chunk
        self.index = None           # FAISS L2 index over self.embeddings
        self.tfidf_vectorizer = None
        self.tfidf_matrix = None
        # Lazily-built small-chunk index used by strategy 7 (parent/child);
        # cached so it is not rebuilt on every question.
        self._small_chunks = None
        self._small_index = None

    def load_pdf(self, pdf_file):
        """Load a PDF, chunk its text, and (re)build both retrieval indexes.

        Args:
            pdf_file: path or file-like object accepted by ``pypdf.PdfReader``.

        Returns:
            A human-readable status string (success or failure).
        """
        try:
            reader = PdfReader(pdf_file)
            full_text = ""
            for page in reader.pages:
                # extract_text() can return None/"" for image-only pages;
                # guard so concatenation never raises TypeError.
                full_text += (page.extract_text() or "") + "\n"
            # Split into overlapping chunks for retrieval.
            self.chunks = self._split_text(full_text, chunk_size=800, overlap=150)
            # Dense embeddings + FAISS index.
            self.embeddings = self.embedding_model.encode(
                self.chunks,
                convert_to_numpy=True
            )
            dimension = self.embeddings.shape[1]
            self.index = faiss.IndexFlatL2(dimension)
            self.index.add(self.embeddings.astype('float32'))
            # Sparse TF-IDF index.
            self.tfidf_vectorizer = TfidfVectorizer(max_features=1000)
            self.tfidf_matrix = self.tfidf_vectorizer.fit_transform(self.chunks)
            # A new document invalidates the cached small-chunk index.
            self._small_chunks = None
            self._small_index = None
            return f"✅ 成功載入 PDF!共 {len(reader.pages)} 頁,分割為 {len(self.chunks)} 個片段"
        except Exception as e:
            return f"❌ 載入失敗: {str(e)}"

    def _split_text(self, text, chunk_size, overlap):
        """Split *text* into whitespace-normalized chunks of ~``chunk_size`` chars.

        Consecutive chunks share ``overlap`` characters so that facts straddling
        a chunk boundary remain retrievable.

        Args:
            text: raw text to split.
            chunk_size: maximum characters per chunk.
            overlap: characters shared between consecutive chunks.

        Returns:
            List of non-empty, whitespace-collapsed chunk strings.
        """
        # Guard against a non-positive stride: overlap >= chunk_size would
        # otherwise make this loop forever.
        step = max(1, chunk_size - overlap)
        chunks = []
        start = 0
        text_length = len(text)
        while start < text_length:
            chunk = re.sub(r'\s+', ' ', text[start:start + chunk_size]).strip()
            if chunk:
                chunks.append(chunk)
            start += step
        return chunks

    # ==================== The 8 RAG strategies ====================
    def strategy_1_basic_similarity(self, query, top_k=3):
        """Strategy 1: dense semantic similarity search (FAISS L2)."""
        query_vector = self.embedding_model.encode([query])
        _, indices = self.index.search(query_vector.astype('float32'), top_k)
        return [self.chunks[idx] for idx in indices[0]]

    def strategy_2_tfidf(self, query, top_k=3):
        """Strategy 2: sparse TF-IDF keyword search."""
        query_vector = self.tfidf_vectorizer.transform([query])
        similarities = (self.tfidf_matrix * query_vector.T).toarray().flatten()
        top_indices = similarities.argsort()[-top_k:][::-1]
        return [self.chunks[idx] for idx in top_indices]

    def strategy_3_hybrid(self, query, top_k=3):
        """Strategy 3: hybrid search (semantic + TF-IDF), relevance order kept."""
        # Dense candidates.
        query_vector = self.embedding_model.encode([query])
        _, sem_indices = self.index.search(query_vector.astype('float32'), top_k * 2)
        # Sparse candidates.
        query_tfidf = self.tfidf_vectorizer.transform([query])
        tfidf_scores = (self.tfidf_matrix * query_tfidf.T).toarray().flatten()
        tfidf_indices = tfidf_scores.argsort()[-top_k * 2:][::-1]
        # Merge with order-preserving dedup. A plain set() (as originally
        # written) reorders candidates by index value, discarding the ranking.
        combined = list(dict.fromkeys(sem_indices[0].tolist() + tfidf_indices.tolist()))
        return [self.chunks[idx] for idx in combined[:top_k]]

    def strategy_4_reranking(self, query, top_k=3):
        """Strategy 4: over-retrieve candidates, then rerank by LLM score."""
        # Retrieve extra candidates for the reranker to choose from.
        candidates = self.strategy_1_basic_similarity(query, top_k=top_k * 2)
        reranked = []
        for chunk in candidates:
            prompt = f"問題:{query}\n\n文本:{chunk[:200]}...\n\n這段文本與問題的相關度(0-10):"
            try:
                response = self.client.chat.completions.create(
                    model="llama-3.1-8b-instant",
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=10,
                    temperature=0
                )
                raw = response.choices[0].message.content.strip()
                # Take the first number in the reply as the relevance score.
                digits = re.findall(r'\d+', raw)
                score = float(digits[0]) if digits else 0
                reranked.append((chunk, score))
            except Exception:
                # Best-effort: keep the chunk with a zero score on API failure.
                reranked.append((chunk, 0))
        reranked.sort(key=lambda x: x[1], reverse=True)
        return [chunk for chunk, _ in reranked[:top_k]]

    def strategy_5_multi_query(self, query, top_k=3):
        """Strategy 5: multi-query expansion via LLM rewrites."""
        expansion_prompt = f"將以下問題改寫成3個相關但不同角度的問題,用換行分隔:\n{query}"
        try:
            response = self.client.chat.completions.create(
                model="llama-3.1-8b-instant",
                messages=[{"role": "user", "content": expansion_prompt}],
                max_tokens=200,
                temperature=0.7
            )
            queries = [query] + response.choices[0].message.content.strip().split('\n')[:3]
        except Exception:
            # Fall back to the original query alone.
            queries = [query]
        # Search with every query variant, then dedup preserving first-seen order.
        all_chunks = []
        for q in queries:
            all_chunks.extend(self.strategy_1_basic_similarity(q, top_k=2))
        return list(dict.fromkeys(all_chunks))[:top_k]

    def strategy_6_contextual_compression(self, query, top_k=3):
        """Strategy 6: contextual compression (LLM extracts the relevant sentences)."""
        chunks = self.strategy_1_basic_similarity(query, top_k=top_k)
        compressed = []
        for chunk in chunks:
            compress_prompt = f"從以下文本中提取與問題「{query}」最相關的1-2句話:\n\n{chunk}"
            try:
                response = self.client.chat.completions.create(
                    model="llama-3.1-8b-instant",
                    messages=[{"role": "user", "content": compress_prompt}],
                    max_tokens=150,
                    temperature=0
                )
                compressed.append(response.choices[0].message.content.strip())
            except Exception:
                # Fall back to a raw prefix of the chunk on API failure.
                compressed.append(chunk[:300])
        return compressed

    def strategy_7_parent_child(self, query, top_k=3):
        """Strategy 7: parent/child — search small chunks, return their big parents."""
        # Build the small-chunk index once and cache it; the original rebuilt
        # all small-chunk embeddings on every single question.
        if self._small_index is None:
            self._small_chunks = self._split_text(
                ' '.join(self.chunks), chunk_size=300, overlap=50
            )
            small_embeddings = self.embedding_model.encode(
                self._small_chunks, convert_to_numpy=True
            )
            self._small_index = faiss.IndexFlatL2(small_embeddings.shape[1])
            self._small_index.add(small_embeddings.astype('float32'))
        query_vector = self.embedding_model.encode([query])
        _, indices = self._small_index.search(query_vector.astype('float32'), top_k)
        results = []
        for idx in indices[0]:
            # Map each small chunk back to the first big chunk containing it.
            for big_chunk in self.chunks:
                if self._small_chunks[idx] in big_chunk:
                    results.append(big_chunk)
                    break
        return list(dict.fromkeys(results))[:top_k]

    def strategy_8_hypothetical_answer(self, query, top_k=3):
        """Strategy 8: HyDE — embed a hypothetical LLM answer, then search with it."""
        hyde_prompt = f"請對以下問題給出一個假設性的答案(即使不確定):\n{query}"
        try:
            response = self.client.chat.completions.create(
                model="llama-3.1-8b-instant",
                messages=[{"role": "user", "content": hyde_prompt}],
                max_tokens=200,
                temperature=0.7
            )
            hypothetical_answer = response.choices[0].message.content
        except Exception:
            # Fall back to searching with the raw query.
            hypothetical_answer = query
        query_vector = self.embedding_model.encode([hypothetical_answer])
        _, indices = self.index.search(query_vector.astype('float32'), top_k)
        return [self.chunks[idx] for idx in indices[0]]

    def generate_answer(self, query, strategy, top_k=3):
        """Retrieve context with the chosen strategy and generate an answer.

        Args:
            query: user question.
            strategy: display name of the strategy (UI dropdown label).
            top_k: number of chunks to retrieve; coerced to int because the
                Gradio slider may deliver a float and FAISS requires int k.

        Returns:
            Tuple ``(answer, source_info)``. On failure the first element is
            an error message and the second is empty.
        """
        if not self.chunks:
            return "❌ 請先上傳 PDF 檔案!", ""
        top_k = int(top_k)
        # Dispatch table: dropdown label -> retrieval method.
        strategies = {
            "1. 基礎語意搜尋": self.strategy_1_basic_similarity,
            "2. TF-IDF 關鍵詞": self.strategy_2_tfidf,
            "3. 混合搜尋": self.strategy_3_hybrid,
            "4. 重新排序": self.strategy_4_reranking,
            "5. 多查詢擴展": self.strategy_5_multi_query,
            "6. 上下文壓縮": self.strategy_6_contextual_compression,
            "7. 父子文檔": self.strategy_7_parent_child,
            "8. 假設性答案 (HyDE)": self.strategy_8_hypothetical_answer,
        }
        retrieval_func = strategies.get(strategy, self.strategy_1_basic_similarity)
        relevant_chunks = retrieval_func(query, top_k)
        # Join the retrieved chunks into one context block.
        context = "\n\n---\n\n".join(relevant_chunks)
        prompt = f"""請根據以下上下文回答問題。如果上下文中沒有相關資訊,請說明無法回答。
上下文:
{context}
問題:{query}
請用繁體中文詳細回答:"""
        try:
            response = self.client.chat.completions.create(
                model="llama-3.1-8b-instant",
                messages=[
                    {"role": "system", "content": "你是專業的文件分析助手。"},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=1024,
                temperature=0.3
            )
            answer = response.choices[0].message.content
            source_info = f"📚 使用策略:{strategy}\n📄 檢索片段數:{len(relevant_chunks)}\n\n" + \
                "=" * 50 + "\n相關文本片段:\n" + "=" * 50 + "\n\n" + context
            return answer, source_info
        except Exception as e:
            return f"❌ 生成答案失敗: {str(e)}", ""
# 建立 Gradio 介面
def create_interface():
    """Build and return the Gradio Blocks UI wired to a MultiStrategyRAG backend.

    Returns:
        The assembled ``gr.Blocks`` application (call ``.launch()`` to serve).
    """
    # SECURITY: the original committed a Groq API key in source. Read it from
    # the environment instead (on Hugging Face Spaces, set a GROQ_API_KEY secret).
    rag = MultiStrategyRAG(api_key=os.environ.get("GROQ_API_KEY", ""))

    def upload_pdf(file):
        # Gradio hands us a tempfile-like object; guard the no-selection case.
        if file is None:
            return "⚠️ 請選擇 PDF 檔案"
        return rag.load_pdf(file.name)

    def ask_question(query, strategy, top_k):
        return rag.generate_answer(query, strategy, top_k)

    # Assemble the UI.
    with gr.Blocks(title="🤖 多策略 RAG PDF 問答系統", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
# 🤖 多策略 RAG PDF 問答系統
採用 **8 種不同的 RAG 策略**,為您的 PDF 文件提供智能問答服務!
""")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📤 步驟 1: 上傳 PDF")
                pdf_input = gr.File(
                    label="選擇 PDF 檔案",
                    file_types=[".pdf"]
                )
                upload_btn = gr.Button("🚀 載入文件", variant="primary")
                upload_status = gr.Textbox(label="載入狀態", interactive=False)
                gr.Markdown("### ⚙️ 步驟 2: 選擇 RAG 策略")
                strategy_dropdown = gr.Dropdown(
                    choices=[
                        "1. 基礎語意搜尋",
                        "2. TF-IDF 關鍵詞",
                        "3. 混合搜尋",
                        "4. 重新排序",
                        "5. 多查詢擴展",
                        "6. 上下文壓縮",
                        "7. 父子文檔",
                        "8. 假設性答案 (HyDE)"
                    ],
                    value="1. 基礎語意搜尋",
                    label="RAG 策略"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="檢索片段數量 (Top-K)"
                )
                gr.Markdown("""
### 📖 策略說明
1. **基礎語意搜尋**: 使用向量相似度
2. **TF-IDF 關鍵詞**: 基於詞頻統計
3. **混合搜尋**: 結合語意與關鍵詞
4. **重新排序**: LLM 重新評分
5. **多查詢擴展**: 生成多個相關問題
6. **上下文壓縮**: 提取最相關部分
7. **父子文檔**: 小片段檢索大上下文
8. **假設性答案**: 先生成答案再搜尋
""")
            with gr.Column(scale=2):
                gr.Markdown("### 💬 步驟 3: 提問")
                question_input = gr.Textbox(
                    label="輸入您的問題",
                    placeholder="例如:這份文件的主要內容是什麼?",
                    lines=3
                )
                ask_btn = gr.Button("🔍 提問", variant="primary", size="lg")
                gr.Markdown("### 💡 答案")
                answer_output = gr.Textbox(
                    label="AI 回答",
                    lines=10,
                    interactive=False
                )
                with gr.Accordion("📚 查看檢索到的文本片段", open=False):
                    source_output = gr.Textbox(
                        label="相關來源",
                        lines=15,
                        interactive=False
                    )
        # Wire events.
        upload_btn.click(
            fn=upload_pdf,
            inputs=[pdf_input],
            outputs=[upload_status]
        )
        ask_btn.click(
            fn=ask_question,
            inputs=[question_input, strategy_dropdown, top_k_slider],
            outputs=[answer_output, source_output]
        )
        # Example questions.
        gr.Examples(
            examples=[
                ["這份文件的主要內容是什麼?"],
                ["文件中提到哪些重要概念?"],
                ["有哪些關鍵數據或統計資料?"],
                ["文件的結論是什麼?"]
            ],
            inputs=question_input
        )
    return demo
if __name__ == "__main__":
demo = create_interface()
demo.launch(
share=True, # 設為 True 可生成公開連結
server_name="0.0.0.0",
server_port=7860
)