File size: 11,335 Bytes
fd5c0a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 |
import os
import glob
import zipfile
import requests
import argparse
from PIL import Image
from tqdm import tqdm
# --- Helper Functions ---
def download_file(url, dest_path, chunk_size=8192):
    """Stream *url* into *dest_path*, showing a tqdm progress bar.

    Returns True on success. On any failure (HTTP error, network problem,
    disk error) the partially written file is removed and False is returned.
    """
    file_label = os.path.basename(dest_path)
    try:
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()  # surface 4xx/5xx responses as exceptions
        total_size = int(response.headers.get('content-length', 0))
        print(f"Downloading {file_label} ({total_size / (1024*1024):.2f} MB)...")
        with open(dest_path, 'wb') as out_file, tqdm(
                desc=file_label,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024) as bar:
            for chunk in response.iter_content(chunk_size=chunk_size):
                bar.update(out_file.write(chunk))
        print(f"Download complete: {dest_path}")
        return True
    except requests.exceptions.RequestException as e:
        print(f"Error downloading {url}: {e}")
    except Exception as e:
        print(f"An unexpected error occurred during download: {e}")
    # Shared failure path: drop any partially downloaded file.
    if os.path.exists(dest_path):
        os.remove(dest_path)
    return False
def unzip_file(zip_path, extract_to):
    """Extract the zip archive at *zip_path* into the directory *extract_to*.

    Returns True on success, False if the archive is corrupt or extraction
    fails for any other reason (e.g. the file does not exist).
    """
    print(f"Extracting {os.path.basename(zip_path)} to {extract_to}...")
    try:
        with zipfile.ZipFile(zip_path, 'r') as archive:
            # extractall is efficient enough for typical dataset archives;
            # a per-member progress bar is not worth the extra code here.
            archive.extractall(extract_to)
        print("Extraction complete.")
    except zipfile.BadZipFile:
        print(f"Error: Invalid or corrupted zip file: {zip_path}")
        return False
    except Exception as e:
        print(f"An error occurred during extraction: {e}")
        return False
    return True
def find_image_dir(base_path, expected_subdir_suffix='_HR'):
    """Locate the directory that actually contains the extracted images.

    Zip archives sometimes wrap their contents in an extra top-level folder;
    this looks one level down for it. Returns *base_path* itself when images
    live there directly (or when the layout cannot be determined).
    """
    def _has_images(directory):
        # Any png/jpg/jpeg file counts as evidence of an image directory.
        return any(
            glob.glob(os.path.join(directory, pattern))
            for pattern in ('*.png', '*.jpg', '*.jpeg')
        )

    if _has_images(base_path):
        return base_path

    # Common post-extraction layout: base_path/<DatasetName>_HR/<images>
    subdirs = [p for p in glob.glob(os.path.join(base_path, '*')) if os.path.isdir(p)]
    if len(subdirs) == 1:
        candidate = subdirs[0]
        if candidate.endswith(expected_subdir_suffix) or _has_images(candidate):
            print(f"Found image directory: {candidate}")
            return candidate

    print(f"Warning: Could not definitively locate image subdirectory in {base_path}. Assuming images are directly within or in a single nested folder.")
    # Prefer a lone subdirectory if one exists; otherwise fall back to base.
    return subdirs[0] if len(subdirs) == 1 else base_path
