CHUNYU0505 committed
Commit 8be1b46 · verified · 1 Parent(s): 58ba158

Update app.py

Files changed (1)
  1. app.py +20 -97
app.py CHANGED
@@ -1,7 +1,4 @@
 # app.py
-# -------------------------------
-# 1. Package imports
-# -------------------------------
 import os, glob
 from langchain.docstore.document import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -13,28 +10,21 @@ from huggingface_hub import login, snapshot_download
 import gradio as gr
 
 # -------------------------------
-# 2. Environment variables and login
-# -------------------------------
-HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
-if HF_TOKEN:
-    login(token=HF_TOKEN)
-    print("✅ Logged in to Hugging Face with HUGGINGFACEHUB_API_TOKEN")
-else:
-    print("⚠️ No HUGGINGFACEHUB_API_TOKEN; some gated models may not be downloadable")
-
-# -------------------------------
-# 3. Model list (runs on free CPU)
+# 1. Model list (all public)
 # -------------------------------
 MODEL_MAP = {
     "Auto": None,
-    "Gemma-2B": "google/gemma-2b",              # gated; requires repository access
-    "BTLM-3B-8K": "cerebras/btlm-3b-8k-base",   # correct repo
-    "DistilGPT2": "distilgpt2",                 # small model
-    "BART-Base": "facebook/bart-base"           # small model
+    "BTLM-3B-8K": "cerebras/btlm-3b-8k-base",   # 3B model, public
+    "DistilGPT2": "distilgpt2",                 # small model
+    "BART-Base": "facebook/bart-base"           # small model
 }
 
+HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+if HF_TOKEN:
+    login(token=HF_TOKEN)
+
 # -------------------------------
-# 4. Pre-download models to local ./models/
+# 2. Pre-download models to ./models/
 # -------------------------------
 LOCAL_MODEL_DIRS = {}
 for name, repo in MODEL_MAP.items():
@@ -44,11 +34,7 @@ for name, repo in MODEL_MAP.items():
         local_dir = f"./models/{repo.split('/')[-1]}"
         if not os.path.exists(local_dir):
             print(f"⬇️ Downloading model {repo} ...")
-            snapshot_download(
-                repo_id=repo,
-                token=HF_TOKEN,
-                local_dir=local_dir
-            )
+            snapshot_download(repo_id=repo, token=HF_TOKEN, local_dir=local_dir)
         else:
             print(f"✅ Model {repo} already present -> {local_dir}")
         LOCAL_MODEL_DIRS[name] = local_dir
@@ -56,63 +42,13 @@ for name, repo in MODEL_MAP.items():
         print(f"⚠️ Could not download model {repo}: {e}")
 
 # -------------------------------
-# 5. Model availability check
-# -------------------------------
-def test_models():
-    print("\n🔍 Starting model check:")
-    for name, local_dir in LOCAL_MODEL_DIRS.items():
-        try:
-            _ = pipeline(
-                "text-generation",
-                model=local_dir,
-                tokenizer=local_dir,
-                device_map="cpu"
-            )
-            print(f"✅ Model {name} is available")
-        except Exception as e:
-            print(f"❌ Model {name} failed to load: {e}")
-
-test_models()
-
-# -------------------------------
-# 6. Build or load the vector database
-# -------------------------------
-TXT_FOLDER = "./out_texts"
-DB_PATH = "./faiss_db"
-os.makedirs(DB_PATH, exist_ok=True)
-os.makedirs(TXT_FOLDER, exist_ok=True)
-
-EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
-embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
-
-if os.path.exists(os.path.join(DB_PATH, "index.faiss")):
-    print("Loading existing vector database...")
-    db = FAISS.load_local(DB_PATH, embeddings_model, allow_dangerous_deserialization=True)
-else:
-    print("No existing database; building a new vector database...")
-    txt_files = glob.glob(f"{TXT_FOLDER}/*.txt")
-    docs = []
-    for filepath in txt_files:
-        with open(filepath, "r", encoding="utf-8") as f:
-            docs.append(Document(page_content=f.read(), metadata={"source": os.path.basename(filepath)}))
-    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
-    split_docs = splitter.split_documents(docs)
-    db = FAISS.from_documents(split_docs, embeddings_model)
-    db.save_local(DB_PATH)
-
-retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
-
-# -------------------------------
-# 7. Local pipeline
+# 3. Pipeline loading
 # -------------------------------
 _loaded_pipelines = {}
 
 def get_pipeline(model_name):
     if model_name not in _loaded_pipelines:
         local_path = LOCAL_MODEL_DIRS.get(model_name)
