# Hugging Face Space (status: Running) - file size: 16,344 bytes, revision dfa6c7d
import gradio as gr
from huggingface_hub import HfApi
import os
import git
import requests
import tempfile
import shutil
import json
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import base64
import sys
import io
from urllib.parse import unquote
import time
# --- Global settings ---
# Scratch directory created once at import time; cloned repositories and
# downloaded files are stored underneath it.
TEMP_DIR = tempfile.mkdtemp()
# --- Log capture helper ---
class LogCapture:
    """Context manager that diverts ``sys.stdout`` into an in-memory buffer.

    While the ``with`` block is active, everything written to ``sys.stdout``
    lands in ``log_stream``; the previous stream is restored on exit.
    """

    def __init__(self):
        self.log_stream = io.StringIO()

    def __enter__(self):
        # Swap in the buffer and remember the stream it replaces.
        self.old_stdout, sys.stdout = sys.stdout, self.log_stream
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always hand stdout back; exceptions are not suppressed.
        sys.stdout = self.old_stdout

    def get_value(self):
        """Return all text captured so far."""
        return self.log_stream.getvalue()
# --- Backend logic ---
def prepare_repo(hf_token, repo_id, repo_type, space_sdk):
    """Ensure the target Hugging Face repo exists and return an API client.

    Args:
        hf_token: Access token handed to ``HfApi``.
        repo_id: Target repository id.
        repo_type: Repository type; ``space_sdk`` is only forwarded when
            this equals ``"space"``.
        space_sdk: SDK name used when creating a Space.

    Returns:
        The ``HfApi`` instance built from ``hf_token``.

    Raises:
        gr.Error: If the token or repo id is missing, or if
            ``create_repo`` fails.
    """
    if not hf_token:
        raise gr.Error("必须提供 Hugging Face Token!")
    if not repo_id:
        raise gr.Error("必须提供目标 Hugging Face 仓库ID!")

    api = HfApi(token=hf_token)

    # Only Spaces carry an SDK setting.
    if repo_type == "space":
        sdk_to_use = space_sdk
    else:
        sdk_to_use = None

    try:
        api.create_repo(
            repo_id=repo_id,
            repo_type=repo_type,
            space_sdk=sdk_to_use,
            exist_ok=True,
        )
    except Exception as e:
        raise gr.Error(f"创建或访问仓库失败: {e}")
    return api
class CloneProgress(git.remote.RemoteProgress):
    """Relay git clone progress updates to a progress callable."""

    def __init__(self, progress_bar):
        super().__init__()
        self.progress = progress_bar

    def update(self, op_code, cur_count, max_count=None, message=''):
        """Report ``cur_count / max_count``; do nothing when the total is unknown."""
        if not max_count:
            return
        fraction = cur_count / max_count
        label = f"克隆中: {message.strip()} ({int(cur_count)}/{int(max_count)})"
        self.progress(fraction, desc=label)
def _repo_name_from_url(github_url):
    """Return the directory name implied by the last path segment of a repo URL."""
    name = github_url.strip().rstrip('/').split('/')[-1]
    # Strip only a trailing ".git"; names such as "my.github.io" must survive.
    if name.endswith('.git'):
        name = name[:-len('.git')]
    return name


def clone_github_repo(github_url, progress=gr.Progress()):
    """Clone a repository into ``TEMP_DIR`` and list its top-level entries.

    Args:
        github_url: Repository URL to clone.
        progress: Gradio progress tracker.

    Returns:
        A 4-tuple: update for the selection radio (choices are the
        "whole repo" option plus the top-level entries), update that shows
        the upload button, a status message, and the clone path.

    Raises:
        gr.Error: If the URL is missing, no repository name can be derived
            from it, or the clone fails.
    """
    if not github_url:
        raise gr.Error("必须提供 GitHub 仓库链接!")
    progress(0, desc="准备克隆...")

    url = github_url.strip()
    repo_name = _repo_name_from_url(url)
    # An empty or dot-only name would make clone_path resolve to TEMP_DIR (or
    # its parent), and the rmtree below would wipe the whole scratch area.
    if not repo_name or repo_name in ('.', '..'):
        raise gr.Error("无法从链接中解析仓库名称,请检查 GitHub 仓库链接!")

    clone_path = os.path.join(TEMP_DIR, repo_name)
    if os.path.exists(clone_path):
        shutil.rmtree(clone_path)
    try:
        git.Repo.clone_from(url, clone_path, progress=CloneProgress(progress))
    except Exception as e:
        raise gr.Error(f"克隆仓库失败: {e}") from e
    progress(1, "克隆完成!")

    # Git metadata is not a meaningful upload target, so it is not offered.
    entries = sorted(entry for entry in os.listdir(clone_path) if entry != ".git")
    items = ["上传整个仓库"] + entries
    return (
        gr.update(choices=items, value="上传整个仓库", visible=True),
        gr.update(visible=True),
        f"✅ 成功克隆 '{repo_name}'。",
        clone_path,
    )
