Add direct streaming zip download button
Browse files
app.py
CHANGED
|
@@ -3,8 +3,7 @@ import re
|
|
| 3 |
import queue
|
| 4 |
import threading
|
| 5 |
import zipfile
|
| 6 |
-
from urllib.parse import quote
|
| 7 |
-
import html
|
| 8 |
|
| 9 |
import requests
|
| 10 |
import gradio as gr
|
|
@@ -61,7 +60,11 @@ def get_repo_resolve_prefix(repo_type: str, repo: str) -> str:
|
|
| 61 |
|
| 62 |
def list_repo_files(repo: str, revision: str, repo_type: str):
|
| 63 |
api_path = get_repo_api_path(repo_type, repo)
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
response = requests.get(
|
| 67 |
url,
|
|
@@ -76,14 +79,17 @@ def list_repo_files(repo: str, revision: str, repo_type: str):
|
|
| 76 |
|
| 77 |
items = response.json()
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
|
| 89 |
def get_file_download_url(repo: str, repo_type: str, revision: str, path: str) -> str:
|
|
@@ -96,17 +102,19 @@ def get_file_download_url(repo: str, repo_type: str, revision: str, path: str) -
|
|
| 96 |
)
|
| 97 |
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
class QueueWriter:
|
| 100 |
"""
|
| 101 |
-
给 zipfile 使用的
|
| 102 |
|
| 103 |
-
zipfile
|
| 104 |
-
|
| 105 |
|
| 106 |
-
|
| 107 |
-
不落盘。
|
| 108 |
-
不先生成完整 zip。
|
| 109 |
-
一边从 HF 拉文件,一边写 zip,一边传给浏览器。
|
| 110 |
"""
|
| 111 |
|
| 112 |
def __init__(self, output_queue: queue.Queue):
|
|
@@ -139,6 +147,7 @@ def stream_repo_as_zip(repo: str, revision: str, repo_type: str, files: list):
|
|
| 139 |
) as zip_file:
|
| 140 |
for file in files:
|
| 141 |
file_path = file["path"]
|
|
|
|
| 142 |
print(f"正在写入 ZIP:{file_path}", flush=True)
|
| 143 |
|
| 144 |
url = get_file_download_url(
|
|
@@ -197,17 +206,12 @@ def stream_repo_as_zip(repo: str, revision: str, repo_type: str, files: list):
|
|
| 197 |
yield item
|
| 198 |
|
| 199 |
|
| 200 |
-
def safe_filename(value: str) -> str:
    """Return an ASCII-only, filename-safe version of *value*.

    Every run of characters other than ASCII letters, digits, ``_``, ``.``
    and ``-`` collapses into a single underscore, and leading/trailing
    underscores are removed.

    Falls back to ``"repo"`` when nothing usable is left, including names
    made up solely of dots (``.`` / ``..``), which are not valid file names.
    """
    # re.ASCII keeps \w from matching non-ASCII letters; presumably the result
    # ends up in a download filename header, where non-ASCII text is unsafe
    # (TODO confirm against the caller).
    cleaned = re.sub(r"[^\w.-]+", "_", value, flags=re.ASCII).strip("_")
    if not cleaned.strip("."):
        return "repo"
    return cleaned
|
| 203 |
-
|
| 204 |
-
|
| 205 |
api_app = FastAPI()
|
| 206 |
|
| 207 |
|
| 208 |
@api_app.get("/api/hf-model-zip")
|
| 209 |
def hf_model_zip(
|
| 210 |
-
repo: str = Query(..., description="例如
|
| 211 |
revision: str = Query("main"),
|
| 212 |
repoType: str = Query("model"),
|
| 213 |
):
|
|
@@ -224,7 +228,7 @@ def hf_model_zip(
|
|
| 224 |
if not repo or "/" not in repo:
|
| 225 |
raise HTTPException(
|
| 226 |
status_code=400,
|
| 227 |
-
detail="repo 参数格式错误,应为 owner/name,例如
|
| 228 |
)
|
| 229 |
|
| 230 |
print(f"准备下载仓库:{repo}", flush=True)
|
|
@@ -258,51 +262,54 @@ def hf_model_zip(
|
|
| 258 |
)
|
| 259 |
|
| 260 |
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
|
|
|
|
|
|
| 264 |
|
| 265 |
-
if
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
{
|
| 270 |
-
"repo": repo,
|
| 271 |
-
"revision": revision,
|
| 272 |
-
"repoType": repo_type,
|
| 273 |
-
}
|
| 274 |
-
)
|
| 275 |
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
|
| 295 |
with gr.Blocks(title="Hugging Face ZIP 流式下载器") as demo:
|
| 296 |
gr.Markdown("# Hugging Face ZIP 流式下载器")
|
| 297 |
gr.Markdown(
|
| 298 |
-
"输入 Hugging Face 模型、数据集或 Space 仓库名,
|
| 299 |
)
|
| 300 |
|
| 301 |
with gr.Row():
|
| 302 |
repo_input = gr.Textbox(
|
| 303 |
label="仓库名",
|
| 304 |
value="sshleifer/tiny-gpt2",
|
| 305 |
-
placeholder="例如:Qwen/
|
| 306 |
)
|
| 307 |
|
| 308 |
revision_input = gr.Textbox(
|
|
@@ -316,13 +323,25 @@ with gr.Blocks(title="Hugging Face ZIP 流式下载器") as demo:
|
|
| 316 |
value="model",
|
| 317 |
)
|
| 318 |
|
| 319 |
-
|
| 320 |
-
output = gr.HTML(label="下载链接")
|
| 321 |
|
| 322 |
-
|
| 323 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
inputs=[repo_input, revision_input, repo_type_input],
|
| 325 |
-
outputs=
|
|
|
|
|
|
|
| 326 |
)
|
| 327 |
|
| 328 |
|
|
@@ -333,10 +352,15 @@ app = gr.mount_gradio_app(
|
|
| 333 |
ssr_mode=False,
|
| 334 |
)
|
| 335 |
|
|
|
|
| 336 |
if __name__ == "__main__":
|
| 337 |
import uvicorn
|
| 338 |
|
| 339 |
-
port = int(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
|
| 341 |
uvicorn.run(
|
| 342 |
app,
|
|
|
|
| 3 |
import queue
|
| 4 |
import threading
|
| 5 |
import zipfile
|
| 6 |
+
from urllib.parse import quote
|
|
|
|
| 7 |
|
| 8 |
import requests
|
| 9 |
import gradio as gr
|
|
|
|
| 60 |
|
| 61 |
def list_repo_files(repo: str, revision: str, repo_type: str):
|
| 62 |
api_path = get_repo_api_path(repo_type, repo)
|
| 63 |
+
|
| 64 |
+
url = (
|
| 65 |
+
f"{HF_BASE}{api_path}"
|
| 66 |
+
f"/tree/{quote(revision, safe='')}?recursive=1"
|
| 67 |
+
)
|
| 68 |
|
| 69 |
response = requests.get(
|
| 70 |
url,
|
|
|
|
| 79 |
|
| 80 |
items = response.json()
|
| 81 |
|
| 82 |
+
files = []
|
| 83 |
+
for item in items:
|
| 84 |
+
if item.get("type") == "file":
|
| 85 |
+
files.append(
|
| 86 |
+
{
|
| 87 |
+
"path": item.get("path"),
|
| 88 |
+
"size": item.get("size"),
|
| 89 |
+
}
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
return files
|
| 93 |
|
| 94 |
|
| 95 |
def get_file_download_url(repo: str, repo_type: str, revision: str, path: str) -> str:
|
|
|
|
| 102 |
)
|
| 103 |
|
| 104 |
|
| 105 |
+
def safe_filename(value: str) -> str:
    """Return an ASCII-only, filename-safe version of *value*.

    Every run of characters other than ASCII letters, digits, ``_``, ``.``
    and ``-`` collapses into a single underscore, and leading/trailing
    underscores are removed.

    Falls back to ``"repo"`` when nothing usable is left, including names
    made up solely of dots (``.`` / ``..``), which are not valid file names.
    """
    # re.ASCII keeps \w from matching non-ASCII letters; presumably the result
    # ends up in a download filename header, where non-ASCII text is unsafe
    # (TODO confirm against the caller).
    cleaned = re.sub(r"[^\w.-]+", "_", value, flags=re.ASCII).strip("_")
    if not cleaned.strip("."):
        return "repo"
    return cleaned
|
| 108 |
+
|
| 109 |
+
|
| 110 |
class QueueWriter:
|
| 111 |
"""
|
| 112 |
+
给 zipfile 使用的流式 writer。
|
| 113 |
|
| 114 |
+
zipfile 写出来的每一段 bytes 都会进入 queue,
|
| 115 |
+
FastAPI StreamingResponse 再把 queue 里的 bytes 发送给浏览器。
|
| 116 |
|
| 117 |
+
所以这里不会先生成完整 zip 文件。
|
|
|
|
|
|
|
|
|
|
| 118 |
"""
|
| 119 |
|
| 120 |
def __init__(self, output_queue: queue.Queue):
|
|
|
|
| 147 |
) as zip_file:
|
| 148 |
for file in files:
|
| 149 |
file_path = file["path"]
|
| 150 |
+
|
| 151 |
print(f"正在写入 ZIP:{file_path}", flush=True)
|
| 152 |
|
| 153 |
url = get_file_download_url(
|
|
|
|
| 206 |
yield item
|
| 207 |
|
| 208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
api_app = FastAPI()
|
| 210 |
|
| 211 |
|
| 212 |
@api_app.get("/api/hf-model-zip")
|
| 213 |
def hf_model_zip(
|
| 214 |
+
repo: str = Query(..., description="例如 Qwen/Qwen3.6-35B-A3B"),
|
| 215 |
revision: str = Query("main"),
|
| 216 |
repoType: str = Query("model"),
|
| 217 |
):
|
|
|
|
| 228 |
if not repo or "/" not in repo:
|
| 229 |
raise HTTPException(
|
| 230 |
status_code=400,
|
| 231 |
+
detail="repo 参数格式错误,应为 owner/name,例如 Qwen/Qwen3.6-35B-A3B",
|
| 232 |
)
|
| 233 |
|
| 234 |
print(f"准备下载仓库:{repo}", flush=True)
|
|
|
|
| 262 |
)
|
| 263 |
|
| 264 |
|
| 265 |
+
# Client-side handler passed as ``js=`` to the download button's click event.
# It receives the three input values, validates the repo name, and starts the
# download by clicking a temporary hidden <a> that points at /api/hf-model-zip.
# Defaults are applied AFTER trimming, so whitespace-only input falls back to
# "main" / "model" instead of being sent to the API as an empty string.
download_js = """
(repo, revision, repoType) => {
    repo = (repo || "").trim();
    revision = (revision || "").trim() || "main";
    repoType = (repoType || "").trim() || "model";

    if (!repo || !repo.includes("/")) {
        alert("repo 参数格式错误,应为 owner/name,例如 Qwen/Qwen3.6-35B-A3B");
        return [repo, revision, repoType];
    }

    const params = new URLSearchParams({
        repo: repo,
        revision: revision,
        repoType: repoType
    });

    const url = `/api/hf-model-zip?${params.toString()}`;

    const a = document.createElement("a");
    a.href = url;
    a.download = "";
    a.style.display = "none";

    document.body.appendChild(a);
    a.click();
    document.body.removeChild(a);

    return [repo, revision, repoType];
}
"""
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def noop(repo, revision, repo_type):
    """Do nothing on the server side and return ``None``.

    The three arguments mirror the ``inputs`` of the click event this
    function is registered on; they are accepted and deliberately ignored.
    """
    return
|
| 300 |
|
| 301 |
|
| 302 |
with gr.Blocks(title="Hugging Face ZIP 流式下载器") as demo:
|
| 303 |
gr.Markdown("# Hugging Face ZIP 流式下载器")
|
| 304 |
gr.Markdown(
|
| 305 |
+
"输入 Hugging Face 模型、数据集或 Space 仓库名,点击按钮后直接开始流式下载 ZIP。"
|
| 306 |
)
|
| 307 |
|
| 308 |
with gr.Row():
|
| 309 |
repo_input = gr.Textbox(
|
| 310 |
label="仓库名",
|
| 311 |
value="sshleifer/tiny-gpt2",
|
| 312 |
+
placeholder="例如:Qwen/Qwen3.6-35B-A3B",
|
| 313 |
)
|
| 314 |
|
| 315 |
revision_input = gr.Textbox(
|
|
|
|
| 323 |
value="model",
|
| 324 |
)
|
| 325 |
|
| 326 |
+
download_button = gr.Button("下载 ZIP", variant="primary")
|
|
|
|
| 327 |
|
| 328 |
+
gr.Markdown(
|
| 329 |
+
"""
|
| 330 |
+
说明:
|
| 331 |
+
|
| 332 |
+
- 点击按钮后会直接触发浏览器下载。
|
| 333 |
+
- 不会先把完整 ZIP 保存到 Space。
|
| 334 |
+
- 流量路径是:Hugging Face → Space → 客户浏览器。
|
| 335 |
+
- 私有模型或 gated 模型需要在 Space Secrets 里配置 `HF_TOKEN`。
|
| 336 |
+
"""
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
download_button.click(
|
| 340 |
+
fn=noop,
|
| 341 |
inputs=[repo_input, revision_input, repo_type_input],
|
| 342 |
+
outputs=[],
|
| 343 |
+
js=download_js,
|
| 344 |
+
queue=False,
|
| 345 |
)
|
| 346 |
|
| 347 |
|
|
|
|
| 352 |
ssr_mode=False,
|
| 353 |
)
|
| 354 |
|
| 355 |
+
|
| 356 |
if __name__ == "__main__":
|
| 357 |
import uvicorn
|
| 358 |
|
| 359 |
+
port = int(
|
| 360 |
+
os.environ.get("PORT")
|
| 361 |
+
or os.environ.get("GRADIO_SERVER_PORT")
|
| 362 |
+
or "7860"
|
| 363 |
+
)
|
| 364 |
|
| 365 |
uvicorn.run(
|
| 366 |
app,
|