import os import glob import zipfile import requests import argparse from PIL import Image from tqdm import tqdm # --- Helper Functions --- def download_file(url, dest_path, chunk_size=8192): """Downloads a file from a URL to a destination path with progress bar.""" try: response = requests.get(url, stream=True, timeout=30) # Added timeout response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx) total_size = int(response.headers.get('content-length', 0)) print(f"Downloading {os.path.basename(dest_path)} ({total_size / (1024*1024):.2f} MB)...") with open(dest_path, 'wb') as f, tqdm( desc=os.path.basename(dest_path), total=total_size, unit='iB', unit_scale=True, unit_divisor=1024, ) as bar: for chunk in response.iter_content(chunk_size=chunk_size): size = f.write(chunk) bar.update(size) print(f"Download complete: {dest_path}") return True except requests.exceptions.RequestException as e: print(f"Error downloading {url}: {e}") # Clean up partially downloaded file if it exists if os.path.exists(dest_path): os.remove(dest_path) return False except Exception as e: print(f"An unexpected error occurred during download: {e}") if os.path.exists(dest_path): os.remove(dest_path) return False def unzip_file(zip_path, extract_to): """Unzips a file to a specified directory.""" print(f"Extracting {os.path.basename(zip_path)} to {extract_to}...") try: with zipfile.ZipFile(zip_path, 'r') as zip_ref: # You could add a progress bar here for large zips if needed # using zip_ref.infolist() and iterating extraction, but # extractall is usually efficient enough. zip_ref.extractall(extract_to) print("Extraction complete.") return True except zipfile.BadZipFile: print(f"Error: Invalid or corrupted zip file: {zip_path}") return False except Exception as e: print(f"An error occurred during extraction: {e}") return False def find_image_dir(base_path, expected_subdir_suffix='_HR'): """ Tries to find the actual directory containing images after extraction. Handles cases where unzip creates an extra top-level folder. """ # Check if images are directly in base_path if glob.glob(os.path.join(base_path, '*.png')) or \ glob.glob(os.path.join(base_path, '*.jpg')) or \ glob.glob(os.path.join(base_path, '*.jpeg')): return base_path # Check common pattern: base_path/DatasetName_HR/ potential_dirs = [d for d in glob.glob(os.path.join(base_path, '*')) if os.path.isdir(d)] if len(potential_dirs) == 1: subdir = potential_dirs[0] # Check if this subdir contains images or ends with the expected suffix if subdir.endswith(expected_subdir_suffix) or \ glob.glob(os.path.join(subdir, '*.png')) or \ glob.glob(os.path.join(subdir, '*.jpg')) or \ glob.glob(os.path.join(subdir, '*.jpeg')): print(f"Found image directory: {subdir}") return subdir # Fallback if specific pattern not found, maybe it's still just base_path print(f"Warning: Could not definitively locate image subdirectory in {base_path}. Assuming images are directly within or in a single nested folder.") # If we found exactly one directory, return that, otherwise return the original path return potential_dirs[0] if len(potential_dirs) == 1 else base_path def downsample_images(hr_dir, lr_dir, scale_factor): """Downsamples HR images using bicubic interpolation.""" if not os.path.exists(lr_dir): os.makedirs(lr_dir) print(f"Created LR directory: {lr_dir}") hr_images = glob.glob(os.path.join(hr_dir, '*.png')) + \ glob.glob(os.path.join(hr_dir, '*.jpg')) + \ glob.glob(os.path.join(hr_dir, '*.jpeg')) if not hr_images: print(f"Error: No images found in the determined HR directory: {hr_dir}") return False print(f"Found {len(hr_images)} HR images in {hr_dir}. Starting downsampling (x{scale_factor})...") processed_count = 0 for hr_path in tqdm(hr_images, desc=f"Downsampling x{scale_factor}"): try: hr_img = Image.open(hr_path).convert('RGB') # Ensure RGB hr_width, hr_height = hr_img.size lr_width = hr_width // scale_factor lr_height = hr_height // scale_factor if lr_width == 0 or lr_height == 0: print(f"\nWarning: Image {os.path.basename(hr_path)} is too small ({hr_width}x{hr_height}) for scale factor {scale_factor}. Skipping.") continue lr_img = hr_img.resize((lr_width, lr_height), resample=Image.BICUBIC) base_name = os.path.basename(hr_path) lr_save_path = os.path.join(lr_dir, base_name) lr_img.save(lr_save_path) processed_count += 1 except Exception as e: print(f"\nError processing {hr_path}: {e}") print(f"Downsampling complete. Processed {processed_count}/{len(hr_images)} images.") return processed_count > 0 # --- Main Execution --- if __name__ == "__main__": parser = argparse.ArgumentParser(description="Download and prepare dataset for Super-Resolution.") parser.add_argument('--url', type=str, default='https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip', help='URL of the dataset zip file (default: DIV2K Train HR).') parser.add_argument('--base_dir', type=str, default='./datasets', help='Base directory to store datasets.') parser.add_argument('--dataset_name', type=str, default='DIV2K', help='Name for the dataset folder.') parser.add_argument('--scale', type=int, default=4, help='Downsampling scale factor (e.g., 4 for x4).') parser.add_argument('--force', action='store_true', help='Force redownload and reprocessing even if data exists.') args = parser.parse_args() # --- Define Paths --- dataset_base_path = os.path.join(args.base_dir, args.dataset_name) zip_filename = os.path.basename(args.url) zip_save_path = os.path.join(dataset_base_path, zip_filename) hr_extract_base = os.path.join(dataset_base_path, 'HR_extracted') # Temp extraction location # We will determine the *actual* HR image dir after extraction lr_save_dir = os.path.join(dataset_base_path, f'DIV2K_train_LR_bicubic/X{args.scale}') # Following previous convention print(f"--- Configuration ---") print(f"Dataset URL: {args.url}") print(f"Base Directory: {args.base_dir}") print(f"Dataset Name: {args.dataset_name}") print(f"Target Scale: x{args.scale}") print(f"Zip Save Path: {zip_save_path}") print(f"Initial Extract Path: {hr_extract_base}") print(f"LR Save Path: {lr_save_dir}") print(f"Force Re-run: {args.force}") print(f"--------------------") # --- Create Base Directory --- os.makedirs(dataset_base_path, exist_ok=True) # --- Step 1: Download --- hr_dir_exists = os.path.isdir(hr_extract_base) # Check if base extraction dir exists download_needed = not os.path.exists(zip_save_path) or args.force if download_needed: if args.force and os.path.exists(zip_save_path): print("Force enabled: Removing existing zip file...") os.remove(zip_save_path) if not download_file(args.url, zip_save_path): print("Exiting due to download failure.") exit(1) elif hr_dir_exists: # If zip exists and hr dir exists, assume download & unzip ok unless forced print("Zip file already exists. Skipping download (use --force to override).") else: # Zip exists but HR dir doesn't - need to unzip print("Zip file found, but extraction directory missing. Will proceed to unzip.") # --- Step 2: Unzip --- # Check if the *potential* content directory already exists. Be a bit lenient here. # A more robust check would be to look inside the zip first or check for specific files. unzip_needed = not hr_dir_exists or args.force actual_hr_dir = None # Will store the path to the actual images if unzip_needed: if args.force and hr_dir_exists: print("Force enabled: Removing existing extraction directory...") import shutil shutil.rmtree(hr_extract_base) # Careful! Removes directory and contents if not os.path.exists(zip_save_path): print("Error: Zip file not found, cannot unzip. Please check download step or path.") exit(1) os.makedirs(hr_extract_base, exist_ok=True) # Ensure extraction target exists if not unzip_file(zip_save_path, hr_extract_base): print("Exiting due to extraction failure.") exit(1) # Find the actual directory containing images post-extraction actual_hr_dir = find_image_dir(hr_extract_base, expected_subdir_suffix=f'{args.dataset_name}_HR') # e.g., DIV2K_HR if not actual_hr_dir or not (glob.glob(os.path.join(actual_hr_dir, '*.png')) or glob.glob(os.path.join(actual_hr_dir, '*.jpg'))): print(f"Error: Could not locate the directory with HR images within {hr_extract_base} after extraction.") exit(1) print(f"Located HR images in: {actual_hr_dir}") else: print("HR extraction directory already exists. Skipping unzip (use --force to override).") # Try to find the HR dir even if we skipped unzipping actual_hr_dir = find_image_dir(hr_extract_base, expected_subdir_suffix=f'{args.dataset_name}_HR') if not actual_hr_dir: print(f"Error: Could not locate the directory with HR images within existing {hr_extract_base}.") exit(1) print(f"Using existing HR images from: {actual_hr_dir}") # --- Step 3: Process (Downsample) --- lr_dir_exists_and_populated = os.path.isdir(lr_save_dir) and len(os.listdir(lr_save_dir)) > 0 processing_needed = not lr_dir_exists_and_populated or args.force if processing_needed: if args.force and lr_dir_exists_and_populated: print("Force enabled: Removing existing LR directory...") import shutil shutil.rmtree(lr_save_dir) # Careful! if not actual_hr_dir: print("Error: Cannot proceed with downsampling, HR image directory not determined.") exit(1) if not downsample_images(actual_hr_dir, lr_save_dir, args.scale): print("Downsampling process failed or produced no images.") # Optionally exit here depending on desired behavior # exit(1) else: print("Downsampling finished successfully.") else: print("LR directory already exists and is populated. Skipping downsampling (use --force to override).") print("\n--- Script Finished ---") print(f"HR images should be available in/under: {actual_hr_dir}") print(f"LR images (x{args.scale}) should be available in: {lr_save_dir}") print("You can now use these directories with the SRDataset class.")