# Hlex Helftd
# Handle unreadable images and stream Space logs
# 633d07c verified
import hashlib
import io
import os
import shutil
import tempfile
import traceback
import zipfile
import gradio as gr
import pandas as pd
from datasets import Dataset, Image
from huggingface_hub import HfApi, hf_hub_download
from PIL import Image as PILImage
from PIL import ImageFile, UnidentifiedImageError
# ---------------------------------------------------------------------------
# Process-wide switches
# ---------------------------------------------------------------------------

# Opt in to the accelerated Hub transfer path unless the environment already
# made a choice (setdefault never overrides an existing value).
# NOTE(review): huggingface_hub is imported above this line; confirm the
# library still honours the flag when it is set after import.
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

# A handful of source files are damaged. Ask Pillow to decode images that are
# merely truncated; files that cannot be decoded at all are skipped later.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# ---------------------------------------------------------------------------
# Pipeline configuration
# ---------------------------------------------------------------------------

# Hub dataset the raw zips, label CSV and split manifests are read from.
SOURCE_REPO = "alkzar90/NIH-Chest-X-ray-dataset"
# Hub dataset the resized parquet shards are written to.
TARGET_REPO = "HlexNC/chest-xray-14"
# Number of images_XXX.zip archives to process (numbered from 001).
NUM_ZIPS = 12
# Output (width, height) every image is resized to.
IMG_SIZE = (224, 224)
# Use the Resampling namespace when this Pillow build has one, otherwise fall
# back to the LANCZOS constant on the Image module itself.
RESAMPLE = getattr(PILImage, "Resampling", PILImage).LANCZOS

# The token is mandatory: a missing HF_TOKEN raises KeyError at import time.
HF_TOKEN = os.environ["HF_TOKEN"]
api = HfApi(token=HF_TOKEN)
def _stable_bucket(filename: str) -> int:
digest = hashlib.sha256(filename.encode("utf-8")).digest()
return int.from_bytes(digest[:8], "big") % 10
def _split_for(filename: str, train_files: set[str], test_files: set[str]) -> str:
"""Return train, validation, test, or unknown for a given filename."""
if filename in test_files:
return "test"
if filename in train_files:
return "train" if _stable_bucket(filename) < 9 else "validation"
return "unknown"
def _remove_cached_download(path: str) -> None:
    """Delete a file returned by ``hf_hub_download`` and the data behind it.

    On platforms with symlink support the Hub cache hands back a snapshot
    path that is a symlink into the cache's blob store. Removing only that
    path leaves the (multi-gigabyte) blob on disk, so the link target is
    removed as well.

    Raises:
        OSError: if *path* (or its target) cannot be removed.
    """
    target = os.path.realpath(path)
    os.remove(path)
    # If *path* was a regular file, the target is already gone at this point.
    if os.path.lexists(target):
        os.remove(target)


