import hashlib
import importlib.util
import io
import os
import shutil
import tempfile
import traceback
import zipfile
import zlib

# huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER once, when it is imported
# (gradio and datasets import it transitively), so the flag has to be set
# before the third-party imports below. Only opt in when the hf_transfer
# package is installed: with the flag on but the package missing,
# huggingface_hub refuses to download.
if importlib.util.find_spec("hf_transfer") is not None:
    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

import gradio as gr  # noqa: E402
import pandas as pd  # noqa: E402
from datasets import Dataset, Image  # noqa: E402
from huggingface_hub import HfApi, hf_hub_download  # noqa: E402
from PIL import Image as PILImage  # noqa: E402
from PIL import ImageFile, UnidentifiedImageError  # noqa: E402

# The source dataset contains a small number of problematic members. Let Pillow
# decode truncated-but-salvageable files, and skip the truly unreadable ones.
ImageFile.LOAD_TRUNCATED_IMAGES = True

SOURCE_REPO = "alkzar90/NIH-Chest-X-ray-dataset"
TARGET_REPO = "HlexNC/chest-xray-14"
NUM_ZIPS = 12
IMG_SIZE = (224, 224)
HF_TOKEN = os.environ["HF_TOKEN"]
# Pillow >= 9.1 exposes filters under Image.Resampling; older releases keep
# them directly on the Image module.
RESAMPLE = getattr(PILImage, "Resampling", PILImage).LANCZOS

# Errors meaning "this archive member cannot be turned into an image".
# Pillow's decode failures are OSError subclasses (UnidentifiedImageError
# included), but a corrupt member can also fail inside zipfile itself: a CRC
# mismatch raises BadZipFile, a damaged deflate stream raises zlib.error and
# a truncated stream raises EOFError -- none of which is an OSError.
_UNREADABLE_ERRORS = (
    UnidentifiedImageError,
    OSError,
    ValueError,
    EOFError,
    zipfile.BadZipFile,
    zlib.error,
)

api = HfApi(token=HF_TOKEN)


