File size: 4,828 Bytes
d051231
c4310e4
d051231
 
 
 
 
c4310e4
d051231
 
 
1740855
c4310e4
76b0768
6f0de2e
 
8be1b46
6f0de2e
 
d051231
6f0de2e
033a019
6f0de2e
a23ab36
6f0de2e
2aa3d8b
d0ba755
9c1b3ba
d051231
9c1b3ba
 
 
d051231
 
 
 
 
 
c4310e4
 
 
 
d051231
 
c4310e4
d051231
 
c4310e4
d051231
 
 
 
 
 
 
 
 
 
c4310e4
 
d051231
 
 
 
 
 
c4310e4
d0ba755
e1aabb3
d0ba755
 
 
e1aabb3
d0ba755
 
6f0de2e
e1aabb3
6f0de2e
 
e1aabb3
80fe36a
c4310e4
 
e1aabb3
 
 
6f0de2e
 
e1aabb3
 
 
c4310e4
6f0de2e
a23ab36
132ef2d
6f0de2e
e1aabb3
d051231
 
6f0de2e
 
e1aabb3
6f0de2e
 
e1aabb3
80fe36a
d0ba755
c4310e4
d0ba755
c6f8f84
94b2916
c4310e4
f90da5a
c6f8f84
e1aabb3
fb13185
c6f8f84
fb13185
e1aabb3
d051231
d0ba755
f90da5a
255d19f
e1aabb3
a23ab36
d051231
255d19f
f90da5a
d0ba755
a6c8097
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# app.py
import os
import tempfile

from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from docx import Document as DocxDocument
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from huggingface_hub import login, snapshot_download
import gradio as gr

# -------------------------------
# 0. Vector store loading
# -------------------------------
EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)

DB_PATH = "./faiss_db"

# Fail fast: the app can only answer from a pre-built index on disk.
if not os.path.exists(DB_PATH):
    raise ValueError("❌ 沒找到 faiss_db,請先建立向量資料庫")

print("✅ 載入現有向量資料庫...")
# NOTE(security): allow_dangerous_deserialization lets LangChain unpickle the
# stored index metadata; only point DB_PATH at a directory you built yourself.
db = FAISS.load_local(
    DB_PATH,
    embeddings_model,
    allow_dangerous_deserialization=True,
)

# Every query retrieves the 5 most similar chunks.
retriever = db.as_retriever(search_kwargs={"k": 5}, search_type="similarity")

# -------------------------------
# 1. Chinese seq2seq model (T5 Pegasus)
# -------------------------------
MODEL_NAME = "imxly/t5-pegasus-small"

# Authenticate only when a token is supplied through the environment;
# the same (possibly None) token is also handed to snapshot_download below.
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")

# The snapshot is cached under ./models/<repo-name> and fetched on first run only.
LOCAL_MODEL_DIR = f"./models/{MODEL_NAME.rsplit('/', 1)[-1]}"
# NOTE(review): only the directory's existence is checked, so a download that
# was interrupted half-way is never retried — delete the folder to recover.
if not os.path.exists(LOCAL_MODEL_DIR):
    print(f"⬇️ 嘗試下載模型 {MODEL_NAME} ...")
    snapshot_download(
        repo_id=MODEL_NAME,
        token=HF_TOKEN,
        local_dir=LOCAL_MODEL_DIR,
    )

tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_DIR)
model = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_DIR)

# device=-1 pins inference to the CPU.
generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=-1)

def call_local_inference(prompt, max_new_tokens=256, do_sample=False, temperature=None):
    """Run the module-level text2text pipeline on *prompt* and return its text.

    Args:
        prompt: Input text for the seq2seq model.
        max_new_tokens: Maximum number of tokens to generate.
        do_sample: Whether to sample. Defaults to greedy decoding, which suits
            a summarisation-style model.
        temperature: Sampling temperature. It is forwarded only when
            ``do_sample`` is true. Under greedy decoding transformers ignores
            it and emits a warning on every call.

    Returns:
        The generated string. If generation raises, this returns a
        "(生成失敗:...)" message instead. The fallback is deliberate: the
        streaming UI keeps going and shows the error inline.
    """
    gen_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    if do_sample and temperature is not None:
        gen_kwargs["temperature"] = temperature

    try:
        outputs = generator(prompt, **gen_kwargs)
        return outputs[0]["generated_text"]
    except Exception as e:  # best-effort: surface the failure in the article text
        return f"(生成失敗:{e})"