def repackage(progress=gr.Progress(track_tqdm=False)):
    """Download, resize, shard, and upload the dataset on HF Space.

    For each of the ``NUM_ZIPS`` source archives the images are resized to
    ``IMG_SIZE``, grouped into train/validation/test via ``_split_for``,
    written as one parquet shard per non-empty split, and uploaded to
    ``TARGET_REPO``. Entries missing from both split manifests and entries
    that cannot be read or decoded are skipped and counted.

    Args:
        progress: Gradio progress tracker, called with a fraction and a
            description as the job advances.

    Yields:
        str: the complete log so far (all lines joined by newlines), so the
        bound textbox always shows the full history.

    Any exception is caught at this boundary and appended to the log as a
    traceback instead of being raised to the caller.
    """
    log_lines: list[str] = []
    unreadable_total = 0
    unknown_total = 0
    # Only the first few offenders are kept for the final report.
    unreadable_samples: list[tuple[str, str, str]] = []
    unknown_samples: list[str] = []
    total_images = 0

    def emit(message: str) -> str:
        """Append *message* to the log and return the whole log text."""
        log_lines.append(message)
        return "\n".join(log_lines)

    try:
        progress(0.0, desc="Preparing metadata")
        yield emit("Creating target dataset repo (if needed)...")
        api.create_repo(repo_id=TARGET_REPO, repo_type="dataset", exist_ok=True)

        yield emit("Downloading label CSV...")
        csv_path = hf_hub_download(
            repo_id=SOURCE_REPO,
            filename="data/Data_Entry_2017_v2020.csv",
            repo_type="dataset",
            token=HF_TOKEN,
        )
        labels_df = pd.read_csv(csv_path)
        label_map = dict(zip(labels_df["Image Index"], labels_df["Finding Labels"]))
        yield emit(f" Loaded {len(label_map):,} label rows.")

        yield emit("Downloading split manifests...")
        train_val_path = hf_hub_download(
            repo_id=SOURCE_REPO,
            filename="data/train_val_list.txt",
            repo_type="dataset",
            token=HF_TOKEN,
        )
        with open(train_val_path, encoding="utf-8") as handle:
            train_val_files = {line.strip() for line in handle if line.strip()}
        test_path = hf_hub_download(
            repo_id=SOURCE_REPO,
            filename="data/test_list.txt",
            repo_type="dataset",
            token=HF_TOKEN,
        )
        with open(test_path, encoding="utf-8") as handle:
            test_files = {line.strip() for line in handle if line.strip()}
        yield emit(
            f" Loaded {len(train_val_files):,} train/validation names and "
            f"{len(test_files):,} test names."
        )

        for zip_idx in range(1, NUM_ZIPS + 1):
            zip_name = f"images_{zip_idx:03d}.zip"
            remote_path = f"data/images/{zip_name}"
            # Shard numbering is zero-based while the zip names start at 001.
            shard_tag = f"{zip_idx - 1:05d}-of-{NUM_ZIPS:05d}"
            zip_unreadable = 0
            zip_unknown = 0

            progress((zip_idx - 1) / NUM_ZIPS, desc=f"Processing {zip_name}")
            yield emit(f"\n-- Zip {zip_idx}/{NUM_ZIPS}: {zip_name} --")
            yield emit(f" Downloading {remote_path} ...")
            zip_path = hf_hub_download(
                repo_id=SOURCE_REPO,
                filename=remote_path,
                repo_type="dataset",
                token=HF_TOKEN,
            )
            yield emit(f" Downloaded {remote_path}.")

            # Column-oriented buffers, one per output split.
            buckets: dict[str, dict[str, list]] = {
                "train": {"image": [], "labels": [], "filename": []},
                "validation": {"image": [], "labels": [], "filename": []},
                "test": {"image": [], "labels": [], "filename": []},
            }

            yield emit(" Extracting and resizing images...")
            with zipfile.ZipFile(zip_path, "r") as archive:
                members = [
                    info
                    for info in archive.infolist()
                    if not info.is_dir()
                    and info.filename.lower().endswith((".png", ".jpg", ".jpeg"))
                ]
                yield emit(f" Found {len(members):,} candidate image entries.")

                for index, info in enumerate(members, start=1):
                    member = info.filename
                    # Manifests and labels are keyed by bare file name.
                    filename = os.path.basename(member)
                    if not filename:
                        continue

                    split = _split_for(filename, train_val_files, test_files)
                    if split == "unknown":
                        zip_unknown += 1
                        unknown_total += 1
                        if len(unknown_samples) < 20:
                            unknown_samples.append(f"{zip_name}:{member}")
                        continue

                    try:
                        data = archive.read(info)
                        if not data:
                            raise ValueError("empty file")
                        with PILImage.open(io.BytesIO(data)) as raw_image:
                            image = raw_image.convert("RGB")
                        image = image.resize(IMG_SIZE, RESAMPLE)
                    except (
                        UnidentifiedImageError,
                        OSError,
                        ValueError,
                        # A corrupt (bad CRC) or truncated archive member must
                        # skip that one image, not abort the whole run.
                        zipfile.BadZipFile,
                        EOFError,
                    ) as exc:
                        reason = f"{type(exc).__name__}: {exc}"
                        zip_unreadable += 1
                        unreadable_total += 1
                        if len(unreadable_samples) < 25:
                            unreadable_samples.append((zip_name, member, reason))
                        # Throttle the log: first five per zip, then every 50th.
                        if zip_unreadable <= 5 or zip_unreadable % 50 == 0:
                            yield emit(
                                " Skipping unreadable image "
                                f"{member} ({reason})"
                            )
                        continue

                    buckets[split]["image"].append(image)
                    # Files without a label row are recorded as "No Finding".
                    buckets[split]["labels"].append(
                        label_map.get(filename, "No Finding")
                    )
                    buckets[split]["filename"].append(filename)
                    total_images += 1

                    if index % 2000 == 0:
                        yield emit(
                            f" ... scanned {index:,}/{len(members):,} entries"
                        )

            yield emit(
                " Zip summary: "
                f"train={len(buckets['train']['filename']):,}, "
                f"validation={len(buckets['validation']['filename']):,}, "
                f"test={len(buckets['test']['filename']):,}, "
                f"skipped_unreadable={zip_unreadable:,}, "
                f"skipped_unknown={zip_unknown:,}"
            )

            tmpdir = tempfile.mkdtemp()
            try:
                written_parquets: list[str] = []
                for split_name, data_dict in buckets.items():
                    if not data_dict["filename"]:
                        continue
                    parquet_name = f"{split_name}-{shard_tag}.parquet"
                    parquet_path = os.path.join(tmpdir, parquet_name)
                    yield emit(f" Writing {parquet_name} ...")
                    dataset = Dataset.from_dict(data_dict).cast_column(
                        "image", Image()
                    )
                    dataset.to_parquet(parquet_path)
                    written_parquets.append(parquet_name)

                if written_parquets:
                    yield emit(
                        f" Uploading {len(written_parquets)} shard(s) to "
                        f"{TARGET_REPO} ..."
                    )
                    api.upload_folder(
                        folder_path=tmpdir,
                        path_in_repo="data",
                        repo_id=TARGET_REPO,
                        repo_type="dataset",
                        token=HF_TOKEN,
                        commit_message=f"Add parquet shards for {zip_name}",
                    )
                    yield emit(" Upload complete.")
            finally:
                shutil.rmtree(tmpdir, ignore_errors=True)

            # Free the disk space before the next archive is downloaded.
            # Cleanup is best-effort: a failure is logged, not fatal.
            try:
                _remove_cached_download(zip_path)
            except OSError:
                cleanup_message = " Local zip cache cleanup skipped."
            else:
                cleanup_message = " Deleted local zip cache."
            yield emit(cleanup_message)

        progress(1.0, desc="Done")
        yield emit("")
        yield emit("Repackaging complete.")
        yield emit(f"Total images kept: {total_images:,}")
        yield emit(f"Unreadable images skipped: {unreadable_total:,}")
        yield emit(f"Files missing from split manifests: {unknown_total:,}")
        yield emit(f"Dataset repo: https://huggingface.co/datasets/{TARGET_REPO}")

        if unreadable_samples:
            yield emit("")
            yield emit("Unreadable image samples:")
            for zip_name, member, reason in unreadable_samples:
                yield emit(f" - {zip_name}:{member} -> {reason}")

        if unknown_samples:
            yield emit("")
            yield emit("Unknown split samples:")
            for sample in unknown_samples:
                yield emit(f" - {sample}")
    except Exception:
        # Top-level boundary: surface the failure in the streamed log.
        yield emit("")
        yield emit("ERROR")
        yield emit(traceback.format_exc())
