|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Create dataset with embedded images from pitvqa-comprehensive-spatial. |
|
|
|
|
|
Extracts video frames and embeds them directly in the dataset. |
|
|
This eliminates the need for video extraction during training/inference. |
|
|
|
|
|
Run with: hf jobs uv run --flavor cpu-xlarge --secrets HF_TOKEN create_image_dataset.py |
|
|
""" |
|
|
|
|
|
import os |
|
|
import cv2 |
|
|
from io import BytesIO |
|
|
from PIL import Image |
|
|
from pathlib import Path |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hub repo ids: Q&A annotations to read, raw surgical videos to pull frames
# from, and the destination repo for the image-embedded output.
SOURCE_DATASET = "mmrech/pitvqa-comprehensive-spatial"

VIDEO_DATASET = "UCL-WEISS/PitVis-2023"

OUTPUT_DATASET = "mmrech/pitvqa-spatial-with-images"

# Local scratch directory where downloaded videos are copied for cv2 access.
VIDEO_CACHE = Path("/tmp/videos")

VIDEO_CACHE.mkdir(exist_ok=True)

# Upper bound on how many dataset rows are processed in this run.
MAX_SAMPLES = 1000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from huggingface_hub import login, HfApi, hf_hub_download |
|
|
from datasets import load_dataset, Dataset, Features, Value, Image as ImageFeature |
|
|
|
|
|
# Authenticate with the Hub when a token is provided (e.g. via job secrets);
# without one, public reads may still work but the final push will fail.
hf_token = os.environ.get("HF_TOKEN")

if hf_token:

    login(token=hf_token)

    print("โ Logged in to HuggingFace")

api = HfApi()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n๐ฆ Loading source dataset...")

# Rows carry 'video_id', 'frame_index', and chat-style 'messages'
# (read further down in the extraction loop).
ds = load_dataset(SOURCE_DATASET, split="train")

print(f"โ Loaded {len(ds)} samples")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Maps video_id -> open cv2.VideoCapture so each video is opened only once.
video_cache = {}
|
|
|
|
|
def download_video(video_id: str) -> Path | None:
    """Download one surgical video into the local cache if not already there.

    Args:
        video_id: File stem on the Hub (stored as ``videos/{video_id}.mp4``).

    Returns:
        Path to the cached .mp4, or ``None`` when the download failed.
    """
    import shutil  # local import: keeps the top-of-file dependency block untouched

    video_path = VIDEO_CACHE / f"{video_id}.mp4"
    if not video_path.exists():
        try:
            # hf_hub_download lands the file in the HF cache; copy it into our
            # flat /tmp cache so cv2 always gets a stable, predictable path.
            downloaded = hf_hub_download(
                repo_id=VIDEO_DATASET,
                filename=f"videos/{video_id}.mp4",
                repo_type="dataset"
            )
            shutil.copy(downloaded, video_path)
        except Exception as e:
            # Best-effort: a missing video just skips its samples downstream.
            print(f" โ Could not download {video_id}: {e}")
            return None
    return video_path
|
|
|
|
|
def get_video_capture(video_id: str):
    """Return a cached ``cv2.VideoCapture`` for *video_id*, or ``None``.

    Downloads the video on first use. Returns ``None`` when the video could
    not be downloaded or could not be opened by OpenCV.
    """
    if video_id not in video_cache:
        video_path = download_video(video_id)
        if video_path:
            cap = cv2.VideoCapture(str(video_path))
            # cv2 does not raise on a corrupt/unreadable file -- without this
            # check a broken capture would be cached and every later read on
            # it would silently fail.
            if cap.isOpened():
                video_cache[video_id] = cap
            else:
                cap.release()
                print(f" โ Could not open video {video_id}")
    return video_cache.get(video_id)
|
|
|
|
|
def extract_frame(video_id: str, frame_idx: int) -> Image.Image | None:
    """Seek to *frame_idx* in the given video and return it as an RGB image.

    Args:
        video_id: Video identifier, resolved through the capture cache.
        frame_idx: Zero-based frame number to decode.

    Returns:
        A PIL ``Image`` in RGB, or ``None`` when the video is unavailable or
        the seek/read fails (e.g. index past the end of the stream).
    """
    cap = get_video_capture(video_id)
    if cap is None:
        return None

    # Random-access seek; OpenCV decodes forward from the nearest keyframe.
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    ret, frame = cap.read()

    if ret:
        # OpenCV decodes BGR; convert so the stored image is standard RGB.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return Image.fromarray(frame_rgb)
    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n๐ Processing samples and extracting frames...")

# Gather the distinct source videos referenced anywhere in the dataset.
video_ids = {example['video_id'] for example in ds}
print(f"Found {len(video_ids)} unique videos")
|
|
|
|
|
|
|
|
print("\n๐ฅ Downloading videos...")

# Pre-fetch every video up front so the extraction loop never stalls on I/O.
# tqdm can size the bar directly from the set; no intermediate list needed.
for video in tqdm(video_ids, desc="Videos"):
    download_video(video)
|
|
|
|
|
|
|
|
print("\n๐ผ๏ธ Extracting frames...")
processed_samples = []
failed = 0

# Walk the annotated samples, decode the referenced frame for each one, and
# collect rows that pair the image with its original metadata. Samples whose
# frame cannot be decoded are counted and skipped rather than aborting.
for sample_no, record in enumerate(tqdm(ds, desc="Samples")):
    if sample_no >= MAX_SAMPLES:
        break

    vid = record['video_id']
    idx = record.get('frame_index', 0)

    decoded = extract_frame(vid, idx)
    if decoded is None:
        failed += 1
        continue

    processed_samples.append({
        "image": decoded,
        "video_id": vid,
        "frame_index": idx,
        "messages": record['messages'],
    })

print(f"\nโ Processed {len(processed_samples)} samples ({failed} failed)")
|
|
|
|
|
|
|
|
# Free the OpenCV decoder handles now that all frames are extracted.
for cap in video_cache.values():

    cap.release()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n๐ Creating dataset...")

# Dataset.from_list infers the schema; PIL images become an Image() feature
# with the bytes embedded, so consumers need no access to the source videos.
new_ds = Dataset.from_list(processed_samples)

print(f"โ Created dataset with {len(new_ds)} samples")

print(f"Features: {new_ds.features}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print(f"\n๐ค Uploading to {OUTPUT_DATASET}...")

try:

    new_ds.push_to_hub(OUTPUT_DATASET, private=False)

    print(f"โ Uploaded to https://huggingface.co/datasets/{OUTPUT_DATASET}")

# Broad catch is deliberate at this top-level boundary: an upload failure
# should not crash the job after all the extraction work; the error is
# surfaced in the logs instead.
except Exception as e:

    print(f"โ Upload error: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Final run summary with a copy-paste usage snippet.
print("\n" + "=" * 60)
# FIX: this string literal was split across two physical lines in the
# original source (an unterminated string -> SyntaxError); rejoined here.
print("โ DONE!")
print("=" * 60)
print(f"""
Dataset created: {OUTPUT_DATASET}
Samples: {len(processed_samples)}
Failed: {failed}

To use:
```python
from datasets import load_dataset
ds = load_dataset("{OUTPUT_DATASET}")
# Images are directly available - no video extraction needed!
image = ds['train'][0]['image']
```
""")
|
|
|