File size: 5,694 Bytes
69375af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch",
#     "datasets>=2.18.0",
#     "pillow",
#     "opencv-python-headless",
#     "huggingface_hub>=0.21.0",
#     "av",
#     "tqdm",
# ]
# ///
"""
Create dataset with embedded images from pitvqa-comprehensive-spatial.

Extracts video frames and embeds them directly in the dataset.
This eliminates the need for video extraction during training/inference.

Run with: hf jobs uv run --flavor cpu-xlarge --secrets HF_TOKEN create_image_dataset.py
"""

import os
import shutil
from io import BytesIO
from pathlib import Path

import cv2
from PIL import Image
from tqdm import tqdm

# ============================================================
# Config
# ============================================================

SOURCE_DATASET = "mmrech/pitvqa-comprehensive-spatial"  # QA samples referencing video frames
VIDEO_DATASET = "UCL-WEISS/PitVis-2023"                 # dataset repo holding the source .mp4 files
OUTPUT_DATASET = "mmrech/pitvqa-spatial-with-images"    # destination repo with embedded images

# Local scratch directory for downloaded videos (ephemeral on HF job machines).
VIDEO_CACHE = Path("/tmp/videos")
VIDEO_CACHE.mkdir(exist_ok=True)

MAX_SAMPLES = 1000  # Start with subset for testing

# ============================================================
# Setup
# ============================================================

# NOTE: these imports live mid-file to match the script's section layout.
from huggingface_hub import login, HfApi, hf_hub_download
from datasets import load_dataset, Dataset, Features, Value, Image as ImageFeature

# Authenticate only when a token is provided; public reads still work
# without it, but the final push_to_hub upload would fail.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    print("✓ Logged in to HuggingFace")

api = HfApi()

# ============================================================
# Load Source Dataset
# ============================================================

# Expected columns (used below): 'video_id', 'frame_index', 'messages'.
print("\n📦 Loading source dataset...")
ds = load_dataset(SOURCE_DATASET, split="train")
print(f"✓ Loaded {len(ds)} samples")

# ============================================================
# Video Helpers
# ============================================================

# Maps video_id -> open cv2.VideoCapture, so each video file is opened
# once and all handles can be released together after processing.
video_cache = {}

def download_video(video_id: str) -> Path | None:
    """Download a video from VIDEO_DATASET into the local cache.

    Args:
        video_id: Basename of the video (``videos/{video_id}.mp4`` in the repo).

    Returns:
        The cached local path, or None if the download failed — callers
        treat a missing video as a per-sample failure, not a fatal error.
    """
    video_path = VIDEO_CACHE / f"{video_id}.mp4"
    if video_path.exists():
        return video_path

    try:
        # hf_hub_download places the file in the HF cache; only this call
        # is expected to raise, so keep the try body to just it.
        downloaded = hf_hub_download(
            repo_id=VIDEO_DATASET,
            filename=f"videos/{video_id}.mp4",
            repo_type="dataset",
        )
    except Exception as e:
        # Best-effort: log and skip this video rather than aborting the job.
        print(f"  ⚠ Could not download {video_id}: {e}")
        return None

    # Copy into our flat cache dir so the path is predictable.
    shutil.copy(downloaded, video_path)
    return video_path

def get_video_capture(video_id: str):
    """Return a cached cv2.VideoCapture for the video, or None.

    Captures are memoized in the module-level ``video_cache`` and released
    in bulk after processing. Returns None when the video could not be
    downloaded; in that case nothing is cached, so the download is
    retried on the next call for the same id.
    """
    if video_id not in video_cache:
        video_path = download_video(video_id)
        if video_path:
            video_cache[video_id] = cv2.VideoCapture(str(video_path))
    return video_cache.get(video_id)

def extract_frame(video_id: str, frame_idx: int) -> Image.Image | None:
    """Decode a single frame from a video as an RGB PIL image.

    Returns None if the video is unavailable or the frame could not be
    read (e.g. ``frame_idx`` past the end of the video).
    """
    cap = get_video_capture(video_id)
    if cap is None:
        return None

    # NOTE(review): CAP_PROP_POS_FRAMES seeking can be inexact for some
    # codecs/containers — frames near keyframe boundaries may be off by
    # a few frames. Acceptable here; confirm if exact frames matter.
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    ret, frame = cap.read()

    if ret:
        # OpenCV decodes to BGR; convert to RGB for PIL.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return Image.fromarray(frame_rgb)
    return None

# ============================================================
# Process Dataset
# ============================================================

print("\n🔄 Processing samples and extracting frames...")

# Only the first MAX_SAMPLES samples are processed, so restrict both the
# video-id scan and the pre-download step to that subset. (Previously the
# whole dataset was scanned and every referenced video downloaded, even
# ones no processed sample would touch.)
sample_count = min(MAX_SAMPLES, len(ds))

video_ids = {ds[i]['video_id'] for i in range(sample_count)}
print(f"Found {len(video_ids)} unique videos")

# Pre-download so per-sample extraction doesn't stall on network I/O;
# extract_frame() still falls back to downloading on a cache miss.
print("\n📥 Downloading videos...")
for vid in tqdm(sorted(video_ids), desc="Videos"):
    download_video(vid)

# Decode one frame per sample and pair it with the QA messages.
print("\n🖼️ Extracting frames...")
processed_samples = []
failed = 0

for i in tqdm(range(sample_count), desc="Samples"):
    ex = ds[i]
    video_id = ex['video_id']
    frame_idx = ex.get('frame_index', 0)

    frame = extract_frame(video_id, frame_idx)
    if frame is None:
        # Missing video or unreadable frame: count it and skip the sample.
        failed += 1
        continue

    processed_samples.append({
        "image": frame,
        "video_id": video_id,
        "frame_index": frame_idx,
        "messages": ex['messages'],
    })

print(f"\n✓ Processed {len(processed_samples)} samples ({failed} failed)")

# Release all decoder handles now that extraction is done.
for cap in video_cache.values():
    cap.release()

# ============================================================
# Create Dataset
# ============================================================

print("\n📊 Creating dataset...")

# Dataset.from_list infers the schema from the samples; the PIL images
# should be encoded as an Image feature automatically (the explicit
# Features/Value/ImageFeature imports above are unused) — the printout
# below confirms what was actually inferred.
new_ds = Dataset.from_list(processed_samples)
print(f"✓ Created dataset with {len(new_ds)} samples")

# Check features
print(f"Features: {new_ds.features}")

# ============================================================
# Upload
# ============================================================

print(f"\n📤 Uploading to {OUTPUT_DATASET}...")

# Best-effort upload: a failure is logged but does not abort, so the
# summary below still prints. Requires HF_TOKEN with write access.
try:
    new_ds.push_to_hub(OUTPUT_DATASET, private=False)
    print(f"✓ Uploaded to https://huggingface.co/datasets/{OUTPUT_DATASET}")
except Exception as e:
    print(f"⚠ Upload error: {e}")

# ============================================================
# Summary
# ============================================================

# Final report with a copy-pasteable usage snippet for the new dataset.
print("\n" + "=" * 60)
print("✅ DONE!")
print("=" * 60)
print(f"""
Dataset created: {OUTPUT_DATASET}
Samples: {len(processed_samples)}
Failed: {failed}

To use:
```python
from datasets import load_dataset
ds = load_dataset("{OUTPUT_DATASET}")
# Images are directly available - no video extraction needed!
image = ds['train'][0]['image']
```
""")