| import os |
| import re |
| import queue |
| import threading |
| import zipfile |
| from urllib.parse import quote |
|
|
| import requests |
| import gradio as gr |
| from fastapi import FastAPI, HTTPException, Query |
| from fastapi.responses import StreamingResponse |
|
|
|
|
# Base URL of the Hub; HF_ENDPOINT can point it at another endpoint (e.g. a
# mirror). Every URL in this module is built as f"{HF_BASE}/...", so a blank
# value falls back to the default and a trailing "/" is dropped to avoid "//".
HF_BASE = (
    (os.environ.get("HF_ENDPOINT") or "").strip().rstrip("/")
    or "https://huggingface.co"
)

# Optional Hub access token, sent as "Authorization: Bearer <token>" when set.
# Surrounding whitespace (e.g. a trailing newline in a secret) is stripped,
# because it would otherwise make the header value invalid; blank means unset.
HF_TOKEN = (os.environ.get("HF_TOKEN") or "").strip() or None
|
|
|
|
def get_headers():
    """Build the HTTP headers sent with every request to the Hub.

    A fresh dict is returned on each call. It always carries the
    ``User-Agent``; ``Authorization: Bearer <HF_TOKEN>`` is appended only
    when ``HF_TOKEN`` is truthy.
    """
    auth_headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    return {"User-Agent": "hf-zip-stream-gradio-space/1.0", **auth_headers}
|
|
|
|
def encode_repo(repo: str) -> str:
    """Percent-encode each ``/``-separated segment of a repo id.

    The separating slashes are kept, while every reserved character inside
    a segment (spaces, ``?``, ``#``, ...) is escaped.
    """
    encoded_segments = [quote(segment, safe="") for segment in repo.split("/")]
    return "/".join(encoded_segments)
|
|
|
|
def encode_path(path: str) -> str:
    """Percent-encode a repo-relative file path for use in a URL.

    Directory separators (``/``) are left as-is and everything else that is
    not URL-safe is escaped. Quoting the whole string with ``/`` marked safe
    gives the same result as quoting each segment separately, since ``/`` is
    a single ASCII byte that never occurs inside a UTF-8 multibyte sequence.
    """
    return quote(path, safe="/")
|
|
|
|
def get_repo_api_path(repo_type: str, repo: str) -> str:
    """Return the Hub REST API path for a repository.

    ``"dataset"`` maps to ``/api/datasets/<repo>``, ``"space"`` maps to
    ``/api/spaces/<repo>``, and any other ``repo_type`` (including
    ``"model"``) maps to ``/api/models/<repo>``. The repo id is encoded with
    ``encode_repo``.
    """
    encoded_repo = encode_repo(repo)

    if repo_type == "dataset":
        collection = "datasets"
    elif repo_type == "space":
        collection = "spaces"
    else:
        collection = "models"

    return f"/api/{collection}/{encoded_repo}"
|
|
|
|
def get_repo_resolve_prefix(repo_type: str, repo: str) -> str:
    """Return the URL prefix under which a repo's ``/resolve/`` URLs live.

    Datasets and Spaces are namespaced (``/datasets/<repo>``,
    ``/spaces/<repo>``). Any other ``repo_type`` (models) sits directly at
    ``/<repo>``. The repo id is encoded with ``encode_repo``.
    """
    encoded_repo = encode_repo(repo)
    namespace = (
        "datasets/" if repo_type == "dataset"
        else "spaces/" if repo_type == "space"
        else ""
    )
    return f"/{namespace}{encoded_repo}"
|
|
|
|
def list_repo_files(repo: str, revision: str, repo_type: str):
    """List every file (directories excluded) in a Hub repo at ``revision``.

    Queries ``{HF_BASE}/api/<type>s/<repo>/tree/<revision>?recursive=1``.
    The endpoint is paginated: each page's ``Link: <...>; rel="next"`` header
    points at the next page. Pages are followed until none is left, so large
    repositories are not silently truncated to their first page.

    Args:
        repo: ``owner/name`` repository id.
        revision: branch, tag or commit. It is fully percent-encoded, so a
            ``/`` inside a ref becomes ``%2F``.
        repo_type: ``"model"``, ``"dataset"`` or ``"space"``.

    Returns:
        A list of ``{"path": ..., "size": ...}`` dicts in API order. ``size``
        is whatever the API reported (``None`` if absent).

    Raises:
        RuntimeError: if any page request returns a non-2xx status.
    """
    api_path = get_repo_api_path(repo_type, repo)
    url = f"{HF_BASE}{api_path}/tree/{quote(revision, safe='')}?recursive=1"

    files = []
    while url:
        response = requests.get(
            url,
            headers=get_headers(),
            timeout=60,
        )

        if not response.ok:
            raise RuntimeError(
                f"获取文件列表失败:{response.status_code} {response.text[:500]}"
            )

        files.extend(
            {"path": item.get("path"), "size": item.get("size")}
            for item in response.json()
            if item.get("type") == "file"
        )

        # requests parses the Link header into ``response.links``; no "next"
        # entry means this was the last page.
        url = response.links.get("next", {}).get("url")

    return files