def downsample_images(hr_dir, lr_dir, scale_factor):
    """Create bicubically downsampled copies of every image in *hr_dir*.

    Each HR image (png/jpg/jpeg) is resized to ``1/scale_factor`` of its
    dimensions (integer floor division) and saved under the same filename
    in *lr_dir*, which is created if needed. Images too small for the scale
    factor are skipped with a warning; per-file errors are reported but do
    not abort the run.

    Returns True if at least one image was processed, False otherwise.
    """
    if not os.path.isdir(lr_dir):
        # exist_ok guards against a race if the directory appears between
        # the check above and the creation here.
        os.makedirs(lr_dir, exist_ok=True)
        print(f"Created LR directory: {lr_dir}")
    # Sort for a deterministic processing order (glob order is OS-dependent).
    hr_images = sorted(
        glob.glob(os.path.join(hr_dir, '*.png')) +
        glob.glob(os.path.join(hr_dir, '*.jpg')) +
        glob.glob(os.path.join(hr_dir, '*.jpeg'))
    )
    if not hr_images:
        print(f"Error: No images found in the determined HR directory: {hr_dir}")
        return False
    print(f"Found {len(hr_images)} HR images in {hr_dir}. Starting downsampling (x{scale_factor})...")
    processed_count = 0
    for hr_path in tqdm(hr_images, desc=f"Downsampling x{scale_factor}"):
        try:
            hr_img = Image.open(hr_path).convert('RGB')  # normalize mode (drops alpha/palette)
            hr_width, hr_height = hr_img.size
            lr_width = hr_width // scale_factor
            lr_height = hr_height // scale_factor
            # A zero dimension would make PIL raise inside resize(); skip instead.
            if lr_width == 0 or lr_height == 0:
                print(f"\nWarning: Image {os.path.basename(hr_path)} is too small ({hr_width}x{hr_height}) for scale factor {scale_factor}. Skipping.")
                continue
            lr_img = hr_img.resize((lr_width, lr_height), resample=Image.BICUBIC)
            lr_save_path = os.path.join(lr_dir, os.path.basename(hr_path))
            lr_img.save(lr_save_path)
            processed_count += 1
        except Exception as e:
            # Report above the tqdm bar and keep going on a bad file.
            print(f"\nError processing {hr_path}: {e}")
    print(f"Downsampling complete. Processed {processed_count}/{len(hr_images)} images.")
    return processed_count > 0
