CHUNYU0505 commited on
Commit
58ba158
·
verified ·
1 Parent(s): 6b1b850

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -27,10 +27,10 @@ else:
27
  # -------------------------------
28
  MODEL_MAP = {
29
  "Auto": None,
30
- "Gemma-2B": "google/gemma-2b",
31
- "BTLM-3B-8K": "tiiuae/btlm-3b-8k-base",
32
- "DistilGPT2": "distilgpt2",
33
- "BART-Base": "facebook/bart-base"
34
  }
35
 
36
  # -------------------------------
@@ -68,9 +68,9 @@ def test_models():
68
  tokenizer=local_dir,
69
  device_map="cpu"
70
  )
71
- print(f"✅ 模型 {name} ({local_dir}) 可用")
72
  except Exception as e:
73
- print(f"❌ 模型 {name} ({local_dir}) 無法載入: {e}")
74
 
75
  test_models()
76
 
@@ -112,7 +112,7 @@ def get_pipeline(model_name):
112
  local_path = LOCAL_MODEL_DIRS.get(model_name)
113
  if not local_path:
114
  raise ValueError(f"❌ 模型 {model_name} 尚未下載")
115
- print(f"🔄 正在載入本地模型 {model_name} from {local_path}")
116
  generator = pipeline(
117
  "text-generation",
118
  model=local_path,
@@ -135,11 +135,13 @@ def call_local_inference(model_name, prompt, max_new_tokens=512):
135
  # -------------------------------
136
  def pick_model_auto(segments):
137
  if segments <= 3:
138
- return "Gemma-2B"
139
  elif segments <= 6:
140
- return "BTLM-3B-8K"
 
 
141
  else:
142
- return "BART-Base"
143
 
144
  def generate_article_progress(query, model_name, segments=5):
145
  docx_file = "/tmp/generated_article.docx"
 
27
  # -------------------------------
28
  MODEL_MAP = {
29
  "Auto": None,
30
+ "Gemma-2B": "google/gemma-2b", # gated,需要 Access repository
31
+ "BTLM-3B-8K": "cerebras/btlm-3b-8k-base", # 正確 repo
32
+ "DistilGPT2": "distilgpt2", # 小模型
33
+ "BART-Base": "facebook/bart-base" # 小模型
34
  }
35
 
36
  # -------------------------------
 
68
  tokenizer=local_dir,
69
  device_map="cpu"
70
  )
71
+ print(f"✅ 模型 {name} 可用")
72
  except Exception as e:
73
+ print(f"❌ 模型 {name} 無法載入: {e}")
74
 
75
  test_models()
76
 
 
112
  local_path = LOCAL_MODEL_DIRS.get(model_name)
113
  if not local_path:
114
  raise ValueError(f"❌ 模型 {model_name} 尚未下載")
115
+ print(f"🔄 正在載入模型 {model_name} from {local_path}")
116
  generator = pipeline(
117
  "text-generation",
118
  model=local_path,
 
135
  # -------------------------------
136
  def pick_model_auto(segments):
137
  if segments <= 3:
138
+ return "DistilGPT2" # 短文用最小模型,快
139
  elif segments <= 6:
140
+ return "Gemma-2B" # 中篇用 Gemma-2B
141
+ elif segments <= 8:
142
+ return "BTLM-3B-8K" # 長文用 BTLM
143
  else:
144
+ return "BART-Base" # 超長用 Bart-base
145
 
146
  def generate_article_progress(query, model_name, segments=5):
147
  docx_file = "/tmp/generated_article.docx"