Spaces:
Sleeping
Sleeping
File size: 5,371 Bytes
058eba2 8f7234f d051231 c4310e4 d051231 1740855 058eba2 d051231 058eba2 9c1b3ba d051231 058eba2 d051231 06f5c87 058eba2 06f5c87 058eba2 d051231 058eba2 d051231 058eba2 d051231 06f5c87 c4310e4 d051231 058eba2 06f5c87 058eba2 8f7234f 058eba2 06f5c87 d0ba755 8f7234f d0ba755 6f0de2e 058eba2 80fe36a 06f5c87 6f0de2e 8f7234f 06f5c87 6f0de2e a23ab36 132ef2d 06f5c87 6f0de2e 058eba2 6f0de2e 06f5c87 058eba2 80fe36a d0ba755 06f5c87 d0ba755 c6f8f84 06f5c87 058eba2 c6f8f84 8f7234f fb13185 c6f8f84 058eba2 d0ba755 f90da5a 255d19f 06f5c87 a23ab36 058eba2 255d19f f90da5a d0ba755 a6c8097 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
# app.py
import os, torch
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from docx import Document as DocxDocument
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from huggingface_hub import login, snapshot_download
import gradio as gr
# -------------------------------
# 1. Model configuration (Chinese T5 / Pegasus)
# -------------------------------
PRIMARY_MODEL = "imxly/t5-pegasus-small"  # well suited to Chinese summarization/generation
FALLBACK_MODEL = "uer/gpt2-chinese-cluecorpussmall"  # used when the T5 snapshot cannot be fetched

# Optional Hugging Face login: lets snapshot_download access gated/private repos.
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")
def try_download_model(repo_id: str):
    """Download a model snapshot into ./models/<name> and return the local path.

    Returns None when the download fails so the caller can fall back to
    another model. A partially downloaded directory left behind by a failed
    attempt is removed — otherwise a later run would see the directory,
    skip downloading, and hand a broken snapshot to from_pretrained.
    """
    local_dir = f"./models/{repo_id.split('/')[-1]}"
    if not os.path.isdir(local_dir):
        print(f"⬇️ 嘗試下載模型 {repo_id} ...")
        try:
            snapshot_download(repo_id=repo_id, token=HF_TOKEN, local_dir=local_dir)
        except Exception as e:
            print(f"⚠️ 模型 {repo_id} 無法下載: {e}")
            # Clean up the partial snapshot so the next run retries from scratch.
            import shutil
            shutil.rmtree(local_dir, ignore_errors=True)
            return None
    return local_dir
# Resolve which model to use: primary first, fallback second.
LOCAL_MODEL_DIR = try_download_model(PRIMARY_MODEL)
if LOCAL_MODEL_DIR is None:
    print("⚠️ 切換到 fallback 模型:小型 GPT2-Chinese")
    LOCAL_MODEL_DIR = try_download_model(FALLBACK_MODEL)
    MODEL_NAME = FALLBACK_MODEL
    # Fail fast with a clear message instead of letting
    # from_pretrained(None) blow up later with a confusing error.
    if LOCAL_MODEL_DIR is None:
        raise RuntimeError(
            f"無法下載任何模型({PRIMARY_MODEL} / {FALLBACK_MODEL}),請檢查網路或 HF token。"
        )
else:
    MODEL_NAME = PRIMARY_MODEL
print(f"👉 最終使用模型:{MODEL_NAME}")
# -------------------------------
# 2. Pipeline loading
# -------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    LOCAL_MODEL_DIR,
    use_fast=False  # avoid sentencepiece issues with the fast tokenizer
)

# Pipeline device convention: 0 = first GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1
print(f"💻 使用裝置:{'GPU' if device == 0 else 'CPU'}")

# T5/Pegasus are encoder-decoder models; the GPT2 fallback is causal-LM only,
# so fall back to AutoModelForCausalLM when seq2seq loading fails.
try:
    model = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_DIR)
