Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import duckdb | |
| import requests | |
| import threading | |
| import time | |
| import json | |
| from datetime import datetime | |
# Hugging Face dataset repository whose parquet shards are scanned.
DATASET_REPO = "gmongaras/Imagenet21K"
# Access token taken from the environment; falsy when neither variable is set.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")

# Global status file: the background task writes its state here as JSON.
STATUS_FILE = "extraction_status.json"
# Parquet file that receives the extracted ``id`` column.
OUTPUT_FILE = "ids.parquet"
def get_parquet_urls():
    """Return the URLs of all parquet files of the dataset's train split.

    Queries the Hugging Face dataset API for ``DATASET_REPO``. Two response
    shapes are accepted: a plain list, or a mapping that carries the list
    under ``"parquet_files"``. Each entry may be a bare URL string or a
    mapping with a ``"url"`` key; entries of any other shape are skipped.

    Returns:
        list: The URLs in the order the API reported them.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-success HTTP status.
    """
    api_url = f"https://huggingface.co/api/datasets/{DATASET_REPO}/parquet/default/train"
    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    # requests applies no timeout by default; without one a stalled connection
    # would block the worker thread forever and leave the status on "running".
    response = requests.get(api_url, headers=headers, timeout=30)
    response.raise_for_status()
    data = response.json()

    items = data if isinstance(data, list) else data.get("parquet_files", [])
    urls = []
    for item in items:
        if isinstance(item, str):
            urls.append(item)
        elif isinstance(item, dict):
            url = item.get("url")
            if url:
                urls.append(url)
    return urls
def save_status(status, message, progress=0.0, error=None):
    """Persist the current task state to ``STATUS_FILE`` as JSON.

    Args:
        status: One of "running", "completed", "error" or "idle".
        message: Human-readable description shown in the UI.
        progress: Completion fraction (callers pass values from 0.0 to 1.0).
        error: Error details, or None when there is nothing to report.
    """
    status_data = {
        "status": status,  # "running", "completed", "error", "idle"
        "message": message,
        "progress": progress,
        "timestamp": datetime.now().isoformat(),
        "error": error,
    }
    # Write to a sibling temp file and rename it into place. A reader that
    # opened STATUS_FILE mid-write would otherwise see truncated JSON, fall
    # back to the "idle" default, and allow a second task to be started.
    tmp_path = f"{STATUS_FILE}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(status_data, f)
    os.replace(tmp_path, STATUS_FILE)
def load_status():
    """Load the task state from ``STATUS_FILE``.

    Returns:
        dict: A mapping with the keys "status", "message", "progress",
        "timestamp" and "error". The idle default is returned when the file
        is missing, unreadable, or does not contain a JSON object; keys
        absent from the file are filled in from that default.
    """
    default = {
        "status": "idle",
        "message": "没有运行中的任务",
        "progress": 0.0,
        "timestamp": "",
        "error": None,
    }
    try:
        with open(STATUS_FILE, 'r') as f:
            loaded = json.load(f)
    except (OSError, ValueError):
        # OSError: file missing or unreadable. ValueError: invalid JSON
        # (json.JSONDecodeError is a subclass) or undecodable bytes.
        return default
    if not isinstance(loaded, dict):
        return default
    # Fill in any missing keys so callers can index the result directly.
    return {**default, **loaded}
