File size: 5,449 Bytes
6b8a3ff
 
 
 
1329557
 
 
 
6b8a3ff
 
 
 
1329557
 
 
 
6b8a3ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1329557
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40d0e50
1329557
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b8a3ff
1329557
 
6b8a3ff
1329557
 
 
6b8a3ff
1329557
 
 
6b8a3ff
1329557
 
 
 
 
6b8a3ff
1329557
 
6b8a3ff
1329557
 
 
 
 
6b8a3ff
1329557
 
 
6b8a3ff
1329557
 
6b8a3ff
1329557
 
6b8a3ff
 
1329557
6b8a3ff
1329557
 
 
6b8a3ff
 
1329557
 
 
 
 
 
 
 
 
 
 
6b8a3ff
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import os
import gradio as gr
import duckdb
import requests
import threading
import time
import json
from datetime import datetime

DATASET_REPO = "gmongaras/Imagenet21K"
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")

# Global state files shared between the UI process and the background worker thread
STATUS_FILE = "extraction_status.json"
OUTPUT_FILE = "ids.parquet"

def get_parquet_urls():
    """Return the list of parquet shard URLs for the dataset's train split.

    Queries the Hugging Face parquet listing endpoint for DATASET_REPO and
    normalizes the response, which may be either a bare list of URL strings
    or a list of objects carrying a "url" field.

    Returns:
        list[str]: absolute URLs of the parquet shards (possibly empty).

    Raises:
        requests.HTTPError: if the API responds with a non-2xx status.
        requests.Timeout: if the API does not respond within the timeout.
    """
    api_url = f"https://huggingface.co/api/datasets/{DATASET_REPO}/parquet/default/train"
    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}

    # Fix: the original call had no timeout, so a stalled endpoint would
    # hang the background thread forever. 60s is generous for metadata.
    response = requests.get(api_url, headers=headers, timeout=60)
    response.raise_for_status()

    data = response.json()
    urls = []

    # The endpoint has returned either a plain list or a dict with a
    # "parquet_files" key depending on API version; accept both shapes.
    items = data if isinstance(data, list) else data.get("parquet_files", [])
    for item in items:
        if isinstance(item, str):
            urls.append(item)
        else:
            url = item.get("url")
            if url:
                urls.append(url)

    return urls

def save_status(status, message, progress=0.0, error=None):
    """Persist the current task state to STATUS_FILE as JSON.

    Writing state to disk (rather than keeping it in memory) is what lets
    the task survive browser refreshes: any request handler can re-read it.

    Args:
        status: one of "running", "completed", "error", "idle".
        message: human-readable progress message shown in the UI.
        progress: completion fraction in [0, 1].
        error: optional error detail string, or None.
    """
    payload = {
        "status": status,
        "message": message,
        "progress": progress,
        "timestamp": datetime.now().isoformat(),
        "error": error,
    }

    with open(STATUS_FILE, 'w') as fh:
        json.dump(payload, fh)

def load_status():
    """Load the last persisted task status from STATUS_FILE.

    Returns:
        dict: the parsed status payload, or an "idle" default when the
        file is absent, unreadable, or corrupt (e.g. read mid-write by
        the background thread).
    """
    if os.path.exists(STATUS_FILE):
        try:
            with open(STATUS_FILE, 'r') as f:
                return json.load(f)
        # Fix: the original bare `except:` swallowed everything, including
        # KeyboardInterrupt/SystemExit. Only I/O and JSON-decoding failures
        # (json.JSONDecodeError subclasses ValueError) should fall back.
        except (OSError, ValueError):
            pass

    return {
        "status": "idle",
        "message": "没有运行中的任务",
        "progress": 0.0,
        "timestamp": "",
        "error": None
    }

