File size: 11,335 Bytes
fd5c0a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import os
import glob
import zipfile
import requests
import argparse
from PIL import Image
from tqdm import tqdm

# --- Helper Functions ---

def download_file(url, dest_path, chunk_size=8192):
    """Download ``url`` to ``dest_path``, showing a tqdm progress bar.

    Parameters
    ----------
    url : str
        Source URL to fetch.
    dest_path : str
        Local file path where the payload is written.
    chunk_size : int, optional
        Bytes read per streamed chunk (default 8192).

    Returns
    -------
    bool
        True on success; False on any failure. A partially written file
        is removed so a later run re-downloads from scratch.
    """
    try:
        # Use the response as a context manager: with stream=True the
        # underlying connection is only released once the body is fully
        # consumed or the response is closed, so without this `with` an
        # early error would leak the pooled connection.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)

            # Servers may omit content-length; 0 makes tqdm show an open-ended bar.
            total_size = int(response.headers.get('content-length', 0))

            print(f"Downloading {os.path.basename(dest_path)} ({total_size / (1024*1024):.2f} MB)...")
            with open(dest_path, 'wb') as f, tqdm(
                desc=os.path.basename(dest_path),
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    size = f.write(chunk)
                    bar.update(size)
        print(f"Download complete: {dest_path}")
        return True

    except requests.exceptions.RequestException as e:
        print(f"Error downloading {url}: {e}")
        # Clean up partially downloaded file if it exists
        if os.path.exists(dest_path):
            os.remove(dest_path)
        return False
    except Exception as e:
        print(f"An unexpected error occurred during download: {e}")
        if os.path.exists(dest_path):
            os.remove(dest_path)
        return False


def unzip_file(zip_path, extract_to):
    """Extract the zip archive at ``zip_path`` into ``extract_to``.

    Returns True on success, False if the archive is corrupt or any
    other error occurs during extraction.
    """
    print(f"Extracting {os.path.basename(zip_path)} to {extract_to}...")
    try:
        # extractall is efficient enough for typical dataset archives;
        # a per-member loop with a progress bar could be added if needed.
        with zipfile.ZipFile(zip_path, 'r') as archive:
            archive.extractall(extract_to)
    except zipfile.BadZipFile:
        print(f"Error: Invalid or corrupted zip file: {zip_path}")
        return False
    except Exception as e:
        print(f"An error occurred during extraction: {e}")
        return False
    print("Extraction complete.")
    return True

def find_image_dir(base_path, expected_subdir_suffix='_HR'):
    """
    Locate the directory that actually contains the extracted images.

    Some archives unpack into an extra top-level folder; this looks one
    level down for it when ``base_path`` itself holds no images.
    """
    def has_images(directory):
        # True when at least one png/jpg/jpeg sits directly in `directory`.
        return any(
            glob.glob(os.path.join(directory, pattern))
            for pattern in ('*.png', '*.jpg', '*.jpeg')
        )

    # Images directly under base_path: nothing more to do.
    if has_images(base_path):
        return base_path

    # Common pattern: a single nested folder such as base_path/DatasetName_HR/.
    subdirs = [entry for entry in glob.glob(os.path.join(base_path, '*'))
               if os.path.isdir(entry)]
    if len(subdirs) == 1:
        candidate = subdirs[0]
        # Accept it when it carries the expected suffix or holds images.
        if candidate.endswith(expected_subdir_suffix) or has_images(candidate):
            print(f"Found image directory: {candidate}")
            return candidate

    # No definitive match: fall back to the lone subdir if there is one,
    # otherwise assume base_path itself.
    print(f"Warning: Could not definitively locate image subdirectory in {base_path}. Assuming images are directly within or in a single nested folder.")
    return subdirs[0] if len(subdirs) == 1 else base_path


def downsample_images(hr_dir, lr_dir, scale_factor):
    """Write bicubic-downsampled copies of every HR image into ``lr_dir``.

    Returns True when at least one image was processed, False otherwise.
    Images too small for ``scale_factor`` are skipped with a warning.
    """
    if not os.path.exists(lr_dir):
        os.makedirs(lr_dir)
        print(f"Created LR directory: {lr_dir}")

    # Gather candidates extension-by-extension (png first, matching the
    # order the LR set was originally produced in).
    hr_images = []
    for pattern in ('*.png', '*.jpg', '*.jpeg'):
        hr_images.extend(glob.glob(os.path.join(hr_dir, pattern)))

    if not hr_images:
        print(f"Error: No images found in the determined HR directory: {hr_dir}")
        return False

    print(f"Found {len(hr_images)} HR images in {hr_dir}. Starting downsampling (x{scale_factor})...")

    processed_count = 0
    for hr_path in tqdm(hr_images, desc=f"Downsampling x{scale_factor}"):
        try:
            source = Image.open(hr_path).convert('RGB')  # Ensure RGB
            hr_width, hr_height = source.size

            # Integer division: the LR image is the largest exact 1/scale grid.
            lr_size = (hr_width // scale_factor, hr_height // scale_factor)

            if lr_size[0] == 0 or lr_size[1] == 0:
                print(f"\nWarning: Image {os.path.basename(hr_path)} is too small ({hr_width}x{hr_height}) for scale factor {scale_factor}. Skipping.")
                continue

            shrunken = source.resize(lr_size, resample=Image.BICUBIC)
            shrunken.save(os.path.join(lr_dir, os.path.basename(hr_path)))
            processed_count += 1

        except Exception as e:
            # Keep going on a per-image failure; report it on its own line
            # so it doesn't collide with the tqdm bar.
            print(f"\nError processing {hr_path}: {e}")

    print(f"Downsampling complete. Processed {processed_count}/{len(hr_images)} images.")
    return processed_count > 0


# --- Main Execution ---

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download and prepare dataset for Super-Resolution.")
    parser.add_argument('--url', type=str, default='https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip', help='URL of the dataset zip file (default: DIV2K Train HR).')
    parser.add_argument('--base_dir', type=str, default='./datasets', help='Base directory to store datasets.')
    parser.add_argument('--dataset_name', type=str, default='DIV2K', help='Name for the dataset folder.')
    parser.add_argument('--scale', type=int, default=4, help='Downsampling scale factor (e.g., 4 for x4).')
    parser.add_argument('--force', action='store_true', help='Force redownload and reprocessing even if data exists.')

    args = parser.parse_args()

    # --- Define Paths ---
    dataset_base_path = os.path.join(args.base_dir, args.dataset_name)
    zip_filename = os.path.basename(args.url)
    zip_save_path = os.path.join(dataset_base_path, zip_filename)
    hr_extract_base = os.path.join(dataset_base_path, 'HR_extracted') # Temp extraction location
    # We will determine the *actual* HR image dir after extraction
    lr_save_dir = os.path.join(dataset_base_path, f'DIV2K_train_LR_bicubic/X{args.scale}') # Following previous convention

    print(f"--- Configuration ---")
    print(f"Dataset URL: {args.url}")
    print(f"Base Directory: {args.base_dir}")
    print(f"Dataset Name: {args.dataset_name}")
    print(f"Target Scale: x{args.scale}")
    print(f"Zip Save Path: {zip_save_path}")
    print(f"Initial Extract Path: {hr_extract_base}")
    print(f"LR Save Path: {lr_save_dir}")
    print(f"Force Re-run: {args.force}")
    print(f"--------------------")

    # --- Create Base Directory ---
    os.makedirs(dataset_base_path, exist_ok=True)

    # --- Step 1: Download ---
    hr_dir_exists = os.path.isdir(hr_extract_base) # Check if base extraction dir exists
    download_needed = not os.path.exists(zip_save_path) or args.force

    if download_needed:
        if args.force and os.path.exists(zip_save_path):
            print("Force enabled: Removing existing zip file...")
            os.remove(zip_save_path)
        if not download_file(args.url, zip_save_path):
            print("Exiting due to download failure.")
            exit(1)
    elif hr_dir_exists: # If zip exists and hr dir exists, assume download & unzip ok unless forced
         print("Zip file already exists. Skipping download (use --force to override).")
    else: # Zip exists but HR dir doesn't - need to unzip
        print("Zip file found, but extraction directory missing. Will proceed to unzip.")


    # --- Step 2: Unzip ---
    # Check if the *potential* content directory already exists. Be a bit lenient here.
    # A more robust check would be to look inside the zip first or check for specific files.
    unzip_needed = not hr_dir_exists or args.force

    actual_hr_dir = None # Will store the path to the actual images

    if unzip_needed:
        if args.force and hr_dir_exists:
             print("Force enabled: Removing existing extraction directory...")
             import shutil
             shutil.rmtree(hr_extract_base) # Careful! Removes directory and contents

        if not os.path.exists(zip_save_path):
            print("Error: Zip file not found, cannot unzip. Please check download step or path.")
            exit(1)

        os.makedirs(hr_extract_base, exist_ok=True) # Ensure extraction target exists
        if not unzip_file(zip_save_path, hr_extract_base):
             print("Exiting due to extraction failure.")
             exit(1)
        # Find the actual directory containing images post-extraction
        actual_hr_dir = find_image_dir(hr_extract_base, expected_subdir_suffix=f'{args.dataset_name}_HR') # e.g., DIV2K_HR
        if not actual_hr_dir or not (glob.glob(os.path.join(actual_hr_dir, '*.png')) or glob.glob(os.path.join(actual_hr_dir, '*.jpg'))):
            print(f"Error: Could not locate the directory with HR images within {hr_extract_base} after extraction.")
            exit(1)
        print(f"Located HR images in: {actual_hr_dir}")

    else:
        print("HR extraction directory already exists. Skipping unzip (use --force to override).")
        # Try to find the HR dir even if we skipped unzipping
        actual_hr_dir = find_image_dir(hr_extract_base, expected_subdir_suffix=f'{args.dataset_name}_HR')
        if not actual_hr_dir:
             print(f"Error: Could not locate the directory with HR images within existing {hr_extract_base}.")
             exit(1)
        print(f"Using existing HR images from: {actual_hr_dir}")


    # --- Step 3: Process (Downsample) ---
    lr_dir_exists_and_populated = os.path.isdir(lr_save_dir) and len(os.listdir(lr_save_dir)) > 0
    processing_needed = not lr_dir_exists_and_populated or args.force

    if processing_needed:
        if args.force and lr_dir_exists_and_populated:
             print("Force enabled: Removing existing LR directory...")
             import shutil
             shutil.rmtree(lr_save_dir) # Careful!

        if not actual_hr_dir:
             print("Error: Cannot proceed with downsampling, HR image directory not determined.")
             exit(1)

        if not downsample_images(actual_hr_dir, lr_save_dir, args.scale):
             print("Downsampling process failed or produced no images.")
             # Optionally exit here depending on desired behavior
             # exit(1)
        else:
             print("Downsampling finished successfully.")
    else:
        print("LR directory already exists and is populated. Skipping downsampling (use --force to override).")


    print("\n--- Script Finished ---")
    print(f"HR images should be available in/under: {actual_hr_dir}")
    print(f"LR images (x{args.scale}) should be available in: {lr_save_dir}")
    print("You can now use these directories with the SRDataset class.")