CHUNYU0505 committed on
Commit
4cc15ed
·
verified ·
1 Parent(s): 5abe977

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -24
app.py CHANGED
@@ -1,54 +1,47 @@
1
  # app.py
2
- import os, glob
3
  from langchain.docstore.document import Document
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
  from langchain_community.vectorstores import FAISS
6
  from langchain_huggingface import HuggingFaceEmbeddings
7
  from docx import Document as DocxDocument
8
- from transformers import pipeline
9
  from huggingface_hub import login, snapshot_download
10
  import gradio as gr
11
 
12
  # -------------------------------
13
- # 1. 模型清單(公開可用)
14
  # -------------------------------
15
  MODEL_MAP = {
16
  "Auto": None,
17
- "BTLM-3B-8K": "cerebras/btlm-3b-8k-base", # 需要 trust_remote_code=True
18
- "DistilGPT2": "distilgpt2", # 小模型
19
- "BART-Base": "facebook/bart-base" # 小模型
20
  }
21
 
22
- # -------------------------------
23
- # 2. Hugging Face 登入
24
- # -------------------------------
25
  HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
26
  if HF_TOKEN:
27
  login(token=HF_TOKEN)
28
  print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")
29
- else:
30
- print("⚠️ 沒有 HUGGINGFACEHUB_API_TOKEN,下載速度可能受限")
31
 
32
  # -------------------------------
33
- # 3. 預先下載模型到 ./models/
34
  # -------------------------------
35
  LOCAL_MODEL_DIRS = {}
36
  for name, repo in MODEL_MAP.items():
37
- if repo is None:
38
  continue
39
  try:
40
  local_dir = f"./models/{repo.split('/')[-1]}"
41
  if not os.path.exists(local_dir):
42
  print(f"⬇️ 正在下載模型 {repo} ...")
43
  snapshot_download(repo_id=repo, token=HF_TOKEN, local_dir=local_dir)
44
- else:
45
- print(f"✅ 已存在模型 {repo} -> {local_dir}")
46
  LOCAL_MODEL_DIRS[name] = local_dir
47
  except Exception as e:
48
  print(f"⚠️ 模型 {repo} 無法下載: {e}")
49
 
50
  # -------------------------------
51
- # 4. pipeline 載入(含 trust_remote_code)
52
  # -------------------------------
53
  _loaded_pipelines = {}
54
 
@@ -56,26 +49,38 @@ def get_pipeline(model_name):
56
  if model_name not in _loaded_pipelines:
57
  local_path = LOCAL_MODEL_DIRS.get(model_name)
58
  print(f"🔄 正在載入模型 {model_name} from {local_path}")
 
 
 
 
 
 
 
 
59
  generator = pipeline(
60
  "text-generation",
61
- model=local_path,
62
- tokenizer=local_path,
63
- device_map="cpu",
64
- trust_remote_code=True # <<<< 加這個才能跑 BTLM
65
  )
66
  _loaded_pipelines[model_name] = generator
67
  return _loaded_pipelines[model_name]
68
 
69
- def call_local_inference(model_name, prompt, max_new_tokens=512):
70
  try:
71
  generator = get_pipeline(model_name)
72
- outputs = generator(prompt, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.7)
 
 
 
 
 
73
  return outputs[0]["generated_text"]
74
  except Exception as e:
75
  return f"(生成失敗:{e})"
76
 
77
  # -------------------------------
78
- # 5. Auto 模式邏輯
79
  # -------------------------------
80
  def pick_model_auto(segments):
81
  if segments <= 3:
@@ -107,7 +112,7 @@ def generate_article_progress(query, model_name, segments=5):
107
  yield "\n\n".join(all_text), docx_file, f"本次使用模型:{selected_model}"
108
 
109
  # -------------------------------
110
- # 6. Gradio 介面
111
  # -------------------------------
112
  with gr.Blocks() as demo:
113
  gr.Markdown("# 佛教經論 RAG 系統 (CPU 免費版)")
 
1
  # app.py
2
+ import os, glob, torch
3
  from langchain.docstore.document import Document
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
  from langchain_community.vectorstores import FAISS
