CHUNYU0505 committed on
Commit
ffa9279
·
verified ·
1 Parent(s): 4695cce

使用 Hugging Face API 大模型生成文章

Browse files
Files changed (1) hide show
  1. app.py +61 -82
app.py CHANGED
@@ -1,135 +1,114 @@
1
- # -------------------------------
2
- # 1. 匯入套件
3
- # -------------------------------
4
- import os, glob, time
5
  from langchain_community.text_splitter import RecursiveCharacterTextSplitter
6
- from langchain_community.embeddings import HuggingFaceEmbeddings
7
  from langchain_community.vectorstores import FAISS
8
- from langchain_core.documents import Document
9
- from langchain_community.chat_models import ChatHuggingFaceHub
10
  from langchain.chains import RetrievalQA
11
-
12
  from docx import Document as DocxDocument
13
  import gradio as gr
14
 
 
 
 
 
15
 
16
  # -------------------------------
17
- # 2. 設定路徑
18
  # -------------------------------
19
- txt_folder = "out_texts" # 放你的 .txt 檔
20
- db_path = "faiss_db"
21
  os.makedirs(db_path, exist_ok=True)
22
 
23
  # -------------------------------
24
- # 3. 建立 embeddings
25
  # -------------------------------
26
  embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
27
 
28
  # -------------------------------
29
- # 4. 建立或載入向量資料庫
30
  # -------------------------------
31
  if os.path.exists(os.path.join(db_path, "index.faiss")):
32
  print("載入現有向量資料庫...")
33
  db = FAISS.load_local(db_path, embeddings_model, allow_dangerous_deserialization=True)
34
  else:
35
- print("沒有資料庫,開始建立新向量資料庫...")
36
  txt_files = glob.glob(f"{txt_folder}/*.txt")
37
  docs = []
38
- for filepath in txt_files:
39
- with open(filepath, "r", encoding="utf-8") as f:
40
- docs.append(Document(page_content=f.read(), metadata={"source": os.path.basename(filepath)}))
41
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
42
  split_docs = text_splitter.split_documents(docs)
43
- print("產生向量嵌入中...")
44
  db = FAISS.from_documents(split_docs, embeddings_model)
45
  db.save_local(db_path)
46
  print("向量資料庫建立完成。")
47
 
 
 
48
  # -------------------------------
49
- # 5. Hugging Face 模型設定
50
  # -------------------------------
51
- HUGGINGFACE_API_TOKEN = os.getenv("HF_TOKEN") # 建議在 Spaces Secrets 設定
52
-
53
  MODEL_DICT = {
54
- "google/flan-t5-large": 512,
55
- "tiiuae/falcon-7b-instruct": 512
 
56
  }
57
 
58
- MAX_HOURLY_REQUESTS = 50
59
- request_count = 0
60
- last_reset_time = time.time()
61
-
62
- # -------------------------------
63
- # 6. RAG 主函式
64
- # -------------------------------
65
def rag_generate_hfapi(query, model_name, segments=5, max_words=1500):
    """Generate an article about ``query`` with a retrieval-augmented chain.

    A module-level counter caps the number of successful generations per
    rolling hour. The generated text is split on newlines into paragraphs,
    written to a DOCX file, and returned together with that file's path.

    Args:
        query: Article topic; also used as the DOCX heading.
        model_name: Hugging Face repo id; must be a key of ``MODEL_DICT``.
        segments: Number of paragraphs requested in the prompt.
        max_words: Overall length limit requested in the prompt.

    Returns:
        A ``(text, docx_path)`` tuple. ``docx_path`` is ``None`` when the
        hourly limit has been reached or the chain call raised.
    """
    global request_count, last_reset_time

    # Open a fresh one-hour window once the previous one has elapsed.
    window_expired = time.time() - last_reset_time > 3600
    if window_expired:
        request_count = 0
        last_reset_time = time.time()

    if request_count >= MAX_HOURLY_REQUESTS:
        return f"本小時生成次數已達上限 ({MAX_HOURLY_REQUESTS}),請稍後再試。", None

    # MODEL_DICT maps each repo id to its max_new_tokens budget.
    generation_options = {"temperature": 0.7, "max_new_tokens": MODEL_DICT[model_name]}
    chat_model = ChatHuggingFaceHub(
        repo_id=model_name,
        model_kwargs=generation_options,
        huggingfacehub_api_token=HUGGINGFACE_API_TOKEN,
    )

    top_k_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    chain = RetrievalQA.from_chain_type(
        llm=chat_model,
        retriever=top_k_retriever,
        return_source_documents=True,
    )

    prompt = f"""請依據下列主題生成一篇文章:
主題:{query}
需求:
- 總共 {segments} 段
- 每段約 {max_words // segments} 字
- 總字數請控制在 {max_words} 字以內
- 請自動分段輸出
"""

    try:
        answer = chain({"query": prompt})
        # An empty answer is replaced by a user-facing failure notice.
        full_text = answer["result"].strip() or "(生成失敗,請改用其他模型或調整段落數)"
    except Exception as err:
        return f"(生成失敗:{err!s})", None

    # Only generations that did not raise count against the hourly limit.
    request_count += 1

    stripped_lines = (line.strip() for line in full_text.split("\n"))
    paragraphs = [line for line in stripped_lines if line]

    docx_file = "generated_article.docx"
    article = DocxDocument()
    article.add_heading(query, level=1)
    for paragraph in paragraphs:
        article.add_paragraph(paragraph)
    article.save(docx_file)

    return "\n\n".join(paragraphs), docx_file
 
 
 
 
