illuminati360 commited on
Commit
6b8a3ff
·
1 Parent(s): 79758ec

Add application file

Browse files
Files changed (2) hide show
  1. app.py +92 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import duckdb
4
+ import requests
5
+
6
+ DATASET_REPO = "gmongaras/Imagenet21K"
7
+ HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
8
+
9
+ def get_parquet_urls():
10
+ """获取所有 parquet 文件 URL"""
11
+ api_url = f"https://huggingface.co/api/datasets/{DATASET_REPO}/parquet/default/train"
12
+ headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
13
+
14
+ response = requests.get(api_url, headers=headers)
15
+ response.raise_for_status()
16
+
17
+ data = response.json()
18
+ urls = []
19
+
20
+ items = data if isinstance(data, list) else data.get("parquet_files", [])
21
+ for item in items:
22
+ if isinstance(item, str):
23
+ urls.append(item)
24
+ else:
25
+ url = item.get("url")
26
+ if url:
27
+ urls.append(url)
28
+
29
+ return urls
30
+
31
+ def build_ids_duckdb(progress=gr.Progress()):
32
+ """使用 DuckDB 直接查询所有 Parquet 文件"""
33
+
34
+ progress(0, desc="获取 Parquet 文件列表...")
35
+ urls = get_parquet_urls()
36
+
37
+ progress(0.1, desc=f"找到 {len(urls)} 个文件,开始提取...")
38
+
39
+ con = duckdb.connect()
40
+ con.execute("INSTALL httpfs; LOAD httpfs;")
41
+
42
+ # 设置认证
43
+ if HF_TOKEN:
44
+ con.execute(f"SET httpfs_custom_header='Authorization: Bearer {HF_TOKEN}';")
45
+
46
+ progress(0.2, desc="执行 SQL 查询...")
47
+
48
+ # 构建文件列表
49
+ files_literal = ",".join([f"'{url}'" for url in urls])
50
+
51
+ # 一次性查询所有文件的 id 列
52
+ query = f"""
53
+ COPY (
54
+ SELECT id
55
+ FROM parquet_scan([{files_literal}])
56
+ ) TO 'ids.parquet' (FORMAT 'parquet', COMPRESSION 'zstd');
57
+ """
58
+
59
+ progress(0.5, desc="正在执行大规模查询...")
60
+ con.execute(query)
61
+
62
+ progress(1.0, desc="完成!")
63
+
64
+ file_size = os.path.getsize('ids.parquet') / 1024 / 1024
65
+ return f"提取完成!\n文件大小: {file_size:.1f} MB"
66
+
67
+ def ui_build_ids(progress=gr.Progress()):
68
+ try:
69
+ result = build_ids_duckdb(progress=progress)
70
+ return result, "ids.parquet"
71
+ except Exception as e:
72
+ return f"错误: {e}", None
73
+
74
+ with gr.Blocks() as demo:
75
+ gr.Markdown("# ImageNet21K ID 提取器 (DuckDB)\n使用 DuckDB 快速提取 ID 列")
76
+
77
+ run_btn = gr.Button("开始提取", variant="primary")
78
+
79
+ with gr.Row():
80
+ log = gr.Textbox(label="状态", lines=5)
81
+ download = gr.File(label="下载文件")
82
+
83
+ run_btn.click(
84
+ ui_build_ids,
85
+ outputs=[log, download],
86
+ show_progress=True
87
+ )
88
+
89
+ demo.queue()
90
+
91
+ if __name__ == "__main__":
92
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ requests
3
+ duckdb