CHUNYU0505 committed on
Commit
76b0768
·
verified ·
1 Parent(s): 38fd239

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -28
app.py CHANGED
@@ -9,7 +9,7 @@ from langchain_community.vectorstores import FAISS
9
  from langchain_huggingface import HuggingFaceEmbeddings
10
  from docx import Document as DocxDocument
11
  from transformers import pipeline
12
- from huggingface_hub import login
13
  import gradio as gr
14
 
15
  # -------------------------------
@@ -20,16 +20,49 @@ if HF_TOKEN:
20
  login(token=HF_TOKEN)
21
  print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")
22
  else:
23
- print("⚠️ 沒有 HUGGINGFACEHUB_API_TOKEN,Gemma-7B 等 gated 模型可能無法使用")
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  TXT_FOLDER = "./out_texts"
26
  DB_PATH = "./faiss_db"
27
  os.makedirs(DB_PATH, exist_ok=True)
28
  os.makedirs(TXT_FOLDER, exist_ok=True)
29
 
30
- # -------------------------------
31
- # 3. 建立或載入向量資料庫
32
- # -------------------------------
33
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
34
  embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
35
 
@@ -51,28 +84,21 @@ else:
51
  retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
52
 
53
  # -------------------------------
54
- # 4. 本地推論模型設定
55
  # -------------------------------
56
- MODEL_MAP = {
57
- "Auto": None,
58
- "Gemma-2B": "google/gemma-2b",
59
- "Gemma-7B": "google/gemma-7b", # gated,需要 HF_TOKEN
60
- "BTLM-3B-8K": "cerebras/btlm-3b-8k",
61
- "Mistral-7B": "mistralai/Mistral-7B-v0.1"
62
- }
63
-
64
  _loaded_pipelines = {}
65
 
66
  def get_pipeline(model_name):
67
  if model_name not in _loaded_pipelines:
68
- print(f"🔄 正在載入模型 {model_name} ...")
69
- model_id = MODEL_MAP[model_name]
 
 
70
  generator = pipeline(
71
  "text-generation",
72
- model=model_id,
73
- tokenizer=model_id,
74
- device_map="auto",
75
- token=HF_TOKEN # gated 模型會用這個
76
  )
77
  _loaded_pipelines[model_name] = generator
78
  return _loaded_pipelines[model_name]
@@ -86,23 +112,21 @@ def call_local_inference(model_name, prompt, max_new_tokens=512):
86
  return f"(生成失敗:{e})"
87
 
88
  # -------------------------------
89
- # 5. 自動選模型邏輯
90
  # -------------------------------
91
  def pick_model_auto(segments):
92
- """根據段落數自動挑選模型"""
93
  if segments <= 3:
94
  return "Gemma-2B"
95
  elif segments <= 6:
96
  return "BTLM-3B-8K"
97
  else:
98
- return "Mistral-7B" # 避免 gpt-oss-20B 太大跑不動
99
 
100
  def generate_article_progress(query, model_name, segments=5):
101
  docx_file = "/tmp/generated_article.docx"
102
  doc = DocxDocument()
103
  doc.add_heading(query, level=1)
104
 
105
- # 自動挑模型
106
  if model_name == "Auto":
107
  selected_model = pick_model_auto(int(segments))
108
  else:
@@ -124,11 +148,11 @@ def generate_article_progress(query, model_name, segments=5):
124
  yield "\n\n".join(all_text), docx_file, f"本次使用模型:{selected_model}"
125
 
126
  # -------------------------------
127
- # 6. Gradio 介面
128
  # -------------------------------
129
  with gr.Blocks() as demo:
130
- gr.Markdown("# 佛教經論 RAG 系統 (Gemma / BTLM / Mistral)")
131
- gr.Markdown("支援 Auto 模式,並顯示實際使用的模型名稱。")
132
 
133
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
134
  model_dropdown = gr.Dropdown(
@@ -149,7 +173,7 @@ with gr.Blocks() as demo:
149
  )
150
 
151
  # -------------------------------
152
- # 7. 啟動 Gradio
153
  # -------------------------------
154
  if __name__ == "__main__":
155
  demo.launch()
 
9
  from langchain_huggingface import HuggingFaceEmbeddings
10
  from docx import Document as DocxDocument
11
  from transformers import pipeline
12
+ from huggingface_hub import login, snapshot_download
13
  import gradio as gr
14
 
15
  # -------------------------------
 
20
  login(token=HF_TOKEN)
21
  print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")
22
  else:
23
+ print("⚠️ 沒有 HUGGINGFACEHUB_API_TOKEN,Gemma-7B 可能無法下載")
24
 