def build_ids_background():
    """Worker: extract the `id` column of every parquet shard into OUTPUT_FILE.

    Runs in a daemon thread. All progress and errors are reported through
    save_status() rather than raised, so the UI can poll the state file.
    """
    con = None
    try:
        save_status("running", "获取 Parquet 文件列表...", 0.0)

        urls = get_parquet_urls()
        save_status("running", f"找到 {len(urls)} 个文件,开始提取...", 0.1)

        con = duckdb.connect()
        con.execute("INSTALL httpfs; LOAD httpfs;")

        # Authentication and tuning.
        # NOTE(review): `httpfs_custom_header` may not be a recognized
        # setting on all DuckDB versions (newer httpfs uses `http_headers`
        # or CREATE SECRET) — confirm against the pinned DuckDB version.
        if HF_TOKEN:
            con.execute(f"SET httpfs_custom_header='Authorization: Bearer {HF_TOKEN}';")

        con.execute("SET http_keep_alive=true;")
        con.execute("SET enable_object_cache=true;")
        con.execute("SET threads=4;")

        save_status("running", "执行 SQL 查询...", 0.2)

        # Inline the shard URLs as a SQL array literal. The URLs come from
        # the HF API, but escape single quotes anyway so a quote in a URL
        # cannot break the statement.
        files_literal = ",".join("'" + url.replace("'", "''") + "'" for url in urls)

        # Single query over all shards; COPY streams results to disk.
        query = f"""
        COPY (
            SELECT id 
            FROM parquet_scan([{files_literal}])
        ) TO '{OUTPUT_FILE}' (FORMAT 'parquet', COMPRESSION 'zstd');
        """

        save_status("running", "正在执行大规模查询...", 0.5)
        con.execute(query)

        # Report the final artifact size in MB.
        file_size = os.path.getsize(OUTPUT_FILE) / 1024 / 1024
        save_status("completed", f"提取完成!文件大小: {file_size:.1f} MB", 1.0)

    except Exception as e:
        # Top-level boundary of the worker thread: record, never crash.
        save_status("error", "提取失败", 0.0, str(e))
    finally:
        # Fix: the original leaked the DuckDB connection on every path.
        if con is not None:
            con.close()

def start_extraction():
    """Start the background extraction unless one is already running.

    Returns:
        str: a status message for the UI textbox.
    """
    if load_status()["status"] == "running":
        return "任务已在运行中..."

    # Clear any stale artifact from a previous run before starting.
    if os.path.exists(OUTPUT_FILE):
        os.remove(OUTPUT_FILE)

    # Daemon thread: survives page refreshes, dies with the process.
    worker = threading.Thread(target=build_ids_background, daemon=True)
    worker.start()

    return "后台任务已启动!刷新页面不会中断任务。"

def check_status():
    """Summarize the persisted task state for the UI.

    Returns:
        tuple: (status text for the textbox, file path for the download
        widget or None when no artifact is available).
    """
    state = load_status()

    pct = state['progress']
    progress_text = f"{pct*100:.1f}%" if pct > 0 else ""
    timestamp = state.get('timestamp', '')
    kind = state['status']

    # Guard-clause dispatch on the status value; each branch returns
    # a (text, file) pair.
    if kind == 'completed':
        if os.path.exists(OUTPUT_FILE):
            return f"✅ {state['message']}\n时间: {timestamp}\n进度: {progress_text}", OUTPUT_FILE
        return f"⚠️ 任务完成但文件不存在\n时间: {timestamp}", None

    if kind == 'error':
        error_msg = state.get('error', '未知错误')
        return f"❌ {state['message']}\n错误: {error_msg}\n时间: {timestamp}", None

    if kind == 'running':
        return f"🔄 {state['message']}\n进度: {progress_text}\n时间: {timestamp}", None

    return state['message'], None

# --- Gradio UI wiring (module-level script) ---
# Buttons drive the background worker via the JSON status file, so the
# task keeps running even if the browser disconnects.
with gr.Blocks() as demo:
    gr.Markdown("# ImageNet21K ID 提取器 (后台运行)\n✨ 支持浏览器断线重连,任务在后台持续运行")
    
    with gr.Row():
        start_btn = gr.Button("开始提取", variant="primary")
        refresh_btn = gr.Button("刷新状态", variant="secondary")
    
    with gr.Row():
        status_text = gr.Textbox(label="任务状态", lines=5)
        download_file = gr.File(label="下载文件")
    
    # Auto-refresh status (every 10 seconds)
    # NOTE(review): this hidden textbox is never wired to any event and the
    # 10-second refresh the original comment mentions is not implemented
    # here — confirm whether a gr.Timer/`every=` hookup was intended.
    auto_refresh = gr.Textbox(visible=False)
    
    start_btn.click(start_extraction, outputs=[status_text])
    refresh_btn.click(check_status, outputs=[status_text, download_file])
    
    # Re-check persisted status on page load (reconnect support).
    demo.load(check_status, outputs=[status_text, download_file])

# Enable request queuing so long-running callbacks don't block the app.
demo.queue()

if __name__ == "__main__":
    demo.launch()