|
|
import os |
|
|
import glob |
|
|
import zipfile |
|
|
import requests |
|
|
import argparse |
|
|
from PIL import Image |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
|
|
|
def download_file(url, dest_path, chunk_size=8192):
    """Downloads a file from a URL to a destination path with progress bar.

    Args:
        url: HTTP(S) URL of the file to fetch.
        dest_path: Local filesystem path the downloaded bytes are written to.
        chunk_size: Bytes per streamed chunk read from the response.

    Returns:
        True on success; False on any failure. A partially written
        dest_path is removed before returning False.
    """
    try:
        # Stream so large archives are never fully buffered in memory.
        # The `with` block is the fix: with stream=True the connection is
        # only released once the response is closed/consumed, so the
        # original code leaked the connection (per the requests docs).
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()

            # Servers that omit Content-Length yield total=0, which tqdm
            # renders as an indeterminate progress bar.
            total_size = int(response.headers.get('content-length', 0))

            print(f"Downloading {os.path.basename(dest_path)} ({total_size / (1024*1024):.2f} MB)...")
            with open(dest_path, 'wb') as f, tqdm(
                desc=os.path.basename(dest_path),
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip keep-alive chunks (empty bytes)
                        size = f.write(chunk)
                        bar.update(size)
        print(f"Download complete: {dest_path}")
        return True

    except requests.exceptions.RequestException as e:
        print(f"Error downloading {url}: {e}")
        # Remove the partial file so a later run does not mistake it
        # for a complete download.
        if os.path.exists(dest_path):
            os.remove(dest_path)
        return False
    except Exception as e:
        print(f"An unexpected error occurred during download: {e}")
        if os.path.exists(dest_path):
            os.remove(dest_path)
        return False
|
|
|
|
|
|
|
|
def unzip_file(zip_path, extract_to):
    """Extract the zip archive at ``zip_path`` into ``extract_to``.

    Returns:
        True when extraction succeeds, False on a corrupt archive or
        any other extraction error.
    """
    print(f"Extracting {os.path.basename(zip_path)} to {extract_to}...")
    try:
        # Context manager guarantees the archive handle is closed even
        # if extraction raises partway through.
        with zipfile.ZipFile(zip_path, 'r') as archive:
            archive.extractall(extract_to)
    except zipfile.BadZipFile:
        print(f"Error: Invalid or corrupted zip file: {zip_path}")
        return False
    except Exception as e:
        print(f"An error occurred during extraction: {e}")
        return False
    print("Extraction complete.")
    return True
|
|
|
|
|
def find_image_dir(base_path, expected_subdir_suffix='_HR'):
    """
    Tries to find the actual directory containing images after extraction.
    Handles cases where unzip creates an extra top-level folder.
    """

    def _contains_images(directory):
        # True when the directory directly holds at least one png/jpg/jpeg.
        return any(
            glob.glob(os.path.join(directory, pattern))
            for pattern in ('*.png', '*.jpg', '*.jpeg')
        )

    # Case 1: images sit directly inside base_path.
    if _contains_images(base_path):
        return base_path

    # Case 2: the archive introduced a single nested folder; accept it when
    # its name matches the expected suffix or it actually contains images.
    subdirs = [entry for entry in glob.glob(os.path.join(base_path, '*')) if os.path.isdir(entry)]
    if len(subdirs) == 1:
        candidate = subdirs[0]
        if candidate.endswith(expected_subdir_suffix) or _contains_images(candidate):
            print(f"Found image directory: {candidate}")
            return candidate

    # Fallback: best-effort guess, warning the caller it is unverified.
    print(f"Warning: Could not definitively locate image subdirectory in {base_path}. Assuming images are directly within or in a single nested folder.")
    return subdirs[0] if len(subdirs) == 1 else base_path
|
|
|
|
|
|
|
|
def downsample_images(hr_dir, lr_dir, scale_factor):
    """Downsamples HR images using bicubic interpolation.

    Reads every png/jpg/jpeg in ``hr_dir``, resizes each by
    ``1/scale_factor`` (integer floor), and saves the result under the
    same filename in ``lr_dir``. Returns True when at least one image
    was produced.
    """
    if not os.path.exists(lr_dir):
        os.makedirs(lr_dir)
        print(f"Created LR directory: {lr_dir}")

    # Gather all supported image types from the HR directory.
    source_paths = []
    for pattern in ('*.png', '*.jpg', '*.jpeg'):
        source_paths.extend(glob.glob(os.path.join(hr_dir, pattern)))

    if not source_paths:
        print(f"Error: No images found in the determined HR directory: {hr_dir}")
        return False

    print(f"Found {len(source_paths)} HR images in {hr_dir}. Starting downsampling (x{scale_factor})...")

    saved = 0
    for source_path in tqdm(source_paths, desc=f"Downsampling x{scale_factor}"):
        try:
            image = Image.open(source_path).convert('RGB')
            hr_width, hr_height = image.size
            target = (hr_width // scale_factor, hr_height // scale_factor)

            # An image smaller than the scale factor floors to zero pixels.
            if target[0] == 0 or target[1] == 0:
                print(f"\nWarning: Image {os.path.basename(source_path)} is too small ({hr_width}x{hr_height}) for scale factor {scale_factor}. Skipping.")
                continue

            reduced = image.resize(target, resample=Image.BICUBIC)
            reduced.save(os.path.join(lr_dir, os.path.basename(source_path)))
            saved += 1
        except Exception as e:
            # A single unreadable file should not abort the whole batch.
            print(f"\nError processing {source_path}: {e}")

    print(f"Downsampling complete. Processed {saved}/{len(source_paths)} images.")
    return saved > 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # ---- CLI ------------------------------------------------------------
    parser = argparse.ArgumentParser(description="Download and prepare dataset for Super-Resolution.")
    parser.add_argument('--url', type=str, default='https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip', help='URL of the dataset zip file (default: DIV2K Train HR).')
    parser.add_argument('--base_dir', type=str, default='./datasets', help='Base directory to store datasets.')
    parser.add_argument('--dataset_name', type=str, default='DIV2K', help='Name for the dataset folder.')
    parser.add_argument('--scale', type=int, default=4, help='Downsampling scale factor (e.g., 4 for x4).')
    parser.add_argument('--force', action='store_true', help='Force redownload and reprocessing even if data exists.')

    args = parser.parse_args()

    # ---- Derived paths --------------------------------------------------
    dataset_base_path = os.path.join(args.base_dir, args.dataset_name)
    zip_filename = os.path.basename(args.url)
    zip_save_path = os.path.join(dataset_base_path, zip_filename)
    hr_extract_base = os.path.join(dataset_base_path, 'HR_extracted')

    # NOTE(review): this path hard-codes 'DIV2K_train_LR_bicubic' and so
    # ignores --dataset_name; with a non-DIV2K dataset the LR folder name
    # will be misleading — confirm whether that is intentional.
    lr_save_dir = os.path.join(dataset_base_path, f'DIV2K_train_LR_bicubic/X{args.scale}')

    print(f"--- Configuration ---")
    print(f"Dataset URL: {args.url}")
    print(f"Base Directory: {args.base_dir}")
    print(f"Dataset Name: {args.dataset_name}")
    print(f"Target Scale: x{args.scale}")
    print(f"Zip Save Path: {zip_save_path}")
    print(f"Initial Extract Path: {hr_extract_base}")
    print(f"LR Save Path: {lr_save_dir}")
    print(f"Force Re-run: {args.force}")
    print(f"--------------------")

    os.makedirs(dataset_base_path, exist_ok=True)

    # ---- Stage 1: download ----------------------------------------------
    # Re-download only when the zip is missing or --force was given.
    hr_dir_exists = os.path.isdir(hr_extract_base)
    download_needed = not os.path.exists(zip_save_path) or args.force

    if download_needed:
        if args.force and os.path.exists(zip_save_path):
            print("Force enabled: Removing existing zip file...")
            os.remove(zip_save_path)
        if not download_file(args.url, zip_save_path):
            print("Exiting due to download failure.")
            # NOTE(review): builtin exit() works in scripts but sys.exit()
            # is the conventional choice — consider switching.
            exit(1)
    elif hr_dir_exists:
        print("Zip file already exists. Skipping download (use --force to override).")
    else:
        # Zip is present but was never extracted (e.g. a previous run was
        # interrupted after download).
        print("Zip file found, but extraction directory missing. Will proceed to unzip.")

    # ---- Stage 2: extraction + locating the HR image folder -------------
    unzip_needed = not hr_dir_exists or args.force

    actual_hr_dir = None

    if unzip_needed:
        if args.force and hr_dir_exists:
            print("Force enabled: Removing existing extraction directory...")
            # NOTE(review): shutil is imported locally here and again in the
            # LR cleanup branch below — a single top-of-file import would be
            # cleaner.
            import shutil
            shutil.rmtree(hr_extract_base)

        if not os.path.exists(zip_save_path):
            print("Error: Zip file not found, cannot unzip. Please check download step or path.")
            exit(1)

        os.makedirs(hr_extract_base, exist_ok=True)
        if not unzip_file(zip_save_path, hr_extract_base):
            print("Exiting due to extraction failure.")
            exit(1)

        # The archive may add one extra top-level folder; resolve the real
        # image directory and sanity-check it actually contains images.
        # NOTE(review): this check omits '*.jpeg' even though
        # find_image_dir/downsample_images accept it — confirm intent.
        actual_hr_dir = find_image_dir(hr_extract_base, expected_subdir_suffix=f'{args.dataset_name}_HR')
        if not actual_hr_dir or not (glob.glob(os.path.join(actual_hr_dir, '*.png')) or glob.glob(os.path.join(actual_hr_dir, '*.jpg'))):
            print(f"Error: Could not locate the directory with HR images within {hr_extract_base} after extraction.")
            exit(1)
        print(f"Located HR images in: {actual_hr_dir}")

    else:
        print("HR extraction directory already exists. Skipping unzip (use --force to override).")

        # Re-resolve the image directory from the existing extraction.
        actual_hr_dir = find_image_dir(hr_extract_base, expected_subdir_suffix=f'{args.dataset_name}_HR')
        if not actual_hr_dir:
            print(f"Error: Could not locate the directory with HR images within existing {hr_extract_base}.")
            exit(1)
        print(f"Using existing HR images from: {actual_hr_dir}")

    # ---- Stage 3: bicubic downsampling to LR ----------------------------
    # Skip when the LR folder already has files in it, unless --force.
    lr_dir_exists_and_populated = os.path.isdir(lr_save_dir) and len(os.listdir(lr_save_dir)) > 0
    processing_needed = not lr_dir_exists_and_populated or args.force

    if processing_needed:
        if args.force and lr_dir_exists_and_populated:
            print("Force enabled: Removing existing LR directory...")
            import shutil
            shutil.rmtree(lr_save_dir)

        if not actual_hr_dir:
            print("Error: Cannot proceed with downsampling, HR image directory not determined.")
            exit(1)

        # NOTE(review): a downsampling failure is reported but does not
        # exit(1) like the earlier stages — confirm whether that is
        # deliberate (script still prints the final summary below).
        if not downsample_images(actual_hr_dir, lr_save_dir, args.scale):
            print("Downsampling process failed or produced no images.")

        else:
            print("Downsampling finished successfully.")
    else:
        print("LR directory already exists and is populated. Skipping downsampling (use --force to override).")

    # ---- Summary --------------------------------------------------------
    print("\n--- Script Finished ---")
    print(f"HR images should be available in/under: {actual_hr_dir}")
    print(f"LR images (x{args.scale}) should be available in: {lr_save_dir}")
    print("You can now use these directories with the SRDataset class.")