Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,7 +9,7 @@ from langchain_community.vectorstores import FAISS
|
|
| 9 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 10 |
from docx import Document as DocxDocument
|
| 11 |
from transformers import pipeline
|
| 12 |
-
from huggingface_hub import login
|
| 13 |
import gradio as gr
|
| 14 |
|
| 15 |
# -------------------------------
|
|
@@ -20,16 +20,49 @@ if HF_TOKEN:
|
|
| 20 |
login(token=HF_TOKEN)
|
| 21 |
print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")
|
| 22 |
else:
|
| 23 |
-
print("⚠️ 沒有 HUGGINGFACEHUB_API_TOKEN,Gemma-7B
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
TXT_FOLDER = "./out_texts"
|
| 26 |
DB_PATH = "./faiss_db"
|
| 27 |
os.makedirs(DB_PATH, exist_ok=True)
|
| 28 |
os.makedirs(TXT_FOLDER, exist_ok=True)
|
| 29 |
|
| 30 |
-
# -------------------------------
|
| 31 |
-
# 3. 建立或載入向量資料庫
|
| 32 |
-
# -------------------------------
|
| 33 |
EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
| 34 |
embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
|
| 35 |
|
|
@@ -51,28 +84,21 @@ else:
|
|
| 51 |
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
|
| 52 |
|
| 53 |
# -------------------------------
|
| 54 |
-
#
|
| 55 |
# -------------------------------
|
| 56 |
-
MODEL_MAP = {
|
| 57 |
-
"Auto": None,
|
| 58 |
-
"Gemma-2B": "google/gemma-2b",
|
| 59 |
-
"Gemma-7B": "google/gemma-7b", # gated,需要 HF_TOKEN
|
| 60 |
-
"BTLM-3B-8K": "cerebras/btlm-3b-8k",
|
| 61 |
-
"Mistral-7B": "mistralai/Mistral-7B-v0.1"
|
| 62 |
-
}
|
| 63 |
-
|
| 64 |
_loaded_pipelines = {}
|
| 65 |
|
| 66 |
def get_pipeline(model_name):
|
| 67 |
if model_name not in _loaded_pipelines:
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
| 70 |
generator = pipeline(
|
| 71 |
"text-generation",
|
| 72 |
-
model=
|
| 73 |
-
tokenizer=
|
| 74 |
-
device_map="auto"
|
| 75 |
-
token=HF_TOKEN # gated 模型會用這個
|
| 76 |
)
|
| 77 |
_loaded_pipelines[model_name] = generator
|
| 78 |
return _loaded_pipelines[model_name]
|
|
@@ -86,23 +112,21 @@ def call_local_inference(model_name, prompt, max_new_tokens=512):
|
|
| 86 |
return f"(生成失敗:{e})"
|
| 87 |
|
| 88 |
# -------------------------------
|
| 89 |
-
#
|
| 90 |
# -------------------------------
|
| 91 |
def pick_model_auto(segments):
|
| 92 |
-
"""根據段落數自動挑選模型"""
|
| 93 |
if segments <= 3:
|
| 94 |
return "Gemma-2B"
|
| 95 |
elif segments <= 6:
|
| 96 |
return "BTLM-3B-8K"
|
| 97 |
else:
|
| 98 |
-
return "Mistral-7B"
|
| 99 |
|
| 100 |
def generate_article_progress(query, model_name, segments=5):
|
| 101 |
docx_file = "/tmp/generated_article.docx"
|
| 102 |
doc = DocxDocument()
|
| 103 |
doc.add_heading(query, level=1)
|
| 104 |
|
| 105 |
-
# 自動挑模型
|
| 106 |
if model_name == "Auto":
|
| 107 |
selected_model = pick_model_auto(int(segments))
|
| 108 |
else:
|
|
@@ -124,11 +148,11 @@ def generate_article_progress(query, model_name, segments=5):
|
|
| 124 |
yield "\n\n".join(all_text), docx_file, f"本次使用模型:{selected_model}"
|
| 125 |
|
| 126 |
# -------------------------------
|
| 127 |
-
#
|
| 128 |
# -------------------------------
|
| 129 |
with gr.Blocks() as demo:
|
| 130 |
-
gr.Markdown("# 佛教經論 RAG 系統 (
|
| 131 |
-
gr.Markdown("支援 Auto
|
| 132 |
|
| 133 |
query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
|
| 134 |
model_dropdown = gr.Dropdown(
|
|
@@ -149,7 +173,7 @@ with gr.Blocks() as demo:
|
|
| 149 |
)
|
| 150 |
|
| 151 |
# -------------------------------
|
| 152 |
-
#
|
| 153 |
# -------------------------------
|
| 154 |
if __name__ == "__main__":
|
| 155 |
demo.launch()
|
|
|
|
| 9 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 10 |
from docx import Document as DocxDocument
|
| 11 |
from transformers import pipeline
|
| 12 |
+
from huggingface_hub import login, snapshot_download
|
| 13 |
import gradio as gr
|
| 14 |
|
| 15 |
# -------------------------------
|
|
|
|
| 20 |
login(token=HF_TOKEN)
|
| 21 |
print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")
|
| 22 |
else:
|
| 23 |
+
print("⚠️ 沒有 HUGGINGFACEHUB_API_TOKEN,Gemma-7B 可能無法下載")
|
| 24 |
|
| 25 |
+
# -------------------------------
|
| 26 |
+
# 3. 模型清單
|
| 27 |
+
# -------------------------------
|
| 28 |
+
MODEL_MAP = {
|
| 29 |
+
"Auto": None,
|
| 30 |
+
"Gemma-2B": "google/gemma-2b",
|
| 31 |
+
"Gemma-7B": "google/gemma-7b", # gated
|
| 32 |
+
"BTLM-3B-8K": "cerebras/btlm-3b-8k",
|
| 33 |
+
"Mistral-7B": "mistralai/Mistral-7B-v0.1"
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
# -------------------------------
|
| 37 |
+
# 4. 預先下載模型到本地 ./models/
|
| 38 |
+
# -------------------------------
|
| 39 |
+
LOCAL_MODEL_DIRS = {}
|
| 40 |
+
for name, repo in MODEL_MAP.items():
|
| 41 |
+
if repo is None: # Auto 跳過
|
| 42 |
+
continue
|
| 43 |
+
try:
|
| 44 |
+
local_dir = f"./models/{repo.split('/')[-1]}"
|
| 45 |
+
if not os.path.exists(local_dir):
|
| 46 |
+
print(f"⬇️ 正在下載模型 {repo} ...")
|
| 47 |
+
snapshot_download(
|
| 48 |
+
repo_id=repo,
|
| 49 |
+
token=HF_TOKEN,
|
| 50 |
+
local_dir=local_dir
|
| 51 |
+
)
|
| 52 |
+
else:
|
| 53 |
+
print(f"✅ 已存在模型 {repo} -> {local_dir}")
|
| 54 |
+
LOCAL_MODEL_DIRS[name] = local_dir
|
| 55 |
+
except Exception as e:
|
| 56 |
+
print(f"⚠️ 模型 {repo} 無法下載: {e}")
|
| 57 |
+
|
| 58 |
+
# -------------------------------
|
| 59 |
+
# 5. 建立或載入向量資料庫
|
| 60 |
+
# -------------------------------
|
| 61 |
TXT_FOLDER = "./out_texts"
|
| 62 |
DB_PATH = "./faiss_db"
|
| 63 |
os.makedirs(DB_PATH, exist_ok=True)
|
| 64 |
os.makedirs(TXT_FOLDER, exist_ok=True)
|
| 65 |
|
|
|
|
|
|
|
|
|
|
| 66 |
EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
| 67 |
embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
|
| 68 |
|
|
|
|
| 84 |
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
|
| 85 |
|
| 86 |
# -------------------------------
|
| 87 |
+
# 6. 本地 pipeline
|
| 88 |
# -------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
_loaded_pipelines = {}
|
| 90 |
|
| 91 |
def get_pipeline(model_name):
|
| 92 |
if model_name not in _loaded_pipelines:
|
| 93 |
+
local_path = LOCAL_MODEL_DIRS.get(model_name)
|
| 94 |
+
if not local_path:
|
| 95 |
+
raise ValueError(f"❌ 模型 {model_name} 尚未下載")
|
| 96 |
+
print(f"🔄 正在載入本地模型 {model_name} from {local_path}")
|
| 97 |
generator = pipeline(
|
| 98 |
"text-generation",
|
| 99 |
+
model=local_path,
|
| 100 |
+
tokenizer=local_path,
|
| 101 |
+
device_map="auto"
|
|
|
|
| 102 |
)
|
| 103 |
_loaded_pipelines[model_name] = generator
|
| 104 |
return _loaded_pipelines[model_name]
|
|
|
|
| 112 |
return f"(生成失敗:{e})"
|
| 113 |
|
| 114 |
# -------------------------------
|
| 115 |
+
# 7. Auto 模式邏輯
|
| 116 |
# -------------------------------
|
| 117 |
def pick_model_auto(segments):
|
|
|
|
| 118 |
if segments <= 3:
|
| 119 |
return "Gemma-2B"
|
| 120 |
elif segments <= 6:
|
| 121 |
return "BTLM-3B-8K"
|
| 122 |
else:
|
| 123 |
+
return "Mistral-7B"
|
| 124 |
|
| 125 |
def generate_article_progress(query, model_name, segments=5):
|
| 126 |
docx_file = "/tmp/generated_article.docx"
|
| 127 |
doc = DocxDocument()
|
| 128 |
doc.add_heading(query, level=1)
|
| 129 |
|
|
|
|
| 130 |
if model_name == "Auto":
|
| 131 |
selected_model = pick_model_auto(int(segments))
|
| 132 |
else:
|
|
|
|
| 148 |
yield "\n\n".join(all_text), docx_file, f"本次使用模型:{selected_model}"
|
| 149 |
|
| 150 |
# -------------------------------
|
| 151 |
+
# 8. Gradio 介面
|
| 152 |
# -------------------------------
|
| 153 |
with gr.Blocks() as demo:
|
| 154 |
+
gr.Markdown("# 佛教經論 RAG 系統 (本地模型)")
|
| 155 |
+
gr.Markdown("支援 Gemma / BTLM / Mistral,Auto 模式會自動選擇模型。")
|
| 156 |
|
| 157 |
query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
|
| 158 |
model_dropdown = gr.Dropdown(
|
|
|
|
| 173 |
)
|
| 174 |
|
| 175 |
# -------------------------------
|
| 176 |
+
# 9. 啟動 Gradio
|
| 177 |
# -------------------------------
|
| 178 |
if __name__ == "__main__":
|
| 179 |
demo.launch()
|