def update_folder_upload_options(selection, clone_path):
    """Show the folder-mode control only when the selected entry is a directory.

    Returns a ``gr.update`` with ``visible=True`` when ``selection`` names a
    directory inside ``clone_path`` (and is not the "whole repo" option),
    otherwise ``visible=False``.
    """
    is_single_entry = (
        bool(selection) and bool(clone_path) and selection != "上传整个仓库"
    )
    show_mode = is_single_entry and os.path.isdir(os.path.join(clone_path, selection))
    return gr.update(visible=show_mode)
def upload_from_github(hf_token, repo_id, repo_type, space_sdk, path_in_repo, selection, folder_upload_mode, clone_path, progress=gr.Progress()):
    """Upload a cloned repository, or one of its top-level entries, to the Hub.

    Args:
        hf_token, repo_id, repo_type, space_sdk: Forwarded to ``prepare_repo``.
        path_in_repo: Destination prefix inside the target repo (may be empty).
        selection: Either the "whole repo" option or a top-level entry name.
        folder_upload_mode: For a directory selection, whether to upload the
            folder itself or only its contents.
        clone_path: Local path of the clone.
        progress: Gradio progress tracker.

    Returns:
        A status message describing what was uploaded.

    Raises:
        gr.Error: If required state is missing or the upload fails.
    """
    # Repository paths use forward slashes regardless of the host OS.
    import posixpath

    if not clone_path:
        raise gr.Error("克隆路径状态丢失,请重新克隆仓库!")
    # Without this guard a missing selection fails inside os.path.join with
    # an opaque TypeError before any error handling is in place.
    if not selection:
        raise gr.Error("请选择要上传的内容!")

    progress(0, desc="准备上传...")
    api = prepare_repo(hf_token, repo_id, repo_type, space_sdk)
    source_path = os.path.join(clone_path, selection)
    repo_prefix = path_in_repo or ""
    try:
        if selection == "上传整个仓库":
            progress(0.5, desc=f"正在上传整个仓库 (忽略 .git)...")
            api.upload_folder(
                folder_path=clone_path,
                repo_id=repo_id,
                repo_type=repo_type,
                path_in_repo=path_in_repo,
                ignore_patterns=[".git", ".gitignore"],
            )
            message = "📂 成功上传了整个仓库"
        elif os.path.isdir(source_path):
            if folder_upload_mode == "仅上传文件夹内容":
                final_path_in_repo = path_in_repo
            else:
                final_path_in_repo = posixpath.join(repo_prefix, selection)
            progress(0.5, desc=f"正在上传文件夹 '{selection}'...")
            api.upload_folder(
                folder_path=source_path,
                repo_id=repo_id,
                repo_type=repo_type,
                path_in_repo=final_path_in_repo,
            )
            message = f"📂 成功上传文件夹 '{selection}'"
        else:
            progress(0.5, desc=f"正在上传文件: {selection}")
            api.upload_file(
                path_or_fileobj=source_path,
                repo_id=repo_id,
                repo_type=repo_type,
                path_in_repo=posixpath.join(repo_prefix, selection),
            )
            message = f"📄 成功上传文件 '{selection}'"
        progress(1, "上传完成!")
        return f"✅ {message} 到 '{repo_id}'."
    except Exception as e:
        raise gr.Error(f"上传失败: {e}") from e
def _filename_from_url(url):
    """Return a safe local file name taken from the last path segment of *url*."""
    segment = url.split('/')[-1].split('?')[0].split('#')[0]
    # Percent-decoding can introduce path separators ("%2F", "%5C"); keep only
    # the final component so the file cannot land outside the download dir.
    name = os.path.basename(unquote(segment).replace('\\', '/'))
    if name in ('', '.', '..'):
        return "downloaded_file"
    return name


