diff --git "a/dataloading.ipynb" "b/dataloading.ipynb" new file mode 100644--- /dev/null +++ "b/dataloading.ipynb" @@ -0,0 +1,1333 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "id": "ae9bc87a", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "import datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "b2ffd47f", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0ccb4dc0c6bf4c8f89a0be03b742598f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Resolving data files: 0%| | 0/119 [00:00", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mRemoteTraceback\u001b[39m Traceback (most recent call last)", + "\u001b[31mRemoteTraceback\u001b[39m: \n\"\"\"\nTraceback (most recent call last):\n File \"/Users/arpit-zstch1557/miniconda3/envs/lab_env/lib/python3.12/site-packages/multiprocess/pool.py\", line 125, in worker\n result = (True, func(*args, **kwds))\n ^^^^^^^^^^^^^^^^^^^\n File \"/Users/arpit-zstch1557/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/utils/py_utils.py\", line 586, in _write_generator_to_queue\n for i, result in enumerate(func(**kwargs)):\n ^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/Users/arpit-zstch1557/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/arrow_dataset.py\", line 3697, in _map_single\n for i, batch in iter_outputs(shard_iterable):\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/Users/arpit-zstch1557/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/arrow_dataset.py\", line 3647, in iter_outputs\n yield i, apply_function(example, i, offset=offset)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/Users/arpit-zstch1557/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/arrow_dataset.py\", line 3570, in apply_function\n processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/var/folders/07/1hr7xxpj3sx52fsnpz87jfj40000gp/T/ipykernel_65139/2515864513.py\", line 27, in save_batch_raw\n images = batch['image']\n ~~~~~^^^^^^^^^\n File \"/Users/arpit-zstch1557/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/formatting/formatting.py\", line 285, in __getitem__\n value = self.format(key)\n ^^^^^^^^^^^^^^^^\n File \"/Users/arpit-zstch1557/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/formatting/formatting.py\", line 385, in format\n return self.formatter.format_column(self.pa_table.select([key]))\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/Users/arpit-zstch1557/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/formatting/formatting.py\", line 465, in format_column\n column = self.python_features_decoder.decode_column(column, pa_table.column_names[0])\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/Users/arpit-zstch1557/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/formatting/formatting.py\", line 228, in decode_column\n self.features.decode_column(column, column_name, token_per_repo_id=self.token_per_repo_id)\n File \"/Users/arpit-zstch1557/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/features/features.py\", line 2130, in decode_column\n decode_nested_example(self[column_name], value, token_per_repo_id=token_per_repo_id)\n File \"/Users/arpit-zstch1557/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/features/features.py\", line 1414, in decode_nested_example\n return schema.decode_example(obj, token_per_repo_id=token_per_repo_id) if obj is not None else None\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/Users/arpit-zstch1557/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/features/image.py\", line 192, in decode_example\n image = PIL.Image.open(BytesIO(bytes_))\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"/Users/arpit-zstch1557/miniconda3/envs/lab_env/lib/python3.12/site-packages/PIL/Image.py\", line 3580, in open\n raise UnidentifiedImageError(msg)\nPIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x3250f9d50>\n\"\"\"", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[31mUnidentifiedImageError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[48]\u001b[39m\u001b[32m, line 51\u001b[39m\n\u001b[32m 48\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m split \u001b[38;5;129;01min\u001b[39;00m splits:\n\u001b[32m 49\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m🚀 Processing \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msplit.upper()\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m split...\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m51\u001b[39m \u001b[43mds\u001b[49m\u001b[43m[\u001b[49m\u001b[43msplit\u001b[49m\u001b[43m]\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmap\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 52\u001b[39m \u001b[43m \u001b[49m\u001b[43msave_batch_raw\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 53\u001b[39m \u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 54\u001b[39m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m100\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Process 100 images per chunk\u001b[39;49;00m\n\u001b[32m 55\u001b[39m \u001b[43m \u001b[49m\u001b[43mwith_indices\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Needed for unique filenames\u001b[39;49;00m\n\u001b[32m 56\u001b[39m \u001b[43m \u001b[49m\u001b[43mnum_proc\u001b[49m\u001b[43m=\u001b[49m\u001b[43mNUM_PROC\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# <--- THIS IS THE SPEED BOOST\u001b[39;49;00m\n\u001b[32m 57\u001b[39m \u001b[43m \u001b[49m\u001b[43mfn_kwargs\u001b[49m\u001b[43m=\u001b[49m\u001b[43m{\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43msplit_name\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43msplit\u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 58\u001b[39m \u001b[43m \u001b[49m\u001b[43mdesc\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43mf\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mSaving \u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43msplit\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[33;43m\"\u001b[39;49m\n\u001b[32m 59\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 61\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m✅ DONE! Raw data saved to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mos.path.abspath(OUTPUT_DIR)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/arrow_dataset.py:562\u001b[39m, in \u001b[36mtransmit_format..wrapper\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 555\u001b[39m self_format = {\n\u001b[32m 556\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mtype\u001b[39m\u001b[33m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m._format_type,\n\u001b[32m 557\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mformat_kwargs\u001b[39m\u001b[33m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m._format_kwargs,\n\u001b[32m 558\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mcolumns\u001b[39m\u001b[33m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m._format_columns,\n\u001b[32m 559\u001b[39m \u001b[33m\"\u001b[39m\u001b[33moutput_all_columns\u001b[39m\u001b[33m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m._output_all_columns,\n\u001b[32m 560\u001b[39m }\n\u001b[32m 561\u001b[39m \u001b[38;5;66;03m# apply actual function\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m562\u001b[39m out: Union[\u001b[33m\"\u001b[39m\u001b[33mDataset\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mDatasetDict\u001b[39m\u001b[33m\"\u001b[39m] = \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 563\u001b[39m datasets: \u001b[38;5;28mlist\u001b[39m[\u001b[33m\"\u001b[39m\u001b[33mDataset\u001b[39m\u001b[33m\"\u001b[39m] = \u001b[38;5;28mlist\u001b[39m(out.values()) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(out, \u001b[38;5;28mdict\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m [out]\n\u001b[32m 564\u001b[39m \u001b[38;5;66;03m# re-apply format to the output\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/arrow_dataset.py:3332\u001b[39m, in \u001b[36mDataset.map\u001b[39m\u001b[34m(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc, try_original_type)\u001b[39m\n\u001b[32m 3329\u001b[39m os.environ = prev_env\n\u001b[32m 3330\u001b[39m logger.info(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mSpawning \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_proc\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m processes\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m-> \u001b[39m\u001b[32m3332\u001b[39m \u001b[43m\u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mrank\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdone\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontent\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43miflatmap_unordered\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 3333\u001b[39m \u001b[43m \u001b[49m\u001b[43mpool\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mDataset\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_map_single\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs_iterable\u001b[49m\u001b[43m=\u001b[49m\u001b[43munprocessed_kwargs_per_job\u001b[49m\n\u001b[32m 3334\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 3335\u001b[39m \u001b[43m \u001b[49m\u001b[43mcheck_if_shard_done\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrank\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdone\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontent\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 3337\u001b[39m pool.close()\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/utils/py_utils.py:626\u001b[39m, in \u001b[36miflatmap_unordered\u001b[39m\u001b[34m(pool, func, kwargs_iterable)\u001b[39m\n\u001b[32m 623\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 624\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m pool_changed:\n\u001b[32m 625\u001b[39m \u001b[38;5;66;03m# we get the result in case there's an error to raise\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m626\u001b[39m [\u001b[43masync_result\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m0.05\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m async_result \u001b[38;5;129;01min\u001b[39;00m async_results]\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/multiprocess/pool.py:774\u001b[39m, in \u001b[36mApplyResult.get\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 772\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._value\n\u001b[32m 773\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m774\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m._value\n", + "\u001b[31mUnidentifiedImageError\u001b[39m: cannot identify image file <_io.BytesIO object at 0x3250f9d50>" + ] + } + ], + "source": [ + "import os\n", + "import multiprocessing\n", + "\n", + "# 1. Configuration\n", + "OUTPUT_DIR = \"rvl_cdip\"\n", + "NUM_PROC = os.cpu_count() # Automatically use all CPU cores\n", + "\n", + "# 2. Pre-Calculate Class Names\n", + "# We do this once so workers don't have to look it up repeatedly\n", + "labels_feature = ds['train'].features['label']\n", + "idx_to_class = {idx: name for idx, name in enumerate(labels_feature.names)}\n", + "print(f\"✅ Using {NUM_PROC} workers to save RAW images.\")\n", + "\n", + "# 3. Pre-Create Directories\n", + "# Create all folders upfront to prevent collision errors\n", + "print(\"Creating directory structure...\")\n", + "splits = ['train', 'val', 'test']\n", + "for split in splits:\n", + " for class_name in idx_to_class.values():\n", + " os.makedirs(os.path.join(OUTPUT_DIR, split, class_name), exist_ok=True)\n", + "\n", + "# 4. The Worker Function (Raw Save)\n", + "def save_batch_raw(batch, indices, split_name):\n", + " \"\"\"\n", + " Saves a batch of images in their original, raw format.\n", + " \"\"\"\n", + " images = batch['image']\n", + " labels = batch['label']\n", + " \n", + " for img, label_idx, original_idx in zip(images, labels, indices):\n", + " class_name = idx_to_class[label_idx]\n", + " \n", + " # Define Path\n", + " filename = f\"{original_idx}.png\"\n", + " file_path = os.path.join(OUTPUT_DIR, split_name, class_name, filename)\n", + " \n", + " # Save RAW (No Resize)\n", + " # We only convert to RGB if absolutely necessary (e.g. CMYK/Transparency issues)\n", + " # otherwise we save as is.\n", + " if img.mode not in ['RGB', 'L']: # 'L' is standard grayscale\n", + " img = img.convert('RGB')\n", + " \n", + " img.save(file_path)\n", + " \n", + " return batch\n", + "\n", + "# 5. Execute Parallel Processing\n", + "for split in splits:\n", + " print(f\"\\n🚀 Processing {split.upper()} split...\")\n", + " \n", + " ds[split].map(\n", + " save_batch_raw,\n", + " batched=True,\n", + " batch_size=100, # Process 100 images per chunk\n", + " with_indices=True, # Needed for unique filenames\n", + " num_proc=NUM_PROC, # <--- THIS IS THE SPEED BOOST\n", + " fn_kwargs={'split_name': split},\n", + " desc=f\"Saving {split}\"\n", + " )\n", + "\n", + "print(f\"\\n✅ DONE! Raw data saved to {os.path.abspath(OUTPUT_DIR)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "5645bccb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🛠️ Repairing TEST split (with integrity check)...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "22efa388600c486783774cefabee4455", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Checking test: 0%| | 0/40000 [00:00", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mUnidentifiedImageError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[51]\u001b[39m\u001b[32m, line 18\u001b[39m\n\u001b[32m 15\u001b[39m current_ds = ds[split]\n\u001b[32m 16\u001b[39m skipped_count = \u001b[32m0\u001b[39m\n\u001b[32m---> \u001b[39m\u001b[32m18\u001b[39m \u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mexample\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtqdm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcurrent_ds\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdesc\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43mf\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mChecking \u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43msplit\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 19\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mtry\u001b[39;49;00m\u001b[43m:\u001b[49m\n\u001b[32m 20\u001b[39m \u001b[43m \u001b[49m\u001b[43mlabel_idx\u001b[49m\u001b[43m \u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[43mexample\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mlabel\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m]\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/tqdm/notebook.py:250\u001b[39m, in \u001b[36mtqdm_notebook.__iter__\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 248\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 249\u001b[39m it = \u001b[38;5;28msuper\u001b[39m().\u001b[34m__iter__\u001b[39m()\n\u001b[32m--> \u001b[39m\u001b[32m250\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mit\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 251\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# return super(tqdm...) will not catch exception\u001b[39;49;00m\n\u001b[32m 252\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01myield\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\n\u001b[32m 253\u001b[39m \u001b[38;5;66;03m# NB: except ... [ as ...] breaks IPython async KeyboardInterrupt\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/tqdm/std.py:1181\u001b[39m, in \u001b[36mtqdm.__iter__\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 1178\u001b[39m time = \u001b[38;5;28mself\u001b[39m._time\n\u001b[32m 1180\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1181\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43miterable\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 1182\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01myield\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\n\u001b[32m 1183\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Update and possibly print the progressbar.\u001b[39;49;00m\n\u001b[32m 1184\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Note: does not call self.update(1) for speed optimisation.\u001b[39;49;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/arrow_dataset.py:2483\u001b[39m, in \u001b[36mDataset.__iter__\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 2481\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(pa_subtable.num_rows):\n\u001b[32m 2482\u001b[39m pa_subtable_ex = pa_subtable.slice(i, \u001b[32m1\u001b[39m)\n\u001b[32m-> \u001b[39m\u001b[32m2483\u001b[39m formatted_output = \u001b[43mformat_table\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 2484\u001b[39m \u001b[43m \u001b[49m\u001b[43mpa_subtable_ex\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2485\u001b[39m \u001b[43m \u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 2486\u001b[39m \u001b[43m \u001b[49m\u001b[43mformatter\u001b[49m\u001b[43m=\u001b[49m\u001b[43mformatter\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2487\u001b[39m \u001b[43m \u001b[49m\u001b[43mformat_columns\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_format_columns\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2488\u001b[39m \u001b[43m \u001b[49m\u001b[43moutput_all_columns\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_output_all_columns\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2489\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2490\u001b[39m \u001b[38;5;28;01myield\u001b[39;00m formatted_output\n\u001b[32m 2491\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/formatting/formatting.py:658\u001b[39m, in \u001b[36mformat_table\u001b[39m\u001b[34m(table, key, formatter, format_columns, output_all_columns)\u001b[39m\n\u001b[32m 656\u001b[39m python_formatter = PythonFormatter(features=formatter.features)\n\u001b[32m 657\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m format_columns \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m658\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mformatter\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpa_table\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mquery_type\u001b[49m\u001b[43m=\u001b[49m\u001b[43mquery_type\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 659\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m query_type == \u001b[33m\"\u001b[39m\u001b[33mcolumn\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m 660\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m format_columns:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/formatting/formatting.py:411\u001b[39m, in \u001b[36mFormatter.__call__\u001b[39m\u001b[34m(self, pa_table, query_type)\u001b[39m\n\u001b[32m 409\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, pa_table: pa.Table, query_type: \u001b[38;5;28mstr\u001b[39m) -> Union[RowFormat, ColumnFormat, BatchFormat]:\n\u001b[32m 410\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m query_type == \u001b[33m\"\u001b[39m\u001b[33mrow\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m411\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mformat_row\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpa_table\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 412\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m query_type == \u001b[33m\"\u001b[39m\u001b[33mcolumn\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m 413\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.format_column(pa_table)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/formatting/formatting.py:460\u001b[39m, in \u001b[36mPythonFormatter.format_row\u001b[39m\u001b[34m(self, pa_table)\u001b[39m\n\u001b[32m 458\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m LazyRow(pa_table, \u001b[38;5;28mself\u001b[39m)\n\u001b[32m 459\u001b[39m row = \u001b[38;5;28mself\u001b[39m.python_arrow_extractor().extract_row(pa_table)\n\u001b[32m--> \u001b[39m\u001b[32m460\u001b[39m row = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mpython_features_decoder\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdecode_row\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrow\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 461\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m row\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/formatting/formatting.py:224\u001b[39m, in \u001b[36mPythonFeaturesDecoder.decode_row\u001b[39m\u001b[34m(self, row)\u001b[39m\n\u001b[32m 223\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mdecode_row\u001b[39m(\u001b[38;5;28mself\u001b[39m, row: \u001b[38;5;28mdict\u001b[39m) -> \u001b[38;5;28mdict\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m224\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mfeatures\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdecode_example\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrow\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtoken_per_repo_id\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mtoken_per_repo_id\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.features \u001b[38;5;28;01melse\u001b[39;00m row\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/features/features.py:2106\u001b[39m, in \u001b[36mFeatures.decode_example\u001b[39m\u001b[34m(self, example, token_per_repo_id)\u001b[39m\n\u001b[32m 2091\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mdecode_example\u001b[39m(\u001b[38;5;28mself\u001b[39m, example: \u001b[38;5;28mdict\u001b[39m, token_per_repo_id: Optional[\u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Union[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m]]] = \u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[32m 2092\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Decode example with custom feature decoding.\u001b[39;00m\n\u001b[32m 2093\u001b[39m \n\u001b[32m 2094\u001b[39m \u001b[33;03m Args:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 2102\u001b[39m \u001b[33;03m `dict[str, Any]`\u001b[39;00m\n\u001b[32m 2103\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m 2105\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m {\n\u001b[32m-> \u001b[39m\u001b[32m2106\u001b[39m column_name: \u001b[43mdecode_nested_example\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfeature\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtoken_per_repo_id\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtoken_per_repo_id\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2107\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._column_requires_decoding[column_name]\n\u001b[32m 2108\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m value\n\u001b[32m 2109\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m column_name, (feature, value) \u001b[38;5;129;01min\u001b[39;00m zip_dict(\n\u001b[32m 2110\u001b[39m {key: value \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.items() \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m example}, example\n\u001b[32m 2111\u001b[39m )\n\u001b[32m 2112\u001b[39m }\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/features/features.py:1414\u001b[39m, in \u001b[36mdecode_nested_example\u001b[39m\u001b[34m(schema, obj, token_per_repo_id)\u001b[39m\n\u001b[32m 1411\u001b[39m \u001b[38;5;66;03m# Object with special decoding:\u001b[39;00m\n\u001b[32m 1412\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(schema, \u001b[33m\"\u001b[39m\u001b[33mdecode_example\u001b[39m\u001b[33m\"\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mgetattr\u001b[39m(schema, \u001b[33m\"\u001b[39m\u001b[33mdecode\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[32m 1413\u001b[39m \u001b[38;5;66;03m# we pass the token to read and decode files from private repositories in streaming mode\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m1414\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mschema\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdecode_example\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtoken_per_repo_id\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtoken_per_repo_id\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mif\u001b[39;00m obj \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1415\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m obj\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/features/image.py:192\u001b[39m, in \u001b[36mImage.decode_example\u001b[39m\u001b[34m(self, value, token_per_repo_id)\u001b[39m\n\u001b[32m 190\u001b[39m image = PIL.Image.open(bytes_)\n\u001b[32m 191\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m192\u001b[39m image = \u001b[43mPIL\u001b[49m\u001b[43m.\u001b[49m\u001b[43mImage\u001b[49m\u001b[43m.\u001b[49m\u001b[43mopen\u001b[49m\u001b[43m(\u001b[49m\u001b[43mBytesIO\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbytes_\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 193\u001b[39m image.load() \u001b[38;5;66;03m# to avoid \"Too many open files\" errors\u001b[39;00m\n\u001b[32m 194\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m image.getexif().get(PIL.Image.ExifTags.Base.Orientation) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/PIL/Image.py:3580\u001b[39m, in \u001b[36mopen\u001b[39m\u001b[34m(fp, mode, formats)\u001b[39m\n\u001b[32m 3578\u001b[39m warnings.warn(message)\n\u001b[32m 3579\u001b[39m msg = \u001b[33m\"\u001b[39m\u001b[33mcannot identify image file \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[33m\"\u001b[39m % (filename \u001b[38;5;28;01mif\u001b[39;00m filename \u001b[38;5;28;01melse\u001b[39;00m fp)\n\u001b[32m-> \u001b[39m\u001b[32m3580\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m UnidentifiedImageError(msg)\n", + "\u001b[31mUnidentifiedImageError\u001b[39m: cannot identify image file <_io.BytesIO object at 0x10abe6f20>" + ] + } + ], + "source": [ + "import os\n", + "from tqdm.auto import tqdm\n", + "from PIL import UnidentifiedImageError\n", + "\n", + "# Configuration\n", + "OUTPUT_DIR = \"rvl_cdip\"\n", + "split = \"test\" \n", + "\n", + "# Get Class Mapping\n", + "labels_feature = ds['train'].features['label']\n", + "idx_to_class = {idx: name for idx, name in enumerate(labels_feature.names)}\n", + "\n", + "print(f\"🛠️ Repairing {split.upper()} split (with integrity check)...\")\n", + "\n", + "current_ds = ds[split]\n", + "skipped_count = 0\n", + "\n", + "for i, example in enumerate(tqdm(current_ds, desc=f\"Checking {split}\")):\n", + " try:\n", + " label_idx = example['label']\n", + " class_name = idx_to_class[label_idx]\n", + " \n", + " target_folder = os.path.join(OUTPUT_DIR, split, class_name)\n", + " filename = f\"{i}.png\"\n", + " file_path = os.path.join(target_folder, filename)\n", + " \n", + " # --- IMPROVED CHECK ---\n", + " # Only skip if file exists AND is not empty (larger than 0 bytes)\n", + " # This fixes the edge case where the crash left a 0-byte file\n", + " if os.path.exists(file_path) and os.path.getsize(file_path) > 0:\n", + " continue\n", + " \n", + " # If we reach here, the file is missing OR corrupt (empty). So we save it.\n", + " if not os.path.exists(target_folder):\n", + " os.makedirs(target_folder, exist_ok=True)\n", + "\n", + " image = example['image'] \n", + " if image.mode not in ['RGB', 'L']:\n", + " image = image.convert('RGB')\n", + " \n", + " image.save(file_path)\n", + "\n", + " except (UnidentifiedImageError, OSError) as e:\n", + " print(f\"\\n❌ SKIPPING CORRUPT IMAGE: Index {i}\")\n", + " skipped_count += 1\n", + "\n", + "print(f\"\\n✅ Repair Complete.\")\n", + "print(f\"Total corrupt/unreadable images skipped: {skipped_count}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "41f94f27", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🛠️ Repairing TEST split (Safe Mode)...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6b4725f0cad048119e282d008f56fe5f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Checking test: 0%| | 0/40000 [00:00 0:\n", + " continue\n", + " \n", + " # D. Create Folder\n", + " if not os.path.exists(target_folder):\n", + " os.makedirs(target_folder, exist_ok=True)\n", + "\n", + " # E. Manual Decoding (The Safe Way)\n", + " # Hugging Face gives us a dict with 'bytes' when decoding is off\n", + " image_data = example['image'] \n", + " \n", + " # Check if it's already a PIL object (some versions vary)\n", + " if isinstance(image_data, dict) and 'bytes' in image_data:\n", + " image_bytes = image_data['bytes']\n", + " image = Image.open(io.BytesIO(image_bytes))\n", + " else:\n", + " image = image_data # It might already be loaded\n", + " \n", + " # F. Save\n", + " if image.mode not in ['RGB', 'L']:\n", + " image = image.convert('RGB')\n", + " \n", + " image.save(file_path)\n", + "\n", + " except (UnidentifiedImageError, OSError, ValueError) as e:\n", + " print(f\"\\n❌ SKIPPING CORRUPT IMAGE: Index {i}\")\n", + " # Create a placeholder or just log it\n", + " skipped_count += 1\n", + "\n", + "print(f\"\\n✅ Repair Complete.\")\n", + "print(f\"Skipped {skipped_count} corrupt files.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "b829c704", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Counter({13: 2572,\n", + " 12: 2537,\n", + " 5: 2532,\n", + " 3: 2527,\n", + " 2: 2516,\n", + " 0: 2515,\n", + " 4: 2506,\n", + " 1: 2505,\n", + " 14: 2498,\n", + " 8: 2492,\n", + " 10: 2489,\n", + " 6: 2477,\n", + " 15: 2472,\n", + " 7: 2464,\n", + " 9: 2463,\n", + " 11: 2435})" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Counter(ds[split]['label'])" + ] + }, + { + "cell_type": "markdown", + "id": "c8530c8e", + "metadata": {}, + "source": [ + "## Checking the balance in ds" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "2785360c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SPLIT CLASS NAME COUNT STATUS\n", + "------------------------------------------------------------\n", + "TRAIN advertisement 19963 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN budget 20010 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN email 19954 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN file folder 20022 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN form 19957 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN handwritten 20034 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN invoice 19947 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN letter 20106 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN memo 19975 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN news article 20011 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN presentation 20043 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN questionnaire 20048 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN resume 20036 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN scientific publication 19902 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN scientific report 19994 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN specification 19997 ❌ MISMATCH (Exp: 20000)\n", + "------------------------------------------------------------\n", + "VAL advertisement 2522 ❌ MISMATCH (Exp: 2500)\n", + "VAL budget 2485 ❌ MISMATCH (Exp: 2500)\n", + "VAL email 2530 ❌ MISMATCH (Exp: 2500)\n", + "VAL file folder 2451 ❌ MISMATCH (Exp: 2500)\n", + "VAL form 2537 ❌ MISMATCH (Exp: 2500)\n", + "VAL handwritten 2434 ❌ MISMATCH (Exp: 2500)\n", + "VAL invoice 2576 ❌ MISMATCH (Exp: 2500)\n", + "VAL letter 2430 ❌ MISMATCH (Exp: 2500)\n", + "VAL memo 2533 ❌ MISMATCH (Exp: 2500)\n", + "VAL news article 2526 ❌ MISMATCH (Exp: 2500)\n", + "VAL presentation 2468 ❌ MISMATCH (Exp: 2500)\n", + "VAL questionnaire 2517 ❌ MISMATCH (Exp: 2500)\n", + "VAL resume 2426 ❌ MISMATCH (Exp: 2500)\n", + "VAL scientific publication 2526 ❌ MISMATCH (Exp: 2500)\n", + "VAL scientific report 2508 ❌ MISMATCH (Exp: 2500)\n", + "VAL specification 2531 ❌ MISMATCH (Exp: 2500)\n", + "------------------------------------------------------------\n", + "TEST advertisement 2515 ❌ MISMATCH (Exp: 2500)\n", + "TEST budget 2505 ❌ MISMATCH (Exp: 2500)\n", + "TEST email 2516 ❌ MISMATCH (Exp: 2500)\n", + "TEST file folder 2527 ❌ MISMATCH (Exp: 2500)\n", + "TEST form 2506 ❌ MISMATCH (Exp: 2500)\n", + "TEST handwritten 2532 ❌ MISMATCH (Exp: 2500)\n", + "TEST invoice 2477 ❌ MISMATCH (Exp: 2500)\n", + "TEST letter 2464 ❌ MISMATCH (Exp: 2500)\n", + "TEST memo 2492 ❌ MISMATCH (Exp: 2500)\n", + "TEST news article 2463 ❌ MISMATCH (Exp: 2500)\n", + "TEST presentation 2489 ❌ MISMATCH (Exp: 2500)\n", + "TEST questionnaire 2435 ❌ MISMATCH (Exp: 2500)\n", + "TEST resume 2537 ❌ MISMATCH (Exp: 2500)\n", + "TEST scientific publication 2572 ❌ MISMATCH (Exp: 2500)\n", + "TEST scientific report 2498 ❌ MISMATCH (Exp: 2500)\n", + "TEST specification 2472 ❌ MISMATCH (Exp: 2500)\n", + "------------------------------------------------------------\n" + ] + } + ], + "source": [ + "from collections import Counter\n", + "import pandas as pd\n", + "\n", + "# 1. Setup\n", + "splits = ['train', 'val', 'test']\n", + "label_feature = ds['train'].features['label']\n", + "int2str = label_feature.int2str # Helper to convert ID (0) -> Name (\"letter\")\n", + "\n", + "print(f\"{'SPLIT':<10} {'CLASS NAME':<25} {'COUNT':<10} {'STATUS'}\")\n", + "print(\"-\" * 60)\n", + "\n", + "for split in splits:\n", + " # 2. Get all labels (Load only the label column into memory)\n", + " # This is instant compared to loading images\n", + " labels = ds[split]['label']\n", + " \n", + " # 3. Count frequencies\n", + " counts = Counter(labels)\n", + " \n", + " # 4. Analyze each class\n", + " # We sort by class ID to keep it organized\n", + " for label_id in sorted(counts.keys()):\n", + " count = counts[label_id]\n", + " class_name = int2str(label_id)\n", + " \n", + " # 5. Define Expected Counts based on the Paper\n", + " # Train: 320k / 16 = 20,000\n", + " # Test/Val: 40k / 16 = 2,500\n", + " if split == 'train':\n", + " expected = 20000\n", + " else:\n", + " expected = 2500\n", + " \n", + " status = \"✅ OK\" if count == expected else f\"❌ MISMATCH (Exp: {expected})\"\n", + " \n", + " print(f\"{split.upper():<10} {class_name:<25} {count:<10} {status}\")\n", + " \n", + " print(\"-\" * 60) " + ] + }, + { + "cell_type": "markdown", + "id": "5f7b75a2", + "metadata": {}, + "source": [ + "## checking the balance in dir" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "059bfaa5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "📂 Scanning directory: /Users/arpit-zstch1557/Projects/DL/Course 4/document-classification/rvl_cdip\n", + "SPLIT CLASS NAME FILES STATUS\n", + "-----------------------------------------------------------------\n", + "TRAIN advertisement 19963 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN budget 20010 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN email 19954 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN file folder 20022 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN form 19957 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN handwritten 20034 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN invoice 19947 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN letter 20106 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN memo 19975 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN news article 20011 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN presentation 20043 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN questionnaire 20048 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN resume 20036 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN scientific publication 19902 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN scientific report 19994 ❌ MISMATCH (Exp: 20000)\n", + "TRAIN specification 19997 ⚠️ OK (Off by 3)\n", + "-----------------------------------------------------------------\n", + "VAL advertisement 2522 ❌ MISMATCH (Exp: 2500)\n", + "VAL budget 2485 ❌ MISMATCH (Exp: 2500)\n", + "VAL email 2530 ❌ MISMATCH (Exp: 2500)\n", + "VAL file folder 2451 ❌ MISMATCH (Exp: 2500)\n", + "VAL form 2537 ❌ MISMATCH (Exp: 2500)\n", + "VAL handwritten 2434 ❌ MISMATCH (Exp: 2500)\n", + "VAL invoice 2576 ❌ MISMATCH (Exp: 2500)\n", + "VAL letter 2430 ❌ MISMATCH (Exp: 2500)\n", + "VAL memo 2533 ❌ MISMATCH (Exp: 2500)\n", + "VAL news article 2526 ❌ MISMATCH (Exp: 2500)\n", + "VAL presentation 2468 ❌ MISMATCH (Exp: 2500)\n", + "VAL questionnaire 2517 ❌ MISMATCH (Exp: 2500)\n", + "VAL resume 2426 ❌ MISMATCH (Exp: 2500)\n", + "VAL scientific publication 2526 ❌ MISMATCH (Exp: 2500)\n", + "VAL scientific report 2508 ❌ MISMATCH (Exp: 2500)\n", + "VAL specification 2531 ❌ MISMATCH (Exp: 2500)\n", + "-----------------------------------------------------------------\n", + "TEST advertisement 2515 ❌ MISMATCH (Exp: 2500)\n", + "TEST budget 2505 ❌ MISMATCH (Exp: 2500)\n", + "TEST email 2516 ❌ MISMATCH (Exp: 2500)\n", + "TEST file folder 2527 ❌ MISMATCH (Exp: 2500)\n", + "TEST form 2506 ❌ MISMATCH (Exp: 2500)\n", + "TEST handwritten 2532 ❌ MISMATCH (Exp: 2500)\n", + "TEST invoice 2477 ❌ MISMATCH (Exp: 2500)\n", + "TEST letter 2464 ❌ MISMATCH (Exp: 2500)\n", + "TEST memo 2492 ❌ MISMATCH (Exp: 2500)\n", + "TEST news article 2463 ❌ MISMATCH (Exp: 2500)\n", + "TEST presentation 2489 ❌ MISMATCH (Exp: 2500)\n", + "TEST questionnaire 2435 ❌ MISMATCH (Exp: 2500)\n", + "TEST resume 2537 ❌ MISMATCH (Exp: 2500)\n", + "TEST scientific publication 2571 ❌ MISMATCH (Exp: 2500)\n", + "TEST scientific report 2498 ⚠️ OK (Off by 2)\n", + "TEST specification 2472 ❌ MISMATCH (Exp: 2500)\n", + "-----------------------------------------------------------------\n", + "\n", + "Analysis Complete.\n" + ] + } + ], + "source": [ + "import os\n", + "import pandas as pd\n", + "\n", + "# Configuration\n", + "DATA_DIR = \"rvl_cdip\" # Your directory name\n", + "splits = ['train', 'val', 'test']\n", + "\n", + "print(f\"📂 Scanning directory: {os.path.abspath(DATA_DIR)}\")\n", + "print(f\"{'SPLIT':<10} {'CLASS NAME':<25} {'FILES':<10} {'STATUS'}\")\n", + "print(\"-\" * 65)\n", + "\n", + "stats = []\n", + "\n", + "for split in splits:\n", + " split_dir = os.path.join(DATA_DIR, split)\n", + " \n", + " # Check if split folder exists\n", + " if not os.path.exists(split_dir):\n", + " print(f\"❌ Missing folder: {split}\")\n", + " continue\n", + " \n", + " # Get all class folders (sorted for consistency)\n", + " try:\n", + " classes = sorted([d for d in os.listdir(split_dir) if os.path.isdir(os.path.join(split_dir, d))])\n", + " except OSError:\n", + " print(f\"❌ Error reading {split} directory.\")\n", + " continue\n", + "\n", + " for cls in classes:\n", + " cls_path = os.path.join(split_dir, cls)\n", + " \n", + " # Count actual files (ignoring hidden files like .DS_Store)\n", + " file_count = len([\n", + " name for name in os.listdir(cls_path) \n", + " if os.path.isfile(os.path.join(cls_path, name)) \n", + " and not name.startswith('.')\n", + " ])\n", + " \n", + " # Determine Expected Count based on the paper\n", + " if split == 'train':\n", + " expected = 20000 \n", + " # Note: We know 'train' has 1 missing file in total from the source (319,999)\n", + " else:\n", + " expected = 2500\n", + "\n", + " # Status Check\n", + " # We allow a small tolerance because we know source data has noise\n", + " if file_count == expected:\n", + " status = \"✅ OK\"\n", + " elif abs(file_count - expected) < 5: \n", + " # If off by 1-4 files (like the corrupt one we skipped), it's acceptable\n", + " status = f\"⚠️ OK (Off by {expected - file_count})\"\n", + " else:\n", + " status = f\"❌ MISMATCH (Exp: {expected})\"\n", + " \n", + " print(f\"{split.upper():<10} {cls:<25} {file_count:<10} {status}\")\n", + " \n", + " print(\"-\" * 65)\n", + "\n", + "print(\"\\nAnalysis Complete.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "99ca1af8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🚀 Starting RVL-CDIP Downloader\n", + " Target Folder: /Users/arpit-zstch1557/Projects/DL/Course 4/document-classification/rvl_cdip_data\n", + " Workers: 12\n", + " Loading dataset structure from Hugging Face...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cf6fb62548ea45ebad0d16066ad8a895", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Resolving data files: 0%| | 0/119 [00:00 \u001b[39m\u001b[32m612\u001b[39m \u001b[38;5;28;01myield\u001b[39;00m \u001b[43mqueue\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m0.05\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 613\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m Empty:\n", + "\u001b[36mFile \u001b[39m\u001b[32m:2\u001b[39m, in \u001b[36mget\u001b[39m\u001b[34m(self, *args, **kwds)\u001b[39m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/multiprocess/managers.py:828\u001b[39m, in \u001b[36mBaseProxy._callmethod\u001b[39m\u001b[34m(self, methodname, args, kwds)\u001b[39m\n\u001b[32m 827\u001b[39m conn.send((\u001b[38;5;28mself\u001b[39m._id, methodname, args, kwds))\n\u001b[32m--> \u001b[39m\u001b[32m828\u001b[39m kind, result = \u001b[43mconn\u001b[49m\u001b[43m.\u001b[49m\u001b[43mrecv\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 830\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m kind == \u001b[33m'\u001b[39m\u001b[33m#RETURN\u001b[39m\u001b[33m'\u001b[39m:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/multiprocess/connection.py:253\u001b[39m, in \u001b[36m_ConnectionBase.recv\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 252\u001b[39m \u001b[38;5;28mself\u001b[39m._check_readable()\n\u001b[32m--> \u001b[39m\u001b[32m253\u001b[39m buf = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_recv_bytes\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 254\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m _ForkingPickler.loads(buf.getbuffer())\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/multiprocess/connection.py:433\u001b[39m, in \u001b[36mConnection._recv_bytes\u001b[39m\u001b[34m(self, maxsize)\u001b[39m\n\u001b[32m 432\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_recv_bytes\u001b[39m(\u001b[38;5;28mself\u001b[39m, maxsize=\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[32m--> \u001b[39m\u001b[32m433\u001b[39m buf = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_recv\u001b[49m\u001b[43m(\u001b[49m\u001b[32;43m4\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 434\u001b[39m size, = struct.unpack(\u001b[33m\"\u001b[39m\u001b[33m!i\u001b[39m\u001b[33m\"\u001b[39m, buf.getvalue())\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/multiprocess/connection.py:398\u001b[39m, in \u001b[36mConnection._recv\u001b[39m\u001b[34m(self, size, read)\u001b[39m\n\u001b[32m 397\u001b[39m \u001b[38;5;28;01mwhile\u001b[39;00m remaining > \u001b[32m0\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m398\u001b[39m chunk = \u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhandle\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mremaining\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 399\u001b[39m n = \u001b[38;5;28mlen\u001b[39m(chunk)\n", + "\u001b[31mKeyboardInterrupt\u001b[39m: ", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[31mTimeoutError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[10]\u001b[39m\u001b[32m, line 103\u001b[39m\n\u001b[32m 100\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m datasets.ImageFolder(root=\u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mOUTPUT_DIR\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m/train\u001b[39m\u001b[33m'\u001b[39m\u001b[33m)\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 102\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[34m__name__\u001b[39m == \u001b[33m\"\u001b[39m\u001b[33m__main__\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m103\u001b[39m \u001b[43mmain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[10]\u001b[39m\u001b[32m, line 84\u001b[39m, in \u001b[36mmain\u001b[39m\u001b[34m()\u001b[39m\n\u001b[32m 81\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m split \u001b[38;5;129;01min\u001b[39;00m SPLITS:\n\u001b[32m 82\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m📦 Processing SPLIT: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msplit.upper()\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m84\u001b[39m \u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43msplit\u001b[49m\u001b[43m]\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmap\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 85\u001b[39m \u001b[43m \u001b[49m\u001b[43msave_image_worker\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 86\u001b[39m \u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 87\u001b[39m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m100\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Process 100 images per task\u001b[39;49;00m\n\u001b[32m 88\u001b[39m \u001b[43m \u001b[49m\u001b[43mwith_indices\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# We need the index for the filename\u001b[39;49;00m\n\u001b[32m 89\u001b[39m \u001b[43m \u001b[49m\u001b[43mnum_proc\u001b[49m\u001b[43m=\u001b[49m\u001b[43mNUM_PROC\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Parallel speed!\u001b[39;49;00m\n\u001b[32m 90\u001b[39m \u001b[43m \u001b[49m\u001b[43mfn_kwargs\u001b[49m\u001b[43m=\u001b[49m\u001b[43m{\u001b[49m\n\u001b[32m 91\u001b[39m \u001b[43m \u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43msplit_name\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43msplit\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 92\u001b[39m \u001b[43m \u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43moutput_root\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mOUTPUT_DIR\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 93\u001b[39m \u001b[43m \u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43midx_to_class\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43midx_to_class\u001b[49m\n\u001b[32m 94\u001b[39m \u001b[43m \u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 95\u001b[39m \u001b[43m \u001b[49m\u001b[43mdesc\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43mf\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mSaving \u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43msplit\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[33;43m\"\u001b[39;49m\n\u001b[32m 96\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 98\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m✅ Download and Extraction Complete!\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 99\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m You can now load this in PyTorch using:\u001b[39m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/arrow_dataset.py:562\u001b[39m, in \u001b[36mtransmit_format..wrapper\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 555\u001b[39m self_format = {\n\u001b[32m 556\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mtype\u001b[39m\u001b[33m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m._format_type,\n\u001b[32m 557\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mformat_kwargs\u001b[39m\u001b[33m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m._format_kwargs,\n\u001b[32m 558\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mcolumns\u001b[39m\u001b[33m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m._format_columns,\n\u001b[32m 559\u001b[39m \u001b[33m\"\u001b[39m\u001b[33moutput_all_columns\u001b[39m\u001b[33m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m._output_all_columns,\n\u001b[32m 560\u001b[39m }\n\u001b[32m 561\u001b[39m \u001b[38;5;66;03m# apply actual function\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m562\u001b[39m out: Union[\u001b[33m\"\u001b[39m\u001b[33mDataset\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mDatasetDict\u001b[39m\u001b[33m\"\u001b[39m] = \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 563\u001b[39m datasets: \u001b[38;5;28mlist\u001b[39m[\u001b[33m\"\u001b[39m\u001b[33mDataset\u001b[39m\u001b[33m\"\u001b[39m] = \u001b[38;5;28mlist\u001b[39m(out.values()) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(out, \u001b[38;5;28mdict\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m [out]\n\u001b[32m 564\u001b[39m \u001b[38;5;66;03m# re-apply format to the output\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/arrow_dataset.py:3332\u001b[39m, in \u001b[36mDataset.map\u001b[39m\u001b[34m(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc, try_original_type)\u001b[39m\n\u001b[32m 3329\u001b[39m os.environ = prev_env\n\u001b[32m 3330\u001b[39m logger.info(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mSpawning \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_proc\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m processes\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m-> \u001b[39m\u001b[32m3332\u001b[39m \u001b[43m\u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mrank\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdone\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontent\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43miflatmap_unordered\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 3333\u001b[39m \u001b[43m \u001b[49m\u001b[43mpool\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mDataset\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_map_single\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs_iterable\u001b[49m\u001b[43m=\u001b[49m\u001b[43munprocessed_kwargs_per_job\u001b[49m\n\u001b[32m 3334\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 3335\u001b[39m \u001b[43m \u001b[49m\u001b[43mcheck_if_shard_done\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrank\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdone\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontent\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 3337\u001b[39m pool.close()\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/datasets/utils/py_utils.py:626\u001b[39m, in \u001b[36miflatmap_unordered\u001b[39m\u001b[34m(pool, func, kwargs_iterable)\u001b[39m\n\u001b[32m 623\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 624\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m pool_changed:\n\u001b[32m 625\u001b[39m \u001b[38;5;66;03m# we get the result in case there's an error to raise\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m626\u001b[39m [\u001b[43masync_result\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m0.05\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m async_result \u001b[38;5;129;01min\u001b[39;00m async_results]\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/lab_env/lib/python3.12/site-packages/multiprocess/pool.py:770\u001b[39m, in \u001b[36mApplyResult.get\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 768\u001b[39m \u001b[38;5;28mself\u001b[39m.wait(timeout)\n\u001b[32m 769\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m.ready():\n\u001b[32m--> \u001b[39m\u001b[32m770\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTimeoutError\u001b[39;00m\n\u001b[32m 771\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._success:\n\u001b[32m 772\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._value\n", + "\u001b[31mTimeoutError\u001b[39m: " + ] + } + ], + "source": [ + "import os\n", + "import io\n", + "import multiprocessing\n", + "from datasets import load_dataset, Image as HFImage\n", + "from PIL import Image, UnidentifiedImageError\n", + "\n", + "# ================= CONFIGURATION =================\n", + "OUTPUT_DIR = \"rvl_cdip_data\" # Where data will be saved\n", + "NUM_PROC = os.cpu_count() # Use all available CPU cores\n", + "SPLITS = ['train', 'val', 'test'] # Splits to process\n", + "# =================================================\n", + "\n", + "def save_image_worker(batch, indices, split_name, output_root, idx_to_class):\n", + " \"\"\"\n", + " Worker function that runs on multiple CPU cores.\n", + " Receives raw image bytes, decodes them safely, and saves to disk.\n", + " \"\"\"\n", + " # 1. Unpack batch\n", + " # Since we used decode=False, 'image' contains a dict with 'bytes'\n", + " images_data = batch['image'] \n", + " labels = batch['label']\n", + " \n", + " for i, (img_data, label_idx, original_idx) in enumerate(zip(images_data, labels, indices)):\n", + " try:\n", + " # 2. Determine Paths\n", + " class_name = idx_to_class[label_idx]\n", + " target_folder = os.path.join(output_root, split_name, class_name)\n", + " filename = f\"{original_idx}.png\"\n", + " file_path = os.path.join(target_folder, filename)\n", + " \n", + " # 3. RESUME LOGIC (The \"Skip\" Check)\n", + " # If file exists and is not empty, skip it.\n", + " if os.path.exists(file_path) and os.path.getsize(file_path) > 0:\n", + " continue\n", + " \n", + " # 4. Create Directory (Lazy Creation)\n", + " # We do this here to ensure it exists before writing\n", + " os.makedirs(target_folder, exist_ok=True)\n", + " \n", + " # 5. Decode Image Safely\n", + " # We manually open the bytes. If this fails, we catch the error below.\n", + " image_bytes = img_data['bytes']\n", + " with Image.open(io.BytesIO(image_bytes)) as img:\n", + " # Convert to RGB (standard for PyTorch ResNet)\n", + " if img.mode != 'RGB':\n", + " img = img.convert('RGB')\n", + " \n", + " # Save to disk\n", + " img.save(file_path)\n", + "\n", + " except (UnidentifiedImageError, OSError, ValueError) as e:\n", + " # 6. Error Handling\n", + " # Instead of crashing the whole script, we just log this one failure.\n", + " print(f\"[Worker] Skipping corrupt image ID {original_idx} in {split_name}: {e}\")\n", + " \n", + " return batch\n", + "\n", + "def main():\n", + " print(f\"🚀 Starting RVL-CDIP Downloader\")\n", + " print(f\" Target Folder: {os.path.abspath(OUTPUT_DIR)}\")\n", + " print(f\" Workers: {NUM_PROC}\")\n", + " \n", + " # 1. Load Dataset\n", + " # Assuming you are logged into Hugging Face or have access\n", + " print(\" Loading dataset structure from Hugging Face...\")\n", + " dataset = load_dataset(\"chainyo/rvl-cdip\") \n", + "\n", + " # 2. Setup Class Mapping\n", + " labels_feature = dataset['train'].features['label']\n", + " idx_to_class = {idx: name for idx, name in enumerate(labels_feature.names)}\n", + " print(f\" Found {len(idx_to_class)} categories.\")\n", + "\n", + " # 3. CRITICAL: Disable Auto-Decoding\n", + " # This prevents the Iterator from crashing when it hits a corrupt file.\n", + " # We will handle decoding manually in the worker function.\n", + " print(\" Configuring dataset for safe raw access...\")\n", + " for split in SPLITS:\n", + " dataset[split] = dataset[split].cast_column(\"image\", HFImage(decode=False))\n", + "\n", + " # 4. Execute Parallel Processing\n", + " for split in SPLITS:\n", + " print(f\"\\n📦 Processing SPLIT: {split.upper()}\")\n", + " \n", + " dataset[split].map(\n", + " save_image_worker,\n", + " batched=True,\n", + " batch_size=100, # Process 100 images per task\n", + " with_indices=True, # We need the index for the filename\n", + " num_proc=NUM_PROC, # Parallel speed!\n", + " fn_kwargs={\n", + " 'split_name': split,\n", + " 'output_root': OUTPUT_DIR,\n", + " 'idx_to_class': idx_to_class\n", + " },\n", + " desc=f\"Saving {split}\"\n", + " )\n", + "\n", + " print(f\"\\n✅ Download and Extraction Complete!\")\n", + " print(f\" You can now load this in PyTorch using:\")\n", + " print(f\" datasets.ImageFolder(root='{OUTPUT_DIR}/train')\")\n", + "\n", + "if __name__ == \"__main__\":\n", + " main()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f440bb56", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "lab_env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}