YanmHa commited on
Commit
a13dcd6
·
verified ·
1 Parent(s): 1d37020

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -99
app.py CHANGED
@@ -1,130 +1,156 @@
1
  import gradio as gr
 
2
  import os
3
- import shutil
4
  import zipfile
5
- import time # 引入 time 模块,用于在快速迭代时确保进度条有时间刷新
 
6
 
7
- # 1. 持久存储的挂载路径
 
8
  PERSISTENT_STORAGE_PATH = "/data"
9
 
10
- # 确保持久存储的根目录存在
 
 
 
 
 
 
 
 
 
 
11
  if not os.path.exists(PERSISTENT_STORAGE_PATH):
12
  os.makedirs(PERSISTENT_STORAGE_PATH, exist_ok=True)
13
- print(f"已创建持久存储目录: {PERSISTENT_STORAGE_PATH}")
14
  else:
15
- print(f"持久存储目录已找到: {PERSISTENT_STORAGE_PATH}")
 
16
 
17
- def upload_and_extract_with_progress(zip_file_object, progress: gr.Progress):
 
18
  """
19
- 处理上传的 .zip 文件,将其解压缩到持久存储,并更新进度条。
20
- Gradio 会自动将 gr.Progress() 对象注入到 'progress' 参数中。
21
  """
22
- if zip_file_object is None:
23
- return "错误:没有文件被上传。请选择一个 .zip 文件。"
24
-
25
- temp_zip_file_path = zip_file_object.name # 服务器上临时 .zip 文件的路径
26
-
27
- # 2. 定义解压缩的目标文件夹名称
28
- # 您可以基于上传的 .zip 文件名来命名,或者使用固定名称
29
- # target_folder_name = os.path.splitext(os.path.basename(zip_file_object.name))[0] # 使用zip文件名(无扩展名)
30
- target_folder_name = "my_extracted_content" # 或者一个固定的文件夹名
31
 
32
- full_extraction_path = os.path.join(PERSISTENT_STORAGE_PATH, target_folder_name)
 
 
 
 
33
 
34
  try:
35
- # 3. 确保存放解压缩内容的目标文件夹存在
36
- if not os.path.exists(full_extraction_path):
37
- os.makedirs(full_extraction_path, exist_ok=True)
38
- print(f"在持久存储中创建了目标文件夹: {full_extraction_path}")
39
- else:
40
- print(f"目标文件夹 '{full_extraction_path}' 已存在。文件将被解压缩到此位置(可能会覆盖)。")
41
- # 可选:如果文件夹已存在,您可以在这里添加逻辑来清空它或创建新的带时间戳的文件夹
42
- # shutil.rmtree(full_extraction_path)
43
- # os.makedirs(full_extraction_path, exist_ok=True)
44
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- # 4. 解压缩过程并更新进度
47
- # Gradio 的文件上传组件本身会处理上传进度。
48
- # 此处的 progress 主要用于服务器端的解压过程。
49
- progress(0, desc="准备解压缩...") # 初始化进度条
50
-
51
- with zipfile.ZipFile(temp_zip_file_path, 'r') as zip_ref:
52
- file_list = zip_ref.infolist() # 获取压缩包内所有文件信息
53
- total_files = len(file_list)
54
-
55
- if total_files == 0:
56
- # os.remove(temp_zip_file_path) # 清理空的zip文件(可选)
57
- return f"成功上传 '{os.path.basename(zip_file_object.name)}',但压缩包为空。已将其放置在(或尝试放置在)'{temp_zip_file_path}'(临时)。目标解压路径:'{full_extraction_path}'。"
58
-
59
-
60
- # 迭代解压每个文件
61
  for i, member in enumerate(file_list):
62
- zip_ref.extract(member, path=full_extraction_path)
63
- # 更新进度:(当前步骤, 总步骤, 描述)
64
- # 为了防止UI更新过于频繁导致卡顿(特别是有成千上万个小文件时),
65
- # 可以选择不是每个文件都更新,或者在更新后短暂 sleep。
66
- # 但对于一般情况,每个文件更新一次是可接受的。
67
- progress( (i + 1) / total_files, desc=f"正在解压: {member.filename} ({i+1}/{total_files})")
68
- # 如果文件非常多,且单个文件解压飞快,可能需要 time.sleep(0.001) 来让UI有机会刷新,
69
- # 但通常Gradio的progress会自动处理好。
70
-
71
- # (可选)清理上传的临时 .zip 文件
72
- # try:
73
- # if os.path.exists(temp_zip_file_path):
74
- # os.remove(temp_zip_file_path)
75
- # print(f"已清理临时文件: {temp_zip_file_path}")
76
- # except Exception as e_remove:
77
- # print(f"清理临时文件 {temp_zip_file_path} 时出错: {e_remove}")
78
-
79
-
80
- return f"文件已成功上传并解压缩到 '{full_extraction_path}' 目录。共处理 {total_files} 个压缩包成员。"
81
-
82
- except zipfile.BadZipFile:
83
- # progress(1, desc="错误") # 更新进度条状态
84
- return "错��:上传的文件不是一个有效的 ZIP 文件。"
85
  except Exception as e:
86
- # progress(1, desc="错误") # 更新进度条状态
87
- error_message = f"处理上传和解压文件时发生错误: {str(e)}"
88
- print(error_message) # 打印到容器日志,方便调试
89
- return error_message
 
 
90
 
91
  # --- 创建 Gradio 应用界面 ---
92
  with gr.Blocks() as demo:
93
- gr.Markdown(f"## 带进度条:上传 .zip 并解压到 Space 的 `/data` 目录")
94
  gr.Markdown(
95
- f"请选择您本地的 `.zip` 文件进行上传。"
96
- f"文件上传本身会有浏览器/Gradio提供的进度指示。"
97
- f"上传完成后,服务器端的解压缩过程将通过下方的进度条显示。"
98
- f"最终文件将被解压缩到 Space 持久存储 `{PERSISTENT_STORAGE_PATH}` 下的一个子目录中。"
99
  )
100
 
101
- zip_file_uploader = gr.File(label="选择 .zip 文件", file_types=[".zip"])
102
- # 注意:Gradio 的 gr.File 组件在文件上传时,其UI本身会显示上传百分比。
103
- # 我们添加的 gr.Progress 主要用于指示服务器端解压的进度。
104
-
105
- process_button = gr.Button("上传并解压缩")
106
 
107
- status_output = gr.Textbox(label="处理状态和结果", lines=5, interactive=False)
108
- # gr.Progress() 组件不需要显式添加到布局中,当它作为参数传递给事件处理函数时,
109
- # Gradio 会自动在合适的位置(通常是触发按钮附近或全局)显示进度条。
110
-
111
- process_button.click(
112
- fn=upload_and_extract_with_progress, # 函数签名包含 progress: gr.Progress
113
- inputs=[zip_file_uploader], # zip_file_object 作为第一个参数
114
- outputs=[status_output] # 函数的返回值将更新这个文本框
115
- # Gradio 会自动将 gr.Progress() 实例传递给 fn 中的 progress 参数
116
  )
117
 
118
  if __name__ == "__main__":
119
- # 本地测试的注意事项
120
  if PERSISTENT_STORAGE_PATH == "/data" and not os.path.exists("/data"):
121
  print("警告:正在本地运行,且目标路径 /data 不存在。")
122
- print("如果您想在本地完整测试持久存储逻辑,请修改 PERSISTENT_STORAGE_PATH")
123
- print("为一个有效的本地目录,例如 './tmp_data/',并确保它存在或可以被创建。")
124
- # 例如,为本地测试创建一个模拟目录:
125
- # test_path = "./tmp_data_for_testing"
126
- # if not os.path.exists(test_path):
127
- # os.makedirs(test_path)
128
- # PERSISTENT_STORAGE_PATH = test_path
129
-
130
  demo.launch()
 
1
  import gradio as gr
2
+ from huggingface_hub import hf_hub_download, HfFolder # HfFolder可以用于登录,但通常Space secrets更好
3
  import os
 
4
  import zipfile
5
+ import shutil # 用于清空目标文件夹(如果需要)
6
+ import traceback # 用于打印更详细的错误日志
7
 
8
+ # --- 配置 ---
9
+ # 1. 持久存储的挂载路径 (您已确认为 /data)
10
  PERSISTENT_STORAGE_PATH = "/data"
11
 
12
+ # 2. 存放 .zip 文件的 Hub 仓库信息
13
+ ZIP_FILE_REPO_ID = "YanmHa/image-aligned-experiment-data" # 这是您上传ZIP文件的仓库
14
+
15
+ # !!! 重要:请将下面的 "images.zip" 替换为您在 Hub 仓库中实际使用的确切文件名 !!!
16
+ ZIP_FILENAME_IN_REPO = "images.zip" # <--- 确保这是您上传到 YanmHa/image-aligned-experiment-data 的文件名
17
+
18
+ # 3. 解压缩后在 /data 中存放内容的子文件夹名称
19
+ EXTRACTION_SUBFOLDER_NAME = "images" # 您可以自定义
20
+
21
+
22
+ # --- 辅助函数:确保持久存储目录存在 ---
23
  if not os.path.exists(PERSISTENT_STORAGE_PATH):
