| """ |
| download.py |
| |
| Utility functions for downloading and extracting various datasets to (local) disk. |
| """ |
|
|
| import os |
| import shutil |
| from pathlib import Path |
| from typing import Dict, List, TypedDict |
| from zipfile import ZipFile |
|
|
| import requests |
| from PIL import Image |
| from rich.progress import BarColumn, DownloadColumn, MofNCompleteColumn, Progress, TextColumn, TransferSpeedColumn |
| from tqdm import tqdm |
|
|
| |
|
|
| |
| |
|
|
|
|
| |
| |
class DatasetComponent(TypedDict, total=False):
    """Schema for one downloadable component of a dataset in `DATASET_REGISTRY` (all keys optional)."""

    name: str
    extract: bool
    extract_type: str
    url: str
    do_rename: bool
|
|
# Registry of datasets --> maps a dataset ID to the list of archive components to fetch; each entry
# is downloaded (`url`), optionally unzipped (`extract`/`extract_type`), and then moved into place
# under the path given by `name` (`do_rename`) -- see `download_extract` below for the consumer.
DATASET_REGISTRY: Dict[str, List[DatasetComponent]] = {
    # Image archives for the LLaVA v1.5 instruction-tuning mixture (COCO, GQA, OCR-VQA,
    # TextVQA, Visual Genome); every component is a `.zip` extracted as a directory.
    "llava-v1.5-instruct":
    [
        {
            "name": "coco/train2017",
            "extract": True,
            "extract_type": "directory",
            "url": "http://images.cocodataset.org/zips/train2017.zip",
            "do_rename": True,
        },
        {
            "name": "gqa/images",
            "extract": True,
            "extract_type": "directory",
            "url": "https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip",
            "do_rename": True,
        },
        {
            # NOTE: OCR-VQA images are served from an HF-hosted mirror archive.
            "name": "ocr_vqa/images",
            "extract": True,
            "extract_type": "directory",
            "url": "https://hf-mirror.com/datasets/qnguyen3/ocr_vqa/resolve/main/ocr_vqa.zip",
            "do_rename": True,
        },
        {
            "name": "textvqa/train_images",
            "extract": True,
            "extract_type": "directory",
            "url": "https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip",
            "do_rename": True,
        },
        {
            "name": "vg/VG_100K",
            "extract": True,
            "extract_type": "directory",
            "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip",
            "do_rename": True,
        },
        {
            "name": "vg/VG_100K_2",
            "extract": True,
            "extract_type": "directory",
            "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip",
            "do_rename": True,
        },
    ]
}
| |
|
|
|
|
def convert_to_jpg(image_dir: Path) -> None:
    """Handling for OCR-VQA Images specifically; iterates through directory, converts all GIFs/PNGs.

    Skips files that are already JPEGs and files whose converted `.jpg` counterpart already exists,
    so re-running after an interruption is cheap.

    :param image_dir: Directory containing the images to convert in place.

    :raises ValueError: If a file with an unexpected (non jpg/jpeg/gif/png) suffix is encountered.
    """
    print(f"Converting all Images in `{image_dir}` to JPG")

    for image_fn in tqdm(list(image_dir.iterdir())):
        jpg_fn = image_dir / f"{image_fn.stem}.jpg"
        if image_fn.suffix in {".jpg", ".jpeg"} or jpg_fn.exists():
            continue

        if image_fn.suffix == ".gif":
            # Only the first frame of (possibly animated) GIFs is kept; the context manager
            # ensures the file handle is closed (the original leaked one handle per GIF).
            with Image.open(image_fn) as gif:
                gif.seek(0)
                gif.convert("RGB").save(jpg_fn)
        elif image_fn.suffix == ".png":
            with Image.open(image_fn) as png:
                png.convert("RGB").save(jpg_fn)
        else:
            raise ValueError(f"Unexpected image format `{image_fn.suffix}`")
