File size: 16,344 Bytes
dfa6c7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
import gradio as gr
from huggingface_hub import HfApi
import os
import git
import requests
import tempfile
import shutil
import json
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import base64
import sys
import io
from urllib.parse import unquote
import time

# --- Global settings ---
# Scratch directory created once at import time; the clone and download helpers
# below place their working files underneath it.
TEMP_DIR = tempfile.mkdtemp()

# --- Log capture utility ---
class LogCapture:
    """Context manager that diverts ``sys.stdout`` into an in-memory buffer.

    While the ``with`` block is active everything written to ``sys.stdout``
    lands in ``log_stream``; on exit the previous stream is put back. The
    captured text stays available through :meth:`get_value` afterwards.
    """

    def __init__(self):
        # Buffer that accumulates everything printed while capturing.
        self.log_stream = io.StringIO()

    def __enter__(self):
        # Remember the current stream so __exit__ can restore it.
        self.old_stdout, sys.stdout = sys.stdout, self.log_stream
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore unconditionally; exceptions are not suppressed.
        sys.stdout = self.old_stdout

    def get_value(self):
        """Return all text captured so far."""
        return self.log_stream.getvalue()

# --- Backend logic functions ---
def prepare_repo(hf_token, repo_id, repo_type, space_sdk):
    """Validate the credentials and make sure the target Hub repository exists.

    Args:
        hf_token: Hugging Face access token used to build the API client.
        repo_id: Target repository identifier.
        repo_type: Repository type; ``space_sdk`` is only forwarded when this
            is ``"space"``.
        space_sdk: SDK name passed to ``create_repo`` for Space repositories.

    Returns:
        The ``HfApi`` client bound to ``hf_token``.

    Raises:
        gr.Error: If the token or repo id is missing, or if ``create_repo``
            fails for any reason.
    """
    if not hf_token:
        raise gr.Error("必须提供 Hugging Face Token!")
    if not repo_id:
        raise gr.Error("必须提供目标 Hugging Face 仓库ID!")

    api = HfApi(token=hf_token)
    try:
        # exist_ok=True lets the same call serve both "create" and "reuse".
        api.create_repo(
            repo_id=repo_id,
            repo_type=repo_type,
            space_sdk=space_sdk if repo_type == "space" else None,
            exist_ok=True,
        )
    except Exception as e:
        raise gr.Error(f"创建或访问仓库失败: {e}")
    return api

class CloneProgress(git.remote.RemoteProgress):
    """GitPython progress handler that forwards clone progress to a callback.

    ``progress_bar`` is called as ``progress_bar(fraction, desc=...)`` each
    time ``update`` receives a non-zero ``max_count``.
    """

    def __init__(self, progress_bar):
        super().__init__()
        self.progress = progress_bar

    def update(self, op_code, cur_count, max_count=None, message=''):
        # Without a total there is no fraction to report, so do nothing.
        if not max_count:
            return
        fraction = cur_count / max_count
        label = f"克隆中: {message.strip()} ({int(cur_count)}/{int(max_count)})"
        self.progress(fraction, desc=label)