|
|
|
|
def get_file_download_url(repo: str, repo_type: str, revision: str, path: str) -> str:
    """Return the Hub ``resolve`` URL that serves one file's contents.

    The URL has the shape ``{HF_BASE}<prefix>/resolve/<revision>/<path>``.
    The revision is fully percent-encoded (so ``/`` becomes ``%2F``), and the
    path is encoded with ``encode_path`` so its directory separators survive.
    """
    repo_prefix = get_repo_resolve_prefix(repo_type, repo)
    encoded_revision = quote(revision, safe="")
    encoded_file_path = encode_path(path)
    return f"{HF_BASE}{repo_prefix}/resolve/{encoded_revision}/{encoded_file_path}"
|
|
|
|
def safe_filename(value: str) -> str:
    """Turn ``value`` into a conservative ASCII filename fragment.

    Each run of characters other than ASCII letters, digits, underscore, dot
    and hyphen becomes a single underscore. Leading and trailing underscores
    are stripped, and an empty result falls back to ``"repo"``.

    The result is placed in a ``Content-Disposition: filename="..."`` header.
    ``re.ASCII`` keeps it ASCII-only: without the flag the word-character
    class also matches non-ASCII letters. Starlette encodes header values as
    Latin-1, so such characters can fail to encode, and RFC 6266 expects a
    plain ``filename=`` parameter to be ASCII.
    """
    cleaned = re.sub(r"[^\w.-]+", "_", value, flags=re.ASCII)
    return cleaned.strip("_") or "repo"
|
|
|
|
class QueueWriter:
    """Minimal write-only file object that forwards bytes into a queue.

    ``zipfile`` can write into this object. Every non-empty chunk is copied
    into an immutable ``bytes`` object and put on ``output_queue`` for a
    consumer (e.g. a streaming HTTP response) to send on, so the archive is
    never assembled in full. ``Queue.put`` is called with its default
    blocking behaviour: if the queue is bounded and full, ``write`` waits.
    """

    def __init__(self, output_queue: queue.Queue):
        self.output_queue = output_queue

    def write(self, data):
        """Queue a copy of ``data`` when it is non-empty; return its length."""
        size = len(data)
        if size:
            self.output_queue.put(bytes(data))
        return size

    def flush(self):
        """No-op: every write() is handed off immediately, nothing is buffered."""
        return None
|
|
|
|
# Unique end-of-stream marker put on a queue after the last item. It is
# compared by identity (``is``), so it cannot be confused with a real chunk of
# bytes or an exception object.
_SENTINEL = object()
|
|
|
|
class _ConsumerGone(Exception):
    """Raised inside the ZIP producer thread once nobody is reading its output."""