def download_file_with_progress(url, download_dir, single_file_progress=None):
    """Stream *url* into *download_dir*, printing a tqdm bar to ``sys.stdout``.

    Args:
        url: Address to fetch.
        download_dir: Existing directory that receives the file.
        single_file_progress: Optional callable ``(fraction, desc=...)``
            invoked per chunk when the response declares a content length.

    Returns:
        The path of the written file on success. On any failure the error is
        printed and the exception object is *returned* (not raised); callers
        tell the two apart by type.
    """
    try:
        with requests.get(url, stream=True, timeout=30) as r:
            r.raise_for_status()
            filename = _filename_from_url(url)
            filepath = os.path.join(download_dir, filename)
            total_size = int(r.headers.get('content-length', 0))
            with open(filepath, 'wb') as f, tqdm(total=total_size, unit='B', unit_scale=True, unit_divisor=1024, desc=f"下载 {filename}", ascii=True, file=sys.stdout) as t:
                for chunk in r.iter_content(chunk_size=8192):
                    if not chunk:
                        continue
                    f.write(chunk)
                    t.update(len(chunk))
                    if single_file_progress and total_size > 0:
                        downloaded = t.n
                        # Clamp: decoded bytes can exceed the declared length
                        # when the body is transfer-compressed.
                        fraction = min(1.0, downloaded / total_size)
                        single_file_progress(fraction, desc=f"下载中: {filename} ({downloaded/1e6:.2f}/{total_size/1e6:.2f} MB)")
        return filepath
    except Exception as e:
        print(f"下载失败 {url}: {e}")
        return e
def download_from_url(url_text, mode, thread_count, progress=gr.Progress()):
    """Download every URL in *url_text*, streaming log updates to the UI.

    This is a generator: each yielded 5-tuple updates, in order, the file
    checkbox group, the upload button, the status text, the log box and the
    download-directory state.

    Args:
        url_text: Newline-separated URLs.
        mode: ``"单线程"`` downloads sequentially; any other value uses a
            thread pool.
        thread_count: Pool size for the threaded mode.
        progress: Gradio progress tracker.

    Raises:
        gr.Error: If no non-blank URL is supplied.
    """
    urls = [line.strip() for line in (url_text or "").split('\n') if line.strip()]
    # Validate after stripping so whitespace-only input is rejected too.
    if not urls:
        raise gr.Error("必须提供下载链接!")

    download_dir = os.path.join(TEMP_DIR, "downloads")
    os.makedirs(download_dir, exist_ok=True)
    downloaded_files, errors = [], []
    log_capture = LogCapture()

    def record(result):
        # The downloader returns a path on success and an exception on failure.
        if isinstance(result, str) and os.path.exists(result):
            downloaded_files.append(result)
        else:
            errors.append(result)

    # Reset the UI: hide the upload button and show an empty log box.
    yield gr.update(), gr.update(visible=False), gr.update(), gr.update(value="", visible=True), gr.update()

    with log_capture:
        if mode == "单线程":
            for i, url in enumerate(urls):
                progress(i / len(urls), desc=f"开始下载第 {i+1}/{len(urls)} 个文件...")
                result = download_file_with_progress(url, download_dir, progress)
                yield gr.update(), gr.update(), gr.update(), gr.update(value=log_capture.get_value()), gr.update()
                record(result)
        else:
            # The UI number field may deliver a float; the pool wants an int.
            max_workers = max(1, int(thread_count)) if thread_count else None
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [
                    executor.submit(download_file_with_progress, url, download_dir)
                    for url in urls
                ]
                # Poll until every task has finished. Testing running() alone
                # would stop early whenever tasks are queued but none happens
                # to be executing at the instant of the check.
                while not all(f.done() for f in futures):
                    completed_count = sum(1 for f in futures if f.done())
                    progress(completed_count / len(futures), desc=f"已完成 {completed_count}/{len(futures)} 个文件下载任务")
                    yield gr.update(), gr.update(), gr.update(), gr.update(value=log_capture.get_value()), gr.update()
                    time.sleep(0.5)
                for future in futures:
                    record(future.result())

    progress(1, "所有下载任务已完成!")
    status_msg = f"✅ 成功下载 {len(downloaded_files)} 个文件。"
    if errors:
        status_msg += f" ❌ 失败 {len(errors)} 个: {', '.join(map(str, errors[:2]))}..."
    file_names = [os.path.basename(f) for f in downloaded_files]
    final_log = log_capture.get_value()
    yield gr.update(choices=file_names, value=file_names, visible=True), gr.update(visible=True), status_msg, gr.update(value=final_log), download_dir