def clone_github_repo(github_url, progress=gr.Progress()):
    """Clone a Git repository under ``TEMP_DIR`` and list its top-level entries.

    Args:
        github_url: Repository URL; surrounding whitespace and a trailing
            ``/`` are tolerated.
        progress: Gradio progress callback.

    Returns:
        A 4-tuple: update for the selection radio (choices are the
        "whole repository" option followed by the top-level entries), update
        revealing the upload button, a status message, and the clone path.

    Raises:
        gr.Error: If the URL is missing, no repository name can be derived
            from it, or the clone fails.
    """
    if not github_url or not github_url.strip():
        raise gr.Error("必须提供 GitHub 仓库链接!")
    github_url = github_url.strip()
    progress(0, desc="准备克隆...")

    # Take the last path segment and drop only a *trailing* ".git"; a blanket
    # replace would mangle names such as "user.github.io".
    repo_name = github_url.rstrip('/').split('/')[-1]
    if repo_name.endswith('.git'):
        repo_name = repo_name[:-len('.git')]
    # An empty or dot name would make clone_path resolve to TEMP_DIR (or its
    # parent), and the rmtree below would then wipe far more than one clone.
    if repo_name in ("", ".", ".."):
        raise gr.Error("无法从链接中解析仓库名称!")

    clone_path = os.path.join(TEMP_DIR, repo_name)
    if os.path.exists(clone_path):
        shutil.rmtree(clone_path)
    try:
        git.Repo.clone_from(github_url, clone_path, progress=CloneProgress(progress))
    except Exception as e:
        raise gr.Error(f"克隆仓库失败: {e}") from e
    progress(1, "克隆完成!")

    # Hide Git's own metadata directory from the choices and sort for a stable
    # order.
    entries = sorted(name for name in os.listdir(clone_path) if name != ".git")
    items = ["上传整个仓库"] + entries
    return (
        gr.update(choices=items, value="上传整个仓库", visible=True),
        gr.update(visible=True),
        f"✅ 成功克隆 '{repo_name}'。",
        clone_path,
    )

def update_folder_upload_options(selection, clone_path):
    """Show the folder-upload-mode control only when a directory is selected.

    Returns a Gradio update with ``visible=True`` when ``selection`` names a
    directory inside ``clone_path`` (and is not the "whole repository"
    option); otherwise ``visible=False``.
    """
    is_specific_entry = bool(selection and clone_path) and selection != "上传整个仓库"
    show_mode_picker = is_specific_entry and os.path.isdir(os.path.join(clone_path, selection))
    return gr.update(visible=show_mode_picker)

def upload_from_github(hf_token, repo_id, repo_type, space_sdk, path_in_repo, selection, folder_upload_mode, clone_path, progress=gr.Progress()):
    """Upload a cloned repository, or one of its top-level entries, to the Hub.

    Args:
        hf_token, repo_id, repo_type, space_sdk: Forwarded to ``prepare_repo``.
        path_in_repo: Destination prefix inside the target repository.
        selection: Either the "whole repository" option or the name of a
            top-level entry inside ``clone_path``.
        folder_upload_mode: For a directory selection, whether to upload the
            folder itself (nested under its name) or only its contents.
        clone_path: Local path of the clone.
        progress: Gradio progress callback.

    Returns:
        A status message describing what was uploaded.

    Raises:
        gr.Error: If the clone path or selection is missing, the repository
            cannot be prepared, or the upload fails.
    """
    # Paths inside a Hub repo always use "/", whatever the host OS separator is.
    import posixpath

    if not clone_path:
        raise gr.Error("克隆路径状态丢失,请重新克隆仓库!")
    if not selection:
        # Without this guard os.path.join(clone_path, None) raises a bare TypeError.
        raise gr.Error("请选择要上传的内容!")
    progress(0, desc="准备上传...")
    api = prepare_repo(hf_token, repo_id, repo_type, space_sdk)
    path_in_repo = path_in_repo or ""
    source_path = os.path.join(clone_path, selection)
    try:
        if selection == "上传整个仓库":
            progress(0.5, desc="正在上传整个仓库 (忽略 .git)...")
            api.upload_folder(
                folder_path=clone_path,
                repo_id=repo_id,
                repo_type=repo_type,
                path_in_repo=path_in_repo,
                ignore_patterns=[".git", ".gitignore"],
            )
            message = "📂 成功上传了整个仓库"
        elif os.path.isdir(source_path):
            if folder_upload_mode == "仅上传文件夹内容":
                final_path_in_repo = path_in_repo
            else:
                final_path_in_repo = posixpath.join(path_in_repo, selection)
            progress(0.5, desc=f"正在上传文件夹 '{selection}'...")
            api.upload_folder(
                folder_path=source_path,
                repo_id=repo_id,
                repo_type=repo_type,
                path_in_repo=final_path_in_repo,
            )
            message = f"📂 成功上传文件夹 '{selection}'"
        else:
            progress(0.5, desc=f"正在上传文件: {selection}")
            api.upload_file(
                path_or_fileobj=source_path,
                repo_id=repo_id,
                repo_type=repo_type,
                path_in_repo=posixpath.join(path_in_repo, selection),
            )
            message = f"📄 成功上传文件 '{selection}'"
        progress(1, "上传完成!")
        return f"✅ {message} 到 '{repo_id}'."
    except Exception as e:
        raise gr.Error(f"上传失败: {e}") from e

