Spaces:
Sleeping
Sleeping
File size: 5,449 Bytes
6b8a3ff 1329557 6b8a3ff 1329557 6b8a3ff 1329557 40d0e50 1329557 6b8a3ff 1329557 6b8a3ff 1329557 6b8a3ff 1329557 6b8a3ff 1329557 6b8a3ff 1329557 6b8a3ff 1329557 6b8a3ff 1329557 6b8a3ff 1329557 6b8a3ff 1329557 6b8a3ff 1329557 6b8a3ff 1329557 6b8a3ff 1329557 6b8a3ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import os
import gradio as gr
import duckdb
import requests
import threading
import time
import json
from datetime import datetime
# Hugging Face dataset whose Parquet shards are listed and scanned.
DATASET_REPO = "gmongaras/Imagenet21K"
# Optional access token: HF_TOKEN wins over HUGGINGFACE_HUB_TOKEN. When set it is
# sent as a Bearer Authorization header; otherwise requests go out anonymously.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
# Global status file (JSON) shared between the background job and the UI handlers.
STATUS_FILE = "extraction_status.json"
# Destination of the extracted id column, relative to the working directory.
OUTPUT_FILE = "ids.parquet"
def get_parquet_urls(timeout=30):
    """Return the URLs of every Parquet shard in the dataset's train split.

    Queries the Hugging Face Hub endpoint
    ``/api/datasets/{DATASET_REPO}/parquet/default/train``. The response may be
    a bare JSON list (of URL strings or ``{"url": ...}`` objects) or an object
    holding such a list under ``parquet_files``; entries without a URL are
    skipped.

    Args:
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        list[str]: Shard URLs in the order the API returned them.

    Raises:
        requests.HTTPError: If the API answers with an error status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    api_url = f"https://huggingface.co/api/datasets/{DATASET_REPO}/parquet/default/train"
    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    # Without a timeout a stalled connection would block the worker thread
    # indefinitely while the UI keeps reporting "running".
    response = requests.get(api_url, headers=headers, timeout=timeout)
    response.raise_for_status()
    data = response.json()

    items = data if isinstance(data, list) else data.get("parquet_files", [])
    urls = []
    for item in items:
        if isinstance(item, str):
            urls.append(item)
        elif isinstance(item, dict) and item.get("url"):
            urls.append(item["url"])
    return urls
def save_status(status, message, progress=0.0, error=None, status_file=None):
    """Persist the extraction state as JSON so any UI request can read it back.

    The JSON is written to a temporary sibling file and then swapped into place
    with ``os.replace``, which is atomic. A concurrent reader therefore sees
    either the previous or the new contents, never a half-written file. A
    half-written file would fail to parse and be reported as "idle" in the
    middle of a run.

    Args:
        status: State name: "running", "completed", "error" or "idle".
        message: Human-readable message shown in the UI.
        progress: Completion fraction (callers use 0.0 to 1.0).
        error: Error text for the "error" state, otherwise None.
        status_file: Path to write; defaults to the module-level STATUS_FILE.
    """
    path = STATUS_FILE if status_file is None else status_file
    status_data = {
        "status": status,
        "message": message,
        "progress": progress,
        "timestamp": datetime.now().isoformat(),
        "error": error,
    }
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(status_data, f)
    os.replace(tmp_path, path)
def load_status(status_file=None):
    """Read the persisted extraction state, falling back to an idle state.

    Args:
        status_file: Path to read; defaults to the module-level STATUS_FILE.

    Returns:
        dict: Always has the keys "status", "message", "progress", "timestamp"
        and "error". Keys missing from the file are filled from the idle
        default. A missing, unreadable or malformed file (including JSON that
        is not an object) yields the idle default itself.
    """
    path = STATUS_FILE if status_file is None else status_file
    default = {
        "status": "idle",
        "message": "没有运行中的任务",
        "progress": 0.0,
        "timestamp": "",
        "error": None,
    }
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # Missing file, permission problem or corrupt JSON (JSONDecodeError and
        # UnicodeDecodeError are ValueErrors): report idle instead of crashing.
        return default
    if not isinstance(data, dict):
        return default
    return {**default, **data}
def _sql_string(value):
    """Return *value* as a single-quoted SQL string literal.

    Embedded single quotes are doubled (the standard SQL escape), so a URL or
    token containing ``'`` cannot terminate the literal early.
    """
    return "'" + str(value).replace("'", "''") + "'"


def _configure_duckdb(con):
    """Load httpfs on *con* and apply authentication and tuning settings."""
    con.execute("INSTALL httpfs; LOAD httpfs;")
    if HF_TOKEN:
        # `httpfs_custom_header` is not a DuckDB setting (SET on it fails);
        # extra request headers are supplied through an HTTP secret instead.
        header = _sql_string(f"Bearer {HF_TOKEN}")
        con.execute(
            "CREATE SECRET hf_auth (TYPE HTTP, "
            f"EXTRA_HTTP_HEADERS MAP {{'Authorization': {header}}});"
        )
    con.execute("SET http_keep_alive=true;")
    con.execute("SET enable_object_cache=true;")
    con.execute("SET threads=4;")


def _build_copy_query(urls):
    """Return a COPY statement writing the ``id`` column of *urls* to OUTPUT_FILE."""
    files_literal = ",".join(_sql_string(url) for url in urls)
    return f"""
    COPY (
        SELECT id
        FROM parquet_scan([{files_literal}])
    ) TO {_sql_string(OUTPUT_FILE)} (FORMAT 'parquet', COMPRESSION 'zstd');
    """


def build_ids_background():
    """Extract the ``id`` column of every dataset shard into OUTPUT_FILE.

    Meant to run on a background thread. Progress and the final outcome are
    reported only through ``save_status``. Every exception is caught and
    recorded as an "error" status with the exception text.
    """
    try:
        save_status("running", "获取 Parquet 文件列表...", 0.0)
        urls = get_parquet_urls()
        if not urls:
            # An empty file list cannot be scanned; fail with a clear message.
            raise RuntimeError("未找到任何 Parquet 文件")
        save_status("running", f"找到 {len(urls)} 个文件,开始提取...", 0.1)

        # The context manager closes the connection even if the query fails.
        with duckdb.connect() as con:
            _configure_duckdb(con)
            save_status("running", "执行 SQL 查询...", 0.2)
            query = _build_copy_query(urls)
            # A single COPY over all shards; no finer-grained progress is
            # available until it returns.
            save_status("running", "正在执行大规模查询...", 0.5)
            con.execute(query)

        file_size = os.path.getsize(OUTPUT_FILE) / 1024 / 1024
        save_status("completed", f"提取完成!文件大小: {file_size:.1f} MB", 1.0)
    except Exception as e:
        save_status("error", "提取失败", 0.0, str(e))
# Worker thread started by this process (None until the first start), plus a
# lock that makes the "is one alive? / start one" decision atomic.
_extraction_thread = None
_extraction_lock = threading.Lock()


def start_extraction():
    """Start ``build_ids_background`` on a daemon thread unless one is alive.

    Whether a job is running is decided from the live thread handle, not from
    the status file. The file outlives the process that wrote it, so after a
    restart on the same disk a leftover "running" entry would otherwise block
    every later start even though no worker exists any more.

    Returns:
        str: Message for the status textbox.
    """
    global _extraction_thread
    with _extraction_lock:
        if _extraction_thread is not None and _extraction_thread.is_alive():
            return "任务已在运行中..."

        # Remove the previous result so a stale file is never served.
        if os.path.exists(OUTPUT_FILE):
            os.remove(OUTPUT_FILE)

        # daemon=True: the job does not keep the process alive on shutdown.
        _extraction_thread = threading.Thread(target=build_ids_background, daemon=True)
        _extraction_thread.start()
    return "后台任务已启动!刷新页面不会中断任务。"
def check_status():
    """Summarise the persisted extraction state for the UI.

    Returns:
        tuple: ``(text, file)``. ``text`` is shown in the status textbox.
        ``file`` is OUTPUT_FILE when the job has completed and the file exists,
        otherwise None.
    """
    status = load_status()
    state = status.get("status", "idle")
    message = status.get("message", "")
    progress = status.get("progress") or 0.0
    progress_text = f"{progress*100:.1f}%" if progress > 0 else ""
    timestamp = status.get("timestamp", "")

    if state == "completed":
        if os.path.exists(OUTPUT_FILE):
            return f"✅ {message}\n时间: {timestamp}\n进度: {progress_text}", OUTPUT_FILE
        return f"⚠️ 任务完成但文件不存在\n时间: {timestamp}", None
    if state == "error":
        # The "error" key is always written (possibly as None or ""), so a
        # .get default would never apply; fall back explicitly instead.
        error_msg = status.get("error") or "未知错误"
        return f"❌ {message}\n错误: {error_msg}\n时间: {timestamp}", None
    if state == "running":
        return f"🔄 {message}\n进度: {progress_text}\n时间: {timestamp}", None
    return message, None
with gr.Blocks() as demo:
    gr.Markdown("# ImageNet21K ID 提取器 (后台运行)\n✨ 支持浏览器断线重连,任务在后台持续运行")
    with gr.Row():
        start_btn = gr.Button("开始提取", variant="primary")
        refresh_btn = gr.Button("刷新状态", variant="secondary")
    with gr.Row():
        status_text = gr.Textbox(label="任务状态", lines=5)
        download_file = gr.File(label="下载文件")

    start_btn.click(start_extraction, outputs=[status_text])
    refresh_btn.click(check_status, outputs=[status_text, download_file])
    # Show the current state whenever a (re)connected browser loads the page.
    demo.load(check_status, outputs=[status_text, download_file])

    # Poll the status text every 10 s. Only the text is polled, so the output
    # file is not re-sent to the File widget on every tick; the download link
    # comes from page load or the refresh button. gr.Timer exists only in
    # newer Gradio releases (4.x+, TODO confirm minimum); on older versions
    # the manual refresh button remains the way to update.
    if hasattr(gr, "Timer"):
        gr.Timer(10).tick(lambda: check_status()[0], outputs=[status_text])

demo.queue()

if __name__ == "__main__":
    demo.launch()