# SDK-Docker / prepare_dataset.py
# Author: Lucifer9907 — commit ff0c419 ("Prepare Hugging Face Docker Space")
from __future__ import annotations

import argparse
import base64
import io
import time
from collections.abc import Callable, Iterator
from pathlib import Path

import requests
from PIL import Image

from src.ai_image_detector.config import PROCESSED_DATA_DIR
def save_image(image_bytes: bytes, destination: Path) -> None:
    """Decode encoded image bytes and write them to *destination* as an RGB PNG.

    Parent directories of *destination* are created first if missing.
    """
    destination.parent.mkdir(parents=True, exist_ok=True)
    in_memory = io.BytesIO(image_bytes)
    rgb_image = Image.open(in_memory).convert("RGB")
    rgb_image.save(destination, format="PNG")
def save_pil_image(image: Image.Image, destination: Path) -> None:
    """Write an in-memory PIL image to *destination* as an RGB PNG.

    Parent directories of *destination* are created first if missing.
    """
    target_dir = destination.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    rgb_image = image.convert("RGB")
    rgb_image.save(destination, format="PNG")
def download_image_url(image_url: str, destination: Path, retries: int = 5) -> None:
    """Download the image at *image_url* and save it to *destination* as an RGB PNG.

    Network and HTTP failures (``requests.RequestException``) are retried up to
    *retries* attempts in total, with a linear backoff of 2s, 4s, ...; HTTP 429
    (rate limited) responses back off longer (30s, 60s, ...). There is no sleep
    after the final attempt. Image-decoding errors raised by ``save_image`` are
    not retried and propagate unchanged.

    Raises:
        RuntimeError: if every attempt fails; the last ``requests`` error is
            chained as ``__cause__`` (``None`` when ``retries <= 0``).
    """
    last_error: Exception | None = None
    for attempt in range(retries):
        try:
            response = requests.get(image_url, timeout=120)
            response.raise_for_status()
        except requests.RequestException as exc:
            last_error = exc
            # Only wait if another attempt will follow.
            if attempt + 1 < retries:
                rate_limited = (
                    isinstance(exc, requests.HTTPError)
                    and exc.response is not None
                    and exc.response.status_code == 429
                )
                # HTTP 429 means the server is throttling us: back off much longer.
                time.sleep((30 if rate_limited else 2) * (attempt + 1))
        else:
            save_image(response.content, destination)
            return
    raise RuntimeError(f"Failed to download image after {retries} attempts") from last_error
def fetch_rows(
    dataset_name: str,
    offset: int,
    length: int,
    retries: int = 5,
    *,
    config: str = "default",
    split: str = "train",
) -> list[dict]:
    """Fetch one page of rows from the Hugging Face datasets-server ``/rows`` endpoint.

    Args:
        dataset_name: Hugging Face dataset id.
        offset: Index of the first row to request.
        length: Number of rows to request (the page size).
        retries: Total number of attempts before giving up.
        config: Dataset configuration name; defaults to ``"default"``.
        split: Dataset split; defaults to ``"train"``.

    Returns:
        The ``"rows"`` list from the JSON response, or ``[]`` if that key is absent.

    Raises:
        RuntimeError: if every attempt fails; the last ``requests`` error is
            chained as ``__cause__`` (``None`` when ``retries <= 0``).
    """
    rows_api = "https://datasets-server.huggingface.co/rows"
    params = {
        "dataset": dataset_name,
        "config": config,
        "split": split,
        "offset": offset,
        "length": length,
    }
    last_error: Exception | None = None
    for attempt in range(retries):
        try:
            response = requests.get(rows_api, params=params, timeout=120)
            response.raise_for_status()
            return response.json().get("rows", [])
        except requests.RequestException as exc:
            last_error = exc
            # Only wait if another attempt will follow.
            if attempt + 1 < retries:
                rate_limited = (
                    isinstance(exc, requests.HTTPError)
                    and exc.response is not None
                    and exc.response.status_code == 429
                )
                # HTTP 429 means the server is throttling us: back off much longer.
                time.sleep((30 if rate_limited else 2) * (attempt + 1))
    raise RuntimeError(f"Failed to fetch dataset rows after {retries} attempts") from last_error
def _iter_dataset_rows(dataset_name: str, start_offset: int, page_size: int) -> Iterator[dict]:
    """Yield row items page by page, starting at *start_offset*, until a page comes back empty.

    Pages are fetched lazily, so a consumer that stops iterating triggers no further requests.
    """
    offset = start_offset
    while True:
        rows = fetch_rows(dataset_name, offset, page_size)
        if not rows:
            return
        yield from rows
        offset += page_size


def _parse_community_forensics_row(item: dict) -> tuple[str | None, str, Callable[[Path], None]]:
    """Classify a CommunityForensics-Small row and prepare a writer for its image.

    ``label`` 0 maps to "real", 1 to "fake"; any other label yields ``None`` (row skipped).
    The filename suffix is the stem of ``image_name``.
    """
    row = item["row"]
    image_class = {0: "real", 1: "fake"}.get(int(row["label"]))
    name_suffix = Path(row["image_name"]).stem

    def write(destination: Path) -> None:
        # Decode the inline base64 payload only for rows that are actually kept.
        save_image(base64.b64decode(row["image_data"]), destination)

    return image_class, name_suffix, write