def download_file_with_progress(url, download_dir, single_file_progress=None):
    """Stream ``url`` into ``download_dir`` while reporting progress.

    Args:
        url: Address to fetch.
        download_dir: Existing directory the file is written into.
        single_file_progress: Optional callback invoked as
            ``callback(fraction, desc=...)`` after each chunk when the server
            reported a content length.

    Returns:
        The path of the written file on success. On any failure the exception
        object itself is *returned* (not raised) after being printed; callers
        tell the two cases apart with ``isinstance(result, str)``.
    """
    try:
        with requests.get(url, stream=True, timeout=30) as r:
            r.raise_for_status()
            # The name comes from an untrusted URL: after percent-decoding it
            # may contain "/" or "..", so keep only the final path component to
            # guarantee the file stays inside download_dir.
            raw_name = url.split('/')[-1].split('?')[0]
            filename = os.path.basename(unquote(raw_name).replace('\\', '/'))
            if filename in ("", ".", ".."):
                filename = "downloaded_file"
            filepath = os.path.join(download_dir, filename)
            total_size = int(r.headers.get('content-length', 0))
            with open(filepath, 'wb') as f, tqdm(total=total_size, unit='B', unit_scale=True, unit_divisor=1024, desc=f"下载 {filename}", ascii=True, file=sys.stdout) as t:
                for chunk in r.iter_content(chunk_size=8192):
                    if not chunk:
                        # Skip empty chunks; nothing to write or count.
                        continue
                    f.write(chunk)
                    t.update(len(chunk))
                    if single_file_progress and total_size > 0:
                        downloaded = t.n
                        single_file_progress(downloaded / total_size, desc=f"下载中: {filename} ({downloaded/1e6:.2f}/{total_size/1e6:.2f} MB)")
        return filepath
    except Exception as e:
        # Deliberate best-effort: report and hand the error back to the caller.
        print(f"下载失败 {url}: {e}")
        return e