115
 
116
  # -------------------------------
117
  # 7. Gradio 介面
118
  # -------------------------------
119
  iface = gr.Interface(
120
- fn=rag_generate_hfapi,
121
  inputs=[
122
  gr.Textbox(lines=2, placeholder="請輸入文章主題"),
123
- gr.Dropdown(list(MODEL_DICT.keys()), value="google/flan-t5-large", label="選擇模型"),
124
- gr.Slider(minimum=1, maximum=10, value=5, step=1, label="段落數"),
125
- gr.Slider(minimum=500, maximum=3000, value=1500, step=100, label="文章字數上限")
126
- ],
127
- outputs=[
128
- gr.Textbox(label="生成文章"),
129
- gr.File(label="下載 DOCX")
130
  ],
 
131
  title="佛教經論 RAG 系統 (Hugging Face API)",
132
- description="使用 Hugging Face API 生成文章,可選大模型,分段生成並下載 DOCX,每小時生成次數有限制"
133
  )
134
 
135
  iface.launch()
 
1
+ import os, glob, time, requests
 
 
 
2
  from langchain_community.text_splitter import RecursiveCharacterTextSplitter
 
3
  from langchain_community.vectorstores import FAISS
4
+ from langchain.docstore.document import Document
5
+ from langchain.embeddings import HuggingFaceEmbeddings
6
  from langchain.chains import RetrievalQA
7
+ from langchain_huggingface import HuggingFaceHub
8
  from docx import Document as DocxDocument
9
  import gradio as gr
10
 
11
+ # -------------------------------
12
+ # 1. Hugging Face API Key
13
+ # -------------------------------
14
+ HF_API_TOKEN = os.environ.get("HF_API_TOKEN") # 或直接在 Space Secrets 設定 HF_API_TOKEN
15
 
16
  # -------------------------------
17
+ # 2. 資料路徑
18
  # -------------------------------
19
+ txt_folder = "./out_texts"
20
+ db_path = "./faiss_db"
21
  os.makedirs(db_path, exist_ok=True)
22
 
23
  # -------------------------------
24
+ # 3. Embeddings
25
  # -------------------------------
26
  embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
27
 
28
  # -------------------------------
29
+ # 4. 載入或建立向量資料庫
30
  # -------------------------------
31
  if os.path.exists(os.path.join(db_path, "index.faiss")):
32
  print("載入現有向量資料庫...")
33
  db = FAISS.load_local(db_path, embeddings_model, allow_dangerous_deserialization=True)
34
  else:
35
+ print("建立新向量資料庫...")
36
  txt_files = glob.glob(f"{txt_folder}/*.txt")
37
  docs = []
38
+ for fp in txt_files:
39
+ with open(fp, "r", encoding="utf-8") as f:
40
+ docs.append(Document(page_content=f.read(), metadata={"source": os.path.basename(fp)}))
41
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
42
  split_docs = text_splitter.split_documents(docs)
 
43
  db = FAISS.from_documents(split_docs, embeddings_model)
44
  db.save_local(db_path)
45
  print("向量資料庫建立完成。")
46
 
47
+ retriever = db.as_retriever(search_type="similarity", search_kwargs={"k":5})
48
+
49
  # -------------------------------
50
+ # 5. 模型選擇
51
  # -------------------------------
 
 
52
  MODEL_DICT = {
53
+ "google/flan-t5-base": "text2text-generation",
54
+ "google/flan-t5-large": "text2text-generation",
55
+ "google/flan-t5-xl": "text2text-generation"
56
  }
