File size: 4,641 Bytes
d0ba755
 
 
 
 
 
 
 
299f87b
d0ba755
 
 
 
 
 
 
 
 
 
 
 
 
299f87b
 
 
d0ba755
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299f87b
 
 
d0ba755
 
 
 
 
 
 
 
299f87b
d0ba755
299f87b
d0ba755
299f87b
 
d0ba755
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299f87b
d0ba755
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# app.py
# -------------------------------
# 1. 套件載入
# -------------------------------
import os, glob, requests
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from docx import Document as DocxDocument
import gradio as gr
from langchain_community.vectorstores import FAISS

# -------------------------------
# 2. 環境變數與資料路徑
# -------------------------------
# Folder holding the source .txt files, and the folder where the FAISS index
# is persisted. The index folder is created up front if it does not exist.
TXT_FOLDER = "./out_texts"
DB_PATH = "./faiss_db"
os.makedirs(DB_PATH, exist_ok=True)

# The Hugging Face API token must come from the environment; fail fast at
# startup when it is missing or empty.
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not HF_TOKEN:
    raise ValueError(
        "請在 Hugging Face Space 的 Settings → Repository secrets "
        "設定 HUGGINGFACEHUB_API_TOKEN"
    )

# -------------------------------
# 3. 建立或載入向量資料庫
# -------------------------------
# Sentence-embedding model used both to build the index and to embed queries.
EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)

if os.path.exists(os.path.join(DB_PATH, "index.faiss")):
    print("載入現有向量資料庫...")
    # NOTE(security): FAISS.load_local unpickles the stored docstore, which is
    # why allow_dangerous_deserialization must be enabled. That is acceptable
    # only because the index in DB_PATH is written by this script itself
    # (see save_local below); never point DB_PATH at an untrusted index.
    db = FAISS.load_local(DB_PATH, embeddings_model, allow_dangerous_deserialization=True)
else:
    print("沒有資料庫,開始建立新向量資料庫...")
    # Sorted so the chunk order in the index does not depend on the
    # filesystem's directory listing order.
    txt_files = sorted(glob.glob(f"{TXT_FOLDER}/*.txt"))
    if not txt_files:
        # Without this guard FAISS.from_documents([]) fails later with an
        # opaque IndexError; report the real cause instead.
        raise FileNotFoundError(
            f"在 {TXT_FOLDER} 找不到任何 .txt 檔案,無法建立向量資料庫"
        )

    docs = []
    for filepath in txt_files:
        with open(filepath, "r", encoding="utf-8") as f:
            docs.append(
                Document(page_content=f.read(), metadata={"source": os.path.basename(filepath)})
            )

    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    split_docs = splitter.split_documents(docs)
    if not split_docs:
        # Every source file was empty, so there is nothing to embed.
        raise ValueError(
            f"{TXT_FOLDER} 內的 .txt 檔案皆為空白,無法建立向量資料庫"
        )

    db = FAISS.from_documents(split_docs, embeddings_model)
    # Persist the index so later startups take the load_local branch above.
    db.save_local(DB_PATH)

# Similarity search returning the 5 closest chunks per query.
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})

# -------------------------------
# 4. LLM 設定(Hugging Face Endpoint)
# -------------------------------
# Remote LLM served through the Hugging Face Inference endpoint.
# temperature and max_new_tokens are declared fields of HuggingFaceEndpoint,
# so they are passed as explicit keyword arguments: the class validator
# rejects declared fields that arrive inside `model_kwargs`.
llm = HuggingFaceEndpoint(
    repo_id="google/flan-t5-large",
    task="text2text-generation",  # set the task explicitly
    huggingfacehub_api_token=HF_TOKEN,
    temperature=0.7,
    max_new_tokens=512,
)

# Retrieval-augmented QA chain; the result dict carries the generated answer
# under "result" and, because return_source_documents=True, the retrieved
# chunks under "source_documents".
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    return_source_documents=True,
)

