illuminati360 commited on
Commit
1329557
·
1 Parent(s): 6b8a3ff
Files changed (1) hide show
  1. app.py +125 -43
app.py CHANGED
@@ -2,10 +2,18 @@ import os
2
  import gradio as gr
3
  import duckdb
4
  import requests
 
 
 
 
5
 
6
  DATASET_REPO = "gmongaras/Imagenet21K"
7
  HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
8
 
 
 
 
 
9
  def get_parquet_urls():
10
  """获取所有 parquet 文件 URL"""
11
  api_url = f"https://huggingface.co/api/datasets/{DATASET_REPO}/parquet/default/train"
@@ -28,63 +36,137 @@ def get_parquet_urls():
28
 
29
  return urls
30
 
31
- def build_ids_duckdb(progress=gr.Progress()):
32
- """使用 DuckDB 直接查询所有 Parquet 文件"""
33
-
34
- progress(0, desc="获取 Parquet 文件列表...")
35
- urls = get_parquet_urls()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- progress(0.1, desc=f"找到 {len(urls)} 个文件,开始提取...")
 
38
 
39
- con = duckdb.connect()
40
- con.execute("INSTALL httpfs; LOAD httpfs;")
 
41
 
42
- # 设置认证
43
- if HF_TOKEN:
44
- con.execute(f"SET httpfs_custom_header='Authorization: Bearer {HF_TOKEN}';")
45
 
46
- progress(0.2, desc="执行 SQL 查询...")
 
 
 
 
47
 
48
- # 构建文件列表
49
- files_literal = ",".join([f"'{url}'" for url in urls])
50
 
51
- # 一次性查询所有文件的 id
52
- query = f"""
53
- COPY (
54
- SELECT id
55
- FROM parquet_scan([{files_literal}])
56
- ) TO 'ids.parquet' (FORMAT 'parquet', COMPRESSION 'zstd');
57
- """
58
 
59
- progress(0.5, desc="正在执行大规模查询...")
60
- con.execute(query)
 
61
 
62
- progress(1.0, desc="完成!")
 
63
 
64
- file_size = os.path.getsize('ids.parquet') / 1024 / 1024
65
- return f"提取完成!\n文件大小: {file_size:.1f} MB"
66
-
67
- def ui_build_ids(progress=gr.Progress()):
68
- try:
69
- result = build_ids_duckdb(progress=progress)
70
- return result, "ids.parquet"
71
- except Exception as e:
72
- return f"错误: {e}", None
73
 
74
  with gr.Blocks() as demo:
75
- gr.Markdown("# ImageNet21K ID 提取器 (DuckDB)\n使用 DuckDB 快速提取 ID 列")
76
 
77
- run_btn = gr.Button("开始提取", variant="primary")
 
 
78
 
79
  with gr.Row():
80
- log = gr.Textbox(label="状态", lines=5)
81
- download = gr.File(label="下载文件")
82
-
83
- run_btn.click(
84
- ui_build_ids,
85
- outputs=[log, download],
86
- show_progress=True
87
- )
 
 
 
88
 
89
  demo.queue()
90
 
 
2
  import gradio as gr
3
  import duckdb
4
  import requests
5
+ import threading
6
+ import time
7
+ import json
8
+ from datetime import datetime
9
 
10
  DATASET_REPO = "gmongaras/Imagenet21K"
11
  HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
12
 
13
+ # 全局状态文件
14
+ STATUS_FILE = "extraction_status.json"
15
+ OUTPUT_FILE = "ids.parquet"
16
+
17
  def get_parquet_urls():
18
  """获取所有 parquet 文件 URL"""
19
  api_url = f"https://huggingface.co/api/datasets/{DATASET_REPO}/parquet/default/train"
 
36
 
37
  return urls
38
 
