"""Export a slice of the DiffusionDB dataset to local images and metadata files.

Images are written as PNG files under ``<output_dir>/images`` and the
per-sample metadata is written as JSON, CSV and Parquet under ``<output_dir>``.
"""

import json
import os
import shutil

import pandas as pd
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm


def _first_present(sample, keys, default):
    """Return the value of the first key of *keys* found in *sample*.

    Falls back to *default* when none of the keys is present.
    """
    for key in keys:
        if key in sample:
            return sample[key]
    return default


def load_and_process(split="train[:1000]", output_dir="processed"):
    """Load DiffusionDB samples, save their images and collect metadata.

    Args:
        split: Split expression passed to ``load_dataset``.
        output_dir: Directory under which the ``images`` folder is created.

    Returns:
        A list with one dict per sample, holding the keys ``id``,
        ``image_file``, ``prompt``, ``seed``, ``cfg_scale``, ``steps`` and
        ``sampler``.
    """
    dataset = load_dataset("poloclub/diffusiondb", split=split)
    images_dir = os.path.join(output_dir, "images")
    os.makedirs(images_dir, exist_ok=True)

    processed_data = []
    for idx, sample in enumerate(tqdm(dataset)):
        image_id = f"{idx:06d}.png"

        image = sample.get("image")
        if image is not None:
            image.save(os.path.join(images_dir, image_id))

        # NOTE(review): the entry is recorded even when no image was saved, so
        # "image_file" can name a file that does not exist on disk.
        #
        # Each field is looked up under the full column name first and then
        # under the abbreviated key the script originally used. The full names
        # are presumably the ones in the DiffusionDB schema -- TODO confirm
        # against the dataset card.
        processed_data.append(
            {
                "id": idx,
                "image_file": image_id,
                "prompt": _first_present(sample, ("prompt", "p"), ""),
                "seed": _first_present(sample, ("seed", "se"), 0),
                "cfg_scale": _first_present(sample, ("cfg", "c"), 0.0),
                "steps": _first_present(sample, ("step", "st"), 0),
                "sampler": _first_present(sample, ("sampler", "sa"), ""),
            }
        )

    return processed_data


def save_data(data, output_dir="processed"):
    """Write *data* to ``data.json``, ``data.csv`` and ``data.parquet``.

    Args:
        data: List of per-sample metadata dicts.
        output_dir: Directory the three files are written to; created if it
            does not exist yet.
    """
    # Do not rely on load_and_process() having created the directory.
    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, "data.json"), "w", encoding="utf-8") as f:
        json.dump(data, f)

    df = pd.DataFrame(data)
    df.to_csv(os.path.join(output_dir, "data.csv"), index=False)
    df.to_parquet(os.path.join(output_dir, "data.parquet"), index=False)


def main():
    """Run the export with default settings and report the sample count."""
    data = load_and_process()
    save_data(data)
    print(f"Processed {len(data)} samples")


if __name__ == "__main__":
    main()