def download_from_url(url_text, mode, thread_count, progress=gr.Progress()):
    """Download every URL in ``url_text`` and stream log/progress updates.

    This is a generator: each yielded 5-tuple updates, in order, the file
    checkbox group, the upload button, the status text, the log textbox and the
    download-directory state.

    Args:
        url_text: Newline-separated URLs; blank lines are ignored.
        mode: ``"单线程"`` downloads sequentially; any other value uses a
            thread pool.
        thread_count: Pool size for the threaded mode (coerced to an int >= 1).
        progress: Gradio progress callback.

    Raises:
        gr.Error: If no URL was supplied.
    """
    if not url_text:
        raise gr.Error("必须提供下载链接!")
    urls = [url.strip() for url in url_text.split('\n') if url.strip()]
    if not urls:
        # Whitespace-only input passes the first check but contains no URL.
        raise gr.Error("必须提供下载链接!")
    download_dir = os.path.join(TEMP_DIR, "downloads")
    os.makedirs(download_dir, exist_ok=True)
    downloaded_files, errors = [], []
    log_capture = LogCapture()

    def record(result):
        # The download helper returns a path on success and an exception on failure.
        if isinstance(result, str) and os.path.exists(result):
            downloaded_files.append(result)
        else:
            errors.append(result)

    # Reset the UI: hide the upload button, clear and show the log box.
    yield gr.update(), gr.update(visible=False), gr.update(), gr.update(value="", visible=True), gr.update()

    # NOTE: LogCapture swaps the process-wide sys.stdout, so output printed by
    # anything else running during this window ends up in this log as well.
    with log_capture:
        if mode == "单线程":
            for i, url in enumerate(urls):
                progress(i / len(urls), desc=f"开始下载第 {i+1}/{len(urls)} 个文件...")
                result = download_file_with_progress(url, download_dir, progress)
                yield gr.update(), gr.update(), gr.update(), gr.update(value=log_capture.get_value()), gr.update()
                record(result)
        else:
            # The UI may deliver the number as a float (or nothing at all).
            max_workers = max(1, int(thread_count or 1))
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [executor.submit(download_file_with_progress, url, download_dir) for url in urls]
                # Poll on done(), not running(): a queued future is not
                # "running", so checking running() can end the loop while work
                # is still pending and stop the live log updates.
                while not all(f.done() for f in futures):
                    completed_count = sum(1 for f in futures if f.done())
                    progress(completed_count / len(futures), desc=f"已完成 {completed_count}/{len(futures)} 个文件下载任务")
                    yield gr.update(), gr.update(), gr.update(), gr.update(value=log_capture.get_value()), gr.update()
                    time.sleep(0.5)
            for future in futures:
                record(future.result())

    progress(1, "所有下载任务已完成!")
    status_msg = f"✅ 成功下载 {len(downloaded_files)} 个文件。"
    if errors:
        status_msg += f" ❌ 失败 {len(errors)} 个: {', '.join(map(str, errors[:2]))}..."
    file_names = [os.path.basename(f) for f in downloaded_files]
    final_log = log_capture.get_value()
    yield gr.update(choices=file_names, value=file_names, visible=True), gr.update(visible=True), status_msg, gr.update(value=final_log), download_dir

def upload_from_url(hf_token, repo_id, repo_type, space_sdk, path_in_repo, selected_files, download_dir, progress=gr.Progress()):
    """Upload previously downloaded files from ``download_dir`` to the Hub.

    Args:
        hf_token, repo_id, repo_type, space_sdk: Forwarded to ``prepare_repo``.
        path_in_repo: Destination prefix inside the target repository.
        selected_files: File names (relative to ``download_dir``) to upload.
        download_dir: Directory holding the downloaded files.
        progress: Gradio progress callback.

    Returns:
        A status message with the number of uploaded files.

    Raises:
        gr.Error: If the download directory or selection is missing, the
            repository cannot be prepared, or an upload fails.
    """
    # Paths inside a Hub repo always use "/", whatever the host OS separator is.
    import posixpath

    if not download_dir:
        raise gr.Error("下载目录状态丢失,请重新下载文件!")
    if not selected_files:
        raise gr.Error("请选择要上传的文件!")
    api = prepare_repo(hf_token, repo_id, repo_type, space_sdk)
    path_in_repo = path_in_repo or ""
    total = len(selected_files)
    try:
        for i, filename in enumerate(selected_files):
            # Names arrive from the client; keep only the final component so a
            # crafted value such as "../x" cannot reach outside download_dir.
            safe_name = os.path.basename(filename)
            progress((i + 1) / total, desc=f"正在上传 {safe_name}...")
            api.upload_file(
                path_or_fileobj=os.path.join(download_dir, safe_name),
                path_in_repo=posixpath.join(path_in_repo, safe_name),
                repo_id=repo_id,
                repo_type=repo_type,
            )
    except Exception as e:
        # Same user-facing failure style as the GitHub upload path.
        raise gr.Error(f"上传失败: {e}") from e
    return f"✅ 成功将 {len(selected_files)} 个文件上传到 '{repo_id}'."