# -------------------------------
# 5. 查詢 API 剩餘額度
# -------------------------------
def get_hf_rate_limit(timeout=10):
    """Return a user-facing line describing the remaining hourly API calls.

    Queries the Hugging Face ``whoami`` endpoint with the module-level
    ``HF_TOKEN`` and reads ``rate_limit.used`` from the JSON reply. This is a
    best-effort helper for the UI: it never raises for network, HTTP or
    payload problems and returns a fallback message instead.

    Args:
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled connection cannot hang the calling UI callback.

    Returns:
        A display string with the estimated remaining calls, or a fallback
        message when the information cannot be obtained.
    """
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    try:
        response = requests.get(
            "https://huggingface.co/api/whoami",
            headers=headers,
            timeout=timeout,
        )
        response.raise_for_status()
        data = response.json()
        # NOTE(review): assumes the reply may carry a "rate_limit" object with
        # a "used" counter; a missing object is treated as 0 calls used.
        used = data.get("rate_limit", {}).get("used", 0)
        # 300 is the hourly quota hard-coded by the original author
        # (TODO confirm against the account's actual limit).
        remaining = 300 - used if used is not None else "未知"
    except (requests.RequestException, ValueError, AttributeError, TypeError):
        # Network/HTTP failure, non-JSON body, or an unexpected payload shape.
        # Deliberately narrower than a bare ``except`` so KeyboardInterrupt
        # and SystemExit still propagate.
        return "無法取得 API 速率資訊"
    return f"本小時剩餘 API 次數:約 {remaining}"

# -------------------------------
# 6. 生成文章
# -------------------------------
def generate_article_with_rate(query, segments=5):
    """Generate a multi-paragraph article on *query* and export it as DOCX.

    Each paragraph comes from one call to the module-level ``qa_chain``. The
    first call uses a prompt built from the topic; every later call asks the
    chain to continue from the most recent successfully generated paragraph.
    A paragraph that fails is replaced by a placeholder in the output, and
    generation carries on with the remaining paragraphs.

    Args:
        query: Article topic; also used as the document heading.
        segments: Number of paragraphs to generate. Coerced with ``int()``
            because the UI slider may deliver a non-int number.

    Returns:
        A ``(text, docx_path)`` tuple. ``text`` is the rate-limit line from
        ``get_hf_rate_limit()`` followed by the paragraphs; ``docx_path`` is
        the path of the saved document. For a blank topic no API call is made
        and ``(message, None)`` is returned.
    """
    # Local import: this function is the only user of tempfile in the file.
    import tempfile

    # A blank topic would only waste API quota on meaningless prompts.
    if not (query or "").strip():
        return "請先輸入文章主題。", None

    doc = DocxDocument()
    doc.add_heading(query, level=1)

    paragraphs = []
    prompt = f"請依據下列主題生成段落:{query}\n\n每段約150-200字。"

    for _ in range(int(segments)):
        try:
            result = qa_chain.invoke({"query": prompt})
            paragraph = result["result"].strip()
        except Exception as e:
            # Deliberately broad: one failed remote call must not abort the
            # whole article; the error is surfaced in place of the paragraph.
            paragraph = f"(本段生成失敗:{e})"
            succeeded = False
        else:
            succeeded = bool(paragraph)
            if not succeeded:
                paragraph = "(本段生成失敗,請嘗試減少段落或改用較小模型。)"

        paragraphs.append(paragraph)
        doc.add_paragraph(paragraph)

        # Only continue from real content; feeding a failure placeholder back
        # as "the previous paragraph" would derail every later segment.
        if succeeded:
            prompt = f"請接續上一段生成下一段:\n{paragraph}\n\n下一段:"

    # Unique file per request: a fixed path would let concurrent users
    # overwrite each other's download. The file is closed first and then
    # written by path so this also works where open files are locked.
    with tempfile.NamedTemporaryFile(
        prefix="generated_article_", suffix=".docx", delete=False
    ) as tmp:
        docx_file = tmp.name
    doc.save(docx_file)

    full_text = "\n\n".join(paragraphs)
    rate_info = get_hf_rate_limit()
    return f"{rate_info}\n\n{full_text}", docx_file

# -------------------------------
# 7. Gradio 介面
# -------------------------------
# Input widgets, in the positional order expected by
# generate_article_with_rate(query, segments).
_article_inputs = [
    gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題"),
    gr.Slider(minimum=1, maximum=10, step=1, value=5, label="段落數"),
]

# Output widgets, matching the (text, docx_path) tuple the function returns.
_article_outputs = [
    gr.Textbox(label="生成文章 + API 剩餘次數"),
    gr.File(label="下載 DOCX"),
]

# Web UI wiring the generator function to the widgets above.
iface = gr.Interface(
    fn=generate_article_with_rate,
    inputs=_article_inputs,
    outputs=_article_outputs,
    title="佛教經論 RAG 系統 (HF API)",
    description="使用 Hugging Face Endpoint LLM + FAISS RAG,生成文章並提示 API 剩餘額度。",
)

# Start the Gradio server only when the file is run as a script.
if __name__ == "__main__":
    iface.launch()