|
|
|
|
def download_with_progress(url: str, download_dir: Path, chunk_size_bytes: int = 1024) -> Path:
    """Utility function for downloading files from the internet, with a handy Rich-based progress bar.

    Short-circuits and returns immediately if the destination file already exists.

    :param url: URL to download; the destination filename is the URL's last path component.
    :param download_dir: Directory the file is written into.
    :param chunk_size_bytes: Streaming chunk size (bytes) for writing to disk.

    :return: Path to the downloaded file on disk.
    """
    dest_path = download_dir / Path(url).name
    # NOTE: original passed `ctx_level=1` to `print`, which is not a valid `print` kwarg (TypeError);
    # it was presumably a leftover from a structured logger.
    print(f"Downloading {dest_path} from `{url}`")
    if dest_path.exists():
        return dest_path

    # Stream the response so large archives are never held fully in memory; fail fast on HTTP
    # errors, and close the connection when done.
    with requests.get(url, stream=True, timeout=(10, 60)) as response:
        response.raise_for_status()

        # `content-length` may be absent (e.g., chunked transfer) --> `total=None` renders an
        # indeterminate bar instead of crashing on `int("None")` as the original did.
        content_length = response.headers.get("content-length")
        total = int(content_length) if content_length is not None else None

        with Progress(
            TextColumn("[bold]{task.description} - {task.fields[fname]}"),
            BarColumn(bar_width=None),
            "[progress.percentage]{task.percentage:>3.1f}%",
            "•",
            DownloadColumn(),
            "•",
            TransferSpeedColumn(),
            transient=True,
        ) as dl_progress:
            dl_tid = dl_progress.add_task("Downloading", fname=dest_path.name, total=total)
            with open(dest_path, "wb") as f:
                for data in response.iter_content(chunk_size=chunk_size_bytes):
                    # `f.write` returns the number of bytes written --> advances the bar.
                    dl_progress.advance(dl_tid, f.write(data))

    return dest_path
|
|
|
|
def extract_with_progress(archive_path: Path, download_dir: Path, extract_type: str, cleanup: bool = False) -> Path:
    """Utility function for extracting compressed archives, with a handy Rich-based progress bar.

    :param archive_path: Path to the `.zip` archive (only zip is supported for now).
    :param download_dir: Directory to extract into.
    :param extract_type: "file" for a single-member archive, "directory" for a multi-member one.
    :param cleanup: If True, delete the archive after extraction.

    :return: Path of the first extracted member (the lone file, or the archive's root entry).

    :raises ValueError: If `extract_type` is not "file" or "directory".
    """
    assert archive_path.suffix == ".zip", "Only `.zip` compressed archives are supported for now!"
    # NOTE: original passed `ctx_level=1` to `print`, which is not a valid `print` kwarg (TypeError).
    print(f"Extracting {archive_path.name} to `{download_dir}`")

    with Progress(
        TextColumn("[bold]{task.description} - {task.fields[aname]}"),
        BarColumn(bar_width=None),
        "[progress.percentage]{task.percentage:>3.1f}%",
        "•",
        MofNCompleteColumn(),
        transient=True,
    ) as ext_progress:
        with ZipFile(archive_path) as zf:
            members = zf.infolist()
            ext_tid = ext_progress.add_task("Extracting", aname=archive_path.name, total=len(members))

            # The first member anchors the returned path (single file, or the extracted root).
            extract_path = Path(zf.extract(members[0], download_dir))
            ext_progress.advance(ext_tid)
            if extract_type == "file":
                # Fixed the assert message (original was missing the closing backtick).
                assert len(members) == 1, f"Archive `{archive_path}` with extract type `{extract_type}` has > 1 member!"
            elif extract_type == "directory":
                for member in members[1:]:
                    zf.extract(member, download_dir)
                    ext_progress.advance(ext_tid)
            else:
                raise ValueError(f"Extract type `{extract_type}` for archive `{archive_path}` is not defined!")

    if cleanup:
        archive_path.unlink()

    return extract_path
|
|
|
|
def download_extract(dataset_id: str, root_dir: Path) -> None:
    """Download all files for a given dataset (querying registry above), extracting archives if necessary.

    :param dataset_id: Key into `DATASET_REGISTRY` identifying the dataset to fetch.
    :param root_dir: Root data directory; files land under `root_dir / "download" / dataset_id`.
    """
    download_dir = root_dir / "download" / dataset_id
    os.makedirs(download_dir, exist_ok=True)

    # Only fetch components whose final (post-rename) location is missing --> resumable.
    dl_tasks = [d for d in DATASET_REGISTRY[dataset_id] if not (download_dir / d["name"]).exists()]
    for dl_task in dl_tasks:
        dl_path = download_with_progress(dl_task["url"], download_dir)

        # Extract Archive (if necessary); for files, normalize to the containing directory.
        if dl_task["extract"]:
            dl_path = extract_with_progress(dl_path, download_dir, dl_task["extract_type"])
            dl_path = dl_path.parent if dl_path.is_file() else dl_path

        # Rename/move into final location, creating intermediate directories first -- e.g.,
        # `coco/train2017` needs `coco/` to exist or `shutil.move` raises FileNotFoundError.
        if dl_task["do_rename"]:
            rename_path = download_dir / dl_task["name"]
            rename_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.move(dl_path, rename_path)
if __name__ == "__main__":
    # Script entry point: download every registered dataset into `./data`.
    # (Removed unused `import sys` and the redundant re-import of `Path`, already imported above.)
    root_dir = Path("./data")
    os.makedirs(root_dir, exist_ok=True)

    for dataset_id in DATASET_REGISTRY:
        print(f"开始下载数据集: {dataset_id}")
        download_extract(dataset_id, root_dir)

    print("所有数据集下载完成!")
|