mmrech committed on
Commit
69375af
·
verified ·
1 Parent(s): 2e3edd5

Upload create_image_dataset.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. create_image_dataset.py +206 -0
create_image_dataset.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # requires-python = ">=3.10"
4
+ # dependencies = [
5
+ # "torch",
6
+ # "datasets>=2.18.0",
7
+ # "pillow",
8
+ # "opencv-python-headless",
9
+ # "huggingface_hub>=0.21.0",
10
+ # "av",
11
+ # "tqdm",
12
+ # ]
13
+ # ///
14
+ """
15
+ Create dataset with embedded images from pitvqa-comprehensive-spatial.
16
+
17
+ Extracts video frames and embeds them directly in the dataset.
18
+ This eliminates the need for video extraction during training/inference.
19
+
20
+ Run with: hf jobs uv run --flavor cpu-xlarge --secrets HF_TOKEN create_image_dataset.py
21
+ """
22
+
23
+ import os
24
+ import cv2
25
+ from io import BytesIO
26
+ from PIL import Image
27
+ from pathlib import Path
28
+ from tqdm import tqdm
29
+
30
# ============================================================
# Config
# ============================================================

# Source Q/A dataset that references video frames by (video_id, frame_index).
SOURCE_DATASET = "mmrech/pitvqa-comprehensive-spatial"
# Hub repo holding the raw videos (files live at videos/{video_id}.mp4).
VIDEO_DATASET = "UCL-WEISS/PitVis-2023"
# Destination repo for the new dataset with embedded images.
OUTPUT_DATASET = "mmrech/pitvqa-spatial-with-images"

# Local scratch directory where downloaded videos are copied.
VIDEO_CACHE = Path("/tmp/videos")
VIDEO_CACHE.mkdir(exist_ok=True)

MAX_SAMPLES = 1000  # Start with subset for testing
43
# ============================================================
# Setup
# ============================================================

from huggingface_hub import login, HfApi, hf_hub_download
from datasets import load_dataset, Dataset, Features, Value, Image as ImageFeature

# Authenticate only when a token was injected (e.g. `--secrets HF_TOKEN` per the
# module docstring); public reads still work without it, but push_to_hub won't.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    print("✓ Logged in to HuggingFace")

api = HfApi()
57
# ============================================================
# Load Source Dataset
# ============================================================

print("\n📦 Loading source dataset...")
# Only the train split is processed by this script.
ds = load_dataset(SOURCE_DATASET, split="train")
print(f"✓ Loaded {len(ds)} samples")
65
# ============================================================
# Video Helpers
# ============================================================

# Maps video_id -> open cv2.VideoCapture so each video file is opened once;
# every capture is released at the end of the extraction loop below.
video_cache = {}
71
def download_video(video_id: str) -> Path | None:
    """Download a source video into the local cache if not already present.

    Args:
        video_id: Stem of the video file; resolved to
            ``videos/{video_id}.mp4`` inside the ``VIDEO_DATASET`` repo.

    Returns:
        Path to the cached ``.mp4`` file, or ``None`` when the download
        failed (the original annotation claimed ``Path`` but the failure
        branch returns ``None``).
    """
    # Local import keeps the module's top-level import block unchanged.
    import shutil

    video_path = VIDEO_CACHE / f"{video_id}.mp4"
    if not video_path.exists():
        try:
            downloaded = hf_hub_download(
                repo_id=VIDEO_DATASET,
                filename=f"videos/{video_id}.mp4",
                repo_type="dataset",
            )
            # Copy out of the hub cache so cv2.VideoCapture reads a stable,
            # predictable path under VIDEO_CACHE.
            shutil.copy(downloaded, video_path)
        except Exception as e:
            # Best-effort: one missing video must not abort the whole job;
            # the caller treats None as "video unavailable".
            print(f" ⚠ Could not download {video_id}: {e}")
            return None
    return video_path
87
+
88
def get_video_capture(video_id: str):
    """Return the cached cv2.VideoCapture for *video_id*, opening it on first use.

    Returns None when the underlying video could not be downloaded.
    """
    cap = video_cache.get(video_id)
    if cap is None:
        path = download_video(video_id)
        if path:
            cap = cv2.VideoCapture(str(path))
            video_cache[video_id] = cap
    return cap