def stream_repo_as_zip(repo: str, revision: str, repo_type: str, files: list):
    """Yield a ZIP archive of ``files`` from a Hub repo, chunk by chunk.

    A daemon thread downloads each file from the Hub and writes it
    uncompressed (ZIP_STORED, ZIP64 headers) into a ``zipfile.ZipFile``. The
    ZipFile's output goes into a bounded queue, and this generator yields
    whatever lands in the queue, so the archive is never fully materialised
    in memory or on disk.

    If the generator stops early (for example because the server drops the
    response after the client disconnects), a stop flag is set. The producer
    then abandons the download instead of blocking forever on the full queue.

    Args:
        repo: ``owner/name`` repository id.
        revision: branch, tag or commit to download.
        repo_type: ``"model"``, ``"dataset"`` or ``"space"``.
        files: dicts with a ``"path"`` key and an optional ``"size"`` key, as
            returned by ``list_repo_files``.

    Yields:
        ``bytes`` chunks of the ZIP stream.

    Raises:
        Exception: the error the producer hit (e.g. ``RuntimeError`` for a
            failed file download). It is re-raised here after the chunks
            queued before it have been yielded.
    """
    output_queue = queue.Queue(maxsize=16)
    # Set when this generator finishes, fails or is closed early; tells the
    # producer thread that nobody will drain the queue any more.
    stop_event = threading.Event()

    def put(item):
        """Queue ``item``, waiting while the queue is full.

        Raises ``_ConsumerGone`` instead of waiting forever once
        ``stop_event`` is set.
        """
        while not stop_event.is_set():
            try:
                output_queue.put(item, timeout=1)
                return
            except queue.Full:
                continue
        raise _ConsumerGone()

    def offer(item):
        """Best-effort ``put`` for end-of-stream items; dropped if nobody listens."""
        try:
            put(item)
        except _ConsumerGone:
            pass

    class _CancellableWriter:
        """File-like sink for zipfile (same contract as QueueWriter) built on ``put``."""

        def write(self, data):
            if data:
                put(bytes(data))
            return len(data)

        def flush(self):
            pass

    def add_file(zip_file, file):
        """Download one repo file and copy it into ``zip_file`` as a STORED entry."""
        file_path = file["path"]

        print(f"正在写入 ZIP:{file_path}", flush=True)

        url = get_file_download_url(
            repo=repo,
            repo_type=repo_type,
            revision=revision,
            path=file_path,
        )

        # ``with`` releases the Hub connection even when the status check or
        # the copy loop raises.
        with requests.get(
            url,
            headers=get_headers(),
            stream=True,
            timeout=(20, 300),
        ) as hf_response:
            if not hf_response.ok:
                raise RuntimeError(
                    f"下载文件失败:{file_path},状态码:{hf_response.status_code}"
                )

            zip_info = zipfile.ZipInfo(filename=file_path)
            zip_info.compress_type = zipfile.ZIP_STORED

            if file.get("size") is not None:
                zip_info.file_size = int(file["size"])

            # force_zip64: the listed size is only a hint, so always use ZIP64
            # headers and let entries larger than 4 GiB through.
            with zip_file.open(zip_info, mode="w", force_zip64=True) as zip_entry:
                for chunk in hf_response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        zip_entry.write(chunk)

    def producer():
        """Thread body: build the ZIP, then report an error and/or end of stream."""
        try:
            with zipfile.ZipFile(
                _CancellableWriter(),
                mode="w",
                compression=zipfile.ZIP_STORED,
                allowZip64=True,
            ) as zip_file:
                for file in files:
                    add_file(zip_file, file)

            print("ZIP 流式传输完成", flush=True)

        except _ConsumerGone:
            # The consumer stopped reading; there is nobody left to report to.
            print("客户端已断开,停止 ZIP 传输", flush=True)

        except Exception as error:
            print(f"ZIP 处理失败:{error}", flush=True)
            offer(error)

        finally:
            offer(_SENTINEL)

    threading.Thread(target=producer, daemon=True).start()

    try:
        while True:
            item = output_queue.get()

            if item is _SENTINEL:
                break

            if isinstance(item, Exception):
                raise item

            yield item
    finally:
        # Also runs when the generator is closed early (GeneratorExit at
        # ``yield``). Setting the flag unblocks the producer so it can exit.
        stop_event.set()
|
|
|
|
| api_app = FastAPI() |
|
|
|
|
@api_app.get("/api/hf-model-zip")
def hf_model_zip(
    repo: str = Query(..., description="例如 Qwen/Qwen3.6-35B-A3B"),
    revision: str = Query("main"),
    repoType: str = Query("model"),
):
    """Stream a whole Hub repository to the browser as a ZIP attachment.

    Query parameters:
        repo: ``owner/name`` repository id; both parts must be non-empty.
        revision: branch, tag or commit; blank means ``main``.
        repoType: ``model``, ``dataset`` or ``space`` (case-insensitive;
            blank means ``model``).

    Returns:
        A ``StreamingResponse`` whose body is produced lazily by
        ``stream_repo_as_zip``, with a ``<name>-<revision>.zip`` filename.

    Raises:
        HTTPException: 400 for invalid parameters, 500 if listing the repo
            fails, 404 if the listing contains no files.
    """
    repo = repo.strip()
    revision = revision.strip() or "main"
    repo_type = repoType.strip().lower() or "model"

    if repo_type not in {"model", "dataset", "space"}:
        raise HTTPException(
            status_code=400,
            detail="repoType 必须是 model、dataset 或 space",
        )

    # Require a non-empty owner AND name: "/x" or "x/" contain a slash, but
    # would otherwise be sent to the Hub as-is.
    owner, _, name = repo.partition("/")
    if not owner or not name:
        raise HTTPException(
            status_code=400,
            detail="repo 参数格式错误,应为 owner/name,例如 Qwen/Qwen3.6-35B-A3B",
        )

    print(f"准备下载仓库:{repo}", flush=True)
    print(f"revision:{revision}", flush=True)
    print(f"repoType:{repo_type}", flush=True)

    try:
        files = list_repo_files(repo, revision, repo_type)
    except Exception as error:
        # Chain the cause so the original traceback is kept in server logs.
        raise HTTPException(status_code=500, detail=str(error)) from error

    if not files:
        raise HTTPException(status_code=404, detail="没有找到文件")

    print(f"找到 {len(files)} 个文件", flush=True)

    repo_name = safe_filename(repo.split("/")[-1])
    revision_name = safe_filename(revision)
    zip_name = f"{repo_name}-{revision_name}.zip"

    headers = {
        "Content-Disposition": f'attachment; filename="{zip_name}"',
        "Cache-Control": "no-store",
        # Ask nginx-style reverse proxies not to buffer the streamed body.
        "X-Accel-Buffering": "no",
    }

    return StreamingResponse(
        stream_repo_as_zip(repo, revision, repo_type, files),
        media_type="application/zip",
        headers=headers,
    )