57
 
58
def load_hf_llm(model_name):
    """Build a Hugging Face Hub LLM client for ``model_name``.

    Args:
        model_name: Hugging Face repo id. When it is a key of ``MODEL_DICT``
            the pipeline task recorded there is forwarded to the client;
            otherwise ``task`` is left as ``None``.

    Returns:
        A ``HuggingFaceHub`` instance configured with temperature 0.7 and a
        512-token generation budget, authenticated with ``HF_API_TOKEN``.
    """
    # NOTE(review): assumes HuggingFaceHub accepts a `task` keyword, as the
    # langchain_community implementation does - confirm against the class
    # actually imported at the top of this file.
    return HuggingFaceHub(
        repo_id=model_name,
        # MODEL_DICT declares the task per model; pass it on explicitly
        # instead of leaving the mapping unused.
        task=MODEL_DICT.get(model_name),
        model_kwargs={"temperature": 0.7, "max_new_tokens": 512},
        huggingfacehub_api_token=HF_API_TOKEN,
    )
64
 
65
+ # -------------------------------
66
+ # 6. RAG 生成文章
67
+ # -------------------------------
68
def _fetch_api_quota():
    """Return the quota value from the usage endpoint, or a fallback label."""
    fallback = "無法取得額度"
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    try:
        response = requests.get(
            "https://api-inference.huggingface.co/usage",
            headers=headers,
            timeout=10,  # never let the quota lookup hang the request
        )
        usage = response.json()
    except (requests.RequestException, ValueError):
        # Network failure or a non-JSON body: the quota is informational
        # only, so degrade to the fallback label instead of raising.
        return fallback
    if not isinstance(usage, dict):
        return fallback
    # NOTE(review): the "model_card" key is kept from the original code;
    # confirm that this endpoint and field really report quota.
    return usage.get("model_card", fallback)


def rag_generate_hf(query, model_name, segments=5):
    """Generate an article paragraph by paragraph with a RetrievalQA chain.

    Each successful paragraph becomes the context of the next prompt. Every
    paragraph (or its failure notice) is appended to a DOCX file, and the
    joined text is returned with an API-quota footer.

    Args:
        query: Article topic; also used as the DOCX heading.
        model_name: Hugging Face repo id passed to ``load_hf_llm``.
        segments: Number of paragraphs to generate; coerced with ``int``.

    Returns:
        A ``(text, docx_path)`` tuple.
    """
    llm = load_hf_llm(model_name)
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        return_source_documents=True,
    )

    docx_file = "./generated_article.docx"
    doc = DocxDocument()
    doc.add_heading(query, level=1)

    all_text = []
    prompt = f"請依據下列主題生成段落:{query}\n每段約150-200字。"
    total_segments = int(segments)

    for index in range(total_segments):
        try:
            result = qa_chain({"query": prompt})
            paragraph = result["result"].strip()
        except Exception as e:
            # Record the failure in the output, but keep the previous prompt
            # so the error text is never used as "the previous paragraph".
            paragraph = f"(本段生成失敗: {e})"
        else:
            if paragraph:
                prompt = f"請接續上一段生成下一段:\n{paragraph}\n下一段:"
        all_text.append(paragraph)
        doc.add_paragraph(paragraph)
        # Throttle between API calls; no need to wait after the last one.
        if index < total_segments - 1:
            time.sleep(0.5)

    doc.save(docx_file)
    full_text = "\n\n".join(all_text)

    # The quota lookup can no longer raise, so the article is always returned.
    quota = _fetch_api_quota()
    return full_text + f"\n\n[API 使用額度: {quota}]", docx_file
98
 
99
  # -------------------------------
100
  # 7. Gradio 介面
101
  # -------------------------------
102
  iface = gr.Interface(
103
+ fn=rag_generate_hf,
104
  inputs=[
105
  gr.Textbox(lines=2, placeholder="請輸入文章主題"),
106
+ gr.Dropdown(list(MODEL_DICT.keys()), value="google/flan-t5-base", label="選擇模型"),
107
+ gr.Slider(minimum=1, maximum=10, value=5, step=1, label="段落數")
 
 
 
 
 
108
  ],
109
+ outputs=[gr.Textbox(label="生成文章"), gr.File(label="下載 DOCX")],
110
  title="佛教經論 RAG 系統 (Hugging Face API)",
111
+ description="使用 Hugging Face API 大模型生成文章,可選模型與段落數,生成完成可下載 DOCX"
112
  )
113
 
114
  iface.launch()