95
+
96
def extract_frame(video_id: str, frame_idx: int) -> Image.Image | None:
    """Extract a single RGB frame from a (cached) video.

    Args:
        video_id: Video to read from; downloaded on demand via
            ``get_video_capture``.
        frame_idx: Zero-based frame index to seek to.

    Returns:
        The frame as a PIL image, or ``None`` when the video is unavailable
        or the seek/read fails (the original annotation claimed
        ``Image.Image`` but both failure branches return ``None``).
    """
    cap = get_video_capture(video_id)
    if cap is None:
        return None

    # Random-access seek to the requested frame before reading.
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    ret, frame = cap.read()

    if ret:
        # OpenCV decodes to BGR; convert so the embedded image is standard RGB.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return Image.fromarray(frame_rgb)
    return None
110
+
111
# ============================================================
# Process Dataset
# ============================================================

print("\n🔄 Processing samples and extracting frames...")

# Get unique video IDs first. Column access (ds["video_id"]) returns the whole
# column at once, avoiding a Python-level decode of every full example that the
# original per-row loop performed.
video_ids = set(ds["video_id"])
print(f"Found {len(video_ids)} unique videos")

# Pre-download every video so the extraction loop never stalls on the network;
# sorted() makes the download order deterministic across runs.
print("\n📥 Downloading videos...")
for vid in tqdm(sorted(video_ids), desc="Videos"):
    download_video(vid)
127
+
128
# Process samples
print("\n🖼️ Extracting frames...")
processed_samples = []
# Count of samples whose frame could not be extracted (video missing or read failed).
failed = 0

for i, ex in enumerate(tqdm(ds, desc="Samples")):
    # Hard cap for this run (see MAX_SAMPLES: "start with subset for testing");
    # samples past the cap are skipped entirely.
    if i >= MAX_SAMPLES:
        break

    video_id = ex['video_id']
    # Default to frame 0 when the source sample carries no frame index.
    frame_idx = ex.get('frame_index', 0)

    # Extract frame
    frame = extract_frame(video_id, frame_idx)

    if frame is None:
        failed += 1
        continue

    # Create new sample with image
    sample = {
        "image": frame,  # PIL.Image, embedded directly in the new dataset
        "video_id": video_id,
        "frame_index": frame_idx,
        "messages": ex['messages'],
    }
    processed_samples.append(sample)

print(f"\n✓ Processed {len(processed_samples)} samples ({failed} failed)")

# Close video captures
# Release every OpenCV capture opened lazily by get_video_capture().
for cap in video_cache.values():
    cap.release()
161
+
162
# ============================================================
# Create Dataset
# ============================================================

print("\n📊 Creating dataset...")

# Create dataset with Image feature
# NOTE(review): from_list infers features from the samples; the PIL images are
# presumably encoded as a datasets Image feature automatically — confirm via
# the features printout below.
new_ds = Dataset.from_list(processed_samples)
print(f"✓ Created dataset with {len(new_ds)} samples")

# Check features
print(f"Features: {new_ds.features}")
174
+
175
# ============================================================
# Upload
# ============================================================

print(f"\n📤 Uploading to {OUTPUT_DATASET}...")

try:
    # Push as a public dataset; an upload failure is reported but does not
    # abort the script, so the summary below still prints.
    new_ds.push_to_hub(OUTPUT_DATASET, private=False)
    print(f"✓ Uploaded to https://huggingface.co/datasets/{OUTPUT_DATASET}")
except Exception as e:
    print(f"⚠ Upload error: {e}")
186
+
187
# ============================================================
# Summary
# ============================================================

print("\n" + "=" * 60)
print("✅ DONE!")
print("=" * 60)
# Final report plus a copy-pasteable usage snippet for the uploaded dataset.
print(f"""
Dataset created: {OUTPUT_DATASET}
Samples: {len(processed_samples)}
Failed: {failed}

To use:
```python
from datasets import load_dataset
ds = load_dataset("{OUTPUT_DATASET}")
# Images are directly available - no video extraction needed!
image = ds['train'][0]['image']
```
""")