# --- Main Execution ---
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download and prepare dataset for Super-Resolution.")
    parser.add_argument('--url', type=str, default='https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip', help='URL of the dataset zip file (default: DIV2K Train HR).')
    parser.add_argument('--base_dir', type=str, default='./datasets', help='Base directory to store datasets.')
    parser.add_argument('--dataset_name', type=str, default='DIV2K', help='Name for the dataset folder.')
    parser.add_argument('--scale', type=int, default=4, help='Downsampling scale factor (e.g., 4 for x4).')
    parser.add_argument('--force', action='store_true', help='Force redownload and reprocessing even if data exists.')
    args = parser.parse_args()

    # Imported once here instead of twice inside the --force branches below.
    import shutil

    # --- Define Paths ---
    dataset_base_path = os.path.join(args.base_dir, args.dataset_name)
    zip_filename = os.path.basename(args.url)
    zip_save_path = os.path.join(dataset_base_path, zip_filename)
    # Temporary extraction location; the real image dir is found afterwards.
    hr_extract_base = os.path.join(dataset_base_path, 'HR_extracted')
    # LR output follows the <dataset>_train_LR_bicubic/X<scale> convention.
    # (Was hard-coded to 'DIV2K_...' regardless of --dataset_name; identical
    # for the default dataset, correct for any other.)
    lr_save_dir = os.path.join(dataset_base_path, f'{args.dataset_name}_train_LR_bicubic', f'X{args.scale}')

    print(f"--- Configuration ---")
    print(f"Dataset URL: {args.url}")
    print(f"Base Directory: {args.base_dir}")
    print(f"Dataset Name: {args.dataset_name}")
    print(f"Target Scale: x{args.scale}")
    print(f"Zip Save Path: {zip_save_path}")
    print(f"Initial Extract Path: {hr_extract_base}")
    print(f"LR Save Path: {lr_save_dir}")
    print(f"Force Re-run: {args.force}")
    print(f"--------------------")

    # --- Create Base Directory ---
    os.makedirs(dataset_base_path, exist_ok=True)

    # --- Step 1: Download ---
    hr_dir_exists = os.path.isdir(hr_extract_base)  # base extraction dir present?
    download_needed = not os.path.exists(zip_save_path) or args.force
    if download_needed:
        if args.force and os.path.exists(zip_save_path):
            print("Force enabled: Removing existing zip file...")
            os.remove(zip_save_path)
        if not download_file(args.url, zip_save_path):
            print("Exiting due to download failure.")
            # raise SystemExit instead of the site-provided exit() builtin,
            # which is not guaranteed to exist (e.g. under python -S).
            raise SystemExit(1)
    elif hr_dir_exists:
        # Zip and extraction dir both exist: assume earlier run succeeded.
        print("Zip file already exists. Skipping download (use --force to override).")
    else:
        # Zip exists but extraction dir is missing: Step 2 will unzip it.
        print("Zip file found, but extraction directory missing. Will proceed to unzip.")

    # --- Step 2: Unzip ---
    # Lenient existence check; a stricter version would inspect the zip's
    # member list or look for specific files.
    unzip_needed = not hr_dir_exists or args.force
    actual_hr_dir = None  # path to the directory actually holding HR images
    if unzip_needed:
        if args.force and hr_dir_exists:
            print("Force enabled: Removing existing extraction directory...")
            shutil.rmtree(hr_extract_base)  # Careful! Removes directory and contents
        if not os.path.exists(zip_save_path):
            print("Error: Zip file not found, cannot unzip. Please check download step or path.")
            raise SystemExit(1)
        os.makedirs(hr_extract_base, exist_ok=True)  # ensure extraction target exists
        if not unzip_file(zip_save_path, hr_extract_base):
            print("Exiting due to extraction failure.")
            raise SystemExit(1)
        # Find the directory containing the images post-extraction,
        # e.g. HR_extracted/DIV2K_HR/.
        actual_hr_dir = find_image_dir(hr_extract_base, expected_subdir_suffix=f'{args.dataset_name}_HR')
        # Verify images really exist there (jpeg included, matching the
        # extension set used everywhere else in this script).
        has_images = any(
            glob.glob(os.path.join(actual_hr_dir, pattern))
            for pattern in ('*.png', '*.jpg', '*.jpeg')
        ) if actual_hr_dir else False
        if not has_images:
            print(f"Error: Could not locate the directory with HR images within {hr_extract_base} after extraction.")
            raise SystemExit(1)
        print(f"Located HR images in: {actual_hr_dir}")
    else:
        print("HR extraction directory already exists. Skipping unzip (use --force to override).")
        # Try to find the HR dir even though unzipping was skipped.
        actual_hr_dir = find_image_dir(hr_extract_base, expected_subdir_suffix=f'{args.dataset_name}_HR')
        if not actual_hr_dir:
            print(f"Error: Could not locate the directory with HR images within existing {hr_extract_base}.")
            raise SystemExit(1)
        print(f"Using existing HR images from: {actual_hr_dir}")

    # --- Step 3: Process (Downsample) ---
    lr_dir_exists_and_populated = os.path.isdir(lr_save_dir) and len(os.listdir(lr_save_dir)) > 0
    processing_needed = not lr_dir_exists_and_populated or args.force
    if processing_needed:
        if args.force and lr_dir_exists_and_populated:
            print("Force enabled: Removing existing LR directory...")
            shutil.rmtree(lr_save_dir)  # Careful!
        if not actual_hr_dir:
            print("Error: Cannot proceed with downsampling, HR image directory not determined.")
            raise SystemExit(1)
        if not downsample_images(actual_hr_dir, lr_save_dir, args.scale):
            # Deliberately non-fatal: partial output may still be usable.
            print("Downsampling process failed or produced no images.")
        else:
            print("Downsampling finished successfully.")
    else:
        print("LR directory already exists and is populated. Skipping downsampling (use --force to override).")

    print("\n--- Script Finished ---")
    print(f"HR images should be available in/under: {actual_hr_dir}")
    print(f"LR images (x{args.scale}) should be available in: {lr_save_dir}")
    print("You can now use these directories with the SRDataset class.")