| import hashlib |
| import io |
| import os |
| import shutil |
| import tempfile |
| import traceback |
| import zipfile |
|
|
| import gradio as gr |
| import pandas as pd |
| from datasets import Dataset, Image |
| from huggingface_hub import HfApi, hf_hub_download |
| from PIL import Image as PILImage |
| from PIL import ImageFile, UnidentifiedImageError |
|
|
| |
# Ask huggingface_hub to use the hf_transfer download backend unless the
# environment already decided otherwise.
# NOTE(review): huggingface_hub is imported above this line; if the library
# reads this flag at import time, the default set here arrives too late.
# Confirm against the huggingface_hub environment-variable documentation.
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

# Let Pillow decode images whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Dataset repo the raw zips, label CSV and split manifests are read from.
SOURCE_REPO: str = "alkzar90/NIH-Chest-X-ray-dataset"
# Dataset repo the parquet shards are uploaded to.
TARGET_REPO: str = "HlexNC/chest-xray-14"
# Number of images_NNN.zip archives to process (numbered from 1).
NUM_ZIPS: int = 12
# Size passed to PIL's resize() for every image.
IMG_SIZE: tuple[int, int] = (224, 224)
# Required: a missing HF_TOKEN raises KeyError while the module is loading.
HF_TOKEN: str = os.environ["HF_TOKEN"]
# Uses PILImage.Resampling.LANCZOS when that attribute exists and falls back
# to PILImage.LANCZOS otherwise.
RESAMPLE = getattr(PILImage, "Resampling", PILImage).LANCZOS

# Shared Hub client used for repo creation and uploads.
api = HfApi(token=HF_TOKEN)
|
|
|
|
def _stable_bucket(filename: str) -> int:
    """Map *filename* to a bucket in ``range(10)`` that never changes.

    The bucket is derived from the first eight bytes of the SHA-256 digest of
    the UTF-8 encoded name, read as a big-endian integer, so the result is the
    same in every process and on every run.
    """
    hasher = hashlib.sha256()
    hasher.update(filename.encode("utf-8"))
    leading_bytes = hasher.digest()[:8]
    return int.from_bytes(leading_bytes, byteorder="big") % 10
|
|
|
|
def _split_for(filename: str, train_files: set[str], test_files: set[str]) -> str:
    """Return train, validation, test, or unknown for a given filename.

    Membership in *test_files* is checked first and wins. A name listed in
    *train_files* goes to "validation" when its stable bucket is 9 and to
    "train" for buckets 0-8. A name in neither collection is "unknown".
    """
    if filename in test_files:
        return "test"
    if filename not in train_files:
        return "unknown"
    bucket = _stable_bucket(filename)
    return "validation" if bucket >= 9 else "train"
|
|
|
|
def _load_resized_image(archive: zipfile.ZipFile, info: zipfile.ZipInfo):
    """Decode one archive member and return it as an RGB image of IMG_SIZE.

    Raises ValueError("empty file") for a zero-byte member; decoding problems
    surface as whatever ``zipfile`` or Pillow raises.
    """
    payload = archive.read(info)
    if not payload:
        raise ValueError("empty file")
    with PILImage.open(io.BytesIO(payload)) as raw_image:
        # convert() returns a new image, so the result outlives the `with`.
        rgb_image = raw_image.convert("RGB")
    return rgb_image.resize(IMG_SIZE, RESAMPLE)


def _remove_cached_download(path: str) -> None:
    """Delete a downloaded file and, if *path* is a symlink, its target too.

    The Hub cache normally hands out paths that are symlinks to the actual
    blob (see the huggingface_hub cache documentation). Removing only the
    link would leave the multi-gigabyte blob on disk. The target is resolved
    before the link is removed because it cannot be resolved afterwards.

    Raises OSError (from ``os.remove``) when a deletion fails.
    """
    resolved = os.path.realpath(path)
    os.remove(path)
    if resolved != path and os.path.exists(resolved):
        os.remove(resolved)