25
# -------------------------------
# 3. Model registry
# -------------------------------
# Display name -> Hugging Face repo id. "Auto" is a sentinel resolved later
# by pick_model_auto(), so it maps to None and is skipped when downloading.
MODEL_MAP = {
    "Auto": None,
    "Gemma-2B": "google/gemma-2b",
    "Gemma-7B": "google/gemma-7b",  # gated repo: needs HF_TOKEN
    "BTLM-3B-8K": "cerebras/btlm-3b-8k",
    "Mistral-7B": "mistralai/Mistral-7B-v0.1"
}

# -------------------------------
# 4. Pre-download models into ./models/
# -------------------------------
# Display name -> local snapshot directory, populated only for models that
# downloaded successfully (or were already present). get_pipeline() reads
# this mapping; a model missing here was not downloadable at startup.
LOCAL_MODEL_DIRS = {}
for name, repo in MODEL_MAP.items():
    if repo is None:  # "Auto" is not a real repo — skip
        continue
    try:
        local_dir = f"./models/{repo.split('/')[-1]}"
        # Bug fix: an interrupted snapshot_download can leave an empty or
        # partial directory behind, and a bare os.path.exists() check would
        # then skip the model forever. Only treat a NON-EMPTY directory as
        # an already-completed download.
        if not (os.path.isdir(local_dir) and os.listdir(local_dir)):
            print(f"⬇️ 正在下載模型 {repo} ...")
            snapshot_download(
                repo_id=repo,
                token=HF_TOKEN,
                local_dir=local_dir
            )
        else:
            print(f"✅ 已存在模型 {repo} -> {local_dir}")
        LOCAL_MODEL_DIRS[name] = local_dir
    except Exception as e:
        # Best-effort: a gated/offline model simply stays out of the map;
        # get_pipeline() raises a clear error if it is requested later.
        print(f"⚠️ 模型 {repo} 無法下載: {e}")
58
+ # -------------------------------
59
+ # 5. 建立或載入向量資料庫
60
+ # -------------------------------
61
  TXT_FOLDER = "./out_texts"
62
  DB_PATH = "./faiss_db"
63
  os.makedirs(DB_PATH, exist_ok=True)
64
  os.makedirs(TXT_FOLDER, exist_ok=True)
65
 
 
 
 
66
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
67
  embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
68
 
 
84
  retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
85
 
86
  # -------------------------------
87
+ # 6. 本地 pipeline
88
  # -------------------------------
 
 
 
 
 
 
 
 
89
  _loaded_pipelines = {}
90
 
91
def get_pipeline(model_name):
    """Return a cached text-generation pipeline for *model_name*.

    On first use, loads the model from its pre-downloaded local directory
    (LOCAL_MODEL_DIRS) and memoizes the pipeline in _loaded_pipelines so
    later calls are free.

    Raises:
        ValueError: if the model was never downloaded (not present in
            LOCAL_MODEL_DIRS).
    """
    cached = _loaded_pipelines.get(model_name)
    if cached is not None:
        return cached

    local_path = LOCAL_MODEL_DIRS.get(model_name)
    if not local_path:
        raise ValueError(f"❌ 模型 {model_name} 尚未下載")

    print(f"🔄 正在載入本地模型 {model_name} from {local_path}")
    generator = pipeline(
        "text-generation",
        model=local_path,
        tokenizer=local_path,
        device_map="auto"
    )
    _loaded_pipelines[model_name] = generator
    return generator
 
112
  return f"(生成失敗:{e})"
113
 
114
  # -------------------------------
115
+ # 7. Auto 模式邏輯
116
  # -------------------------------
117
def pick_model_auto(segments):
    """Pick a model name from the number of article segments requested.

    Small jobs (<= 3 segments) use Gemma-2B, medium jobs (<= 6) use
    BTLM-3B-8K, and anything larger falls back to Mistral-7B.
    """
    # Ordered threshold table: first limit that covers `segments` wins.
    for limit, model in ((3, "Gemma-2B"), (6, "BTLM-3B-8K")):
        if segments <= limit:
            return model
    return "Mistral-7B"
 
125
  def generate_article_progress(query, model_name, segments=5):
126
  docx_file = "/tmp/generated_article.docx"
127
  doc = DocxDocument()
128
  doc.add_heading(query, level=1)
129
 
 
130
  if model_name == "Auto":
131
  selected_model = pick_model_auto(int(segments))
132
  else:
 
148
  yield "\n\n".join(all_text), docx_file, f"本次使用模型:{selected_model}"
149
 
150
  # -------------------------------
151
+ # 8. Gradio 介面
152
  # -------------------------------
153
  with gr.Blocks() as demo:
154
+ gr.Markdown("# 佛教經論 RAG 系統 (本地模型)")
155
+ gr.Markdown("支援 Gemma / BTLM / Mistral,Auto 模式會自動選擇模型。")
156
 
157
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
158
  model_dropdown = gr.Dropdown(
 
173
  )
174
 
175
  # -------------------------------
176
+ # 9. 啟動 Gradio
177
  # -------------------------------
178
  if __name__ == "__main__":
179
  demo.launch()