def build_ids_background():
    """Run the DuckDB extraction task; intended to run in a background thread.

    Fetches the parquet URL list, scans the ``id`` column of every file with
    DuckDB's httpfs extension and writes the result to ``OUTPUT_FILE`` as
    zstd-compressed parquet. Progress and the final outcome are reported via
    ``save_status``; any exception is recorded as an "error" status instead
    of propagating.
    """
    try:
        save_status("running", "获取 Parquet 文件列表...", 0.0)
        urls = get_parquet_urls()
        if not urls:
            # Fail with a clear message rather than an opaque DuckDB error
            # caused by scanning an empty file list.
            raise RuntimeError(f"no parquet files found for dataset {DATASET_REPO}")
        save_status("running", f"找到 {len(urls)} 个文件,开始提取...", 0.1)

        con = duckdb.connect()
        try:
            con.execute("INSTALL httpfs; LOAD httpfs;")
            # Authentication and tuning settings.
            if HF_TOKEN:
                # Double any single quote so the value cannot end the SQL literal.
                # NOTE(review): confirm that `httpfs_custom_header` is a setting
                # recognised by the installed DuckDB/httpfs version.
                token_literal = HF_TOKEN.replace("'", "''")
                con.execute(f"SET httpfs_custom_header='Authorization: Bearer {token_literal}';")
            con.execute("SET http_keep_alive=true;")
            con.execute("SET enable_object_cache=true;")
            con.execute("SET threads=4;")
            save_status("running", "执行 SQL 查询...", 0.2)

            # Build the file list; URLs come from a remote API response, so
            # escape single quotes before embedding them as SQL string literals.
            files_literal = ",".join("'{}'".format(url.replace("'", "''")) for url in urls)
            # Query the id column of all files in a single statement.
            query = f"""
            COPY (
                SELECT id
                FROM parquet_scan([{files_literal}])
            ) TO '{OUTPUT_FILE}' (FORMAT 'parquet', COMPRESSION 'zstd');
            """
            save_status("running", "正在执行大规模查询...", 0.5)
            con.execute(query)
        finally:
            # Release the connection whether or not the query succeeded.
            con.close()

        # Done: report the output size in MB.
        file_size = os.path.getsize(OUTPUT_FILE) / 1024 / 1024
        save_status("completed", f"提取完成!文件大小: {file_size:.1f} MB", 1.0)
    except Exception as e:
        # Top-level boundary of the worker thread: surface the failure in the
        # status file, since nothing observes the thread's exception.
        save_status("error", "提取失败", 0.0, str(e))
def start_extraction():
    """Start the extraction in a background thread unless one is running.

    Returns:
        str: A message for the status textbox saying whether a new task was
        started or one was already in progress.
    """
    # NOTE(review): a "running" status left behind by a killed process makes
    # this return early until the status file is removed or overwritten.
    status = load_status()
    if status["status"] == "running":
        return "任务已在运行中..."

    # Delete the previous output file, if any.
    try:
        os.remove(OUTPUT_FILE)
    except FileNotFoundError:
        pass

    # Record "running" before the thread exists. Otherwise a second click that
    # arrives before the worker's first save_status() would pass the check
    # above and launch a duplicate extraction.
    save_status("running", "获取 Parquet 文件列表...", 0.0)

    # Start the background thread; daemon=True so it never blocks shutdown.
    thread = threading.Thread(target=build_ids_background, daemon=True)
    thread.start()
    return "后台任务已启动!刷新页面不会中断任务。"
def check_status():
    """Build the UI view of the persisted task state.

    Returns:
        tuple: ``(text, file)`` where ``text`` is the status description for
        the textbox and ``file`` is ``OUTPUT_FILE`` when the task completed
        and the file exists, otherwise None.
    """
    info = load_status()
    fraction = info['progress']
    # The percentage is shown only once some progress has been recorded.
    if fraction > 0:
        percent = f"{fraction*100:.1f}%"
    else:
        percent = ""
    stamp = info.get('timestamp', '')
    state = info['status']

    if state == 'running':
        return f"🔄 {info['message']}\n进度: {percent}\n时间: {stamp}", None

    if state == 'error':
        reason = info.get('error', '未知错误')
        return f"❌ {info['message']}\n错误: {reason}\n时间: {stamp}", None

    if state != 'completed':
        # Idle or any unrecognised state: show the stored message as-is.
        return info['message'], None

    if not os.path.exists(OUTPUT_FILE):
        return f"⚠️ 任务完成但文件不存在\n时间: {stamp}", None

    return f"✅ {info['message']}\n时间: {stamp}\n进度: {percent}", OUTPUT_FILE
# Gradio UI: one button starts the extraction, another re-reads the status.
with gr.Blocks() as demo:
    gr.Markdown("# ImageNet21K ID 提取器 (后台运行)\n✨ 支持浏览器断线重连,任务在后台持续运行")
    with gr.Row():
        start_btn = gr.Button("开始提取", variant="primary")
        refresh_btn = gr.Button("刷新状态", variant="secondary")
    with gr.Row():
        status_text = gr.Textbox(label="任务状态", lines=5)
        download_file = gr.File(label="下载文件")
    # Hidden component apparently intended for a 10-second auto-refresh.
    # NOTE(review): no event or timer below references it, so the status only
    # updates on page load or when the refresh button is clicked.
    auto_refresh = gr.Textbox(visible=False)
    # Starting a task only updates the status text, not the download slot.
    start_btn.click(start_extraction, outputs=[status_text])
    refresh_btn.click(check_status, outputs=[status_text, download_file])
    # Check the status when the page loads.
    demo.load(check_status, outputs=[status_text, download_file])

demo.queue()

if __name__ == "__main__":
    demo.launch()