def upload_from_local(hf_token, repo_id, repo_type, space_sdk, path_in_repo, local_files, progress=gr.Progress()):
    """Upload user-supplied local files to the target Hub repository.

    Each item of ``local_files`` must expose its on-disk path as ``.name``;
    the file is stored under ``path_in_repo`` using its base name.

    Returns:
        A status message with the number of uploaded files.

    Raises:
        gr.Error: If no files were provided (or if ``prepare_repo`` rejects
            the target).
    """
    if not local_files:
        raise gr.Error("请至少上传一个文件!")

    api = prepare_repo(hf_token, repo_id, repo_type, space_sdk)
    file_total = len(local_files)
    for position, file_obj in enumerate(local_files, start=1):
        local_path = file_obj.name
        base_name = os.path.basename(local_path)
        progress(position / file_total, desc=f"正在上传 {base_name}...")
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=os.path.join(path_in_repo, base_name),
            repo_id=repo_id,
            repo_type=repo_type,
        )
    return f"✅ 成功上传 {file_total} 个本地文件到 '{repo_id}'."

def generate_share_link(repo_id, repo_type, space_sdk, path_in_repo, github_url, url_text, download_mode, thread_count):
    """Serialize the current settings into a shareable URL-safe base64 code.

    The settings are packed into a JSON object (keys in the order of the
    parameters), UTF-8 encoded and URL-safe base64 encoded. The return value is
    the instruction text followed by that code.
    """
    field_names = (
        "repo_id", "repo_type", "space_sdk", "path_in_repo",
        "github_url", "url_text", "download_mode", "thread_count",
    )
    field_values = (
        repo_id, repo_type, space_sdk, path_in_repo,
        github_url, url_text, download_mode, thread_count,
    )
    config = dict(zip(field_names, field_values))
    payload = json.dumps(config).encode()
    encoded = base64.urlsafe_b64encode(payload).decode()
    return f"将此编码粘贴到URL参数 `?config=` 后面: \n\n{encoded}"

def apply_config(config_data):
    """Decode a shared configuration and turn it into component updates.

    Args:
        config_data: Either a URL-safe base64 string wrapping a JSON object,
            or an uploaded-file value exposing its path as ``.name``.

    Returns:
        An 8-tuple of Gradio updates for: repo id, repo type, Space SDK,
        path in repo, GitHub URL, URL text, download mode and thread count.
        Missing keys fall back to defaults.

    Raises:
        gr.Error: If the input cannot be read, decoded or parsed, or if the
            decoded JSON is not an object.
    """
    try:
        # Check for a file-like value first: an uploaded file may itself be a
        # ``str`` subclass carrying ``.name``, so an isinstance(str) test alone
        # would mistake its path for a base64 code.
        # NOTE(review): confirm the upload value type for the installed Gradio version.
        file_path = getattr(config_data, "name", None)
        if file_path:
            with open(file_path, "r", encoding="utf-8") as f:
                config = json.load(f)
        elif isinstance(config_data, str):
            encoded = config_data.strip()
            # Codes copied through a URL often lose their trailing "=" padding.
            encoded += "=" * (-len(encoded) % 4)
            config = json.loads(base64.urlsafe_b64decode(encoded).decode())
        else:
            raise ValueError("未提供配置")
        if not isinstance(config, dict):
            # Valid JSON that is not an object would break the .get() calls below.
            raise ValueError("配置必须是 JSON 对象")
    except Exception as e:
        raise gr.Error(f"解析配置失败: {e}") from e

    repo_type = config.get("repo_type", "space")
    mode = config.get("download_mode", "多线程")
    return (
        gr.update(value=config.get("repo_id", "")),
        gr.update(value=repo_type),
        gr.update(value=config.get("space_sdk", "docker"), visible=(repo_type == "space")),
        gr.update(value=config.get("path_in_repo", "")),
        gr.update(value=config.get("github_url", "")),
        gr.update(value=config.get("url_text", "")),
        gr.update(value=mode),
        gr.update(value=config.get("thread_count", 8), visible=(mode == "多线程")),
    )

