lgccccc commited on
Commit
34cd439
·
1 Parent(s): c4d3dc4

Add Gradio zip downloader app

Browse files
Files changed (2) hide show
  1. app.py +177 -3
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,7 +1,181 @@
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  demo.launch()
 
1
+ import os
2
+ import zipfile
3
+ import tempfile
4
+ import requests
5
  import gradio as gr
6
+ from urllib.parse import quote
7
+
8
+
9
+ HF_BASE = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
10
+ HF_TOKEN = os.environ.get("HF_TOKEN")
11
+
12
+
13
+ def get_headers():
14
+ headers = {
15
+ "User-Agent": "hf-zip-gradio-demo/1.0"
16
+ }
17
+
18
+ if HF_TOKEN:
19
+ headers["Authorization"] = f"Bearer {HF_TOKEN}"
20
+
21
+ return headers
22
+
23
+
24
+ def encode_repo(repo: str) -> str:
25
+ return "/".join(quote(part, safe="") for part in repo.split("/"))
26
+
27
+
28
+ def encode_path(path: str) -> str:
29
+ return "/".join(quote(part, safe="") for part in path.split("/"))
30
+
31
+
32
+ def get_repo_api_path(repo_type: str, repo: str) -> str:
33
+ encoded_repo = encode_repo(repo)
34
+
35
+ if repo_type == "dataset":
36
+ return f"/api/datasets/{encoded_repo}"
37
+
38
+ if repo_type == "space":
39
+ return f"/api/spaces/{encoded_repo}"
40
+
41
+ return f"/api/models/{encoded_repo}"
42
+
43
+
44
+ def get_repo_resolve_prefix(repo_type: str, repo: str) -> str:
45
+ encoded_repo = encode_repo(repo)
46
+
47
+ if repo_type == "dataset":
48
+ return f"/datasets/{encoded_repo}"
49
+
50
+ if repo_type == "space":
51
+ return f"/spaces/{encoded_repo}"
52
+
53
+ return f"/{encoded_repo}"
54
+
55
+
56
+ def list_repo_files(repo: str, revision: str, repo_type: str):
57
+ api_path = get_repo_api_path(repo_type, repo)
58
+ url = f"{HF_BASE}{api_path}/tree/{quote(revision, safe='')}?recursive=1"
59
+
60
+ response = requests.get(url, headers=get_headers(), timeout=60)
61
+
62
+ if not response.ok:
63
+ raise RuntimeError(f"获取文件列表失败:{response.status_code} {response.text[:500]}")
64
+
65
+ items = response.json()
66
+
67
+ files = []
68
+ for item in items:
69
+ if item.get("type") == "file":
70
+ files.append({
71
+ "path": item.get("path"),
72
+ "size": item.get("size", 0)
73
+ })
74
+
75
+ return files
76
+
77
+
78
+ def get_file_download_url(repo: str, repo_type: str, revision: str, path: str) -> str:
79
+ prefix = get_repo_resolve_prefix(repo_type, repo)
80
+
81
+ return (
82
+ f"{HF_BASE}{prefix}"
83
+ f"/resolve/{quote(revision, safe='')}"
84
+ f"/{encode_path(path)}"
85
+ )
86
+
87
+
88
+ def download_repo_as_zip(repo: str, revision: str, repo_type: str, progress=gr.Progress()):
89
+ repo = repo.strip()
90
+ revision = revision.strip() or "main"
91
+
92
+ if not repo or "/" not in repo:
93
+ raise gr.Error("repo 参数格式错误,应为 owner/name,例如 sshleifer/tiny-gpt2")
94
+
95
+ files = list_repo_files(repo, revision, repo_type)
96
+
97
+ if not files:
98
+ raise gr.Error("没有找到文件")
99
+
100
+ safe_repo_name = repo.split("/")[-1].replace("/", "_")
101
+ zip_path = os.path.join(
102
+ tempfile.gettempdir(),
103
+ f"{safe_repo_name}-{revision}.zip"
104
+ )
105
+
106
+ progress(0, desc=f"找到 {len(files)} 个文件,开始打包...")
107
+
108
+ with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_STORED, allowZip64=True) as zip_file:
109
+ for index, file in enumerate(files):
110
+ file_path = file["path"]
111
+
112
+ progress(
113
+ index / len(files),
114
+ desc=f"正在下载并写入:{file_path}"
115
+ )
116
+
117
+ url = get_file_download_url(
118
+ repo=repo,
119
+ repo_type=repo_type,
120
+ revision=revision,
121
+ path=file_path
122
+ )
123
+
124
+ response = requests.get(url, headers=get_headers(), stream=True, timeout=120)
125
+
126
+ if not response.ok:
127
+ raise gr.Error(f"下载文件失败:{file_path},状态码:{response.status_code}")
128
+
129
+ with tempfile.NamedTemporaryFile(delete=False) as temp_file:
130
+ temp_file_path = temp_file.name
131
+
132
+ try:
133
+ for chunk in response.iter_content(chunk_size=1024 * 1024):
134
+ if chunk:
135
+ temp_file.write(chunk)
136
+
137
+ temp_file.flush()
138
+ zip_file.write(temp_file_path, arcname=file_path)
139
+
140
+ finally:
141
+ if os.path.exists(temp_file_path):
142
+ os.remove(temp_file_path)
143
+
144
+ progress(1, desc="ZIP 生成完成")
145
+
146
+ return zip_path
147
+
148
+
149
+ with gr.Blocks(title="Hugging Face 仓库 ZIP 下载器") as demo:
150
+ gr.Markdown("# Hugging Face 仓库 ZIP 下载器")
151
+ gr.Markdown("输入模型、数据集或 Space 仓库名,生成 ZIP 文件下载。")
152
+
153
+ with gr.Row():
154
+ repo = gr.Textbox(
155
+ label="仓库名",
156
+ value="sshleifer/tiny-gpt2",
157
+ placeholder="例如:Qwen/Qwen2.5-0.5B-Instruct"
158
+ )
159
+
160
+ revision = gr.Textbox(
161
+ label="分支 / revision",
162
+ value="main"
163
+ )
164
+
165
+ repo_type = gr.Dropdown(
166
+ label="仓库类型",
167
+ choices=["model", "dataset", "space"],
168
+ value="model"
169
+ )
170
+
171
+ button = gr.Button("生成 ZIP")
172
+ output_file = gr.File(label="下载 ZIP 文件")
173
+
174
+ button.click(
175
+ fn=download_repo_as_zip,
176
+ inputs=[repo, revision, repo_type],
177
+ outputs=output_file
178
+ )
179
 
 
 
180
 
 
181
  demo.launch()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ requests