# -------------------------------
# 2. 基於 RAG 的文章生成
# -------------------------------
def generate_article_rag_only(query, segments=3):
    """Stream a RAG-grounded article on *query*, one paragraph per context chunk.

    The function retrieves documents for *query*, joins their text, and
    splits it into chunks of about 500 characters. It then asks the local
    model for one paragraph per chunk, using at most *segments* chunks.
    Each paragraph is appended to a DOCX file as soon as it is generated.

    Args:
        query: Article topic. It is also the retrieval query and the DOCX heading.
        segments: Requested number of paragraphs. Gradio sliders may deliver
            it as a float, so it is coerced to ``int`` before slicing.

    Yields:
        5-tuples ``(article_text, docx_path_or_None, model_label, context,
        progress_message)`` in the order of the Gradio outputs. The DOCX
        path is only given on the final yield, once the file is complete.
    """
    segments = int(segments)

    # Per-call directory: a fixed /tmp path would let concurrent requests
    # overwrite each other's file. The file name is kept for the download.
    docx_file = os.path.join(tempfile.mkdtemp(), "generated_article.docx")
    doc = DocxDocument()
    doc.add_heading(query, level=1)
    doc.save(docx_file)

    # RAG retrieval (`invoke` replaces the deprecated get_relevant_documents).
    retrieved_docs = retriever.invoke(query)
    context_texts = [d.page_content for d in retrieved_docs]
    full_context = "\n".join(context_texts)

    # Split the context so each prompt stays short enough for the small model.
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_text(full_context)[:segments]
    # Retrieval may yield fewer chunks than requested; report the real count.
    total = len(chunks)

    model_label = f"本次使用模型:{MODEL_NAME}"
    all_text = []
    for i, chunk in enumerate(chunks, start=1):
        progress_msg = f"⏳ 正在生成第 {i}/{total} 段..."
        prompt = (
            f"以下是唯一可用的參考內容:\n{chunk}\n\n"
            f"請基於這些內容,寫一段約150-200字的中文文章,"
            f"主題:{query}。\n"
            f"⚠️ 僅能使用參考內容,不可加入外部知識。"
        )
        paragraph = call_local_inference(prompt)
        all_text.append(paragraph)

        # Save after every paragraph so partial output is persisted. The
        # in-memory document is reused rather than re-read from disk.
        doc.add_paragraph(f"第{i}段:\n{paragraph}")
        doc.save(docx_file)

        yield "\n\n".join(all_text), None, model_label, full_context, progress_msg

    final_progress = f"✅ 已完成全部 {total} 段生成!"
    yield "\n\n".join(all_text), docx_file, model_label, full_context, final_progress

# -------------------------------
# 3. Gradio UI
# -------------------------------
# Inside gr.Blocks, components are laid out in the order they are created,
# so the statement order below defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("# 📺 電視弘法視頻生成文章RAG系統")
    gr.Markdown("只基於 faiss_db 內容生成中文文章。")

    # Inputs: article topic and the number of paragraphs to generate.
    query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
    segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="段落數")
    # Outputs, in the same order as the 5-tuples yielded by
    # generate_article_rag_only: article text, DOCX file, model label,
    # retrieved context, progress message.
    output_text = gr.Textbox(label="生成文章")
    output_file = gr.File(label="下載 DOCX")
    model_used_text = gr.Textbox(label="實際使用模型", interactive=False)
    context_text = gr.Textbox(label="檢索到的內容", interactive=False, lines=6)
    progress_text = gr.Textbox(label="生成進度", interactive=False)

    btn = gr.Button("生成文章")
    # generate_article_rag_only is a generator function (it uses `yield`),
    # so each yielded tuple updates the outputs while generation proceeds.
    btn.click(
        generate_article_rag_only,
        inputs=[query_input, segments_input],
        outputs=[output_text, output_file, model_used_text, context_text, progress_text]
    )

# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()