except Exception:  # was a bare except, which also swallows KeyboardInterrupt/SystemExit
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(LOCAL_MODEL_DIR)

# Task must match the model family: text2text for T5/Pegasus, plain
# text-generation for the GPT2 fallback.
generator = pipeline(
    "text2text-generation" if ("t5" in MODEL_NAME or "pegasus" in MODEL_NAME) else "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device
)
def call_local_inference(prompt, max_new_tokens=256):
    """Generate text for *prompt* via the local pipeline.

    Best-effort: any generation error is returned as a readable text
    message instead of propagating, so the Gradio UI keeps working.
    """
    try:
        results = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
        )
    except Exception as e:
        return f"(生成失敗:{e})"
    return results[0]["generated_text"]
# -------------------------------
# 3. FAISS vector database loading
# -------------------------------
DB_PATH = "./faiss_db"
EMBEDDINGS_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)

if os.path.exists(os.path.join(DB_PATH, "index.faiss")):
    print("✅ 載入現有向量資料庫...")
    db = FAISS.load_local(DB_PATH, embeddings_model, allow_dangerous_deserialization=True)
else:
    print("⚠️ 找不到向量資料庫,將建立空的 DB")
    # FAISS.from_documents([]) raises because the embedding dimension cannot
    # be inferred from zero texts; seed the index with a single placeholder
    # text instead so the app still starts without a prebuilt DB.
    db = FAISS.from_texts([" "], embeddings_model)

retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
# -------------------------------
# 4. Article generation (RAG-grounded)
# -------------------------------
def generate_article_progress(query, segments=3):
    """Generate a multi-paragraph Chinese article on *query*, grounded in the FAISS DB.

    Yields (article_so_far, docx_path_or_None, status) after each paragraph so
    Gradio can stream progress; the final yield carries the saved DOCX path.
    """
    # Gradio sliders may deliver floats; range() requires an int.
    segments = int(segments)
    # Status text is loop-invariant — build it once.
    status = f"本次使用模型:{MODEL_NAME},裝置:{'GPU' if device == 0 else 'CPU'}"

    docx_file = "/tmp/generated_article.docx"
    doc = DocxDocument()
    doc.add_heading(query, level=1)

    # Retrieve supporting passages once; only the top 3 go into the prompt
    # to keep it within the model's context budget.
    retrieved_docs = retriever.get_relevant_documents(query)
    context_texts = [d.page_content for d in retrieved_docs]
    context = "\n".join(f"{i+1}. {txt}" for i, txt in enumerate(context_texts[:3]))

    all_text = []
    for i in range(segments):
        prompt = (
            f"以下是佛教經論的相關內容:\n{context}\n\n"
            f"請依據上面內容,寫一段約150-200字的中文文章,"
            f"主題:{query}。\n第{i+1}段:"
        )
        paragraph = call_local_inference(prompt)
        all_text.append(paragraph)
        doc.add_paragraph(paragraph)
        yield "\n\n".join(all_text), None, status

    doc.save(docx_file)
    yield "\n\n".join(all_text), docx_file, status
# -------------------------------
# 5. Gradio interface
# -------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 📺 電視弘法視頻生成文章 RAG 系統")
    gr.Markdown("使用 Hugging Face 本地模型 + FAISS RAG,僅基於資料庫生成文章。")

    # Inputs: article topic and paragraph count.
    topic_box = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
    paragraph_slider = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="段落數")

    # Outputs: streamed article text, downloadable DOCX, status line.
    article_box = gr.Textbox(label="生成文章")
    docx_download = gr.File(label="下載 DOCX")
    status_label = gr.Label(label="狀態")

    generate_btn = gr.Button("生成文章")
    generate_btn.click(
        generate_article_progress,
        inputs=[topic_box, paragraph_slider],
        outputs=[article_box, docx_download, status_label],
    )

if __name__ == "__main__":
    demo.launch()
|