def _stable_bucket(filename: str) -> int:
    """Map a string to a bucket in 0..9 that is identical on every run.

    SHA-256 is used instead of ``hash()`` because str hashing is salted per
    process, which would reshuffle the split each time the job runs.
    """
    digest = hashlib.sha256(filename.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % 10


def _split_for(
    filename: str,
    train_files: set[str],
    test_files: set[str],
    group_key: str | None = None,
) -> str:
    """Return "train", "validation", "test", or "unknown" for a filename.

    Membership in ``test_files`` wins, then ``train_files``; anything else is
    "unknown". Train/validation names are divided roughly 90/10 by hashing
    ``group_key``, falling back to the filename when no key is given. Pass a
    patient identifier as ``group_key`` so all images of one patient land on
    the same side of the train/validation boundary.
    """
    if filename in test_files:
        return "test"
    if filename in train_files:
        key = filename if group_key is None else group_key
        return "train" if _stable_bucket(key) < 9 else "validation"
    return "unknown"


def _download(filename: str) -> str:
    """Download ``filename`` from SOURCE_REPO into the Hub cache; return its path."""
    return hf_hub_download(
        repo_id=SOURCE_REPO,
        filename=filename,
        repo_type="dataset",
        token=HF_TOKEN,
    )


def _read_name_list(path: str) -> set[str]:
    """Read a newline-separated list of image names, ignoring blank lines."""
    with open(path, encoding="utf-8") as handle:
        return {line.strip() for line in handle if line.strip()}


def _decode_and_resize(data: bytes) -> PILImage.Image:
    """Decode raw image bytes into an RGB image resized to IMG_SIZE.

    Raises:
        ValueError: if ``data`` is empty.
        OSError: (including UnidentifiedImageError) if Pillow cannot decode it.
    """
    if not data:
        raise ValueError("empty file")
    with PILImage.open(io.BytesIO(data)) as raw_image:
        image = raw_image.convert("RGB")
    return image.resize(IMG_SIZE, RESAMPLE)


def _remove_cached_file(path: str) -> None:
    """Delete a file returned by hf_hub_download, including its cache blob.

    hf_hub_download returns a path under the cache's ``snapshots/`` tree that
    is normally a symlink into ``blobs/``. Removing only the symlink frees no
    disk space, so the resolved target is removed as well.

    Raises:
        OSError: if a removal fails.
    """
    target = os.path.realpath(path)
    os.remove(path)
    # When `path` was a symlink its target still exists here; when it was a
    # regular file, the target is the file just removed and this is skipped.
    if os.path.exists(target):
        os.remove(target)


def repackage(progress=gr.Progress(track_tqdm=False)):
    """Download, resize, shard, and upload the dataset on HF Space.

    Generator used as a Gradio event handler: each yielded value is the whole
    log so far, which Gradio streams into the output textbox. Every source
    zip is processed on its own and yields up to three parquet shards
    (train/validation/test), uploaded in one commit before the zip is deleted
    and the next one downloaded. Any exception is caught and its traceback is
    appended to the log instead of propagating.
    """
    log_lines: list[str] = []
    unreadable_total = 0
    unknown_total = 0
    missing_label_total = 0
    unreadable_samples: list[tuple[str, str, str]] = []
    unknown_samples: list[str] = []
    total_images = 0

    def emit(message: str) -> str:
        log_lines.append(message)
        return "\n".join(log_lines)

    try:
        progress(0.0, desc="Preparing metadata")
        yield emit("Creating target dataset repo (if needed)...")
        api.create_repo(repo_id=TARGET_REPO, repo_type="dataset", exist_ok=True)

        yield emit("Downloading label CSV...")
        csv_path = _download("data/Data_Entry_2017_v2020.csv")
        labels_df = pd.read_csv(csv_path)
        label_map = dict(zip(labels_df["Image Index"], labels_df["Finding Labels"]))
        # Hash the train/validation split by patient when the CSV provides a
        # patient column, so no patient's scans straddle both splits.
        if "Patient ID" in labels_df.columns:
            patient_map = dict(
                zip(labels_df["Image Index"], labels_df["Patient ID"].astype(str))
            )
        else:
            patient_map = {}
        yield emit(f" Loaded {len(label_map):,} label rows.")
        yield emit(
            " Validation split grouped by patient ID."
            if patient_map
            else " Validation split grouped by filename (no 'Patient ID' column)."
        )

        yield emit("Downloading split manifests...")
        train_val_files = _read_name_list(_download("data/train_val_list.txt"))
        test_files = _read_name_list(_download("data/test_list.txt"))
        yield emit(
            f" Loaded {len(train_val_files):,} train/validation names and "
            f"{len(test_files):,} test names."
        )

        for zip_idx in range(1, NUM_ZIPS + 1):
            zip_name = f"images_{zip_idx:03d}.zip"
            remote_path = f"data/images/{zip_name}"
            shard_tag = f"{zip_idx - 1:05d}-of-{NUM_ZIPS:05d}"
            zip_unreadable = 0
            zip_unknown = 0

            progress((zip_idx - 1) / NUM_ZIPS, desc=f"Processing {zip_name}")
            yield emit(f"\n-- Zip {zip_idx}/{NUM_ZIPS}: {zip_name} --")
            yield emit(f" Downloading {remote_path} ...")
            zip_path = _download(remote_path)
            yield emit(f" Downloaded {remote_path}.")

            buckets: dict[str, dict[str, list]] = {
                split_name: {"image": [], "labels": [], "filename": []}
                for split_name in ("train", "validation", "test")
            }

            yield emit(" Extracting and resizing images...")
            with zipfile.ZipFile(zip_path, "r") as archive:
                members = [
                    info
                    for info in archive.infolist()
                    if not info.is_dir()
                    and info.filename.lower().endswith((".png", ".jpg", ".jpeg"))
                ]
                yield emit(f" Found {len(members):,} candidate image entries.")

                for index, info in enumerate(members, start=1):
                    member = info.filename
                    filename = os.path.basename(member)
                    if not filename:
                        continue

                    split = _split_for(
                        filename,
                        train_val_files,
                        test_files,
                        group_key=patient_map.get(filename),
                    )
                    if split == "unknown":
                        zip_unknown += 1
                        unknown_total += 1
                        if len(unknown_samples) < 20:
                            unknown_samples.append(f"{zip_name}:{member}")
                        continue

                    try:
                        image = _decode_and_resize(archive.read(info))
                    except _UNREADABLE_ERRORS as exc:
                        zip_unreadable += 1
                        unreadable_total += 1
                        if len(unreadable_samples) < 25:
                            unreadable_samples.append(
                                (zip_name, member, f"{type(exc).__name__}: {exc}")
                            )
                        if zip_unreadable <= 5 or zip_unreadable % 50 == 0:
                            yield emit(
                                " Skipping unreadable image "
                                f"{member} ({type(exc).__name__}: {exc})"
                            )
                        continue

                    label = label_map.get(filename)
                    if label is None:
                        # Keep the image, but record that its label is a default.
                        missing_label_total += 1
                        label = "No Finding"

                    buckets[split]["image"].append(image)
                    buckets[split]["labels"].append(label)
                    buckets[split]["filename"].append(filename)
                    total_images += 1

                    if index % 2000 == 0:
                        yield emit(f" ... scanned {index:,}/{len(members):,} entries")

            yield emit(
                " Zip summary: "
                f"train={len(buckets['train']['filename']):,}, "
                f"validation={len(buckets['validation']['filename']):,}, "
                f"test={len(buckets['test']['filename']):,}, "
                f"skipped_unreadable={zip_unreadable:,}, "
                f"skipped_unknown={zip_unknown:,}"
            )

            tmpdir = tempfile.mkdtemp()
            try:
                written_parquets: list[str] = []
                for split_name, data_dict in buckets.items():
                    if not data_dict["filename"]:
                        continue
                    parquet_name = f"{split_name}-{shard_tag}.parquet"
                    parquet_path = os.path.join(tmpdir, parquet_name)
                    yield emit(f" Writing {parquet_name} ...")
                    dataset = Dataset.from_dict(data_dict).cast_column("image", Image())
                    dataset.to_parquet(parquet_path)
                    written_parquets.append(parquet_name)

                if written_parquets:
                    yield emit(
                        f" Uploading {len(written_parquets)} shard(s) to "
                        f"{TARGET_REPO} ..."
                    )
                    api.upload_folder(
                        folder_path=tmpdir,
                        path_in_repo="data",
                        repo_id=TARGET_REPO,
                        repo_type="dataset",
                        token=HF_TOKEN,
                        commit_message=f"Add parquet shards for {zip_name}",
                    )
                    yield emit(" Upload complete.")
            finally:
                shutil.rmtree(tmpdir, ignore_errors=True)

            # Release this zip's decoded images before downloading the next one.
            buckets.clear()

            try:
                _remove_cached_file(zip_path)
                yield emit(" Deleted local zip cache.")
            except OSError:
                yield emit(" Local zip cache cleanup skipped.")

        progress(1.0, desc="Done")
        yield emit("")
        yield emit("Repackaging complete.")
        yield emit(f"Total images kept: {total_images:,}")
        yield emit(f"Unreadable images skipped: {unreadable_total:,}")
        yield emit(f"Files missing from split manifests: {unknown_total:,}")
        yield emit(
            "Labels missing from CSV (stored as 'No Finding'): "
            f"{missing_label_total:,}"
        )
        yield emit(f"Dataset repo: https://huggingface.co/datasets/{TARGET_REPO}")

        if unreadable_samples:
            yield emit("")
            yield emit("Unreadable image samples:")
            for zip_name, member, reason in unreadable_samples:
                yield emit(f" - {zip_name}:{member} -> {reason}")

        if unknown_samples:
            yield emit("")
            yield emit("Unknown split samples:")
            for sample in unknown_samples:
                yield emit(f" - {sample}")

    except Exception:
        # Top-level UI boundary: surface the failure in the log textbox.
        yield emit("")
        yield emit("ERROR")
        yield emit(traceback.format_exc())


with gr.Blocks(title="CheXVision Data Pipeline") as demo:
    gr.Markdown(
        "# CheXVision Data Pipeline\n\n"
        "One-click repackaging of the NIH Chest X-ray14 dataset "
        "(112,120 images, about 45 GB) from "
        "[alkzar90/NIH-Chest-X-ray-dataset]"
        "(https://huggingface.co/datasets/alkzar90/NIH-Chest-X-ray-dataset) "
        "into pre-processed 224x224 parquet shards at "
        "[HlexNC/chest-xray-14]"
        "(https://huggingface.co/datasets/HlexNC/chest-xray-14).\n\n"
        "The job runs entirely on Hugging Face Space hardware. Progress and "
        "logs appear below while the pipeline is running."
    )
    start_btn = gr.Button("Start Repackaging", variant="primary")
    output_box = gr.Textbox(
        label="Log output",
        lines=25,
        max_lines=50,
        interactive=False,
    )
    start_btn.click(fn=repackage, inputs=[], outputs=[output_box])

# One run at a time: a second click waits in the queue.
demo.queue(default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch(show_error=True, ssr_mode=False)