# app.py
import os, torch
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from docx import Document as DocxDocument
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from huggingface_hub import login, snapshot_download
import gradio as gr
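
# Assumed dependencies (not pinned in the source; this install list is a sketch
# inferred from the imports above):
#   pip install torch transformers gradio python-docx faiss-cpu sentence-transformers \
#       huggingface_hub langchain langchain-community langchain-huggingface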

# -------------------------------
# 1. Model setup (Chinese T5 / Pegasus)
# -------------------------------
PRIMARY_MODEL = "imxly/t5-pegasus-small"   # suited to Chinese summarization/generation
FALLBACK_MODEL = "uer/gpt2-chinese-cluecorpussmall"  # fall back to GPT-2 if T5 cannot be downloaded

HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")

def try_download_model(repo_id):
    local_dir = f"./models/{repo_id.split('/')[-1]}"
    if not os.path.exists(local_dir):
        print(f"⬇️ 嘗試下載模型 {repo_id} ...")
        try:
            snapshot_download(repo_id=repo_id, token=HF_TOKEN, local_dir=local_dir)
        except Exception as e:
            print(f"⚠️ 模型 {repo_id} 無法下載: {e}")
            return None
    return local_dir

LOCAL_MODEL_DIR = try_download_model(PRIMARY_MODEL)
if LOCAL_MODEL_DIR is None:
    print("⚠️ Switching to fallback model: small Chinese GPT-2")
    LOCAL_MODEL_DIR = try_download_model(FALLBACK_MODEL)
    MODEL_NAME = FALLBACK_MODEL
else:
    MODEL_NAME = PRIMARY_MODEL

# Fail fast if neither model could be fetched; everything below needs a local model dir
if LOCAL_MODEL_DIR is None:
    raise RuntimeError("Neither the primary nor the fallback model could be downloaded.")

print(f"👉 Using model: {MODEL_NAME}")

# -------------------------------
# 2. Load the pipeline
# -------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    LOCAL_MODEL_DIR,
    use_fast=False  # avoid sentencepiece tokenizer issues
)

# Detect GPU (CUDA) or CPU
device = 0 if torch.cuda.is_available() else -1
print(f"💻 使用裝置:{'GPU' if device == 0 else 'CPU'}")

try:
    model = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_DIR)
except Exception:
    # The GPT-2 fallback is a causal LM, not a seq2seq model
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(LOCAL_MODEL_DIR)

# T5/Pegasus checkpoints are seq2seq; the GPT-2 fallback needs plain text-generation
task = "text2text-generation" if ("t5" in MODEL_NAME or "pegasus" in MODEL_NAME) else "text-generation"
generator = pipeline(
    task,
    model=model,
    tokenizer=tokenizer,
    device=device
)
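
# Optional sanity check (hypothetical prompt; works for either task):
#   print(generator("請簡述佛教的緣起思想。", max_new_tokens=32)[0]["generated_text"])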

def call_local_inference(prompt, max_new_tokens=256):
    try:
        outputs = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7
        )
        text = outputs[0]["generated_text"]
        # text-generation pipelines echo the prompt; strip it for the GPT-2 fallback
        if generator.task == "text-generation" and text.startswith(prompt):
            text = text[len(prompt):]
        return text.strip()
    except Exception as e:
        return f"(生成失敗:{e})"  # "generation failed", kept in Chinese for the UI

# -------------------------------
# 3. Load the FAISS vector store
# -------------------------------
DB_PATH = "./faiss_db"
EMBEDDINGS_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)

if os.path.exists(os.path.join(DB_PATH, "index.faiss")):
    print("✅ Loading existing vector store...")
    db = FAISS.load_local(DB_PATH, embeddings_model, allow_dangerous_deserialization=True)
else:
    print("⚠️ No vector store found; creating an empty one")
    # FAISS.from_documents([]) raises on an empty list, so build the empty index by hand
    import faiss
    from langchain_community.docstore.in_memory import InMemoryDocstore
    dim = len(embeddings_model.embed_query("dimension probe"))
    db = FAISS(
        embedding_function=embeddings_model,
        index=faiss.IndexFlatL2(dim),
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
    )

retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
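
# --- Optional ingestion sketch (not part of the original app) ---
# The otherwise-unused Document / RecursiveCharacterTextSplitter imports suggest the
# vector store is built in a separate step. A minimal sketch of that step, assuming
# source .docx files live under a hypothetical ./source_docs folder:
def ingest_docx_folder(folder="./source_docs"):
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    docs = []
    for name in os.listdir(folder):
        if not name.endswith(".docx"):
            continue
        # Flatten each .docx into plain text, then split into overlapping chunks
        text = "\n".join(p.text for p in DocxDocument(os.path.join(folder, name)).paragraphs)
        docs.extend(
            Document(page_content=chunk, metadata={"source": name})
            for chunk in splitter.split_text(text)
        )
    if docs:
        db.add_documents(docs)
        db.save_local(DB_PATH)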

# -------------------------------
# 4. Article generation (RAG-grounded)
# -------------------------------
def generate_article_progress(query, segments=3):
    docx_file = "/tmp/generated_article.docx"
    doc = DocxDocument()
    doc.add_heading(query, level=1)

    all_text = []
    # Status line shown in the UI, kept in Chinese: "model used this run / device"
    status = f"本次使用模型:{MODEL_NAME},裝置:{'GPU' if device == 0 else 'CPU'}"

    # Retrieve supporting passages and keep the top three as context
    retrieved_docs = retriever.invoke(query)  # invoke() replaces the deprecated get_relevant_documents()
    context_texts = [d.page_content for d in retrieved_docs]
    context = "\n".join([f"{i+1}. {txt}" for i, txt in enumerate(context_texts[:3])])

    for i in range(segments):
        # Chinese prompt for the Chinese model: "Here are related passages from Buddhist
        # scriptures; based on them, write a ~150-200 character paragraph on the topic"
        prompt = (
            f"以下是佛教經論的相關內容:\n{context}\n\n"
            f"請依據上面內容,寫一段約150-200字的中文文章,"
            f"主題:{query}。\n第{i+1}段:"
        )
        paragraph = call_local_inference(prompt)
        all_text.append(paragraph)
        doc.add_paragraph(paragraph)

        # Stream each paragraph to the UI as it is generated
        yield "\n\n".join(all_text), None, status

    doc.save(docx_file)
    yield "\n\n".join(all_text), docx_file, status

# -------------------------------
# 5. Gradio UI
# -------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 📺 電視弘法視頻生成文章 RAG 系統")
    gr.Markdown("使用 Hugging Face 本地模型 + FAISS RAG,僅基於資料庫生成文章。")

    query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
    segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="段落數")
    output_text = gr.Textbox(label="生成文章")
    output_file = gr.File(label="下載 DOCX")
    status_info = gr.Label(label="狀態")

    btn = gr.Button("生成文章")
    btn.click(
        generate_article_progress,
        inputs=[query_input, segments_input],
        outputs=[output_text, output_file, status_info]
    )

if __name__ == "__main__":
    demo.launch()