-        if not local_path:
-            raise ValueError(f"❌ Model {model_name} has not been downloaded")
-        print(f"🔄 Loading model {model_name} from {local_path}")
         generator = pipeline(
             "text-generation",
             model=local_path,
@@ -131,56 +67,46 @@ def call_local_inference(model_name, prompt, max_new_tokens=512):
         return f"(Generation failed: {e})"
 
 # -------------------------------
-# 8. Auto mode logic
+# 4. Auto mode logic
 # -------------------------------
 def pick_model_auto(segments):
     if segments <= 3:
-        return "DistilGPT2"    # smallest model for short texts, fast
+        return "DistilGPT2"
     elif segments <= 6:
-        return "Gemma-2B"      # Gemma-2B for medium-length texts
-    elif segments <= 8:
-        return "BTLM-3B-8K"    # BTLM for long texts
+        return "BTLM-3B-8K"
     else:
-        return "BART-Base"     # BART-Base for very long texts
+        return "BART-Base"
 
 def generate_article_progress(query, model_name, segments=5):
     docx_file = "/tmp/generated_article.docx"
     doc = DocxDocument()
     doc.add_heading(query, level=1)
 
-    if model_name == "Auto":
-        selected_model = pick_model_auto(int(segments))
-    else:
-        selected_model = model_name
+    selected_model = pick_model_auto(segments) if model_name == "Auto" else model_name
     print(f"👉 Using model: {selected_model}")
 
    all_text = []
     prompt = f"Generate a paragraph on the following topic: {query}\n\nEach paragraph should be about 150-200 characters."
 
-    for i in range(int(segments)):
+    for i in range(segments):
         paragraph = call_local_inference(selected_model, prompt)
         all_text.append(paragraph)
         doc.add_paragraph(paragraph)
         prompt = f"Continue from the previous paragraph and write the next one:\n{paragraph}\n\nNext paragraph:"
-
         yield "\n\n".join(all_text), None, f"Model used: {selected_model}"
 
     doc.save(docx_file)
     yield "\n\n".join(all_text), docx_file, f"Model used: {selected_model}"
 
 # -------------------------------
-# 9. Gradio interface
+# 5. Gradio interface
 # -------------------------------
 with gr.Blocks() as demo:
     gr.Markdown("# Buddhist Scriptures RAG System (Free CPU Edition)")
-    gr.Markdown("Supports Gemma-2B / BTLM-3B / DistilGPT2 / BART-Base; Auto mode selects a model automatically.")
+    gr.Markdown("Supports DistilGPT2 / BTLM-3B / BART-Base; Auto mode selects a model automatically.")
 
     query_input = gr.Textbox(lines=2, placeholder="Enter an article topic", label="Article topic")
-    model_dropdown = gr.Dropdown(
-        choices=list(MODEL_MAP.keys()),
-        value="Auto",
-        label="Generation model"
-    )
+    model_dropdown = gr.Dropdown(choices=list(MODEL_MAP.keys()), value="Auto", label="Generation model")
     segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of paragraphs")
     output_text = gr.Textbox(label="Generated article")
     output_file = gr.File(label="Download DOCX")
@@ -193,8 +119,5 @@ with gr.Blocks() as demo:
         outputs=[output_text, output_file, model_used_text]
     )
 
-# -------------------------------
-# 10. Launch Gradio
-# -------------------------------
 if __name__ == "__main__":
     demo.launch()
 