# RAG_Test_System / app.py
# Hugging Face Space application file (commit 772ae76, ~4.63 kB).
# -------------------------------
# 1. 匯入套件
# -------------------------------
import os, glob, time
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatHuggingFaceHub
from langchain.chains import RetrievalQA
from docx import Document as DocxDocument
import gradio as gr
# -------------------------------
# 2. Configure paths
# -------------------------------
txt_folder = "out_texts" # directory holding the source .txt files to index
db_path = "faiss_db"     # directory where the FAISS index is persisted
os.makedirs(db_path, exist_ok=True)  # ensure the index directory exists before save/load
# -------------------------------
# 3. Build the embedding model
# -------------------------------
# Sentence-transformers MiniLM: small, fast, multilingual-capable embedder
# used both for indexing and for query-time retrieval.
embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# -------------------------------
# 4. Create or load the vector store
# -------------------------------
# Reuse a previously persisted FAISS index when one exists; otherwise build a
# fresh one from every .txt file under `txt_folder` and persist it.
index_file = os.path.join(db_path, "index.faiss")
if os.path.exists(index_file):
    print("載入現有向量資料庫...")
    db = FAISS.load_local(db_path, embeddings_model, allow_dangerous_deserialization=True)
else:
    print("沒有資料庫,開始建立新向量資料庫...")

    def _read_doc(path):
        # Wrap one UTF-8 text file as a LangChain Document tagged with its filename.
        with open(path, "r", encoding="utf-8") as fh:
            return Document(page_content=fh.read(), metadata={"source": os.path.basename(path)})

    docs = [_read_doc(p) for p in glob.glob(f"{txt_folder}/*.txt")]
    # Chunk into ~1000-char windows with 100-char overlap for retrieval.
    split_docs = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
    print("產生向量嵌入中...")
    db = FAISS.from_documents(split_docs, embeddings_model)
    db.save_local(db_path)
    print("向量資料庫建立完成。")
# -------------------------------
# 5. Hugging Face model settings
# -------------------------------
HUGGINGFACE_API_TOKEN = os.getenv("HF_TOKEN") # set this in Spaces Secrets; None when unset
# Selectable HF repo ids mapped to the max_new_tokens limit passed to each model.
MODEL_DICT = {
    "google/flan-t5-large": 512,
    "tiiuae/falcon-7b-instruct": 512
}
MAX_HOURLY_REQUESTS = 50       # simple in-process rate limit per rolling hour
request_count = 0              # generations served in the current window
last_reset_time = time.time()  # start timestamp of the current one-hour window
# -------------------------------
# 6. RAG 主函式
# -------------------------------
def rag_generate_hfapi(query, model_name, segments=5, max_words=1500):
    """Generate a multi-paragraph article on *query* via RAG over the FAISS store.

    Args:
        query: Article topic entered by the user.
        model_name: Hugging Face repo id; must be a key of MODEL_DICT.
        segments: Desired number of paragraphs (coerced to an int >= 1).
        max_words: Soft upper bound on the total word count.

    Returns:
        Tuple ``(article_text, docx_path)`` on success, or
        ``(error_message, None)`` when the hourly quota is exhausted or the
        generation call fails.
    """
    global request_count, last_reset_time

    # Roll the rate-limit window over once an hour has elapsed.
    if time.time() - last_reset_time > 3600:
        request_count = 0
        last_reset_time = time.time()
    if request_count >= MAX_HOURLY_REQUESTS:
        return f"本小時生成次數已達上限 ({MAX_HOURLY_REQUESTS}),請稍後再試。", None

    # Gradio sliders may deliver floats; normalise and guard the integer
    # division below against segments == 0.
    segments = max(1, int(segments))
    max_words = int(max_words)

    llm = ChatHuggingFaceHub(
        repo_id=model_name,
        model_kwargs={"temperature": 0.7, "max_new_tokens": MODEL_DICT[model_name]},
        huggingfacehub_api_token=HUGGINGFACE_API_TOKEN
    )
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=db.as_retriever(search_type="similarity", search_kwargs={"k": 5}),
        return_source_documents=True
    )

    # BUG FIX: the original prompt dropped the measure words ("段"/"字") after
    # the paragraph count and per-paragraph length, leaving those requirements
    # meaningless to the model.
    prompt = f"""請依據下列主題生成一篇文章:
主題:{query}
需求:
- 總共 {segments} 段
- 每段約 {max_words // segments} 字
- 總字數請控制在 {max_words} 字以內
- 請自動分段輸出
"""
    try:
        result = qa_chain({"query": prompt})
        full_text = result["result"].strip()
        if not full_text:
            full_text = "(生成失敗,請改用其他模型或調整段落數)"
    except Exception as e:
        # Best-effort: surface the API error to the UI instead of crashing.
        return f"(生成失敗:{str(e)})", None

    request_count += 1  # only successful generations count against the quota

    paragraphs = [p.strip() for p in full_text.split("\n") if p.strip()]

    # Write the DOCX to a unique temp file so concurrent requests cannot
    # clobber each other's download (the original reused one fixed filename).
    import tempfile
    fd, docx_file = tempfile.mkstemp(suffix=".docx", prefix="generated_article_")
    os.close(fd)
    doc = DocxDocument()
    doc.add_heading(query, level=1)
    for p in paragraphs:
        doc.add_paragraph(p)
    doc.save(docx_file)
    return "\n\n".join(paragraphs), docx_file
# -------------------------------
# 7. Gradio UI
# -------------------------------
# Input widgets, bound positionally to rag_generate_hfapi's parameters.
topic_box = gr.Textbox(lines=2, placeholder="請輸入文章主題")
model_dropdown = gr.Dropdown(list(MODEL_DICT.keys()), value="google/flan-t5-large", label="選擇模型")
segment_slider = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="段落數")
words_slider = gr.Slider(minimum=500, maximum=3000, value=1500, step=100, label="文章字數上限")

iface = gr.Interface(
    fn=rag_generate_hfapi,
    inputs=[topic_box, model_dropdown, segment_slider, words_slider],
    outputs=[gr.Textbox(label="生成文章"), gr.File(label="下載 DOCX")],
    title="佛教經論 RAG 系統 (Hugging Face API)",
    description="使用 Hugging Face API 生成文章,可選大模型,分段生成並下載 DOCX,每小時生成次數有限制"
)
iface.launch()