def upload_from_url(hf_token, repo_id, repo_type, space_sdk, path_in_repo, selected_files, download_dir, progress=gr.Progress()):
    """Upload previously downloaded files to the target Hub repository.

    Args:
        hf_token, repo_id, repo_type, space_sdk: Forwarded to ``prepare_repo``.
        path_in_repo: Destination prefix inside the target repo (may be empty).
        selected_files: File names (relative to *download_dir*) to upload.
        download_dir: Directory that holds the downloaded files.
        progress: Gradio progress tracker.

    Returns:
        A status message with the number of uploaded files.

    Raises:
        gr.Error: If required state is missing or an upload fails.
    """
    # Repository paths use forward slashes regardless of the host OS.
    import posixpath

    if not download_dir:
        raise gr.Error("下载目录状态丢失,请重新下载文件!")
    if not selected_files:
        raise gr.Error("请选择要上传的文件!")

    api = prepare_repo(hf_token, repo_id, repo_type, space_sdk)
    total = len(selected_files)
    try:
        for index, filename in enumerate(selected_files):
            # Names arrive from the client; keep only the final component so
            # they cannot point outside download_dir.
            safe_name = os.path.basename(filename)
            # Report the fraction already finished, not the one about to start.
            progress(index / total, desc=f"正在上传 {safe_name}...")
            api.upload_file(
                path_or_fileobj=os.path.join(download_dir, safe_name),
                path_in_repo=posixpath.join(path_in_repo or "", safe_name),
                repo_id=repo_id,
                repo_type=repo_type,
            )
    except Exception as e:
        raise gr.Error(f"上传失败: {e}") from e
    progress(1, "上传完成!")
    return f"✅ 成功将 {len(selected_files)} 个文件上传到 '{repo_id}'."
def upload_from_local(hf_token, repo_id, repo_type, space_sdk, path_in_repo, local_files, progress=gr.Progress()):
    """Upload files received from the local file picker to the Hub repository.

    Args:
        hf_token, repo_id, repo_type, space_sdk: Forwarded to ``prepare_repo``.
        path_in_repo: Destination prefix inside the target repo (may be empty).
        local_files: Uploaded file handles; each is expected to expose the
            local path as ``.name`` (a plain path string is accepted too).
        progress: Gradio progress tracker.

    Returns:
        A status message with the number of uploaded files.

    Raises:
        gr.Error: If no file was provided or an upload fails.
    """
    # Repository paths use forward slashes regardless of the host OS.
    import posixpath

    if not local_files:
        raise gr.Error("请至少上传一个文件!")

    api = prepare_repo(hf_token, repo_id, repo_type, space_sdk)
    total = len(local_files)
    try:
        for index, file_obj in enumerate(local_files):
            local_path = getattr(file_obj, "name", file_obj)
            base_name = os.path.basename(local_path)
            # Report the fraction already finished, not the one about to start.
            progress(index / total, desc=f"正在上传 {base_name}...")
            api.upload_file(
                path_or_fileobj=local_path,
                path_in_repo=posixpath.join(path_in_repo or "", base_name),
                repo_id=repo_id,
                repo_type=repo_type,
            )
    except Exception as e:
        raise gr.Error(f"上传失败: {e}") from e
    progress(1, "上传完成!")
    return f"✅ 成功上传 {len(local_files)} 个本地文件到 '{repo_id}'."
def generate_share_link(repo_id, repo_type, space_sdk, path_in_repo, github_url, url_text, download_mode, thread_count):
    """Pack the current form values into a URL-safe base64 string.

    The values are stored as a JSON object (keys match the parameter names)
    and the returned text embeds the encoded payload after a short
    instruction line.
    """
    keys = (
        "repo_id", "repo_type", "space_sdk", "path_in_repo",
        "github_url", "url_text", "download_mode", "thread_count",
    )
    values = (
        repo_id, repo_type, space_sdk, path_in_repo,
        github_url, url_text, download_mode, thread_count,
    )
    config = dict(zip(keys, values))
    payload = json.dumps(config).encode()
    encoded = base64.urlsafe_b64encode(payload).decode()
    return f"将此编码粘贴到URL参数 `?config=` 后面: \n\n{encoded}"
