# OxO_Image-Repair / prep.py
# (Hugging Face upload metadata: Gordon-H — "Upload 13 files", commit fd5c0a6, verified)
import os
import glob
import zipfile
import requests
import argparse
from PIL import Image
from tqdm import tqdm
# --- Helper Functions ---
def download_file(url, dest_path, chunk_size=8192):
    """Download a file from a URL to a destination path with a progress bar.

    Args:
        url: Source URL to download.
        dest_path: Local filesystem path to write the file to.
        chunk_size: Number of bytes per streamed chunk.

    Returns:
        True on success, False on any failure. On failure, any partially
        written file at ``dest_path`` is removed.
    """
    try:
        # Use the response as a context manager so the streamed connection is
        # released deterministically instead of waiting for garbage collection.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
            # Servers may omit content-length; tqdm then shows an unbounded bar.
            total_size = int(response.headers.get('content-length', 0))
            print(f"Downloading {os.path.basename(dest_path)} ({total_size / (1024*1024):.2f} MB)...")
            with open(dest_path, 'wb') as f, tqdm(
                desc=os.path.basename(dest_path),
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    bar.update(f.write(chunk))
        print(f"Download complete: {dest_path}")
        return True
    except requests.exceptions.RequestException as e:
        print(f"Error downloading {url}: {e}")
    except Exception as e:
        print(f"An unexpected error occurred during download: {e}")
    # Shared failure path: clean up a partially downloaded file if it exists.
    if os.path.exists(dest_path):
        os.remove(dest_path)
    return False
def unzip_file(zip_path, extract_to):
    """Extract a zip archive into the given directory.

    Returns True on success, False if the archive is invalid or
    extraction fails for any other reason.
    """
    print(f"Extracting {os.path.basename(zip_path)} to {extract_to}...")
    try:
        with zipfile.ZipFile(zip_path, 'r') as archive:
            # extractall is efficient enough here; a per-member loop with a
            # progress bar could be substituted for very large archives.
            archive.extractall(extract_to)
    except zipfile.BadZipFile:
        print(f"Error: Invalid or corrupted zip file: {zip_path}")
        return False
    except Exception as e:
        print(f"An error occurred during extraction: {e}")
        return False
    print("Extraction complete.")
    return True
def find_image_dir(base_path, expected_subdir_suffix='_HR'):
    """Locate the directory that actually holds the extracted images.

    Handles archives that unpack with an extra top-level folder: if
    ``base_path`` itself has no images but contains exactly one
    subdirectory, that subdirectory is considered instead.
    """
    def _has_images(directory):
        # True when the directory directly contains any png/jpg/jpeg file.
        return any(
            glob.glob(os.path.join(directory, pattern))
            for pattern in ('*.png', '*.jpg', '*.jpeg')
        )

    # Case 1: images sit directly inside base_path.
    if _has_images(base_path):
        return base_path
    # Case 2: a single nested folder, e.g. base_path/DatasetName_HR/.
    subdirs = [d for d in glob.glob(os.path.join(base_path, '*')) if os.path.isdir(d)]
    if len(subdirs) == 1:
        candidate = subdirs[0]
        # Accept it if it matches the expected suffix or contains images.
        if candidate.endswith(expected_subdir_suffix) or _has_images(candidate):
            print(f"Found image directory: {candidate}")
            return candidate
    # Fallback: warn, then return the lone subdirectory if there is one.
    print(f"Warning: Could not definitively locate image subdirectory in {base_path}. Assuming images are directly within or in a single nested folder.")
    return subdirs[0] if len(subdirs) == 1 else base_path
def downsample_images(hr_dir, lr_dir, scale_factor):
    """Downsample HR images into ``lr_dir`` using bicubic interpolation.

    Args:
        hr_dir: Directory containing the high-resolution source images.
        lr_dir: Output directory for low-resolution images (created if missing).
        scale_factor: Integer downsampling factor (e.g. 4 for x4).

    Returns:
        True if at least one image was processed, False otherwise.
    """
    if not os.path.exists(lr_dir):
        # exist_ok guards the check-then-create race if two runs overlap.
        os.makedirs(lr_dir, exist_ok=True)
        print(f"Created LR directory: {lr_dir}")
    # Gather once per extension; sort so the processing order (and the
    # progress bar) is deterministic across filesystems.
    hr_images = sorted(
        path
        for pattern in ('*.png', '*.jpg', '*.jpeg')
        for path in glob.glob(os.path.join(hr_dir, pattern))
    )
    if not hr_images:
        print(f"Error: No images found in the determined HR directory: {hr_dir}")
        return False
    print(f"Found {len(hr_images)} HR images in {hr_dir}. Starting downsampling (x{scale_factor})...")
    processed_count = 0
    for hr_path in tqdm(hr_images, desc=f"Downsampling x{scale_factor}"):
        try:
            hr_img = Image.open(hr_path).convert('RGB')  # Ensure 3-channel RGB
            hr_width, hr_height = hr_img.size
            # Floor division: LR dimensions drop any remainder pixels.
            lr_width = hr_width // scale_factor
            lr_height = hr_height // scale_factor
            if lr_width == 0 or lr_height == 0:
                print(f"\nWarning: Image {os.path.basename(hr_path)} is too small ({hr_width}x{hr_height}) for scale factor {scale_factor}. Skipping.")
                continue
            lr_img = hr_img.resize((lr_width, lr_height), resample=Image.BICUBIC)
            lr_img.save(os.path.join(lr_dir, os.path.basename(hr_path)))
            processed_count += 1
        except Exception as e:
            # Best-effort: report and continue with the remaining images.
            print(f"\nError processing {hr_path}: {e}")
    print(f"Downsampling complete. Processed {processed_count}/{len(hr_images)} images.")
    return processed_count > 0
# --- Main Execution ---
# Pipeline: parse args -> download zip -> extract -> locate HR images -> downsample to LR.
# Each step is skipped if its output already exists, unless --force is given.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download and prepare dataset for Super-Resolution.")
    parser.add_argument('--url', type=str, default='https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip', help='URL of the dataset zip file (default: DIV2K Train HR).')
    parser.add_argument('--base_dir', type=str, default='./datasets', help='Base directory to store datasets.')
    parser.add_argument('--dataset_name', type=str, default='DIV2K', help='Name for the dataset folder.')
    parser.add_argument('--scale', type=int, default=4, help='Downsampling scale factor (e.g., 4 for x4).')
    parser.add_argument('--force', action='store_true', help='Force redownload and reprocessing even if data exists.')
    args = parser.parse_args()
    # --- Define Paths ---
    dataset_base_path = os.path.join(args.base_dir, args.dataset_name)
    zip_filename = os.path.basename(args.url)
    zip_save_path = os.path.join(dataset_base_path, zip_filename)
    hr_extract_base = os.path.join(dataset_base_path, 'HR_extracted') # Temp extraction location
    # We will determine the *actual* HR image dir after extraction
    # NOTE(review): this path hard-codes 'DIV2K_train_LR_bicubic' and ignores
    # --dataset_name; confirm this is intended before using other datasets.
    lr_save_dir = os.path.join(dataset_base_path, f'DIV2K_train_LR_bicubic/X{args.scale}') # Following previous convention
    print(f"--- Configuration ---")
    print(f"Dataset URL: {args.url}")
    print(f"Base Directory: {args.base_dir}")
    print(f"Dataset Name: {args.dataset_name}")
    print(f"Target Scale: x{args.scale}")
    print(f"Zip Save Path: {zip_save_path}")
    print(f"Initial Extract Path: {hr_extract_base}")
    print(f"LR Save Path: {lr_save_dir}")
    print(f"Force Re-run: {args.force}")
    print(f"--------------------")
    # --- Create Base Directory ---
    os.makedirs(dataset_base_path, exist_ok=True)
    # --- Step 1: Download ---
    hr_dir_exists = os.path.isdir(hr_extract_base) # Check if base extraction dir exists
    download_needed = not os.path.exists(zip_save_path) or args.force
    if download_needed:
        if args.force and os.path.exists(zip_save_path):
            print("Force enabled: Removing existing zip file...")
            os.remove(zip_save_path)
        if not download_file(args.url, zip_save_path):
            print("Exiting due to download failure.")
            exit(1)
    elif hr_dir_exists: # If zip exists and hr dir exists, assume download & unzip ok unless forced
        print("Zip file already exists. Skipping download (use --force to override).")
    else: # Zip exists but HR dir doesn't - need to unzip
        print("Zip file found, but extraction directory missing. Will proceed to unzip.")
    # --- Step 2: Unzip ---
    # Check if the *potential* content directory already exists. Be a bit lenient here.
    # A more robust check would be to look inside the zip first or check for specific files.
    unzip_needed = not hr_dir_exists or args.force
    actual_hr_dir = None # Will store the path to the actual images
    if unzip_needed:
        if args.force and hr_dir_exists:
            print("Force enabled: Removing existing extraction directory...")
            import shutil
            shutil.rmtree(hr_extract_base) # Careful! Removes directory and contents
        if not os.path.exists(zip_save_path):
            print("Error: Zip file not found, cannot unzip. Please check download step or path.")
            exit(1)
        os.makedirs(hr_extract_base, exist_ok=True) # Ensure extraction target exists
        if not unzip_file(zip_save_path, hr_extract_base):
            print("Exiting due to extraction failure.")
            exit(1)
        # Find the actual directory containing images post-extraction
        actual_hr_dir = find_image_dir(hr_extract_base, expected_subdir_suffix=f'{args.dataset_name}_HR') # e.g., DIV2K_HR
        # Sanity-check: the located directory must actually contain png/jpg files.
        if not actual_hr_dir or not (glob.glob(os.path.join(actual_hr_dir, '*.png')) or glob.glob(os.path.join(actual_hr_dir, '*.jpg'))):
            print(f"Error: Could not locate the directory with HR images within {hr_extract_base} after extraction.")
            exit(1)
        print(f"Located HR images in: {actual_hr_dir}")
    else:
        print("HR extraction directory already exists. Skipping unzip (use --force to override).")
        # Try to find the HR dir even if we skipped unzipping
        actual_hr_dir = find_image_dir(hr_extract_base, expected_subdir_suffix=f'{args.dataset_name}_HR')
        if not actual_hr_dir:
            print(f"Error: Could not locate the directory with HR images within existing {hr_extract_base}.")
            exit(1)
        print(f"Using existing HR images from: {actual_hr_dir}")
    # --- Step 3: Process (Downsample) ---
    # "Populated" means the LR dir exists and holds at least one entry.
    lr_dir_exists_and_populated = os.path.isdir(lr_save_dir) and len(os.listdir(lr_save_dir)) > 0
    processing_needed = not lr_dir_exists_and_populated or args.force
    if processing_needed:
        if args.force and lr_dir_exists_and_populated:
            print("Force enabled: Removing existing LR directory...")
            import shutil
            shutil.rmtree(lr_save_dir) # Careful!
        if not actual_hr_dir:
            print("Error: Cannot proceed with downsampling, HR image directory not determined.")
            exit(1)
        if not downsample_images(actual_hr_dir, lr_save_dir, args.scale):
            print("Downsampling process failed or produced no images.")
            # Optionally exit here depending on desired behavior
            # exit(1)
        else:
            print("Downsampling finished successfully.")
    else:
        print("LR directory already exists and is populated. Skipping downsampling (use --force to override).")
    print("\n--- Script Finished ---")
    print(f"HR images should be available in/under: {actual_hr_dir}")
    print(f"LR images (x{args.scale}) should be available in: {lr_save_dir}")
    print("You can now use these directories with the SRDataset class.")