# pitvqa-training-scripts / create_image_dataset.py
# mmrech's picture
# Upload create_image_dataset.py with huggingface_hub
# 69375af verified
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "torch",
# "datasets>=2.18.0",
# "pillow",
# "opencv-python-headless",
# "huggingface_hub>=0.21.0",
# "av",
# "tqdm",
# ]
# ///
"""
Create dataset with embedded images from pitvqa-comprehensive-spatial.
Extracts video frames and embeds them directly in the dataset.
This eliminates the need for video extraction during training/inference.
Run with: hf jobs uv run --flavor cpu-xlarge --secrets HF_TOKEN create_image_dataset.py
"""
import os
import shutil
from io import BytesIO
from pathlib import Path

import cv2
from PIL import Image
from tqdm import tqdm
# ============================================================
# Config
# ============================================================
SOURCE_DATASET = "mmrech/pitvqa-comprehensive-spatial"  # QA samples that reference video frames
VIDEO_DATASET = "UCL-WEISS/PitVis-2023"  # dataset repo holding the raw videos/<id>.mp4 files
OUTPUT_DATASET = "mmrech/pitvqa-spatial-with-images"  # destination repo (images embedded)
VIDEO_CACHE = Path("/tmp/videos")  # local scratch directory for downloaded videos
VIDEO_CACHE.mkdir(exist_ok=True)
MAX_SAMPLES = 1000 # Start with subset for testing
# ============================================================
# Setup
# ============================================================
from huggingface_hub import login, HfApi, hf_hub_download
from datasets import load_dataset, Dataset, Features, Value, Image as ImageFeature
# Authenticate if a token is available — needed both to read the source
# repos (if gated) and to push the output dataset at the end.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    print("โœ“ Logged in to HuggingFace")
api = HfApi()
# ============================================================
# Load Source Dataset
# ============================================================
print("\n๐Ÿ“ฆ Loading source dataset...")
ds = load_dataset(SOURCE_DATASET, split="train")
print(f"โœ“ Loaded {len(ds)} samples")
# ============================================================
# Video Helpers
# ============================================================
# video_id -> cv2.VideoCapture; memoises open captures so each video
# file is opened at most once across all samples.
video_cache = {}
def download_video(video_id: str) -> Path | None:
    """Download a video into the local cache, if not already present.

    Args:
        video_id: Stem of the video file, i.e. ``videos/{video_id}.mp4``
            inside the ``VIDEO_DATASET`` repo.

    Returns:
        Path to the cached ``.mp4`` file, or ``None`` if the download
        failed (file missing on the hub, network/auth error, ...).
        The original annotation claimed ``Path`` but ``None`` is a
        possible return; callers already check for it.
    """
    video_path = VIDEO_CACHE / f"{video_id}.mp4"
    if not video_path.exists():
        try:
            # hf_hub_download stores the file inside the HF cache tree;
            # copy it into our flat cache dir so the path is predictable.
            downloaded = hf_hub_download(
                repo_id=VIDEO_DATASET,
                filename=f"videos/{video_id}.mp4",
                repo_type="dataset"
            )
            shutil.copy(downloaded, video_path)
        except Exception as e:
            # Best-effort: log and signal failure to the caller rather
            # than aborting the whole extraction job.
            print(f" โš  Could not download {video_id}: {e}")
            return None
    return video_path
def get_video_capture(video_id: str):
    """Return a memoised ``cv2.VideoCapture`` for *video_id*, or ``None``.

    Captures are stored in the module-level ``video_cache`` dict so each
    video file is opened at most once; a failed download yields ``None``
    without poisoning the cache (so a later call may retry).
    """
    cached = video_cache.get(video_id)
    if cached is not None:
        return cached
    path = download_video(video_id)
    if not path:
        return None
    capture = cv2.VideoCapture(str(path))
    video_cache[video_id] = capture
    return capture
def extract_frame(video_id: str, frame_idx: int) -> Image.Image | None:
    """Extract a single RGB frame from a video as a PIL image.

    Args:
        video_id: Identifier of the video (resolved via the capture cache).
        frame_idx: Zero-based frame number to seek to.

    Returns:
        PIL ``Image`` in RGB mode, or ``None`` when the video could not
        be opened or the frame could not be decoded (e.g. index past end
        of stream). The original annotation claimed ``Image.Image`` but
        ``None`` is a possible return; callers already check for it.
    """
    cap = get_video_capture(video_id)
    if cap is None:
        return None
    # Seek directly to the requested frame, then decode it.
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    ret, frame = cap.read()
    if not ret:
        return None
    # OpenCV decodes to BGR; convert to RGB for PIL / datasets.
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return Image.fromarray(frame_rgb)
# ============================================================
# Process Dataset
# ============================================================
print("\n๐Ÿ”„ Processing samples and extracting frames...")
# Get unique video IDs first
video_ids = set()
for ex in ds:
    video_ids.add(ex['video_id'])
print(f"Found {len(video_ids)} unique videos")
# Download videos first (sequentially, before frame extraction) so that
# extraction below only hits the local cache.
print("\n๐Ÿ“ฅ Downloading videos...")
for vid in tqdm(list(video_ids), desc="Videos"):
    download_video(vid)
# Process samples: decode one frame per sample and embed it alongside
# the original QA fields. Capped at MAX_SAMPLES for this test run.
print("\n๐Ÿ–ผ๏ธ Extracting frames...")
processed_samples = []
failed = 0  # samples whose frame could not be extracted
for i, ex in enumerate(tqdm(ds, desc="Samples")):
    if i >= MAX_SAMPLES:
        break
    video_id = ex['video_id']
    # Default to frame 0 if the sample lacks a frame_index field.
    frame_idx = ex.get('frame_index', 0)
    # Extract frame; None means the video/frame was unavailable.
    frame = extract_frame(video_id, frame_idx)
    if frame is None:
        failed += 1
        continue
    # Create new sample with the decoded image embedded directly.
    sample = {
        "image": frame,
        "video_id": video_id,
        "frame_index": frame_idx,
        "messages": ex['messages'],
    }
    processed_samples.append(sample)
print(f"\nโœ“ Processed {len(processed_samples)} samples ({failed} failed)")
# Close video captures to release file handles / decoder resources.
for cap in video_cache.values():
    cap.release()
# ============================================================
# Create Dataset
# ============================================================
print("\n๐Ÿ“Š Creating dataset...")
# Create dataset with Image feature; datasets infers the Image() feature
# from the PIL objects in the "image" column.
new_ds = Dataset.from_list(processed_samples)
print(f"โœ“ Created dataset with {len(new_ds)} samples")
# Check features (sanity check that "image" was inferred as Image()).
print(f"Features: {new_ds.features}")
# ============================================================
# Upload
# ============================================================
print(f"\n๐Ÿ“ค Uploading to {OUTPUT_DATASET}...")
try:
    new_ds.push_to_hub(OUTPUT_DATASET, private=False)
    print(f"โœ“ Uploaded to https://huggingface.co/datasets/{OUTPUT_DATASET}")
except Exception as e:
    # Best-effort: report the failure but still print the summary below.
    print(f"โš  Upload error: {e}")
# ============================================================
# Summary
# ============================================================
print("\n" + "=" * 60)
print("โœ… DONE!")
print("=" * 60)
print(f"""
Dataset created: {OUTPUT_DATASET}
Samples: {len(processed_samples)}
Failed: {failed}
To use:
```python
from datasets import load_dataset
ds = load_dataset("{OUTPUT_DATASET}")
# Images are directly available - no video extraction needed!
image = ds['train'][0]['image']
```
""")