{ "cells": [ { "cell_type": "code", "execution_count": 7, "id": "ae9bc87a", "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "import datasets\n", "from tqdm.notebook import tqdm" ] }, { "cell_type": "code", "execution_count": null, "id": "d5bc67fe", "metadata": {}, "outputs": [], "source": [ "ds = load_dataset(\"chainyo/rvl-cdip\")" ] }, { "cell_type": "markdown", "id": "85f49eeb", "metadata": {}, "source": [ "## Creates the \"rvl_cdip_data\" dir" ] }, { "cell_type": "code", "execution_count": null, "id": "936deafa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "šŸš€ Starting RVL-CDIP Downloader (Disk Optimized)\n", " Target Folder: /Users/arpit-zstch1557/Projects/document-classification/rvl_cdip_data\n", " Workers: 12\n", " Loading dataset structure from Hugging Face...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0a79c4079dd44915af9193231077adc9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Resolving data files: 0%| | 0/119 [00:00\n", "\n", "āœ… Download and Extraction Complete!\n", " You can now load this in PyTorch using:\n", " datasets.ImageFolder(root='rvl_cdip_data/train')\n" ] } ], "source": [ "import os\n", "import io\n", "from datasets import load_dataset, Image as HFImage\n", "from PIL import Image, UnidentifiedImageError\n", "\n", "OUTPUT_DIR = \"rvl_cdip_data\" # Where data will be saved\n", "NUM_PROC = os.cpu_count() # Use all available CPU cores\n", "SPLITS = ['train', 'val', 'test'] # Splits to process\n", "\n", "def save_image_worker(batch, indices, split_name, output_root, idx_to_class):\n", " # Unpack batch\n", " images_data = batch['image'] \n", " labels = batch['label']\n", " \n", " for i, (img_data, label_idx, original_idx) in enumerate(zip(images_data, labels, indices)):\n", " try:\n", " # Determine Paths\n", " class_name = idx_to_class[label_idx]\n", " target_folder = os.path.join(output_root, split_name, class_name)\n", " 
filename = f\"{original_idx}.png\"\n", "            file_path = os.path.join(target_folder, filename)\n", "\n", "            # Resume support: skip files that were already written successfully\n", "            if os.path.exists(file_path) and os.path.getsize(file_path) > 0:\n", "                continue\n", "\n", "            # Create the class directory lazily (exist_ok makes this race-safe across workers)\n", "            os.makedirs(target_folder, exist_ok=True)\n", "\n", "            # Decode safely. With HFImage(decode=False) each sample is a dict\n", "            # {'bytes': ..., 'path': ...}; 'bytes' can be None when the dataset\n", "            # references an on-disk file, so fall back to the path in that case\n", "            # (BytesIO(None) would raise TypeError, which the old except missed).\n", "            image_bytes = img_data['bytes']\n", "            source = io.BytesIO(image_bytes) if image_bytes is not None else img_data['path']\n", "            with Image.open(source) as img:\n", "                if img.mode != 'RGB':\n", "                    img = img.convert('RGB')\n", "                img.save(file_path)\n", "\n", "        except (UnidentifiedImageError, OSError, ValueError, KeyError, TypeError) as e:\n", "            # A handful of RVL-CDIP scans are corrupt; skip them rather than abort the worker\n", "            print(f\"[Worker] Skipping corrupt image ID {original_idx} in {split_name}: {e}\")\n", "\n", "    # datasets.map() expects a dict; returning an empty one (together with\n", "    # remove_columns) keeps the mapped dataset empty and avoids a duplicate cache.\n", "    return {}\n", "\n", "def main():\n", "    print(f\"šŸš€ Starting RVL-CDIP Downloader (Disk Optimized)\")\n", "    print(f\" Target Folder: {os.path.abspath(OUTPUT_DIR)}\")\n", "    print(f\" Workers: {NUM_PROC}\")\n", "\n", "    # Load dataset structure only; images stay lazy until accessed\n", "    print(\" Loading dataset structure from Hugging Face...\")\n", "    dataset = load_dataset(\"chainyo/rvl-cdip\")\n", "\n", "    # Map integer label ids -> human-readable class folder names\n", "    labels_feature = dataset['train'].features['label']\n", "    idx_to_class = {idx: name for idx, name in enumerate(labels_feature.names)}\n", "    print(f\" Found {len(idx_to_class)} categories.\")\n", "\n", "    # Disable auto-decoding so corrupt files surface as raw bytes instead of crashing\n", "    print(\" Configuring dataset for safe raw access...\")\n", "    for split in SPLITS:\n", "        dataset[split] = dataset[split].cast_column(\"image\", HFImage(decode=False))\n", "\n", "    # Write images to disk in parallel, one worker pool per split\n", "    for split in SPLITS:\n", "        print(f\"\\nšŸ“¦ Processing SPLIT: {split.upper()}\")\n", "\n", "        # remove_columns keeps the output dataset empty.\n", "        # This prevents the 50GB duplicate cache file.\n", "        dataset[split].map(\n", "            save_image_worker,\n", "            batched=True,\n", "            batch_size=100,\n", "            with_indices=True,\n", "            num_proc=NUM_PROC,\n", "            remove_columns=dataset[split].column_names,\n", "            fn_kwargs={\n", "                'split_name': split,\n", "                'output_root': 
OUTPUT_DIR,\n", " 'idx_to_class': idx_to_class\n", " },\n", " desc=f\"Saving {split}\"\n", " )\n", "\n", " print(f\"\\nāœ… Download and Extraction Complete!\")\n", " print(f\" You can now load this in PyTorch using:\")\n", " print(f\" datasets.ImageFolder(root='{OUTPUT_DIR}/train')\")\n", "\n", "if __name__ == \"__main__\":\n", " main()" ] }, { "cell_type": "markdown", "id": "c8530c8e", "metadata": {}, "source": [ "## Checking the Data Imbalance in ds (from HF)" ] }, { "cell_type": "code", "execution_count": 8, "id": "2785360c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SPLIT CLASS NAME COUNT STATUS\n", "------------------------------------------------------------\n", "TRAIN advertisement 19963 āŒ MISMATCH (Exp: 20000)\n", "TRAIN budget 20010 āŒ MISMATCH (Exp: 20000)\n", "TRAIN email 19954 āŒ MISMATCH (Exp: 20000)\n", "TRAIN file folder 20022 āŒ MISMATCH (Exp: 20000)\n", "TRAIN form 19957 āŒ MISMATCH (Exp: 20000)\n", "TRAIN handwritten 20034 āŒ MISMATCH (Exp: 20000)\n", "TRAIN invoice 19947 āŒ MISMATCH (Exp: 20000)\n", "TRAIN letter 20106 āŒ MISMATCH (Exp: 20000)\n", "TRAIN memo 19975 āŒ MISMATCH (Exp: 20000)\n", "TRAIN news article 20011 āŒ MISMATCH (Exp: 20000)\n", "TRAIN presentation 20043 āŒ MISMATCH (Exp: 20000)\n", "TRAIN questionnaire 20048 āŒ MISMATCH (Exp: 20000)\n", "TRAIN resume 20036 āŒ MISMATCH (Exp: 20000)\n", "TRAIN scientific publication 19902 āŒ MISMATCH (Exp: 20000)\n", "TRAIN scientific report 19994 āŒ MISMATCH (Exp: 20000)\n", "TRAIN specification 19997 āŒ MISMATCH (Exp: 20000)\n", "------------------------------------------------------------\n", "VAL advertisement 2522 āŒ MISMATCH (Exp: 2500)\n", "VAL budget 2485 āŒ MISMATCH (Exp: 2500)\n", "VAL email 2530 āŒ MISMATCH (Exp: 2500)\n", "VAL file folder 2451 āŒ MISMATCH (Exp: 2500)\n", "VAL form 2537 āŒ MISMATCH (Exp: 2500)\n", "VAL handwritten 2434 āŒ MISMATCH (Exp: 2500)\n", "VAL invoice 2576 āŒ MISMATCH (Exp: 2500)\n", "VAL 
letter 2430 āŒ MISMATCH (Exp: 2500)\n", "VAL memo 2533 āŒ MISMATCH (Exp: 2500)\n", "VAL news article 2526 āŒ MISMATCH (Exp: 2500)\n", "VAL presentation 2468 āŒ MISMATCH (Exp: 2500)\n", "VAL questionnaire 2517 āŒ MISMATCH (Exp: 2500)\n", "VAL resume 2426 āŒ MISMATCH (Exp: 2500)\n", "VAL scientific publication 2526 āŒ MISMATCH (Exp: 2500)\n", "VAL scientific report 2508 āŒ MISMATCH (Exp: 2500)\n", "VAL specification 2531 āŒ MISMATCH (Exp: 2500)\n", "------------------------------------------------------------\n", "TEST advertisement 2515 āŒ MISMATCH (Exp: 2500)\n", "TEST budget 2505 āŒ MISMATCH (Exp: 2500)\n", "TEST email 2516 āŒ MISMATCH (Exp: 2500)\n", "TEST file folder 2527 āŒ MISMATCH (Exp: 2500)\n", "TEST form 2506 āŒ MISMATCH (Exp: 2500)\n", "TEST handwritten 2532 āŒ MISMATCH (Exp: 2500)\n", "TEST invoice 2477 āŒ MISMATCH (Exp: 2500)\n", "TEST letter 2464 āŒ MISMATCH (Exp: 2500)\n", "TEST memo 2492 āŒ MISMATCH (Exp: 2500)\n", "TEST news article 2463 āŒ MISMATCH (Exp: 2500)\n", "TEST presentation 2489 āŒ MISMATCH (Exp: 2500)\n", "TEST questionnaire 2435 āŒ MISMATCH (Exp: 2500)\n", "TEST resume 2537 āŒ MISMATCH (Exp: 2500)\n", "TEST scientific publication 2572 āŒ MISMATCH (Exp: 2500)\n", "TEST scientific report 2498 āŒ MISMATCH (Exp: 2500)\n", "TEST specification 2472 āŒ MISMATCH (Exp: 2500)\n", "------------------------------------------------------------\n" ] } ], "source": [ "from collections import Counter\n", "import pandas as pd\n", "\n", "#Setup\n", "splits = ['train', 'val', 'test']\n", "label_feature = ds['train'].features['label']\n", "int2str = label_feature.int2str \n", "\n", "print(f\"{'SPLIT':<10} {'CLASS NAME':<25} {'COUNT':<10} {'STATUS'}\")\n", "print(\"-\" * 60)\n", "\n", "for split in splits:\n", " # Get all labels (Load only the label column into memory)\n", " # This is instant compared to loading images\n", " labels = ds[split]['label']\n", " \n", " # Count frequencies\n", " counts = Counter(labels)\n", " \n", " # 
Analyze each class\n", " # We sort by class ID to keep it organized\n", " for label_id in sorted(counts.keys()):\n", " count = counts[label_id]\n", " class_name = int2str(label_id)\n", " \n", " # Define Expected Counts based on the Paper\n", " # Train: 320k / 16 = 20,000\n", " # Test/Val: 40k / 16 = 2,500\n", " if split == 'train':\n", " expected = 20000\n", " else:\n", " expected = 2500\n", " \n", " status = \"āœ… OK\" if count == expected else f\"āŒ MISMATCH (Exp: {expected})\"\n", " \n", " print(f\"{split.upper():<10} {class_name:<25} {count:<10} {status}\")\n", " \n", " print(\"-\" * 60) " ] }, { "cell_type": "markdown", "id": "5f7b75a2", "metadata": {}, "source": [ "## Checking the data imbalance in \"rvl_cdip_data\" dir" ] }, { "cell_type": "code", "execution_count": 9, "id": "059bfaa5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "šŸ“‚ Scanning directory: /Users/arpit-zstch1557/Projects/document-classification/rvl_cdip_data\n", "SPLIT CLASS NAME FILES STATUS\n", "-----------------------------------------------------------------\n", "TRAIN advertisement 19963 āŒ MISMATCH (Exp: 20000)\n", "TRAIN budget 20010 āŒ MISMATCH (Exp: 20000)\n", "TRAIN email 19954 āŒ MISMATCH (Exp: 20000)\n", "TRAIN file folder 20022 āŒ MISMATCH (Exp: 20000)\n", "TRAIN form 19957 āŒ MISMATCH (Exp: 20000)\n", "TRAIN handwritten 20034 āŒ MISMATCH (Exp: 20000)\n", "TRAIN invoice 19947 āŒ MISMATCH (Exp: 20000)\n", "TRAIN letter 20106 āŒ MISMATCH (Exp: 20000)\n", "TRAIN memo 19975 āŒ MISMATCH (Exp: 20000)\n", "TRAIN news article 20011 āŒ MISMATCH (Exp: 20000)\n", "TRAIN presentation 20043 āŒ MISMATCH (Exp: 20000)\n", "TRAIN questionnaire 20048 āŒ MISMATCH (Exp: 20000)\n", "TRAIN resume 20036 āŒ MISMATCH (Exp: 20000)\n", "TRAIN scientific publication 19902 āŒ MISMATCH (Exp: 20000)\n", "TRAIN scientific report 19994 āŒ MISMATCH (Exp: 20000)\n", "TRAIN specification 19997 āš ļø OK (Off by 3)\n", 
"-----------------------------------------------------------------\n", "VAL advertisement 2522 āŒ MISMATCH (Exp: 2500)\n", "VAL budget 2485 āŒ MISMATCH (Exp: 2500)\n", "VAL email 2530 āŒ MISMATCH (Exp: 2500)\n", "VAL file folder 2451 āŒ MISMATCH (Exp: 2500)\n", "VAL form 2537 āŒ MISMATCH (Exp: 2500)\n", "VAL handwritten 2434 āŒ MISMATCH (Exp: 2500)\n", "VAL invoice 2576 āŒ MISMATCH (Exp: 2500)\n", "VAL letter 2430 āŒ MISMATCH (Exp: 2500)\n", "VAL memo 2533 āŒ MISMATCH (Exp: 2500)\n", "VAL news article 2526 āŒ MISMATCH (Exp: 2500)\n", "VAL presentation 2468 āŒ MISMATCH (Exp: 2500)\n", "VAL questionnaire 2517 āŒ MISMATCH (Exp: 2500)\n", "VAL resume 2426 āŒ MISMATCH (Exp: 2500)\n", "VAL scientific publication 2526 āŒ MISMATCH (Exp: 2500)\n", "VAL scientific report 2508 āŒ MISMATCH (Exp: 2500)\n", "VAL specification 2531 āŒ MISMATCH (Exp: 2500)\n", "-----------------------------------------------------------------\n", "TEST advertisement 2515 āŒ MISMATCH (Exp: 2500)\n", "TEST budget 2505 āŒ MISMATCH (Exp: 2500)\n", "TEST email 2516 āŒ MISMATCH (Exp: 2500)\n", "TEST file folder 2527 āŒ MISMATCH (Exp: 2500)\n", "TEST form 2506 āŒ MISMATCH (Exp: 2500)\n", "TEST handwritten 2532 āŒ MISMATCH (Exp: 2500)\n", "TEST invoice 2477 āŒ MISMATCH (Exp: 2500)\n", "TEST letter 2464 āŒ MISMATCH (Exp: 2500)\n", "TEST memo 2492 āŒ MISMATCH (Exp: 2500)\n", "TEST news article 2463 āŒ MISMATCH (Exp: 2500)\n", "TEST presentation 2489 āŒ MISMATCH (Exp: 2500)\n", "TEST questionnaire 2435 āŒ MISMATCH (Exp: 2500)\n", "TEST resume 2537 āŒ MISMATCH (Exp: 2500)\n", "TEST scientific publication 2571 āŒ MISMATCH (Exp: 2500)\n", "TEST scientific report 2498 āš ļø OK (Off by 2)\n", "TEST specification 2472 āŒ MISMATCH (Exp: 2500)\n", "-----------------------------------------------------------------\n", "\n", "Analysis Complete.\n" ] } ], "source": [ "import os\n", "import pandas as pd\n", "\n", "# Configuration\n", "DATA_DIR = \"rvl_cdip_data\" # Your directory name\n", 
"splits = ['train', 'val', 'test']\n", "\n", "print(f\"šŸ“‚ Scanning directory: {os.path.abspath(DATA_DIR)}\")\n", "print(f\"{'SPLIT':<10} {'CLASS NAME':<25} {'FILES':<10} {'STATUS'}\")\n", "print(\"-\" * 65)\n", "\n", "for split in splits:\n", "    split_dir = os.path.join(DATA_DIR, split)\n", "\n", "    # Skip splits that were never downloaded\n", "    if not os.path.exists(split_dir):\n", "        print(f\"āŒ Missing folder: {split}\")\n", "        continue\n", "\n", "    # Class folders, sorted for a stable output order\n", "    try:\n", "        classes = sorted([d for d in os.listdir(split_dir) if os.path.isdir(os.path.join(split_dir, d))])\n", "    except OSError:\n", "        print(f\"āŒ Error reading {split} directory.\")\n", "        continue\n", "\n", "    for cls in classes:\n", "        cls_path = os.path.join(split_dir, cls)\n", "\n", "        # Count regular files only, ignoring hidden files like .DS_Store\n", "        file_count = len([\n", "            name for name in os.listdir(cls_path)\n", "            if os.path.isfile(os.path.join(cls_path, name))\n", "            and not name.startswith('.')\n", "        ])\n", "\n", "        # Expected counts from the RVL-CDIP paper:\n", "        # Train: 320k / 16 = 20,000; Val/Test: 40k / 16 = 2,500\n", "        if split == 'train':\n", "            expected = 20000\n", "        else:\n", "            expected = 2500\n", "\n", "        # Status check. abs() in the message so the delta reads correctly\n", "        # whether the class has too few OR too many files (the old message\n", "        # printed a negative number on surplus).\n", "        if file_count == expected:\n", "            status = \"āœ… OK\"\n", "        elif abs(file_count - expected) < 5:\n", "            # Off by 1-4 files (like the corrupt ones we skipped) is acceptable\n", "            status = f\"āš ļø OK (Off by {abs(expected - file_count)})\"\n", "        else:\n", "            status = f\"āŒ MISMATCH (Exp: {expected})\"\n", "\n", "        print(f\"{split.upper():<10} {cls:<25} {file_count:<10} {status}\")\n", "\n", "    print(\"-\" * 65)\n", "\n", "print(\"\\nAnalysis Complete.\")" ] }, { "cell_type": "code", "execution_count": null, "id": "4ef697a2", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "lab_env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": 
".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.11" } }, "nbformat": 4, "nbformat_minor": 5 }