6
  from langchain_huggingface import HuggingFaceEmbeddings
7
  from docx import Document as DocxDocument
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
9
  from huggingface_hub import login, snapshot_download
10
  import gradio as gr
11
 
12
  # -------------------------------
13
+ # 1. 模型清單
14
  # -------------------------------
15
  MODEL_MAP = {
16
  "Auto": None,
17
+ "BTLM-3B-8K": "cerebras/btlm-3b-8k-base",
18
+ "DistilGPT2": "distilgpt2",
19
+ "BART-Base": "facebook/bart-base"
20
  }
21
 
 
 
 
22
  HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
23
  if HF_TOKEN:
24
  login(token=HF_TOKEN)
25
  print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")
 
 
26
 
27
  # -------------------------------
28
+ # 2. 預先下載模型
29
  # -------------------------------
30
  LOCAL_MODEL_DIRS = {}
31
  for name, repo in MODEL_MAP.items():
32
+ if repo is None:
33
  continue
34
  try:
35
  local_dir = f"./models/{repo.split('/')[-1]}"
36
  if not os.path.exists(local_dir):
37
  print(f"⬇️ 正在下載模型 {repo} ...")
38
  snapshot_download(repo_id=repo, token=HF_TOKEN, local_dir=local_dir)
 
 
39
  LOCAL_MODEL_DIRS[name] = local_dir
40
  except Exception as e:
41
  print(f"⚠️ 模型 {repo} 無法下載: {e}")
42
 
43
  # -------------------------------
44
+ # 3. pipeline 載入
45
  # -------------------------------
46
  _loaded_pipelines = {}
47
 
 
49
  if model_name not in _loaded_pipelines:
50
  local_path = LOCAL_MODEL_DIRS.get(model_name)
51
  print(f"🔄 正在載入模型 {model_name} from {local_path}")
52
+
53
+ if model_name == "BTLM-3B-8K":
54
+ tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True)
55
+ model = AutoModelForCausalLM.from_pretrained(local_path, trust_remote_code=True)
56
+ else:
57
+ tokenizer = AutoTokenizer.from_pretrained(local_path)
58
+ model = AutoModelForCausalLM.from_pretrained(local_path)
59
+
60
  generator = pipeline(
61
  "text-generation",
62
+ model=model,
63
+ tokenizer=tokenizer,
64
+ device= -1 # 強制 CPU
 
65
  )
66
  _loaded_pipelines[model_name] = generator
67
  return _loaded_pipelines[model_name]
68
 
69
def call_local_inference(model_name, prompt, max_new_tokens=256, *,
                         temperature=0.7, do_sample=True):
    """Generate text for *prompt* using a locally loaded pipeline.

    Args:
        model_name: Model key; passed unchanged to ``get_pipeline``.
        prompt: Input text handed to the generation pipeline.
        max_new_tokens: Forwarded to the pipeline as ``max_new_tokens``.
        temperature: Forwarded to the pipeline as ``temperature``.
            Defaults to the previously hard-coded 0.7.
        do_sample: Forwarded to the pipeline as ``do_sample``.
            Defaults to the previously hard-coded True.

    Returns:
        The ``"generated_text"`` field of the first pipeline output, or a
        string of the form ``"(生成失敗:<error>)"`` if loading the
        pipeline or generating raised an exception.
    """
    try:
        generator = get_pipeline(model_name)
        outputs = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
        )
        return outputs[0]["generated_text"]
    except Exception as e:
        # Deliberately best-effort: the failure is reported as text in the
        # return value instead of being raised to the caller.
        return f"(生成失敗:{e})"
81
 
82
  # -------------------------------
83
+ # 4. Auto 模式
84
  # -------------------------------
85
  def pick_model_auto(segments):
86
  if segments <= 3:
 
112
  yield "\n\n".join(all_text), docx_file, f"本次使用模型:{selected_model}"
113
 
114
  # -------------------------------
115
+ # 5. Gradio 介面
116
  # -------------------------------
117
  with gr.Blocks() as demo:
118
  gr.Markdown("# 佛教經論 RAG 系統 (CPU 免費版)")