Add Gradio zip downloader app
Browse files- app.py +177 -3
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -1,7 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
def greet(name):
    """Return the greeting ``Hello <name>!!`` for the given name string."""
    pieces = ("Hello ", name, "!!")
    return "".join(pieces)
|
| 5 |
|
| 6 |
-
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
|
| 7 |
demo.launch()
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import zipfile
|
| 3 |
+
import tempfile
|
| 4 |
+
import requests
|
| 5 |
import gradio as gr
|
| 6 |
+
from urllib.parse import quote
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Base URL of the Hub. An unset or empty HF_ENDPOINT falls back to the public
# site, and trailing slashes are dropped because every path appended to this
# value already starts with "/".
HF_BASE = (os.environ.get("HF_ENDPOINT") or "https://huggingface.co").rstrip("/")
# Optional access token read from the environment; None when it is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_headers():
    """Build the HTTP headers used for requests to the Hub.

    Returns:
        A new dict holding a fixed ``User-Agent`` and, when the module-level
        ``HF_TOKEN`` is truthy, a ``Bearer`` ``Authorization`` entry.
    """
    auth = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    return {"User-Agent": "hf-zip-gradio-demo/1.0", **auth}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def encode_repo(repo: str) -> str:
    """Percent-encode each ``/``-separated segment of a repo id.

    The slashes themselves are kept as separators; every other reserved
    character inside a segment is escaped.
    """
    segments = [quote(segment, safe="") for segment in repo.split("/")]
    return "/".join(segments)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def encode_path(path: str) -> str:
    """Percent-encode a repo-relative file path one component at a time.

    Directory separators (``/``) are preserved; each component is escaped
    with no characters treated as safe.
    """
    components = path.split("/")
    escaped = map(lambda component: quote(component, safe=""), components)
    return "/".join(escaped)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_repo_api_path(repo_type: str, repo: str) -> str:
    """Return the API path (without host) for a repository.

    ``"dataset"`` maps to ``/api/datasets/...`` and ``"space"`` to
    ``/api/spaces/...``; any other ``repo_type`` is treated as a model and
    maps to ``/api/models/...``.
    """
    collection = {"dataset": "datasets", "space": "spaces"}.get(repo_type, "models")
    return f"/api/{collection}/{encode_repo(repo)}"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_repo_resolve_prefix(repo_type: str, repo: str) -> str:
    """Return the site path prefix under which a repository's files resolve.

    Datasets live under ``/datasets/`` and Spaces under ``/spaces/``; every
    other ``repo_type`` is treated as a model, which has no extra prefix.
    """
    repo_segment = encode_repo(repo)

    if repo_type == "dataset":
        namespace = "/datasets"
    elif repo_type == "space":
        namespace = "/spaces"
    else:
        namespace = ""

    return f"{namespace}/{repo_segment}"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def list_repo_files(repo: str, revision: str, repo_type: str):
    """List every file of a repository revision via the Hub tree API.

    Args:
        repo: Repository id in ``owner/name`` form.
        revision: Branch, tag or commit to list.
        repo_type: ``"dataset"``, ``"space"`` or anything else for a model.

    Returns:
        A list of ``{"path": str, "size": int}`` dicts, one per file entry
        (directory entries are skipped).

    Raises:
        RuntimeError: If any page of the listing returns a non-2xx status.
    """
    api_path = get_repo_api_path(repo_type, repo)
    url = f"{HF_BASE}{api_path}/tree/{quote(revision, safe='')}?recursive=1"
    headers = get_headers()

    files = []
    while url:
        response = requests.get(url, headers=headers, timeout=60)

        if not response.ok:
            raise RuntimeError(f"获取文件列表失败:{response.status_code} {response.text[:500]}")

        files.extend(
            {"path": item["path"], "size": item.get("size", 0)}
            for item in response.json()
            # Skip directories and malformed entries that carry no path.
            if item.get("type") == "file" and item.get("path")
        )

        # Large listings are split across pages; the next page, if any, is
        # advertised through the HTTP Link header (rel="next"). Reading only
        # the first response would silently truncate big repositories.
        url = response.links.get("next", {}).get("url")

    return files
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_file_download_url(repo: str, repo_type: str, revision: str, path: str) -> str:
    """Build the absolute ``resolve`` URL for one file of a repository.

    The result has the shape
    ``<HF_BASE><repo prefix>/resolve/<revision>/<path>`` with the revision
    and every path component percent-encoded.
    """
    root = f"{HF_BASE}{get_repo_resolve_prefix(repo_type, repo)}"
    revision_segment = quote(revision, safe="")
    path_segment = encode_path(path)
    return "/".join((root, "resolve", revision_segment, path_segment))
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _safe_filename_part(text: str, fallback: str) -> str:
    """Reduce ``text`` to characters that are safe inside a single file name."""
    cleaned = "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in text)
    # Leading/trailing dots are dropped so values such as ".." cannot survive.
    return cleaned.strip(".") or fallback