# Build the Gradio UI: a config import/export panel, the target-repository
# settings, and three source tabs (GitHub, URL download, local files), followed
# by the event wiring that connects the components to the backend functions.
with gr.Blocks(theme=gr.themes.Soft(), title="云端同步工具 Pro") as demo:
    gr.Markdown("# ☁️ 云端同步工具 Pro (V11 最终定义修复版)")
    # NOTE(review): global_progress is not referenced again in the visible code.
    global_progress = gr.Progress()
    # Shared status line written by the clone/download/upload handlers.
    status_output = gr.Markdown()
    # Collapsible panel for generating and applying a shareable configuration.
    with gr.Accordion("⚙️ 导入/导出配置", open=False):
        with gr.Row(): gen_link_btn = gr.Button("生成分享配置")
        # Hidden until a share code has been generated.
        share_link_output = gr.Textbox(label="分享配置编码", interactive=False, max_lines=3, visible=False)
        with gr.Row():
            config_link_input = gr.Textbox(label="粘贴配置编码", scale=3)
            config_file_input = gr.UploadButton("或上传配置文件", file_types=[".json"])
        apply_config_btn = gr.Button("应用配置", variant="primary")
    gr.Markdown("---")
    # Step 1: target Hugging Face repository settings.
    gr.Markdown("### **第一步:设置目标仓库**")
    with gr.Row(): hf_token_input = gr.Textbox(label="Hugging Face Token", type="password", info="请提供有 'write' 写入权限的Token。")
    with gr.Row():
        hf_repo_input = gr.Textbox(label="目标 HF 仓库 ID", scale=2)
        hf_repo_type_input = gr.Dropdown(label="仓库类型", choices=["space", "dataset", "model"], value="space")
        hf_space_sdk_input = gr.Dropdown(label="Space SDK", choices=["docker", "gradio", "streamlit", "static"], value="docker", visible=True)
        hf_path_in_repo_input = gr.Textbox(label="仓库内路径 (可选)")
    # The SDK dropdown is only shown while the repo type is "space".
    def toggle_sdk_visibility(repo_type): return gr.update(visible=repo_type == "space")
    hf_repo_type_input.change(fn=toggle_sdk_visibility, inputs=hf_repo_type_input, outputs=hf_space_sdk_input)
    # Step 2: pick a source tab and run the sync.
    gr.Markdown("### **第二步:选择同步源并执行**")
    with gr.Tabs():
        # Tab 1: clone a GitHub repository, then upload all or part of it.
        with gr.TabItem("从 GitHub 同步"):
            github_url_input = gr.Textbox(label="GitHub 仓库链接")
            clone_btn = gr.Button("1. 克隆", variant="secondary")
            # The upload button and both radios start hidden; the clone and
            # selection handlers return updates that reveal them.
            upload_github_btn = gr.Button("2. 上传到 Hugging Face", variant="primary", visible=False)
            github_selection_radio = gr.Radio(label="选择要上传的内容", visible=False)
            github_folder_mode_radio = gr.Radio(["上传文件夹本身", "仅上传文件夹内容"], label="文件夹上传模式", value="上传文件夹本身", visible=False)
            # Carries the local clone path from the clone step to the upload step.
            github_clone_path_state = gr.State()
        # Tab 2: download files from URLs, then upload the selected ones.
        with gr.TabItem("从 URL 链接下载"):
            url_input = gr.TextArea(label="下载链接 (每行一个)")
            with gr.Row():
                url_download_mode = gr.Radio(["单线程", "多线程"], label="下载模式", value="多线程")
                url_thread_count = gr.Number(label="线程数", value=8, minimum=1, maximum=32, step=1, visible=True)
            # The thread-count field is only shown in multi-threaded mode.
            def toggle_thread_count(mode): return gr.update(visible=mode == "多线程")
            url_download_mode.change(fn=toggle_thread_count, inputs=url_download_mode, outputs=url_thread_count)
            download_btn = gr.Button("1. 下载", variant="secondary")
            upload_url_btn = gr.Button("2. 上传到 Hugging Face", variant="primary", visible=False)
            url_selection_checkbox = gr.CheckboxGroup(label="选择要上传的文件", visible=False)
            download_log_output = gr.Textbox(label="下载日志", lines=10, interactive=False, visible=False)
            # Carries the download directory from the download step to the upload step.
            url_download_dir_state = gr.State()
        # Tab 3: upload files picked from the user's machine.
        with gr.TabItem("从本地上传"):
            local_file_input = gr.File(label="选择本地文件", file_count="multiple")
            upload_local_btn = gr.Button("上传到 Hugging Face", variant="primary")

    # --- V11 core fix: define all shared component lists before they are used below ---
    # Order matters: all_config_inputs must match the parameter order of
    # generate_share_link and the tuple order returned by apply_config.
    all_config_inputs = [hf_repo_input, hf_repo_type_input, hf_space_sdk_input, hf_path_in_repo_input, github_url_input, url_input, url_download_mode, url_thread_count]
    # Leading arguments shared by every upload_* handler.
    upload_inputs = [hf_token_input, hf_repo_input, hf_repo_type_input, hf_space_sdk_input, hf_path_in_repo_input]

    # --- Wire UI events to backend logic ---
    # Generate the share code, then reveal the textbox that displays it.
    gen_link_btn.click(fn=generate_share_link, inputs=all_config_inputs, outputs=share_link_output).then(lambda: gr.update(visible=True), outputs=share_link_output)
    # An uploaded config file takes precedence over the pasted code.
    def apply_wrapper(config_data, config_link): return apply_config(config_data if config_data else config_link)
    apply_config_btn.click(fn=apply_wrapper, inputs=[config_file_input, config_link_input], outputs=all_config_inputs)
    clone_btn.click(fn=clone_github_repo, inputs=[github_url_input], outputs=[github_selection_radio, upload_github_btn, status_output, github_clone_path_state], api_name="clone_github")
    github_selection_radio.change(fn=update_folder_upload_options, inputs=[github_selection_radio, github_clone_path_state], outputs=github_folder_mode_radio)
    upload_github_btn.click(fn=upload_from_github, inputs=upload_inputs + [github_selection_radio, github_folder_mode_radio, github_clone_path_state], outputs=[status_output], api_name="upload_github")
    download_btn.click(fn=download_from_url, inputs=[url_input, url_download_mode, url_thread_count], outputs=[url_selection_checkbox, upload_url_btn, status_output, download_log_output, url_download_dir_state], api_name="download_url")
    upload_url_btn.click(fn=upload_from_url, inputs=upload_inputs + [url_selection_checkbox, url_download_dir_state], outputs=[status_output], api_name="upload_url")
    upload_local_btn.click(fn=upload_from_local, inputs=upload_inputs + [local_file_input], outputs=[status_output], api_name="upload_local")
    # On page load, apply a configuration passed as the `?config=` query
    # parameter; otherwise return no-op updates for every config component.
    def get_url_params(request: gr.Request):
        if request and "config" in request.query_params: return apply_config(request.query_params["config"])
        return (gr.update(),) * len(all_config_inputs)
    demo.load(get_url_params, inputs=None, outputs=all_config_inputs)
    # NOTE(review): TEMP_DIR is one module-level directory, so this cleanup is
    # not per-session — confirm that removing it on unload is intended.
    demo.unload(lambda: shutil.rmtree(TEMP_DIR, ignore_errors=True))

# Start the Gradio app (with debug=True) only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(debug=True)