24
  os.makedirs(PERSISTENT_STORAGE_PATH, exist_ok=True)
25
+ print(f"持久存储目录 '{PERSISTENT_STORAGE_PATH}' 已创建。")
26
  else:
27
+ print(f"持久存储目录 '{PERSISTENT_STORAGE_PATH}' 已找到。")
28
+
29
 
30
+ # --- Gradio 应用的核心逻辑 ---
31
+ def download_and_unzip_from_hub(progress: gr.Progress):
32
  """
33
+ Hugging Face Hub 下载指定的 .zip 文件到持久存储,然后解压缩。
34
+ Gradio 会自动注入 gr.Progress 对象。
35
  """
36
+ status_messages = [] # 用于收集操作过程中的信息
 
 
 
 
 
 
 
 
37
 
38
+ # 构造 .zip 文件在持久存储中的预期下载路径
39
+ local_zip_download_path = os.path.join(PERSISTENT_STORAGE_PATH, ZIP_FILENAME_IN_REPO)
40
+
41
+ # 构造解压缩的目标完整路径
42
+ full_extraction_target_path = os.path.join(PERSISTENT_STORAGE_PATH, EXTRACTION_SUBFOLDER_NAME)
43
 
44
  try:
45
+ # 阶段1: 下载文件
46
+ status_messages.append(f"阶段 1/2: 开始从 Hub 仓库 '{ZIP_FILE_REPO_ID}' 下载 '{ZIP_FILENAME_IN_REPO}'...")
47
+ progress(0.0, desc=status_messages[-1]) # 初始化进度
48
+
49
+ # 确保存放下载文件的目录存在
50
+ if not os.path.exists(PERSISTENT_STORAGE_PATH):
51
+ os.makedirs(PERSISTENT_STORAGE_PATH, exist_ok=True)
52
+
53
+ # 使用 hf_hub_download 下载文件
54
+ # repo_type 需要根据您在 Hub 上创建仓库时的类型来指定。
55
+ # 如果是 Dataset 仓库,使用 repo_type="dataset"。如果是 Model 仓库,可以省略或用 repo_type="model"。
56
+ actual_downloaded_file_path = hf_hub_download(
57
+ repo_id=ZIP_FILE_REPO_ID,
58
+ filename=ZIP_FILENAME_IN_REPO,
59
+ repo_type="dataset", # <--- 如果您创建的是Dataset仓库,请保留或修改此项
60
+ token=None, # 公开仓库不需要token。私有仓库需要配置Space Secret: HUGGING_FACE_HUB_TOKEN
61
+ cache_dir=os.path.join(PERSISTENT_STORAGE_PATH, ".cache_hf_hub"), # 将缓存也放在持久存储中
62
+ local_dir=PERSISTENT_STORAGE_PATH, # 尝试让它直接下载到 /data
63
+ local_dir_use_symlinks=False # 确保文件实际复制,而不是符号链接
64
+ )
65
+
66
+ # 验证文件是否下载到了期望的 local_zip_download_path
67
+ if actual_downloaded_file_path != local_zip_download_path:
68
+ status_messages.append(f"注意:文件下载到: {actual_downloaded_file_path}。期望路径: {local_zip_download_path}。")
69
+ status_messages.append(f"如果路径不一致且文件未在期望路径,请检查hf_hub_download的local_dir行为或手动移动。")
70
+ # 为确保后续操作使用正确路径,我们以 local_zip_download_path 为准,假设文件已在该位置
71
+ if not os.path.exists(local_zip_download_path) and os.path.exists(actual_downloaded_file_path):
72
+ status_messages.append(f"尝试将文件从 {actual_downloaded_file_path} 移动到 {local_zip_download_path}")
73
+ shutil.move(actual_downloaded_file_path, local_zip_download_path)
74
+
75
+
76
+ if not os.path.exists(local_zip_download_path):
77
+ raise FileNotFoundError(f"文件下载后未在期望路径 {local_zip_download_path} 找到。实际下载路径: {actual_downloaded_file_path}")
78
+
79
+ status_messages.append(f"文件 '{ZIP_FILENAME_IN_REPO}' 下载完成,位于: {local_zip_download_path}")
80
+ progress(0.5, desc="下载完成,准备解压...") # 下载占总进度的50%
81
+
82
+ # 阶段 2: 解压缩文件
83
+ status_messages.append(f"阶段 2/2: 开始解压缩 '{ZIP_FILENAME_IN_REPO}' 到 '{full_extraction_target_path}'...")
84
+ progress(0.5, desc=status_messages[-1]) # 解压从总进度的50%开始
85
+
86
+ # (可选) 如果目标解压文件夹已存在,可以选择先清空
87
+ if os.path.exists(full_extraction_target_path):
88
+ status_messages.append(f"目标文件夹 '{full_extraction_target_path}' 已存在,正在清空...")
89
+ progress(0.55, desc=status_messages[-1])
90
+ try:
91
+ shutil.rmtree(full_extraction_target_path)
92
+ status_messages.append(f"旧文件夹 '{full_extraction_target_path}' 已清空。")
93
+ except Exception as e_rm:
94
+ status_messages.append(f"警告:清空旧文件夹 '{full_extraction_target_path}' 时出错: {e_rm}")
95
+
96
+ if not os.path.exists(full_extraction_target_path):
97
+ os.makedirs(full_extraction_target_path, exist_ok=True)
98
+
99
+ with zipfile.ZipFile(local_zip_download_path, 'r') as zip_ref:
100
+ file_list = zip_ref.infolist()
101
+ total_files_to_extract = len(file_list)
102
+
103
+ if total_files_to_extract == 0:
104
+ status_messages.append("下载的 .zip 文件为空。")
105
+ progress(1.0, desc="压缩包为空")
106
+ # (可选) 删除空的 .zip 文件
107
+ # os.remove(local_zip_download_path)
108
+ return "\n".join(status_messages)
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  for i, member in enumerate(file_list):
111
+ zip_ref.extract(member, path=full_extraction_target_path)
112
+ current_progress_overall = 0.5 + (((i + 1) / total_files_to_extract) * 0.5)
113
+ progress(current_progress_overall, desc=f"正在解压: {member.filename} ({i+1}/{total_files_to_extract})")
114
+
115
+ status_messages.append(f"文件已成功解压缩到 '{full_extraction_target_path}'。共处理 {total_files_to_extract} 个压缩包成员。")
116
+ progress(1.0, desc="全部完成!")
117
+
118
+ # (可选) 解压完成后删除 .zip 文件以节省 /data 空间
119
+ try:
120
+ os.remove(local_zip_download_path)
121
+ status_messages.append(f"已删除下载的 .zip 文件: {local_zip_download_path}")
122
+ except Exception as e_del_zip:
123
+ status_messages.append(f"警告:删除 .zip 文件 {local_zip_download_path} 时出错: {e_del_zip}")
124
+
125
+ return "\n".join(status_messages)
126
+
 
 
 
 
 
 
 