def _add_remote_file_to_zip(zip_file, url: str, arcname: str) -> None:
    """Stream one remote file to a temporary file, then store it in ``zip_file``."""
    temp_file_path = None
    try:
        # Using the response as a context manager releases the streamed
        # connection even when the download fails part-way.
        with requests.get(url, headers=get_headers(), stream=True, timeout=120) as response:
            if not response.ok:
                raise gr.Error(f"下载文件失败:{arcname},状态码:{response.status_code}")

            with tempfile.NamedTemporaryFile(delete=False) as temp_file:
                temp_file_path = temp_file.name
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        temp_file.write(chunk)

        # The temporary file is closed at this point, so it is fully flushed
        # and can be re-opened and later removed on every platform.
        zip_file.write(temp_file_path, arcname=arcname)
    except requests.RequestException as exc:
        raise gr.Error(f"下载文件失败:{arcname},{exc}") from exc
    finally:
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)


def download_repo_as_zip(repo: str, revision: str, repo_type: str, progress=gr.Progress()):
    """Download every file of a Hub repository and bundle them into one ZIP.

    Args:
        repo: Repository id in ``owner/name`` form (surrounding whitespace is ignored).
        revision: Branch, tag or commit; blank means ``"main"``.
        repo_type: ``"model"``, ``"dataset"`` or ``"space"``.
        progress: Gradio progress tracker used to report per-file progress.

    Returns:
        Path of the generated ZIP file inside the system temp directory.

    Raises:
        gr.Error: If the repo id is malformed, the listing fails or is empty,
            or any file cannot be downloaded.
    """
    repo = repo.strip()
    revision = revision.strip() or "main"

    if not repo or "/" not in repo:
        raise gr.Error("repo 参数格式错误,应为 owner/name,例如 sshleifer/tiny-gpt2")

    try:
        files = list_repo_files(repo, revision, repo_type)
    except (RuntimeError, requests.RequestException) as exc:
        # Re-raise as gr.Error so the message reaches the user.
        raise gr.Error(str(exc)) from exc

    if not files:
        raise gr.Error("没有找到文件")

    # Both parts come from user input; a revision such as "refs/pr/1" or
    # "../x" must not introduce path separators into the output file name.
    safe_repo_name = _safe_filename_part(repo.split("/")[-1], "repo")
    safe_revision = _safe_filename_part(revision, "revision")
    zip_path = os.path.join(
        tempfile.gettempdir(),
        f"{safe_repo_name}-{safe_revision}.zip"
    )

    progress(0, desc=f"找到 {len(files)} 个文件,开始打包...")

    # Build the archive under a unique name and move it into place only when
    # complete: concurrent requests for the same repo cannot interleave their
    # writes, and a failed run never leaves a truncated ZIP at zip_path.
    partial_fd, partial_path = tempfile.mkstemp(prefix="hf-zip-", suffix=".part")
    os.close(partial_fd)

    try:
        with zipfile.ZipFile(partial_path, "w", compression=zipfile.ZIP_STORED, allowZip64=True) as zip_file:
            for index, file in enumerate(files):
                file_path = file["path"]

                progress(
                    index / len(files),
                    desc=f"正在下载并写入:{file_path}"
                )

                url = get_file_download_url(
                    repo=repo,
                    repo_type=repo_type,
                    revision=revision,
                    path=file_path
                )
                _add_remote_file_to_zip(zip_file, url, file_path)

        os.replace(partial_path, zip_path)
    finally:
        # After a successful os.replace the partial file no longer exists.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    progress(1, desc="ZIP 生成完成")

    return zip_path
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# Single-page UI: three inputs, a trigger button and a file output that
# receives the path returned by download_repo_as_zip.
with gr.Blocks(title="Hugging Face 仓库 ZIP 下载器") as demo:
    gr.Markdown("# Hugging Face 仓库 ZIP 下载器")
    gr.Markdown("输入模型、数据集或 Space 仓库名,生成 ZIP 文件下载。")

    # NOTE(review): the three inputs are assumed to sit inside the row and the
    # button/output below it; confirm against the intended layout.
    with gr.Row():
        repo = gr.Textbox(label="仓库名", value="sshleifer/tiny-gpt2", placeholder="例如:Qwen/Qwen2.5-0.5B-Instruct")
        revision = gr.Textbox(label="分支 / revision", value="main")
        repo_type = gr.Dropdown(label="仓库类型", choices=["model", "dataset", "space"], value="model")

    button = gr.Button("生成 ZIP")
    output_file = gr.File(label="下载 ZIP 文件")

    # Clicking the button runs the download and hands the ZIP path to the file widget.
    button.click(fn=download_repo_as_zip, inputs=[repo, revision, repo_type], outputs=output_file)

demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
requests
|