# File size: 3,348 bytes
# Revision: dcd2bd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import requests
import zipfile
import tarfile
from tqdm import tqdm

# Root folder that every dataset archive is downloaded into and extracted under.
BASE_DIR = "./datasets"

# Dataset name -> archive URL. main() uses each name both as the archive file
# name (<name>.zip or <name>.tar) and as the extraction sub-folder.
#
# High-resolution training data is currently disabled. To re-enable it, add:
#   "Flickr2K": "https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar",
DATASETS = dict(
    # Standard benchmark test sets, taken from the FFDNet GitHub repository archive.
    Test_Datasets="https://github.com/cszn/FFDNet/archive/refs/heads/master.zip",
)


def download_file(url, dest_path):
    """Download ``url`` to ``dest_path`` with a progress bar.

    The response body is streamed into a sibling ``<dest_path>.part`` file.
    That file is renamed to ``dest_path`` only once the transfer completes, so
    an interrupted or failed run never leaves a truncated file that a later run
    would mistake for a finished download.

    Args:
        url: HTTP(S) URL to fetch.
        dest_path: Final location of the downloaded file.

    Returns:
        True if ``dest_path`` exists afterwards (it was already present or was
        just downloaded). False if the request failed; the failure is printed
        rather than raised so the caller can skip this dataset.
    """
    if os.path.exists(dest_path):
        print(f"[*] {os.path.basename(dest_path)} already exists. Skipping download.")
        return True

    print(f"[*] Downloading {url}...")

    # Some hosts refuse requests that lack a browser-like User-Agent.
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    part_path = dest_path + ".part"

    try:
        # The context manager releases the connection even if streaming fails midway.
        with requests.get(url, stream=True, headers=headers, timeout=30) as response:
            response.raise_for_status()

            # Servers that stream chunked responses may omit Content-Length.
            # Passing None gives tqdm an open-ended counter instead of a bogus 0-byte total.
            total_size = int(response.headers.get('content-length', 0)) or None
            block_size = 1024 * 1024  # 1 MB chunks

            with open(part_path, 'wb') as file, tqdm(
                total=total_size, unit='B', unit_scale=True, desc=os.path.basename(dest_path)
            ) as bar:
                for data in response.iter_content(block_size):
                    file.write(data)
                    bar.update(len(data))

        # Publish the file under its final name only after a complete transfer.
        os.replace(part_path, dest_path)
        return True

    except requests.exceptions.RequestException as e:
        print(f"\n[!] The server rejected the connection: {e}")
        print(f"[!] Skipping {os.path.basename(dest_path)}. You can proceed without it.")
        return False

    finally:
        # Remove any partial download, including after Ctrl-C or a disk error.
        # This is a no-op after a successful rename.
        if os.path.exists(part_path):
            os.remove(part_path)

def extract_file(file_path, extract_to):
    """Extract a ``.zip`` or tar archive into ``extract_to``.

    Tar archives are opened in ``'r:*'`` mode, so gzip/bz2/xz compression is
    detected automatically regardless of the file suffix.

    The archives come from the network, so tar members are extracted with the
    ``"data"`` filter whenever this Python provides it (``tarfile.data_filter``).
    That filter rejects absolute paths, ``..`` traversal, links escaping the
    destination, and special files. ``zipfile.extractall`` already sanitises
    member paths on its own.

    Args:
        file_path: Path to the archive; its suffix selects the reader.
        extract_to: Destination directory.

    Returns:
        True if the archive was extracted, False if its suffix is not recognised.

    Raises:
        zipfile.BadZipFile / tarfile.TarError: if the file is not a valid archive.
    """
    print(f"[*] Extracting {os.path.basename(file_path)}...")

    tar_suffixes = (".tar", ".tar.gz", ".tgz", ".tar.bz2", ".tbz2", ".tar.xz", ".txz")

    if file_path.endswith(".zip"):
        with zipfile.ZipFile(file_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
        return True

    if file_path.endswith(tar_suffixes):
        with tarfile.open(file_path, 'r:*') as tar_ref:
            if hasattr(tarfile, "data_filter"):
                # Python 3.12+ (and security backports): block path traversal.
                tar_ref.extractall(extract_to, filter="data")
            else:
                # NOTE(review): older interpreters have no extraction filter, so a
                # malicious tar could write outside extract_to here.
                tar_ref.extractall(extract_to)
        return True

    print(f"[!] Unknown file format for {file_path}")
    return False

def _process_dataset(name, url):
    """Download, extract and clean up one dataset; return True on success."""
    print(f"\n--- Processing {name} ---")

    # Save the archive under a suffix extract_file recognises. Every tar flavour
    # is stored as ".tar" because tarfile's 'r:*' mode auto-detects compression.
    ext = ".tar" if (".tar" in url or url.endswith(".tgz")) else ".zip"
    file_name = f"{name}{ext}"
    download_path = os.path.join(BASE_DIR, file_name)

    # download_file reports failures by printing instead of raising, so check
    # for the file itself before trying to extract it.
    download_file(url, download_path)
    if not os.path.exists(download_path):
        return False

    extract_dir = os.path.join(BASE_DIR, name)
    os.makedirs(extract_dir, exist_ok=True)
    try:
        extract_file(download_path, extract_dir)
    except (zipfile.BadZipFile, tarfile.TarError) as e:
        # The saved file is not a valid archive (e.g. an error page). Delete it
        # so the next run downloads it again instead of skipping it.
        print(f"[!] Could not extract {file_name}: {e}")
        os.remove(download_path)
        return False

    # Clean up the archive to save disk space.
    print(f"[*] Cleaning up archive {file_name}...")
    os.remove(download_path)
    return True


def main():
    """Download and extract every entry of ``DATASETS`` under ``BASE_DIR``.

    Each dataset is handled independently. A failed download or a corrupt
    archive is reported and skipped rather than aborting the whole run, and
    the final summary lists any skipped datasets.
    """
    os.makedirs(BASE_DIR, exist_ok=True)

    failed = []
    for name, url in DATASETS.items():
        if not _process_dataset(name, url):
            failed.append(name)

    if failed:
        print(f"\n[!] Finished with {len(failed)} dataset(s) skipped: {', '.join(failed)}")
    else:
        print("\n[+] All datasets downloaded and extracted successfully!")
    print(f"[+] Look inside the '{BASE_DIR}' folder.")


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()