127
  except Exception as e:
128
+ error_msg = f"处理过程中发生严重错误: {str(e)}"
129
+ status_messages.append(error_msg)
130
+ print(f"ERROR in download_and_unzip_from_hub: {error_msg}\n{traceback.format_exc()}") # 打印完整错误到容器日志
131
+ progress(1.0, desc="发生错误!")
132
+ return "\n".join(status_messages)
133
+
134
 
135
  # --- 创建 Gradio 应用界面 ---
136
  with gr.Blocks() as demo:
 
137
  gr.Markdown(
138
+ f"## Hugging Face Hub 同步数据\n"
139
+ f"点击按钮将从 Hub 仓库 `{ZIP_FILE_REPO_ID}` 下载文件 `{ZIP_FILENAME_IN_REPO}` "
140
+ f"到 Space 的持久存储 `{PERSISTENT_STORAGE_PATH}` 目录(通常是 `/data`),并将其解压缩到子目录 `{EXTRACTION_SUBFOLDER_NAME}` 中。"
 
141
  )
142
 
143
+ sync_button = gr.Button(f"开始从 Hub 下载并解压缩 '{ZIP_FILENAME_IN_REPO}'")
 
 
 
 
144
 
145
+ status_display = gr.Markdown("点击上方按钮开始操作...") # 使用 Markdown 以支持多行和换行
146
+
147
+ sync_button.click(
148
+ fn=download_and_unzip_from_hub,
149
+ inputs=None,
150
+ outputs=[status_display]
 
 
 
151
  )
152
 
153
  if __name__ == "__main__":
 
154
  if PERSISTENT_STORAGE_PATH == "/data" and not os.path.exists("/data"):
155
  print("警告:正在本地运行,且目标路径 /data 不存在。")
 
 
 
 
 
 
 
 
156
  demo.launch()