def _load_config(config_data):
    """Return the config dict from an uploaded JSON file or a base64 string."""
    if not config_data:
        raise ValueError("empty config")
    # Uploaded files expose their local path as ``.name``. Check this before
    # the str test: a file handle may itself be a str subclass, and would
    # otherwise be fed to the base64 decoder.
    path = getattr(config_data, "name", None)
    if path is not None:
        with open(path, "r", encoding="utf-8") as f:
            config = json.load(f)
    elif isinstance(config_data, str):
        text = config_data.strip()
        # Tolerate '=' padding lost when the code was pasted into a URL.
        text += "=" * (-len(text) % 4)
        config = json.loads(base64.urlsafe_b64decode(text).decode())
    else:
        raise TypeError(f"unsupported config source: {type(config_data).__name__}")
    if not isinstance(config, dict):
        raise ValueError("config must be a JSON object")
    return config


def apply_config(config_data):
    """Decode a shared configuration and return updates for the form fields.

    Args:
        config_data: Either a base64 config string or an uploaded JSON file
            handle (an object exposing the file path as ``.name``).

    Returns:
        An 8-tuple of ``gr.update`` objects for, in order: repo id, repo
        type, Space SDK, path in repo, GitHub URL, URL list, download mode
        and thread count. Missing keys fall back to defaults.

    Raises:
        gr.Error: If the configuration cannot be read or parsed.
    """
    try:
        config = _load_config(config_data)
    except Exception as e:
        raise gr.Error(f"解析配置失败: {e}") from e

    repo_type = config.get("repo_type", "space")
    mode = config.get("download_mode", "多线程")
    return (
        gr.update(value=config.get("repo_id", "")),
        gr.update(value=repo_type),
        # The SDK selector only applies to Spaces.
        gr.update(value=config.get("space_sdk", "docker"), visible=(repo_type == "space")),
        gr.update(value=config.get("path_in_repo", "")),
        gr.update(value=config.get("github_url", "")),
        gr.update(value=config.get("url_text", "")),
        gr.update(value=mode),
        # The thread count only applies to the multi-threaded mode.
        gr.update(value=config.get("thread_count", 8), visible=(mode == "多线程")),
    )
