| | import argparse |
| | import json |
| | import os |
| | import uuid |
| | import zipfile |
| | from PIL import Image |
| | import base64 |
| | from io import BytesIO |
| |
|
| | import braceexpand |
| | import webdataset as wds |
| |
|
# Command-line interface.
# NOTE: the arguments are parsed at import time and the resulting namespace is
# read by main() as the module-level global `args`.
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument(
    "--output_dir",
    type=str,
    # Required: main() creates this directory and builds the shard file
    # pattern from it, so a missing value would only surface as a TypeError.
    required=True,
    help="Pass in the directory where the output shards (as tar files) will be written to.",
)
arg_parser.add_argument(
    "--zip_files",
    type=str,
    # Required: main() brace-expands this string into the list of input zips.
    required=True,
    help="Pass in a list of MMC4 shards in the format path_to_shard/shard_{0..23098}.zip",
)
arg_parser.add_argument(
    "--image_dir",
    type=str,
    # Left optional on purpose: when it is omitted every image lookup fails
    # inside main()'s per-image error handler and the documents are still
    # written, just without embedded images.
    help="Pass in the directory where the images have been downloaded to.",
)
arg_parser.add_argument(
    "--num_files_per_shard",
    type=int,
    default=1000,
    help="Start a new output tar shard after this many input zip files.",
)
args = arg_parser.parse_args()
| |
|
| |
|
def _encode_image_as_base64_jpeg(image_path):
    """Return the image at *image_path* re-encoded as a base64 JPEG string.

    The image is converted to RGB, saved as JPEG into an in-memory buffer,
    and the resulting bytes are base64-encoded and decoded to ``str``.

    Raises:
        FileNotFoundError: if *image_path* does not exist.
        Other exceptions raised by PIL while reading or encoding the image
        propagate unchanged; the caller decides how to report them.
    """
    # Use the context manager so the file handle is released as soon as the
    # pixel data has been copied by convert().
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert("RGB")
    buffered = BytesIO()
    rgb_image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def main():
    """Convert MMC4 document shards (zip files) into WebDataset tar shards.

    Reads the module-level ``args`` namespace:

    * ``args.zip_files`` is brace-expanded into the list of input zips.
    * Each zip's first member is read line by line as JSON documents.
    * For every entry of a document's ``"image_info"`` list, the image named
      by ``"image_name"`` is loaded from
      ``<args.image_dir>/<index of the zip in the expanded list>/`` and stored
      back into that entry as ``"image_base64"``.
    * Each document is written to the tar sink under the ``"json"`` key with
      a random UUID hex string as ``"__key__"``.

    Images that cannot be found or processed are reported on stdout and
    skipped; the document is still written without them.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    doc_shards = list(braceexpand.braceexpand(args.zip_files))
    last_idx = len(doc_shards) - 1

    with wds.ShardWriter(args.output_dir + "/%09d.tar") as sink:
        for idx, shard_path in enumerate(doc_shards):
            with zipfile.ZipFile(shard_path, "r") as zip_file:
                # Assumes the JSON-lines file is the first member of the archive.
                json_filename = zip_file.namelist()[0]
                with zip_file.open(json_filename, "r") as json_file:
                    for raw_line in json_file:
                        # A blank (e.g. trailing) line is not a document;
                        # skip it instead of aborting on a JSON decode error.
                        if not raw_line.strip():
                            continue

                        sample_data = json.loads(raw_line)

                        for image in sample_data["image_info"]:
                            image_name = image["image_name"]
                            # NOTE(review): the sub-directory is the position
                            # of the zip in the expanded list, which equals
                            # the shard number only when the brace range
                            # starts at 0 - confirm against the downloader.
                            image_path = os.path.join(
                                args.image_dir, str(idx), image_name
                            )
                            try:
                                image["image_base64"] = _encode_image_as_base64_jpeg(
                                    image_path
                                )
                            except FileNotFoundError:
                                print(
                                    f"Did not find {image_name} downloaded. This can happen if the url is now 404."
                                )
                            except Exception as e:
                                # Best effort: one bad image must not stop
                                # the whole conversion.
                                print(f"Error processing {image_name}: {e}")

                        key_str = uuid.uuid4().hex
                        sink.write({"__key__": key_str, "json": sample_data})

            # Roll over to a new output tar every num_files_per_shard inputs,
            # but not after the final input: that would leave an empty
            # trailing tar file behind.
            if (idx + 1) % args.num_files_per_shard == 0 and idx < last_idx:
                sink.next_stream()
| |
|
| |
|
# Script entry point: run the conversion only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
| |
|