def repackage(progress=gr.Progress(track_tqdm=False)):
    """Download, resize, shard, and upload the dataset on HF Space.

    This is a generator: each yielded value is the complete log so far (all
    messages joined by newlines), so the consumer can simply display the
    latest value.

    For every source zip the function downloads the archive, assigns each
    image to train/validation/test via ``_split_for``, resizes it, writes one
    parquet file per non-empty split, uploads those files under ``data/`` in
    TARGET_REPO, and then deletes the local copy of the zip.

    Images that cannot be read or that are missing from both split manifests
    are skipped and counted. Any other exception ends the run, and its
    traceback is appended to the log instead of propagating.

    Args:
        progress: Gradio progress tracker. The default instance is how
            Gradio's documented pattern requests progress reporting, so it
            is intentionally created in the signature.
    """
    log_lines: list[str] = []
    unreadable_total = 0
    unknown_total = 0
    unreadable_samples: list[tuple[str, str, str]] = []
    unknown_samples: list[str] = []
    total_images = 0

    def emit(message: str) -> str:
        """Append *message* to the log and return the whole log as one string."""
        log_lines.append(message)
        return "\n".join(log_lines)

    try:
        progress(0.0, desc="Preparing metadata")
        yield emit("Creating target dataset repo (if needed)...")
        api.create_repo(repo_id=TARGET_REPO, repo_type="dataset", exist_ok=True)

        yield emit("Downloading label CSV...")
        csv_path = hf_hub_download(
            repo_id=SOURCE_REPO,
            filename="data/Data_Entry_2017_v2020.csv",
            repo_type="dataset",
            token=HF_TOKEN,
        )
        labels_df = pd.read_csv(csv_path)
        label_map = dict(zip(labels_df["Image Index"], labels_df["Finding Labels"]))
        yield emit(f" Loaded {len(label_map):,} label rows.")

        yield emit("Downloading split manifests...")
        train_val_path = hf_hub_download(
            repo_id=SOURCE_REPO,
            filename="data/train_val_list.txt",
            repo_type="dataset",
            token=HF_TOKEN,
        )
        with open(train_val_path, encoding="utf-8") as handle:
            train_val_files = {line.strip() for line in handle if line.strip()}

        test_path = hf_hub_download(
            repo_id=SOURCE_REPO,
            filename="data/test_list.txt",
            repo_type="dataset",
            token=HF_TOKEN,
        )
        with open(test_path, encoding="utf-8") as handle:
            test_files = {line.strip() for line in handle if line.strip()}

        yield emit(
            f" Loaded {len(train_val_files):,} train/validation names and "
            f"{len(test_files):,} test names."
        )

        for zip_idx in range(1, NUM_ZIPS + 1):
            zip_name = f"images_{zip_idx:03d}.zip"
            remote_path = f"data/images/{zip_name}"
            # Shard numbering is zero-based while the source zips start at 1.
            shard_tag = f"{zip_idx - 1:05d}-of-{NUM_ZIPS:05d}"
            zip_unreadable = 0
            zip_unknown = 0

            progress((zip_idx - 1) / NUM_ZIPS, desc=f"Processing {zip_name}")
            yield emit(f"\n-- Zip {zip_idx}/{NUM_ZIPS}: {zip_name} --")
            yield emit(f" Downloading {remote_path} ...")
            zip_path = hf_hub_download(
                repo_id=SOURCE_REPO,
                filename=remote_path,
                repo_type="dataset",
                token=HF_TOKEN,
            )
            yield emit(f" Downloaded {remote_path}.")

            # One column store per split; only this zip's images are held.
            buckets: dict[str, dict[str, list]] = {
                split_name: {"image": [], "labels": [], "filename": []}
                for split_name in ("train", "validation", "test")
            }

            yield emit(" Extracting and resizing images...")
            with zipfile.ZipFile(zip_path, "r") as archive:
                members = [
                    info
                    for info in archive.infolist()
                    if not info.is_dir()
                    and info.filename.lower().endswith((".png", ".jpg", ".jpeg"))
                ]
                yield emit(f" Found {len(members):,} candidate image entries.")

                for index, info in enumerate(members, start=1):
                    member = info.filename
                    # Manifests and labels are keyed by bare file name.
                    filename = os.path.basename(member)
                    if not filename:
                        continue

                    split = _split_for(filename, train_val_files, test_files)
                    if split == "unknown":
                        zip_unknown += 1
                        unknown_total += 1
                        if len(unknown_samples) < 20:
                            unknown_samples.append(f"{zip_name}:{member}")
                        continue

                    try:
                        image = _load_resized_image(archive, info)
                    except (
                        UnidentifiedImageError,
                        OSError,
                        ValueError,
                        # Raised by ZipFile.read() on CRC/header errors. One
                        # corrupt member should be skipped like any other
                        # unreadable image rather than abort the whole run.
                        zipfile.BadZipFile,
                    ) as exc:
                        zip_unreadable += 1
                        unreadable_total += 1
                        if len(unreadable_samples) < 25:
                            unreadable_samples.append(
                                (zip_name, member, f"{type(exc).__name__}: {exc}")
                            )
                        # Log the first few failures, then only every 50th.
                        if zip_unreadable <= 5 or zip_unreadable % 50 == 0:
                            yield emit(
                                " Skipping unreadable image "
                                f"{member} ({type(exc).__name__}: {exc})"
                            )
                        continue

                    buckets[split]["image"].append(image)
                    # Files absent from the label CSV fall back to "No Finding".
                    buckets[split]["labels"].append(label_map.get(filename, "No Finding"))
                    buckets[split]["filename"].append(filename)
                    total_images += 1

                    if index % 2000 == 0:
                        yield emit(f" ... scanned {index:,}/{len(members):,} entries")

            yield emit(
                " Zip summary: "
                f"train={len(buckets['train']['filename']):,}, "
                f"validation={len(buckets['validation']['filename']):,}, "
                f"test={len(buckets['test']['filename']):,}, "
                f"skipped_unreadable={zip_unreadable:,}, "
                f"skipped_unknown={zip_unknown:,}"
            )

            tmpdir = tempfile.mkdtemp()
            try:
                written_parquets: list[str] = []
                for split_name, data_dict in buckets.items():
                    if not data_dict["filename"]:
                        continue

                    parquet_name = f"{split_name}-{shard_tag}.parquet"
                    parquet_path = os.path.join(tmpdir, parquet_name)

                    yield emit(f" Writing {parquet_name} ...")
                    dataset = Dataset.from_dict(data_dict).cast_column("image", Image())
                    dataset.to_parquet(parquet_path)
                    written_parquets.append(parquet_name)

                if written_parquets:
                    yield emit(
                        f" Uploading {len(written_parquets)} shard(s) to "
                        f"{TARGET_REPO} ..."
                    )
                    api.upload_folder(
                        folder_path=tmpdir,
                        path_in_repo="data",
                        repo_id=TARGET_REPO,
                        repo_type="dataset",
                        token=HF_TOKEN,
                        commit_message=f"Add parquet shards for {zip_name}",
                    )
                    yield emit(" Upload complete.")
            finally:
                shutil.rmtree(tmpdir, ignore_errors=True)

            # Best effort: failing to free disk space must not stop the run.
            try:
                _remove_cached_download(zip_path)
                yield emit(" Deleted local zip cache.")
            except OSError:
                yield emit(" Local zip cache cleanup skipped.")

        progress(1.0, desc="Done")
        yield emit("")
        yield emit("Repackaging complete.")
        yield emit(f"Total images kept: {total_images:,}")
        yield emit(f"Unreadable images skipped: {unreadable_total:,}")
        yield emit(f"Files missing from split manifests: {unknown_total:,}")
        yield emit(f"Dataset repo: https://huggingface.co/datasets/{TARGET_REPO}")

        if unreadable_samples:
            yield emit("")
            yield emit("Unreadable image samples:")
            for zip_name, member, reason in unreadable_samples:
                yield emit(f" - {zip_name}:{member} -> {reason}")

        if unknown_samples:
            yield emit("")
            yield emit("Unknown split samples:")
            for sample in unknown_samples:
                yield emit(f" - {sample}")
    except Exception:
        # Top-level boundary: show the failure in the UI log instead of
        # letting the event handler die without output.
        yield emit("")
        yield emit("ERROR")
        yield emit(traceback.format_exc())