# Build the Gradio UI and wire its components to the backend functions.
# NOTE(review): the source was flattened (no indentation); the nesting of the
# layout containers below is a reconstruction - confirm against the rendered UI.
with gr.Blocks(theme=gr.themes.Soft(), title="云端同步工具 Pro") as demo:
    gr.Markdown("# ☁️ 云端同步工具 Pro (V11 最终定义修复版)")
    # NOTE(review): global_progress is not referenced again in the visible source.
    global_progress = gr.Progress()
    # Shared status line; the clone, download and upload handlers write to it.
    status_output = gr.Markdown()
    # Import/export of the form configuration.
    with gr.Accordion("⚙️ 导入/导出配置", open=False):
        with gr.Row(): gen_link_btn = gr.Button("生成分享配置")
        share_link_output = gr.Textbox(label="分享配置编码", interactive=False, max_lines=3, visible=False)
        with gr.Row():
            config_link_input = gr.Textbox(label="粘贴配置编码", scale=3)
            config_file_input = gr.UploadButton("或上传配置文件", file_types=[".json"])
        apply_config_btn = gr.Button("应用配置", variant="primary")
    gr.Markdown("---")
    # Step 1: target repository settings.
    gr.Markdown("### **第一步:设置目标仓库**")
    with gr.Row(): hf_token_input = gr.Textbox(label="Hugging Face Token", type="password", info="请提供有 'write' 写入权限的Token。")
    with gr.Row():
        hf_repo_input = gr.Textbox(label="目标 HF 仓库 ID", scale=2)
        hf_repo_type_input = gr.Dropdown(label="仓库类型", choices=["space", "dataset", "model"], value="space")
        hf_space_sdk_input = gr.Dropdown(label="Space SDK", choices=["docker", "gradio", "streamlit", "static"], value="docker", visible=True)
    hf_path_in_repo_input = gr.Textbox(label="仓库内路径 (可选)")
    # The SDK dropdown is shown only while the repo type is "space".
    def toggle_sdk_visibility(repo_type): return gr.update(visible=repo_type == "space")
    hf_repo_type_input.change(fn=toggle_sdk_visibility, inputs=hf_repo_type_input, outputs=hf_space_sdk_input)
    # Step 2: choose a source and run the sync.
    gr.Markdown("### **第二步:选择同步源并执行**")
    with gr.Tabs():
        with gr.TabItem("从 GitHub 同步"):
            github_url_input = gr.Textbox(label="GitHub 仓库链接")
            clone_btn = gr.Button("1. 克隆", variant="secondary")
            # Hidden until clone_github_repo returns an update that shows it.
            upload_github_btn = gr.Button("2. 上传到 Hugging Face", variant="primary", visible=False)
            github_selection_radio = gr.Radio(label="选择要上传的内容", visible=False)
            github_folder_mode_radio = gr.Radio(["上传文件夹本身", "仅上传文件夹内容"], label="文件夹上传模式", value="上传文件夹本身", visible=False)
            # Holds the clone path returned by clone_github_repo.
            github_clone_path_state = gr.State()
        with gr.TabItem("从 URL 链接下载"):
            url_input = gr.TextArea(label="下载链接 (每行一个)")
            with gr.Row():
                url_download_mode = gr.Radio(["单线程", "多线程"], label="下载模式", value="多线程")
                url_thread_count = gr.Number(label="线程数", value=8, minimum=1, maximum=32, step=1, visible=True)
            # The thread-count field is shown only in multi-threaded mode.
            def toggle_thread_count(mode): return gr.update(visible=mode == "多线程")
            url_download_mode.change(fn=toggle_thread_count, inputs=url_download_mode, outputs=url_thread_count)
            download_btn = gr.Button("1. 下载", variant="secondary")
            # Hidden until download_from_url yields an update that shows it.
            upload_url_btn = gr.Button("2. 上传到 Hugging Face", variant="primary", visible=False)
            url_selection_checkbox = gr.CheckboxGroup(label="选择要上传的文件", visible=False)
            download_log_output = gr.Textbox(label="下载日志", lines=10, interactive=False, visible=False)
            # Holds the download directory yielded by download_from_url.
            url_download_dir_state = gr.State()
        with gr.TabItem("从本地上传"):
            local_file_input = gr.File(label="选择本地文件", file_count="multiple")
            upload_local_btn = gr.Button("上传到 Hugging Face", variant="primary")
    # --- V11 core fix: define all component lists before they are used ---
    # Order matters: it matches generate_share_link's parameters and the
    # tuple returned by apply_config.
    all_config_inputs = [hf_repo_input, hf_repo_type_input, hf_space_sdk_input, hf_path_in_repo_input, github_url_input, url_input, url_download_mode, url_thread_count]
    # Common leading arguments of the three upload_from_* handlers.
    upload_inputs = [hf_token_input, hf_repo_input, hf_repo_type_input, hf_space_sdk_input, hf_path_in_repo_input]
    # --- Bind UI events to logic ---
    gen_link_btn.click(fn=generate_share_link, inputs=all_config_inputs, outputs=share_link_output).then(lambda: gr.update(visible=True), outputs=share_link_output)
    # An uploaded config file takes precedence over the pasted code.
    def apply_wrapper(config_data, config_link): return apply_config(config_data if config_data else config_link)
    apply_config_btn.click(fn=apply_wrapper, inputs=[config_file_input, config_link_input], outputs=all_config_inputs)
    clone_btn.click(fn=clone_github_repo, inputs=[github_url_input], outputs=[github_selection_radio, upload_github_btn, status_output, github_clone_path_state], api_name="clone_github")
    github_selection_radio.change(fn=update_folder_upload_options, inputs=[github_selection_radio, github_clone_path_state], outputs=github_folder_mode_radio)
    upload_github_btn.click(fn=upload_from_github, inputs=upload_inputs + [github_selection_radio, github_folder_mode_radio, github_clone_path_state], outputs=[status_output], api_name="upload_github")
    download_btn.click(fn=download_from_url, inputs=[url_input, url_download_mode, url_thread_count], outputs=[url_selection_checkbox, upload_url_btn, status_output, download_log_output, url_download_dir_state], api_name="download_url")
    upload_url_btn.click(fn=upload_from_url, inputs=upload_inputs + [url_selection_checkbox, url_download_dir_state], outputs=[status_output], api_name="upload_url")
    upload_local_btn.click(fn=upload_from_local, inputs=upload_inputs + [local_file_input], outputs=[status_output], api_name="upload_local")
    # On page load, apply a configuration passed via the `config` query
    # parameter; otherwise leave every field untouched.
    def get_url_params(request: gr.Request):
        if request and "config" in request.query_params: return apply_config(request.query_params["config"])
        return (gr.update(),) * len(all_config_inputs)
    demo.load(get_url_params, inputs=None, outputs=all_config_inputs)
    # NOTE(review): TEMP_DIR is a single module-level directory; removing it
    # here presumably affects every concurrent session - confirm when
    # `unload` fires and whether per-session directories are needed.
    demo.unload(lambda: shutil.rmtree(TEMP_DIR, ignore_errors=True))
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)