Spaces:
Sleeping
Sleeping
File size: 4,828 Bytes
d051231 c4310e4 d051231 c4310e4 d051231 1740855 c4310e4 76b0768 6f0de2e 8be1b46 6f0de2e d051231 6f0de2e 033a019 6f0de2e a23ab36 6f0de2e 2aa3d8b d0ba755 9c1b3ba d051231 9c1b3ba d051231 c4310e4 d051231 c4310e4 d051231 c4310e4 d051231 c4310e4 d051231 c4310e4 d0ba755 e1aabb3 d0ba755 e1aabb3 d0ba755 6f0de2e e1aabb3 6f0de2e e1aabb3 80fe36a c4310e4 e1aabb3 6f0de2e e1aabb3 c4310e4 6f0de2e a23ab36 132ef2d 6f0de2e e1aabb3 d051231 6f0de2e e1aabb3 6f0de2e e1aabb3 80fe36a d0ba755 c4310e4 d0ba755 c6f8f84 94b2916 c4310e4 f90da5a c6f8f84 e1aabb3 fb13185 c6f8f84 fb13185 e1aabb3 d051231 d0ba755 f90da5a 255d19f e1aabb3 a23ab36 d051231 255d19f f90da5a d0ba755 a6c8097 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
# app.py
import os
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from docx import Document as DocxDocument
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from huggingface_hub import login, snapshot_download
import gradio as gr
# -------------------------------
# 0. Load the FAISS vector store
# -------------------------------
EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)

DB_PATH = "./faiss_db"
if os.path.exists(DB_PATH):
    print("✅ 載入現有向量資料庫...")
    # Deserialization flag is required because the index was pickled locally.
    db = FAISS.load_local(DB_PATH, embeddings_model, allow_dangerous_deserialization=True)
else:
    raise ValueError("❌ 沒找到 faiss_db,請先建立向量資料庫")

# Top-5 similarity retriever used by the generation step below.
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
# -------------------------------
# 1. Chinese model (T5 Pegasus)
# -------------------------------
MODEL_NAME = "imxly/t5-pegasus-small"

# Optional Hugging Face login — lets snapshot_download reach gated/private repos.
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")

# Cache the snapshot under ./models/<repo-name> so restarts skip the download.
LOCAL_MODEL_DIR = f"./models/{MODEL_NAME.split('/')[-1]}"
if not os.path.exists(LOCAL_MODEL_DIR):
    print(f"⬇️ 嘗試下載模型 {MODEL_NAME} ...")
    snapshot_download(repo_id=MODEL_NAME, token=HF_TOKEN, local_dir=LOCAL_MODEL_DIR)

tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_DIR)
model = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_DIR)

generator = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1,  # CPU
)
def call_local_inference(prompt, max_new_tokens=256):
    """Run the local text2text pipeline on *prompt* and return the output text.

    Args:
        prompt: Input text for the seq2seq model.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated string, or a "(生成失敗:...)" message if generation
        raises — the UI treats either as displayable text.
    """
    try:
        outputs = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            # Greedy decoding suits a summarization-style model. The former
            # `temperature=0.7` was removed: it is ignored when
            # do_sample=False and newer transformers versions warn about it.
            do_sample=False,
        )
        return outputs[0]["generated_text"]
    except Exception as e:
        # Best-effort: surface the failure as text so the Gradio stream keeps going.
        return f"(生成失敗:{e})"
# -------------------------------
# 2. RAG-grounded article generation
# -------------------------------
def generate_article_rag_only(query, segments=3):
    """Stream a Chinese article generated only from faiss_db content.

    Generator used directly as a Gradio event handler. Each yield is a
    5-tuple: (article_so_far, docx_path_or_None, model_info, retrieved_context,
    progress_message). The DOCX path is only supplied on the final yield so
    the download widget appears when the article is complete.

    Args:
        query: Article topic typed by the user.
        segments: Number of paragraphs to generate (Gradio sliders may pass
            a float, so it is coerced to int before slicing).
    """
    segments = int(segments)  # Slider values can arrive as floats; slicing needs an int.

    docx_file = "/tmp/generated_article.docx"
    doc = DocxDocument()
    doc.add_heading(query, level=1)
    doc.save(docx_file)

    all_text = []

    # 🔍 RAG retrieval: the only knowledge source allowed in the prompt.
    retrieved_docs = retriever.get_relevant_documents(query)
    full_context = "\n".join(d.page_content for d in retrieved_docs)

    # Split the context so each generation prompt stays within a safe length.
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_text(full_context)

    model_info = f"本次使用模型:{MODEL_NAME}"
    for i, chunk in enumerate(chunks[:segments]):
        progress_text = f"⏳ 正在生成第 {i+1}/{segments} 段..."
        prompt = (
            f"以下是唯一可用的參考內容:\n{chunk}\n\n"
            f"請基於這些內容,寫一段約150-200字的中文文章,"
            f"主題:{query}。\n"
            f"⚠️ 僅能使用參考內容,不可加入外部知識。"
        )
        paragraph = call_local_inference(prompt)
        all_text.append(paragraph)

        # Persist incrementally; reuse the in-memory document instead of
        # re-reading the file from disk on every iteration.
        doc.add_paragraph(f"第{i+1}段:\n{paragraph}")
        doc.save(docx_file)

        yield "\n\n".join(all_text), None, model_info, full_context, progress_text

    final_progress = f"✅ 已完成全部 {segments} 段生成!"
    yield "\n\n".join(all_text), docx_file, model_info, full_context, final_progress
# -------------------------------
# 3. Gradio interface
# -------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 📺 電視弘法視頻生成文章RAG系統")
    gr.Markdown("只基於 faiss_db 內容生成中文文章。")

    # Inputs
    query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
    segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="段落數")

    # Outputs (one widget per element of the handler's 5-tuple yields)
    output_text = gr.Textbox(label="生成文章")
    output_file = gr.File(label="下載 DOCX")
    model_used_text = gr.Textbox(label="實際使用模型", interactive=False)
    context_text = gr.Textbox(label="檢索到的內容", interactive=False, lines=6)
    progress_text = gr.Textbox(label="生成進度", interactive=False)

    btn = gr.Button("生成文章")
    btn.click(
        generate_article_rag_only,
        inputs=[query_input, segments_input],
        outputs=[output_text, output_file, model_used_text, context_text, progress_text],
    )

if __name__ == "__main__":
    demo.launch()
|