lgccccc committed on
Commit
fcbbb3d
·
1 Parent(s): 053ebc2

Add direct streaming zip download button

Browse files
Files changed (1) hide show
  1. app.py +87 -63
app.py CHANGED
@@ -3,8 +3,7 @@ import re
3
  import queue
4
  import threading
5
  import zipfile
6
- from urllib.parse import quote, urlencode
7
- import html
8
 
9
  import requests
10
  import gradio as gr
@@ -61,7 +60,11 @@ def get_repo_resolve_prefix(repo_type: str, repo: str) -> str:
61
 
62
  def list_repo_files(repo: str, revision: str, repo_type: str):
63
  api_path = get_repo_api_path(repo_type, repo)
64
- url = f"{HF_BASE}{api_path}/tree/{quote(revision, safe='')}?recursive=1"
 
 
 
 
65
 
66
  response = requests.get(
67
  url,
@@ -76,14 +79,17 @@ def list_repo_files(repo: str, revision: str, repo_type: str):
76
 
77
  items = response.json()
78
 
79
- return [
80
- {
81
- "path": item["path"],
82
- "size": item.get("size"),
83
- }
84
- for item in items
85
- if item.get("type") == "file"
86
- ]
 
 
 
87
 
88
 
89
  def get_file_download_url(repo: str, repo_type: str, revision: str, path: str) -> str:
@@ -96,17 +102,19 @@ def get_file_download_url(repo: str, repo_type: str, revision: str, path: str) -
96
  )
97
 
98
 
 
 
 
 
 
99
  class QueueWriter:
100
  """
101
- 给 zipfile 使用的“假文件对象”
102
 
103
- zipfile 往这里 write 的每一段 bytes,
104
- 都会立刻放进队列,然后 StreamingResponse 再把发给浏览器。
105
 
106
- 关键点:
107
- 不落盘。
108
- 不先生成完整 zip。
109
- 一边从 HF 拉文件,一边写 zip,一边传给浏览器。
110
  """
111
 
112
  def __init__(self, output_queue: queue.Queue):
@@ -139,6 +147,7 @@ def stream_repo_as_zip(repo: str, revision: str, repo_type: str, files: list):
139
  ) as zip_file:
140
  for file in files:
141
  file_path = file["path"]
 
142
  print(f"正在写入 ZIP:{file_path}", flush=True)
143
 