def _parse_ai_vs_real_row(item: dict) -> tuple[str | None, str, Callable[[Path], None]]:
    """Classify a Parveshiiii/AI-vs-Real row and prepare a writer for its image.

    ``binary_label`` 1 maps to "real", 0 to "fake" (mapping kept from the original
    code — NOTE(review): confirm against the dataset card); other labels yield ``None``.
    The filename suffix is the row's ``row_idx``; the image is downloaded from its URL.
    """
    row = item["row"]
    image_class = {1: "real", 0: "fake"}.get(int(row["binary_label"]))
    name_suffix = str(item["row_idx"])

    def write(destination: Path) -> None:
        download_image_url(row["image"]["src"], destination)

    return image_class, name_suffix, write


# Supported dataset id -> (rows per API page, output filename prefix, row parser).
_DATASET_SOURCES = {
    "OwensLab/CommunityForensics-Small": (10, "communityforensics", _parse_community_forensics_row),
    "Parveshiiii/AI-vs-Real": (20, "aivsreal", _parse_ai_vs_real_row),
}


def download_subset(
    dataset_name: str,
    real_target: int,
    fake_target: int,
    seed: int,
    start_offset: int,
) -> tuple[int, int]:
    """Download real/fake images from a supported dataset until both targets are reached.

    Images are written as PNGs under ``PROCESSED_DATA_DIR / "real"`` and
    ``PROCESSED_DATA_DIR / "fake"``. PNGs already present in those folders count
    toward the targets, so a rerun only tops them up. Rows are read in order
    starting at *start_offset*. NOTE: rows are not de-duplicated against existing
    files, so re-running with the same offset can save the same images again.

    Args:
        dataset_name: One of the supported Hugging Face dataset ids.
        real_target: Desired total number of real images.
        fake_target: Desired total number of fake images.
        seed: Currently unused; rows are read sequentially, not shuffled.
        start_offset: Row offset at which to start reading the dataset.

    Returns:
        ``(real_count, fake_count)``: the PNG totals in each folder after downloading.
        These may fall short of the targets if the dataset runs out of rows.

    Raises:
        ValueError: if *dataset_name* is not supported (checked before touching disk).
    """
    try:
        page_size, file_prefix, parse_row = _DATASET_SOURCES[dataset_name]
    except KeyError:
        raise ValueError(f"Unsupported dataset source: {dataset_name}") from None

    class_dirs = {
        "real": PROCESSED_DATA_DIR / "real",
        "fake": PROCESSED_DATA_DIR / "fake",
    }
    for class_dir in class_dirs.values():
        class_dir.mkdir(parents=True, exist_ok=True)
    counts = {name: len(list(class_dir.glob("*.png"))) for name, class_dir in class_dirs.items()}
    targets = {"real": real_target, "fake": fake_target}

    def targets_met() -> bool:
        return all(counts[name] >= targets[name] for name in counts)

    # Guard first so that no API request is made when existing files already suffice.
    if not targets_met():
        for item in _iter_dataset_rows(dataset_name, start_offset, page_size):
            image_class, name_suffix, write = parse_row(item)
            if image_class is None or counts[image_class] >= targets[image_class]:
                continue
            index = counts[image_class]
            destination = (
                class_dirs[image_class]
                / f"{file_prefix}_{image_class}_{index:05d}_{name_suffix}.png"
            )
            write(destination)
            counts[image_class] += 1
            if targets_met():
                break
    return counts["real"], counts["fake"]
def main() -> None:
    """Parse CLI arguments, download the requested subset, and print per-class totals."""
    parser = argparse.ArgumentParser(
        description=(
            "Download a (by default balanced) starter subset of real and AI-generated "
            "images from a supported Hugging Face dataset."
        )
    )
    parser.add_argument(
        "--per-class",
        type=int,
        default=2000,
        help="Number of real and fake images to download",
    )
    parser.add_argument(
        "--real-count",
        type=int,
        default=None,
        help="Override the real-image target count",
    )
    parser.add_argument(
        "--fake-count",
        type=int,
        default=None,
        help="Override the fake-image target count",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed (currently unused; rows are read in order, not shuffled)",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="OwensLab/CommunityForensics-Small",
        choices=("OwensLab/CommunityForensics-Small", "Parveshiiii/AI-vs-Real"),
        help="Hugging Face dataset id",
    )
    parser.add_argument(
        "--start-offset",
        type=int,
        default=0,
        help="Starting row offset for resumable API downloads",
    )
    args = parser.parse_args()
    real_target = args.per_class if args.real_count is None else args.real_count
    fake_target = args.per_class if args.fake_count is None else args.fake_count
    if real_target < 0 or fake_target < 0:
        parser.error("image counts must be non-negative")
    real_count, fake_count = download_subset(
        dataset_name=args.dataset,
        real_target=real_target,
        fake_target=fake_target,
        seed=args.seed,
        start_offset=args.start_offset,
    )
    # The reported totals include PNGs that already existed before this run.
    print(f"Downloaded {real_count} real images to {PROCESSED_DATA_DIR / 'real'}")
    print(f"Downloaded {fake_count} fake images to {PROCESSED_DATA_DIR / 'fake'}")
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()