|
|
|
|
# Markdown shown at the top of the page.
_PAGE_INTRO = (
    "# CheXVision Data Pipeline\n\n"
    "One-click repackaging of the NIH Chest X-ray14 dataset "
    "(112,120 images, about 45 GB) from "
    "[alkzar90/NIH-Chest-X-ray-dataset]"
    "(https://huggingface.co/datasets/alkzar90/NIH-Chest-X-ray-dataset) "
    "into pre-processed 224x224 parquet shards at "
    "[HlexNC/chest-xray-14]"
    "(https://huggingface.co/datasets/HlexNC/chest-xray-14).\n\n"
    "The job runs entirely on Hugging Face Space hardware. Progress and "
    "logs appear below while the pipeline is running."
)

# Single-page UI: an intro, one button, and a read-only log box that receives
# every value yielded by repackage().
with gr.Blocks(title="CheXVision Data Pipeline") as demo:
    gr.Markdown(_PAGE_INTRO)
    start_btn = gr.Button("Start Repackaging", variant="primary")
    output_box = gr.Textbox(
        label="Log output",
        lines=25,
        max_lines=50,
        interactive=False,
    )
    start_btn.click(fn=repackage, inputs=[], outputs=[output_box])

# Default concurrency limit of 1 for queued events.
demo.queue(default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch(show_error=True, ssr_mode=False)
|
|