|
|
|
|
# Browser-side click handler for the download button (passed as ``js=``).
# It trims the inputs and checks that the repo id contains a "/". It then
# builds /api/hf-model-zip?repo=...&revision=...&repoType=... and clicks a
# temporary hidden <a>, so the browser treats the streamed ZIP as a normal
# download. It returns its trimmed inputs; presumably Gradio passes them on
# to the Python ``fn``.
download_js = """
(repo, revision, repoType) => {
repo = (repo || "").trim();
revision = (revision || "main").trim();
repoType = (repoType || "model").trim();

if (!repo || !repo.includes("/")) {
alert("repo 参数格式错误,应为 owner/name,例如 Qwen/Qwen3.6-35B-A3B");
return [repo, revision, repoType];
}

const params = new URLSearchParams({
repo: repo,
revision: revision,
repoType: repoType
});

const url = `/api/hf-model-zip?${params.toString()}`;

const a = document.createElement("a");
a.href = url;
a.download = "";
a.style.display = "none";

document.body.appendChild(a);
a.click();
document.body.removeChild(a);

return [repo, revision, repoType];
}
"""
|
|
|
|
def noop(repo, revision, repo_type):
    """Server-side no-op for the download button; always returns None.

    It is bound as the button's Python ``fn``. The real work happens in the
    button's browser-side ``js`` handler, which starts the download itself,
    so the three inputs are deliberately ignored here.
    """
|
|
|
|
# Minimal form with three inputs (repo id, revision, repo type) and one
# button. Clicking the button runs ``download_js`` in the browser, which
# requests /api/hf-model-zip directly. The Python callback ``noop`` does
# nothing.
with gr.Blocks(title="Hugging Face ZIP 流式下载器") as demo:
    gr.Markdown("# Hugging Face ZIP 流式下载器")
    gr.Markdown(
        "输入 Hugging Face 模型、数据集或 Space 仓库名,点击按钮后直接开始流式下载 ZIP。"
    )

    # The inputs are listed in the same order as the arguments of
    # download_js and noop.
    with gr.Row():
        repo_input = gr.Textbox(
            label="仓库名",
            value="sshleifer/tiny-gpt2",
            placeholder="例如:Qwen/Qwen3.6-35B-A3B",
        )

        revision_input = gr.Textbox(
            label="分支 / revision",
            value="main",
        )

        repo_type_input = gr.Dropdown(
            label="仓库类型",
            choices=["model", "dataset", "space"],
            value="model",
        )

    download_button = gr.Button("下载 ZIP", variant="primary")

    gr.Markdown(
        """
说明:

- 点击按钮后会直接触发浏览器下载。
"""
    )

    # ``js`` runs in the browser and triggers the download. outputs=[]
    # because nothing in the UI is updated. queue=False keeps this no-op
    # event off Gradio's queue.
    download_button.click(
        fn=noop,
        inputs=[repo_input, revision_input, repo_type_input],
        outputs=[],
        js=download_js,
        queue=False,
    )
|
|
|
|
# Serve the Gradio UI at "/" on the same FastAPI app that defines
# /api/hf-model-zip. The browser-side download link (a relative URL) is
# therefore same-origin. ssr_mode=False presumably turns off Gradio's
# server-side rendering; confirm against the installed Gradio version.
app = gr.mount_gradio_app(
    api_app,
    demo,
    path="/",
    ssr_mode=False,
)
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Port resolution: PORT first, then GRADIO_SERVER_PORT, then 7860.
    # An empty variable counts as unset.
    port_candidates = (
        os.environ.get("PORT"),
        os.environ.get("GRADIO_SERVER_PORT"),
    )
    port = int(next((candidate for candidate in port_candidates if candidate), "7860"))

    # NOTE(review): forwarded_allow_ips="*" trusts X-Forwarded-* headers from
    # any peer; presumably fine only behind a trusted reverse proxy.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        proxy_headers=True,
        forwarded_allow_ips="*",
    )