CHUNYU0505 committed on
Commit
6b1b850
·
verified ·
1 Parent(s): 76b0768

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -15
app.py CHANGED
@@ -20,17 +20,17 @@ if HF_TOKEN:
20
  login(token=HF_TOKEN)
21
  print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")
22
  else:
23
- print("⚠️ 沒有 HUGGINGFACEHUB_API_TOKEN,Gemma-7B 可能無法下載")
24
 
25
  # -------------------------------
26
- # 3. 模型清單
27
  # -------------------------------
28
  MODEL_MAP = {
29
  "Auto": None,
30
  "Gemma-2B": "google/gemma-2b",
31
- "Gemma-7B": "google/gemma-7b", # gated
32
- "BTLM-3B-8K": "cerebras/btlm-3b-8k",
33
- "Mistral-7B": "mistralai/Mistral-7B-v0.1"
34
  }
35
 
36
  # -------------------------------
@@ -38,7 +38,7 @@ MODEL_MAP = {
38
  # -------------------------------
39
  LOCAL_MODEL_DIRS = {}
40
  for name, repo in MODEL_MAP.items():
41
- if repo is None: # Auto 跳過
42
  continue
43
  try:
44
  local_dir = f"./models/{repo.split('/')[-1]}"
@@ -56,7 +56,26 @@ for name, repo in MODEL_MAP.items():
56
  print(f"⚠️ 模型 {repo} 無法下載: {e}")
57
 
58
  # -------------------------------
59
- # 5. 建立或載入向量資料庫
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  # -------------------------------
61
  TXT_FOLDER = "./out_texts"
62
  DB_PATH = "./faiss_db"
@@ -84,7 +103,7 @@ else:
84
  retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
85
 
86
  # -------------------------------
87
- # 6. 本地 pipeline
88
  # -------------------------------
89
  _loaded_pipelines = {}
90
 
@@ -98,7 +117,7 @@ def get_pipeline(model_name):
98
  "text-generation",
99
  model=local_path,
100
  tokenizer=local_path,
101
- device_map="auto"
102
  )
103
  _loaded_pipelines[model_name] = generator
104
  return _loaded_pipelines[model_name]
@@ -112,7 +131,7 @@ def call_local_inference(model_name, prompt, max_new_tokens=512):
112
  return f"(生成失敗:{e})"
113
 
114
  # -------------------------------
115
- # 7. Auto 模式邏輯
116
  # -------------------------------
117
  def pick_model_auto(segments):
118
  if segments <= 3:
@@ -120,7 +139,7 @@ def pick_model_auto(segments):
120
  elif segments <= 6:
121
  return "BTLM-3B-8K"
122
  else:
123
- return "Mistral-7B"
124
 
125
  def generate_article_progress(query, model_name, segments=5):
126
  docx_file = "/tmp/generated_article.docx"
@@ -148,11 +167,11 @@ def generate_article_progress(query, model_name, segments=5):
148
  yield "\n\n".join(all_text), docx_file, f"本次使用模型:{selected_model}"
149
 
150
  # -------------------------------
151
- # 8. Gradio 介面
152
  # -------------------------------
153
  with gr.Blocks() as demo:
154
- gr.Markdown("# 佛教經論 RAG 系統 (本地模型)")
155
- gr.Markdown("支援 Gemma / BTLM / Mistral,Auto 模式會自動選擇模型。")
156
 
157
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
158
  model_dropdown = gr.Dropdown(
@@ -173,7 +192,7 @@ with gr.Blocks() as demo:
173
  )
174
 
175
  # -------------------------------
176
- # 9. 啟動 Gradio
177
  # -------------------------------
178
  if __name__ == "__main__":
179
  demo.launch()
 
20
  login(token=HF_TOKEN)
21
  print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")
22
  else:
23
+ print("⚠️ 沒有 HUGGINGFACEHUB_API_TOKEN,部分 gated 模型可能無法下載")
24
 
25
  # -------------------------------
26
+ # 3. 模型清單(CPU 免費可跑)
27
  # -------------------------------
28
  MODEL_MAP = {
29
  "Auto": None,
30
  "Gemma-2B": "google/gemma-2b",
31
+ "BTLM-3B-8K": "tiiuae/btlm-3b-8k-base",
32
+ "DistilGPT2": "distilgpt2",
33
+ "BART-Base": "facebook/bart-base"
34
  }
35
 
36
  # -------------------------------
 
38
  # -------------------------------
39
  LOCAL_MODEL_DIRS = {}
40
  for name, repo in MODEL_MAP.items():
41
+ if repo is None:
42
  continue
43
  try:
44
  local_dir = f"./models/{repo.split('/')[-1]}"
 
56
  print(f"⚠️ 模型 {repo} 無法下載: {e}")
57
 
58
  # -------------------------------
59
+ # 5. 模型可用性檢查
60
+ # -------------------------------
61
+ def test_models():
62
+ print("\n🔍 啟動模型檢查:")
63
+ for name, local_dir in LOCAL_MODEL_DIRS.items():
64
+ try:
65
+ _ = pipeline(
66
+ "text-generation",
67
+ model=local_dir,
68
+ tokenizer=local_dir,
69
+ device_map="cpu"
70
+ )
71
+ print(f"✅ 模型 {name} ({local_dir}) 可用")
72
+ except Exception as e:
73
+ print(f"❌ 模型 {name} ({local_dir}) 無法載入: {e}")
74
+
75
+ test_models()
76
+
77
+ # -------------------------------
78
+ # 6. 建立或載入向量資料庫
79
  # -------------------------------
80
  TXT_FOLDER = "./out_texts"
81
  DB_PATH = "./faiss_db"
 
103
  retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
104
 
105
  # -------------------------------
106
+ # 7. 本地 pipeline
107
  # -------------------------------
108
  _loaded_pipelines = {}
109
 
 
117
  "text-generation",
118
  model=local_path,
119
  tokenizer=local_path,
120
+ device_map="cpu"
121
  )
122
  _loaded_pipelines[model_name] = generator
123
  return _loaded_pipelines[model_name]
 
131
  return f"(生成失敗:{e})"
132
 
133
  # -------------------------------
134
+ # 8. Auto 模式邏輯
135
  # -------------------------------
136
  def pick_model_auto(segments):
137
  if segments <= 3:
 
139
  elif segments <= 6:
140
  return "BTLM-3B-8K"
141
  else:
142
+ return "BART-Base"
143
 
144
  def generate_article_progress(query, model_name, segments=5):
145
  docx_file = "/tmp/generated_article.docx"
 
167
  yield "\n\n".join(all_text), docx_file, f"本次使用模型:{selected_model}"
168
 
169
  # -------------------------------
170
+ # 9. Gradio 介面
171
  # -------------------------------
172
  with gr.Blocks() as demo:
173
+ gr.Markdown("# 佛教經論 RAG 系統 (CPU 免費版)")
174
+ gr.Markdown("支援 Gemma-2B / BTLM-3B / DistilGPT2 / BART-Base,Auto 模式會自動選擇。")
175
 
176
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
177
  model_dropdown = gr.Dropdown(
 
192
  )
193
 
194
  # -------------------------------
195
+ # 10. 啟動 Gradio
196
  # -------------------------------
197
  if __name__ == "__main__":
198
  demo.launch()