Update app.py
app.py CHANGED
@@ -27,10 +27,10 @@ else:
 # -------------------------------
 MODEL_MAP = {
     "Auto": None,
-    "Gemma-2B": "google/gemma-2b",
-    "BTLM-3B-8K": "
-    "DistilGPT2": "distilgpt2",
-    "BART-Base": "facebook/bart-base"
+    "Gemma-2B": "google/gemma-2b",  # gated, requires access to the repository
+    "BTLM-3B-8K": "cerebras/btlm-3b-8k-base",  # correct repo
+    "DistilGPT2": "distilgpt2",  # small model
+    "BART-Base": "facebook/bart-base"  # small model
 }
 
 # -------------------------------
@@ -68,9 +68,9 @@ def test_models():
             tokenizer=local_dir,
             device_map="cpu"
         )
-        print(f"✅ Model {name}
+        print(f"✅ Model {name} is available")
     except Exception as e:
-        print(f"❌ Model {name}
+        print(f"❌ Model {name} could not be loaded: {e}")
 
 test_models()
 
@@ -112,7 +112,7 @@ def get_pipeline(model_name):
     local_path = LOCAL_MODEL_DIRS.get(model_name)
     if not local_path:
         raise ValueError(f"❌ Model {model_name} has not been downloaded yet")
-    print(f"🔄
+    print(f"🔄 Loading model {model_name} from {local_path}")
     generator = pipeline(
         "text-generation",
         model=local_path,
@@ -135,11 +135,13 @@ def call_local_inference(model_name, prompt, max_new_tokens=512):
 # -------------------------------
 def pick_model_auto(segments):
     if segments <= 3:
-        return "
+        return "DistilGPT2"  # short articles: smallest model, fast
     elif segments <= 6:
-        return "
+        return "Gemma-2B"  # medium-length articles: Gemma-2B
+    elif segments <= 8:
+        return "BTLM-3B-8K"  # long articles: BTLM
     else:
-        return "BART-Base"
+        return "BART-Base"  # very long articles: BART-Base
 
 def generate_article_progress(query, model_name, segments=5):
     docx_file = "/tmp/generated_article.docx"
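For context, a minimal sketch (not part of this commit) of how the updated pieces are assumed to fit together: pick a model name from the segment count, resolve it through MODEL_MAP to a Hugging Face repo id, download it into a local directory, and load it the way the commit does (text-generation pipeline from a local path on CPU). The load_generator helper, the cache_root and hf_token parameters, and the use of huggingface_hub.snapshot_download are illustrative assumptions, not identifiers from app.py.

# Sketch under the assumptions noted above; mirrors the loading pattern in the commit
# (local path for model and tokenizer, device_map="cpu").
from huggingface_hub import snapshot_download
from transformers import pipeline

MODEL_MAP = {
    "Auto": None,
    "Gemma-2B": "google/gemma-2b",             # gated: needs an access-granted token
    "BTLM-3B-8K": "cerebras/btlm-3b-8k-base",
    "DistilGPT2": "distilgpt2",
    "BART-Base": "facebook/bart-base",
}

def pick_model_auto(segments):
    # Same thresholds as the commit: smallest model for short pieces, larger ones as length grows.
    if segments <= 3:
        return "DistilGPT2"
    elif segments <= 6:
        return "Gemma-2B"
    elif segments <= 8:
        return "BTLM-3B-8K"
    else:
        return "BART-Base"

def load_generator(name, cache_root="/tmp/models", hf_token=None):
    # Hypothetical helper: download the repo into a local cache, then build the
    # CPU text-generation pipeline from that local snapshot directory.
    repo_id = MODEL_MAP[name]
    local_dir = snapshot_download(repo_id, cache_dir=cache_root, token=hf_token)
    return pipeline("text-generation", model=local_dir, tokenizer=local_dir, device_map="cpu")

if __name__ == "__main__":
    name = pick_model_auto(segments=2)   # -> "DistilGPT2": not gated, no token needed
    generator = load_generator(name)
    print(generator("Write a short intro about solar energy:", max_new_tokens=64)[0]["generated_text"])

Since google/gemma-2b is gated, load_generator would need hf_token set to a token from an account that has been granted access to that repo; the smaller public repos download without one.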