39
+ def save_status(status, message, progress=0.0, error=None):
40
+ """保存状态到文件"""
41
+ status_data = {
42
+ "status": status, # "running", "completed", "error", "idle"
43
+ "message": message,
44
+ "progress": progress,
45
+ "timestamp": datetime.now().isoformat(),
46
+ "error": error
47
+ }
48
+
49
+ with open(STATUS_FILE, 'w') as f:
50
+ json.dump(status_data, f)
51
+
52
+ def load_status():
53
+ """从文件加载状态"""
54
+ if os.path.exists(STATUS_FILE):
55
+ try:
56
+ with open(STATUS_FILE, 'r') as f:
57
+ return json.load(f)
58
+ except:
59
+ pass
60
+
61
+ return {
62
+ "status": "idle",
63
+ "message": "没有运行中的任务",
64
+ "progress": 0.0,
65
+ "timestamp": "",
66
+ "error": None
67
+ }
68
+
69
+ def build_ids_background():
70
+ """后台运行 DuckDB 提取任务"""
71
+ try:
72
+ save_status("running", "获取 Parquet 文件列表...", 0.0)
73
+
74
+ urls = get_parquet_urls()
75
+ save_status("running", f"找到 {len(urls)} 个文件,开始提取...", 0.1)
76
+
77
+ con = duckdb.connect()
78
+ con.execute("INSTALL httpfs; LOAD httpfs;")
79
+
80
+ # 设置认证和优化
81
+ if HF_TOKEN:
82
+ con.execute(f"SET httpfs_custom_header='Authorization: Bearer {HF_TOKEN}';")
83
+
84
+ con.execute("SET httpfs_keep_alive=true;")
85
+ con.execute("SET enable_object_cache=true;")
86
+ con.execute("SET threads=4;")
87
+
88
+ save_status("running", "执行 SQL 查询...", 0.2)
89
+
90
+ # 构建文件列表
91
+ files_literal = ",".join([f"'{url}'" for url in urls])
92
+
93
+ # 一次性查询所有文件的 id 列
94
+ query = f"""
95
+ COPY (
96
+ SELECT id
97
+ FROM parquet_scan([{files_literal}])
98
+ ) TO '{OUTPUT_FILE}' (FORMAT 'parquet', COMPRESSION 'zstd');
99
+ """
100
+
101
+ save_status("running", "正在执行大规模查询...", 0.5)
102
+ con.execute(query)
103
+
104
+ # 完成
105
+ file_size = os.path.getsize(OUTPUT_FILE) / 1024 / 1024
106
+ save_status("completed", f"提取完成!文件大小: {file_size:.1f} MB", 1.0)
107
+
108
+ except Exception as e:
109
+ save_status("error", "提取失败", 0.0, str(e))
110
+
111
+ def start_extraction():
112
+ """启动后台提取任务"""
113
+ status = load_status()
114
 
115
+ if status["status"] == "running":
116
+ return "任务已在运行中..."
117
 
118
+ # 删除旧的输出文件
119
+ if os.path.exists(OUTPUT_FILE):
120
+ os.remove(OUTPUT_FILE)
121
 
122
+ # 启动后台线程
123
+ thread = threading.Thread(target=build_ids_background, daemon=True)
124
+ thread.start()
125
 
126
+ return "后台任务已启动!刷新页面不会中断任务。"
127
+
128
+ def check_status():
129
+ """检查当前状态"""
130
+ status = load_status()
131
 
132
+ progress_text = f"{status['progress']*100:.1f}%" if status['progress'] > 0 else ""
133
+ timestamp = status.get('timestamp', '')
134
 
135
+ if status['status'] == 'completed':
136
+ if os.path.exists(OUTPUT_FILE):
137
+ return f"✅ {status['message']}\n时间: {timestamp}\n进度: {progress_text}", OUTPUT_FILE
138
+ else:
139
+ return f"⚠️ 任务完成但文件不存在\n时间: {timestamp}", None
 
 
140
 
141
+ elif status['status'] == 'error':
142
+ error_msg = status.get('error', '未知错误')
143
+ return f"❌ {status['message']}\n错误: {error_msg}\n时间: {timestamp}", None
144
 
145
+ elif status['status'] == 'running':
146
+ return f"🔄 {status['message']}\n进度: {progress_text}\n时间: {timestamp}", None
147
 
148
+ else:
149
+ return status['message'], None
 
 
 
 
 
 
 
150
 
151
  with gr.Blocks() as demo:
152
+ gr.Markdown("# ImageNet21K ID 提取器 (后台运行)\n 支持浏览器断线重连,任务在后台持续运行")
153
 
154
+ with gr.Row():
155
+ start_btn = gr.Button("开始提取", variant="primary")
156
+ refresh_btn = gr.Button("刷新状态", variant="secondary")
157
 
158
  with gr.Row():
159
+ status_text = gr.Textbox(label="任务状态", lines=5)
160
+ download_file = gr.File(label="下载文件")
161
+
162
+ # 自动刷新状态 (每10秒)
163
+ auto_refresh = gr.Textbox(visible=False)
164
+
165
+ start_btn.click(start_extraction, outputs=[status_text])
166
+ refresh_btn.click(check_status, outputs=[status_text, download_file])
167
+
168
+ # 页面加载时检查状态
169
+ demo.load(check_status, outputs=[status_text, download_file])
170
 
171
  demo.queue()
172