144
  url = get_file_download_url(
@@ -197,17 +206,12 @@ def stream_repo_as_zip(repo: str, revision: str, repo_type: str, files: list):
197
  yield item
198
 
199
 
200
- def safe_filename(value: str) -> str:
201
- value = re.sub(r"[^\w.-]+", "_", value)
202
- return value.strip("_") or "repo"
203
-
204
-
205
  api_app = FastAPI()
206
 
207
 
208
  @api_app.get("/api/hf-model-zip")
209
  def hf_model_zip(
210
- repo: str = Query(..., description="例如 sshleifer/tiny-gpt2"),
211
  revision: str = Query("main"),
212
  repoType: str = Query("model"),
213
  ):
@@ -224,7 +228,7 @@ def hf_model_zip(
224
  if not repo or "/" not in repo:
225
  raise HTTPException(
226
  status_code=400,
227
- detail="repo 参数格式错误,应为 owner/name,例如 sshleifer/tiny-gpt2",
228
  )
229
 
230
  print(f"准备下载仓库:{repo}", flush=True)
@@ -258,51 +262,54 @@ def hf_model_zip(
258
  )
259
 
260
 
261
- def build_download_link(repo: str, revision: str, repo_type: str):
262
- repo = repo.strip()
263
- revision = revision.strip() or "main"
 
 
264
 
265
- if not repo or "/" not in repo:
266
- return "<p style='color:red'>repo 参数格式错误,应为 owner/name,例如 sshleifer/tiny-gpt2</p>"
267
-
268
- params = urlencode(
269
- {
270
- "repo": repo,
271
- "revision": revision,
272
- "repoType": repo_type,
273
- }
274
- )
275
 
276
- url = f"/api/hf-model-zip?{params}"
277
-
278
- return f"""
279
- <div style="padding: 12px; border: 1px solid #ddd; border-radius: 8px;">
280
- <p><b>下载链接已生成:</b></p>
281
- <p>
282
- <a href="{html.escape(url)}" target="_blank" download>
283
- 点击这里开始流式下载 ZIP
284
- </a>
285
- </p>
286
- <p style="font-size: 13px; color: #666;">
287
- 说明:这个链接不会先把完整 ZIP 保存到 Space。
288
- 它会将 Hugging Face 文件流实时写入 ZIP 流,并直接传给浏览器。
289
- </p>
290
- <code>{html.escape(url)}</code>
291
- </div>
292
- """
 
 
 
 
 
 
 
293
 
294
 
295
  with gr.Blocks(title="Hugging Face ZIP 流式下载器") as demo:
296
  gr.Markdown("# Hugging Face ZIP 流式下载器")
297
  gr.Markdown(
298
- "输入 Hugging Face 模型、数据集或 Space 仓库名,生成一个流式 ZIP 下载链接。"
299
  )
300
 
301
  with gr.Row():
302
  repo_input = gr.Textbox(
303
  label="仓库名",
304
  value="sshleifer/tiny-gpt2",
305
- placeholder="例如:Qwen/Qwen2.5-0.5B-Instruct",
306
  )
307
 
308
  revision_input = gr.Textbox(
@@ -316,13 +323,25 @@ with gr.Blocks(title="Hugging Face ZIP 流式下载器") as demo:
316
  value="model",
317
  )
318
 
319
- button = gr.Button("生成下载链接")
320
- output = gr.HTML(label="下载链接")
321
 
322
- button.click(
323
- fn=build_download_link,
 
 
 
 
 
 
 
 
 
 
 
324
  inputs=[repo_input, revision_input, repo_type_input],
325
- outputs=output,
 
 
326
  )
327
 
328
 
@@ -333,10 +352,15 @@ app = gr.mount_gradio_app(
333
  ssr_mode=False,
334
  )
335
 
 
336
  if __name__ == "__main__":
337
  import uvicorn
338
 
339
- port = int(os.environ.get("PORT", os.environ.get("GRADIO_SERVER_PORT", "7860")))
 
 
 
 
340
 
341
  uvicorn.run(
342
  app,
 
3
  import queue
4
  import threading
5
  import zipfile
6
+ from urllib.parse import quote
 
7
 
8
  import requests
9
  import gradio as gr
 
60
 
61
  def list_repo_files(repo: str, revision: str, repo_type: str):
62
  api_path = get_repo_api_path(repo_type, repo)
63
+
64
+ url = (
65
+ f"{HF_BASE}{api_path}"
66
+ f"/tree/{quote(revision, safe='')}?recursive=1"
67
+ )
68
 
69
  response = requests.get(
70
  url,
 
79
 
80
  items = response.json()
81
 
82
+ files = []
83
+ for item in items:
84
+ if item.get("type") == "file":
85
+ files.append(
86
+ {
87
+ "path": item.get("path"),
88
+ "size": item.get("size"),
89
+ }
90
+ )
91
+
92
+ return files
93
 
94
 
95
  def get_file_download_url(repo: str, repo_type: str, revision: str, path: str) -> str:
 
102
  )
103
 
104
 
105
def safe_filename(value: str) -> str:
    """Reduce *value* to characters that are safe to use in a file name.

    Each run of characters other than word characters, ``.`` and ``-`` is
    collapsed into a single underscore, then leading and trailing
    underscores are removed.

    Returns:
        The cleaned string, or ``"repo"`` when nothing is left after cleaning.
    """
    cleaned = re.sub(r"[^\w.-]+", "_", value).strip("_")
    if cleaned:
        return cleaned
    return "repo"
108
+
109
+
110
  class QueueWriter:
111
  """
112
+ 给 zipfile 使用的流式 writer
113
 
114
+ zipfile 写出来的每一段 bytes 都会进入 queue
115
+ FastAPI StreamingResponse 再把 queue 里的 bytes 给浏览器。
116
 
117
+ 所以这里不会先生成完整 zip 文件。
 
 
 
118
  """
119
 
120
  def __init__(self, output_queue: queue.Queue):
 
147
  ) as zip_file:
148
  for file in files:
149
  file_path = file["path"]
150
+
151
  print(f"正在写入 ZIP:{file_path}", flush=True)
152
 
153
  url = get_file_download_url(
 
206
  yield item
207
 
208
 
 
 
 
 
 
209
  api_app = FastAPI()
210
 
211
 
212
  @api_app.get("/api/hf-model-zip")
213
  def hf_model_zip(
214
+ repo: str = Query(..., description="例如 Qwen/Qwen3.6-35B-A3B"),
215
  revision: str = Query("main"),
216
  repoType: str = Query("model"),
217
  ):
 
228
  if not repo or "/" not in repo:
229
  raise HTTPException(
230
  status_code=400,
231
+ detail="repo 参数格式错误,应为 owner/name,例如 Qwen/Qwen3.6-35B-A3B",
232
  )
233
 
234
  print(f"准备下载仓库:{repo}", flush=True)
 
262
  )
263
 
264
 
265
# Browser-side click handler for the download button.
#
# It receives the current values of the repo / revision / repo-type inputs,
# validates the repo name, builds the `/api/hf-model-zip` URL and starts the
# download by clicking a temporary hidden <a> element. It returns the
# normalized values as a three-element array.
#
# Each field is trimmed *before* its default is applied, so a blank or
# whitespace-only revision / repo type falls back to "main" / "model"
# instead of being sent to the API as an empty string.
download_js = """
(repo, revision, repoType) => {
    repo = (repo || "").trim();
    revision = (revision || "").trim() || "main";
    repoType = (repoType || "").trim() || "model";

    if (!repo || !repo.includes("/")) {
        alert("repo 参数格式错误,应为 owner/name,例如 Qwen/Qwen3.6-35B-A3B");
        return [repo, revision, repoType];
    }

    const params = new URLSearchParams({
        repo: repo,
        revision: revision,
        repoType: repoType
    });

    const url = `/api/hf-model-zip?${params.toString()}`;

    const a = document.createElement("a");
    a.href = url;
    a.download = "";
    a.style.display = "none";

    document.body.appendChild(a);
    a.click();
    document.body.removeChild(a);

    return [repo, revision, repoType];
}
"""
296
+
297
+
298
def noop(repo, revision, repo_type):
    """Do nothing and return ``None``.

    Used as the Python ``fn`` of the download button's click event; the
    three arguments are accepted to match the event's inputs and are
    otherwise ignored.
    """
    return None
300
 
301
 
302
  with gr.Blocks(title="Hugging Face ZIP 流式下载器") as demo:
303
  gr.Markdown("# Hugging Face ZIP 流式下载器")
304
  gr.Markdown(
305
+ "输入 Hugging Face 模型、数据集或 Space 仓库名,点击按钮后直接开始流式下载 ZIP。"
306
  )
307
 
308
  with gr.Row():
309
  repo_input = gr.Textbox(
310
  label="仓库名",
311
  value="sshleifer/tiny-gpt2",
312
+ placeholder="例如:Qwen/Qwen3.6-35B-A3B",
313
  )
314
 
315
  revision_input = gr.Textbox(
 
323
  value="model",
324
  )
325
 
326
+ download_button = gr.Button("下载 ZIP", variant="primary")
 
327
 
328
+ gr.Markdown(
329
+ """
330
+ 说明:
331
+
332
+ - 点击按钮后会直接触发浏览器下载。
333
+ - 不会先把完整 ZIP 保存到 Space。
334
+ - 流量路径是:Hugging Face → Space → 客户浏览器。
335
+ - 私有模型或 gated 模型需要在 Space Secrets 里配置 `HF_TOKEN`。
336
+ """
337
+ )
338
+
339
+ download_button.click(
340
+ fn=noop,
341
  inputs=[repo_input, revision_input, repo_type_input],
342
+ outputs=[],
343
+ js=download_js,
344
+ queue=False,
345
  )
346
 
347
 
 
352
  ssr_mode=False,
353
  )
354
 
355
+
356
  if __name__ == "__main__":
357
  import uvicorn
358
 
359
+ port = int(
360
+ os.environ.get("PORT")
361
+ or os.environ.get("GRADIO_SERVER_PORT")
362
+ or "7860"
363
+ )
364
 
365
  uvicorn.run(
366
  app,