# Single-page UI: an intro blurb, a start button, and a read-only log box.
# Components are created in the order they should appear on the page.
with gr.Blocks(title="CheXVision Data Pipeline") as demo:
    gr.Markdown(
        "# CheXVision Data Pipeline\n\n"
        "One-click repackaging of the NIH Chest X-ray14 dataset "
        "(112,120 images, about 45 GB) from "
        "[alkzar90/NIH-Chest-X-ray-dataset](https://huggingface.co/datasets/alkzar90/NIH-Chest-X-ray-dataset) "
        "into pre-processed 224x224 parquet shards at "
        "[HlexNC/chest-xray-14](https://huggingface.co/datasets/HlexNC/chest-xray-14).\n\n"
        "The job runs entirely on Hugging Face Space hardware. Progress and "
        "logs appear below while the pipeline is running."
    )

    start_btn = gr.Button("Start Repackaging", variant="primary")

    output_box = gr.Textbox(
        label="Log output",
        interactive=False,
        lines=25,
        max_lines=50,
    )

    # repackage is a generator; each yielded string replaces the textbox
    # content, which is how the log streams while the job runs.
    start_btn.click(
        fn=repackage,
        inputs=[],
        outputs=[output_box],
    )

# Default concurrency limit of one for queued events.
demo.queue(default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch(show_error=True, ssr_mode=False)