Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import glob | |
| import cv2 | |
| import numpy as np | |
| import random | |
| import math | |
| from tqdm import tqdm | |
| import matplotlib.pyplot as plt | |
| import torch | |
# Apply compatibility patch BEFORE importing basicsr.
# basicsr tries to import torchvision.transforms.functional_tensor, which does
# not exist in newer torchvision versions. When that module is missing, alias
# it to torchvision.transforms.functional; when it still exists, leave the real
# module untouched so nothing else that imports it is affected.
import importlib.util

try:
    import torchvision.transforms.functional as F
except ImportError:
    # torchvision is not importable: there is nothing to patch here.
    pass
else:
    # find_spec does not execute the submodule, so this check has no side effects
    # beyond importing torchvision.transforms (already done by the line above).
    if importlib.util.find_spec('torchvision.transforms.functional_tensor') is None:
        sys.modules['torchvision.transforms.functional_tensor'] = F
| # Import Real-ESRGAN's actual degradation pipeline | |
| from basicsr.data.degradations import ( | |
| random_add_gaussian_noise_pt, | |
| random_add_poisson_noise_pt, | |
| random_mixed_kernels, | |
| circular_lowpass_kernel | |
| ) | |
| from basicsr.data.realesrgan_dataset import RealESRGANDataset | |
| from basicsr.utils import DiffJPEG, USMSharp | |
| from basicsr.utils.img_process_util import filter2D | |
| from basicsr.data.transforms import paired_random_crop | |
| from torch.nn import functional as F_torch | |
class RealESRGANDegrader:
    """Real-ESRGAN style two-stage degradation pipeline (ResShift-style settings).

    ``degrade`` turns a clean HR tensor into a synthetic LR tensor ``scale``
    times smaller per spatial dimension. The steps are:

    1. blur -> random resize -> noise -> JPEG;
    2. with probability ``second_order_prob``, an optional second blur, a
       random resize and noise;
    3. a final resize to LR size plus a sinc filter (or identity), combined in
       random order with a last JPEG compression.

    All hyper-parameters are read from the project's ``config`` module once,
    in ``__init__``.

    NOTE(review): the first-order blur always calls ``random_mixed_kernels``
    with the fixed ``blur_kernel_size``; ``sinc_prob`` and ``kernel_range1``
    are stored but never used by ``degrade``. Confirm this is intentional.
    """

    # Interpolation modes sampled uniformly for every random resize.
    _INTERP_MODES = ('area', 'bilinear', 'bicubic')

    def __init__(self, scale=4):
        """Load all degradation settings from ``config``.

        Args:
            scale: Factor between the input size and the LR size produced by
                ``degrade`` (the final resize targets ``H // scale``).
        """
        self.scale = scale
        # JPEG compressor shared by every JPEG stage in ``degrade``.
        self.jpeger = DiffJPEG(differentiable=False)
        # Import all parameters from config (project-local module).
        from config import (
            blur_kernel_size, kernel_list, kernel_prob,
            data_train_blur_sigma as blur_sigma,
            noise_range, poisson_scale_range, jpeg_range,
            data_train_blur_sigma2 as blur_sigma2,
            noise_range2, poisson_scale_range2, jpeg_range2,
            second_order_prob, second_blur_prob, final_sinc_prob,
            resize_prob, resize_range, resize_prob2, resize_range2,
            gaussian_noise_prob, gray_noise_prob, gaussian_noise_prob2, gray_noise_prob2,
            data_train_betag_range as betag_range,
            data_train_betap_range as betap_range,
            data_train_betag_range2 as betag_range2,
            data_train_betap_range2 as betap_range2,
            data_train_blur_kernel_size2 as blur_kernel_size2,
            data_train_sinc_prob as sinc_prob,
            data_train_sinc_prob2 as sinc_prob2
        )
        # Blur kernel settings
        self.blur_kernel_size = blur_kernel_size
        self.kernel_list = kernel_list
        self.kernel_prob = kernel_prob
        # First degradation parameters
        self.blur_sigma = blur_sigma
        self.noise_range = noise_range
        self.poisson_scale_range = poisson_scale_range
        self.jpeg_range = jpeg_range
        self.betag_range = betag_range
        self.betap_range = betap_range
        self.sinc_prob = sinc_prob  # currently unused by degrade(); see class NOTE
        # Second degradation parameters
        self.second_order_prob = second_order_prob
        self.second_blur_prob = second_blur_prob
        self.blur_kernel_size2 = blur_kernel_size2
        self.blur_sigma2 = blur_sigma2
        self.noise_range2 = noise_range2
        self.poisson_scale_range2 = poisson_scale_range2
        self.jpeg_range2 = jpeg_range2
        self.betag_range2 = betag_range2
        self.betap_range2 = betap_range2
        self.sinc_prob2 = sinc_prob2
        # Final sinc filter
        self.final_sinc_prob = final_sinc_prob
        # Resize parameters
        self.resize_prob = resize_prob
        self.resize_range = resize_range
        self.resize_prob2 = resize_prob2
        self.resize_range2 = resize_range2
        # Noise probabilities
        self.gaussian_noise_prob = gaussian_noise_prob
        self.gray_noise_prob = gray_noise_prob
        self.gaussian_noise_prob2 = gaussian_noise_prob2
        self.gray_noise_prob2 = gray_noise_prob2
        # Candidate odd kernel sizes 3, 5, ... up to (but excluding) the configured size.
        self.kernel_range1 = [x for x in range(3, self.blur_kernel_size, 2)]
        self.kernel_range2 = [x for x in range(3, self.blur_kernel_size2, 2)]
        # Identity kernel (single 1 at the centre) used when the final sinc filter is skipped.
        self.pulse_tensor = torch.zeros(self.blur_kernel_size2, self.blur_kernel_size2).float()
        self.pulse_tensor[self.blur_kernel_size2 // 2, self.blur_kernel_size2 // 2] = 1

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _as_kernel_tensor(kernel, device):
        """Convert a NumPy kernel to a float32 tensor on ``device``; pass other values through."""
        if isinstance(kernel, np.ndarray):
            return torch.FloatTensor(kernel).to(device)
        return kernel

    @staticmethod
    def _random_scale_factor(resize_prob, resize_range):
        """Pick up/down/keep with weights ``resize_prob`` and return the matching resize factor."""
        updown_type = random.choices(['up', 'down', 'keep'], weights=resize_prob)[0]
        if updown_type == 'up':
            return random.uniform(1, resize_range[1])
        if updown_type == 'down':
            return random.uniform(resize_range[0], 1)
        return 1

    @staticmethod
    def _add_noise(img, gaussian_prob, sigma_range, poisson_scale_range, gray_prob):
        """Add Gaussian noise with probability ``gaussian_prob``, otherwise Poisson noise."""
        if random.random() < gaussian_prob:
            return random_add_gaussian_noise_pt(
                img,
                sigma_range=sigma_range,
                clip=True,
                rounds=False,
                gray_prob=gray_prob
            )
        return random_add_poisson_noise_pt(
            img,
            scale_range=poisson_scale_range,
            gray_prob=gray_prob,
            clip=True,
            rounds=False
        )

    def _jpeg(self, img, quality_range):
        """JPEG-compress ``img`` with a quality drawn uniformly from ``quality_range``.

        The input is clamped to [0, 1] first. The compressor runs on CPU (it is
        never moved to another device), and the result is moved back to ``img``'s device.
        """
        quality = img.new_zeros(img.size(0)).uniform_(*quality_range)
        img = torch.clamp(img, 0, 1)
        return self.jpeger(img.cpu(), quality=quality.cpu()).to(img.device)

    def _second_order_kernel(self):
        """Sample the second-order blur kernel, zero-padded towards ``blur_kernel_size2``.

        With probability ``sinc_prob2`` this is a circular low-pass (sinc) kernel;
        otherwise it is drawn from ``random_mixed_kernels`` over ``kernel_list``.
        """
        kernel_size2 = random.choice(self.kernel_range2)
        if random.random() < self.sinc_prob2:
            # Small kernels use a narrower cutoff range.
            if kernel_size2 < 13:
                omega_c = random.uniform(math.pi / 3, math.pi)
            else:
                omega_c = random.uniform(math.pi / 5, math.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size2, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                kernel_size2,
                self.blur_sigma2,
                self.blur_sigma2,
                [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None
            )
        pad_size = (self.blur_kernel_size2 - kernel_size2) // 2
        return np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

    def _final_sinc_kernel(self, device):
        """Return the final sinc kernel (with prob ``final_sinc_prob``) or the identity kernel."""
        if random.random() < self.final_sinc_prob:
            kernel_size = random.choice(self.kernel_range2)
            omega_c = random.uniform(math.pi / 3, math.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=self.blur_kernel_size2)
            return torch.FloatTensor(sinc_kernel).to(device)
        return self.pulse_tensor.to(device)

    def _resize_and_sinc(self, img, size, sinc_kernel):
        """Resize ``img`` to ``size`` with a random interpolation mode, then apply ``sinc_kernel``."""
        mode = random.choice(self._INTERP_MODES)
        img = F_torch.interpolate(img, size=size, mode=mode)
        return filter2D(img, sinc_kernel)

    # --------------------------------------------------------------- public API

    def degrade(self, img_gt):
        """Apply the Real-ESRGAN degradation to one image.

        Args:
            img_gt: float tensor of shape (C, H, W) with values in [0, 1], on any
                device (CPU or CUDA); kernels are created on the same device.

        Returns:
            Degraded tensor on the same device as ``img_gt``, quantised to
            multiples of 1/255 in [0, 1]. The expected shape is
            (C, H // scale, W // scale), assuming the JPEG stage preserves
            spatial size.
        """
        img_gt = img_gt.unsqueeze(0)  # add batch dimension -> [1, C, H, W]
        device = img_gt.device
        ori_h, ori_w = img_gt.size()[2:4]
        lr_size = (ori_h // self.scale, ori_w // self.scale)

        # ----------------------- First degradation ----------------------- #
        # 1. Blur with a random kernel of fixed size blur_kernel_size.
        kernel = random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            self.blur_sigma,
            self.blur_sigma,
            [-np.pi, np.pi],
            self.betag_range,
            self.betap_range,
            noise_range=None
        )
        img_lq = filter2D(img_gt, self._as_kernel_tensor(kernel, device))

        # 2. Random resize relative to the current size.
        scale_factor = self._random_scale_factor(self.resize_prob, self.resize_range)
        if scale_factor != 1:
            mode = random.choice(self._INTERP_MODES)
            img_lq = F_torch.interpolate(img_lq, scale_factor=scale_factor, mode=mode)

        # 3. Noise, then 4. JPEG.
        img_lq = self._add_noise(
            img_lq, self.gaussian_noise_prob, self.noise_range,
            self.poisson_scale_range, self.gray_noise_prob
        )
        img_lq = self._jpeg(img_lq, self.jpeg_range)

        # ------ Second degradation (applied with probability second_order_prob) ------ #
        if random.random() < self.second_order_prob:
            if random.random() < self.second_blur_prob:
                kernel2 = self._second_order_kernel()
                img_lq = filter2D(img_lq, self._as_kernel_tensor(kernel2, device))

            # Resize relative to the target LR size (ori / scale).
            scale_factor = self._random_scale_factor(self.resize_prob2, self.resize_range2)
            if scale_factor != 1:
                mode = random.choice(self._INTERP_MODES)
                img_lq = F_torch.interpolate(
                    img_lq,
                    size=(int(ori_h / self.scale * scale_factor),
                          int(ori_w / self.scale * scale_factor)),
                    mode=mode
                )

            img_lq = self._add_noise(
                img_lq, self.gaussian_noise_prob2, self.noise_range2,
                self.poisson_scale_range2, self.gray_noise_prob2
            )

        # ------------- Final stage: resize back + sinc filter + JPEG ------------- #
        sinc_kernel = self._final_sinc_kernel(device)
        # Randomise the order: [resize + sinc] then JPEG, or JPEG then [resize + sinc].
        if random.random() < 0.5:
            img_lq = self._resize_and_sinc(img_lq, lr_size, sinc_kernel)
            img_lq = self._jpeg(img_lq, self.jpeg_range2)
        else:
            img_lq = self._jpeg(img_lq, self.jpeg_range2)
            img_lq = self._resize_and_sinc(img_lq, lr_size, sinc_kernel)

        # Clamp and quantise to 8-bit levels.
        img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.0
        return img_lq.squeeze(0)  # drop batch dimension
def _float_chw_to_uint8_bgr(chw):
    """Convert a float RGB array of shape (C, H, W) in [0, 1] to a uint8 BGR (H, W, C) image.

    Values are rounded to the nearest integer rather than truncated, so a value
    stored as k / 255 in float32 maps back to exactly k (plain ``astype`` would
    turn e.g. 254.99998 into 254).
    """
    hwc = np.ascontiguousarray(np.transpose(chw, (1, 2, 0)))
    hwc = np.clip(np.rint(hwc * 255.0), 0, 255).astype(np.uint8)
    return cv2.cvtColor(hwc, cv2.COLOR_RGB2BGR)


def process_dataset(hr_folder, output_base_dir, dataset_name, scale=4, patch_size=256, device='cpu'):
    """
    Process a dataset (train or valid) and generate aligned HR/LR patch pairs.

    Each HR PNG is degraded as a whole with ``RealESRGANDegrader``. The LR result
    is nearest-upsampled back by ``scale``, and both images are cut into
    non-overlapping ``patch_size`` x ``patch_size`` tiles saved under identical
    file names in the HR and LR output folders. Images that fail to load, are
    too small, or raise during processing are reported and skipped.

    Args:
        hr_folder: Path to HR images folder (only ``*.png`` files are used)
        output_base_dir: Base directory for output
        dataset_name: 'train' or 'valid' (used in the output folder names)
        scale: Upscaling factor (default: 4)
        patch_size: Size of patches to extract (default: 256)
        device: Device to use ('cpu' or 'cuda')

    Returns:
        Number of HR/LR patch pairs written (0 if no PNG was found).
    """
    # Create output folders
    hr_patches_folder = os.path.join(output_base_dir, f'DIV2K_{dataset_name}_HR_patches_256x256')
    lr_patches_folder = os.path.join(output_base_dir, f'DIV2K_{dataset_name}_LR_patches_256x256_upsampled')
    os.makedirs(hr_patches_folder, exist_ok=True)
    os.makedirs(lr_patches_folder, exist_ok=True)

    print(f"\n{'='*60}")
    print(f"Processing {dataset_name.upper()} dataset")
    print(f"{'='*60}")
    print(f"Using device: {device}")
    print(f"HR folder: {hr_folder}")
    print("Output folders:")
    print(f" - HR patches: {hr_patches_folder}")
    print(f" - LR patches: {lr_patches_folder}\n")

    # Initialize degradation pipeline
    print("Initializing Real-ESRGAN degradation pipeline (SMOOTH BLUR)...")
    degrader = RealESRGANDegrader(scale=scale)
    # The JPEG module is intentionally not moved to `device`; degrade() runs it on CPU.
    print("Pipeline ready!\n")

    # Get image paths
    hr_image_paths = sorted(glob.glob(os.path.join(hr_folder, '*.png')))
    if not hr_image_paths:
        print(f"ERROR: No images found in {hr_folder}")
        return 0

    print(f"Found {len(hr_image_paths)} images")
    print(f"Processing entire images on {str(device).upper()}")
    print(f"Extracting {patch_size}x{patch_size} patches after degradation")
    print(f"Upsampling LR patches back to {patch_size}x{patch_size}\n")

    patch_count = 0
    processed_images = 0
    upsample_layer = torch.nn.Upsample(scale_factor=scale, mode='nearest').to(device)

    for img_path in tqdm(hr_image_paths, desc=f"Processing {dataset_name} images"):
        try:
            # Loaded as 8-bit BGR. HR patches are sliced straight from this array,
            # so HR pixels are saved exactly (no float round-trip).
            hr_bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
            if hr_bgr is None:
                print(f"Warning: Could not load {img_path}, skipping...")
                continue

            h, w = hr_bgr.shape[:2]
            if h < patch_size or w < patch_size:
                print(f"Warning: Image {img_path} too small ({h}x{w}), skipping...")
                continue

            # Float RGB [C, H, W] in [0, 1] for the degradation pipeline.
            hr_rgb = cv2.cvtColor(hr_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
            hr_tensor_full = torch.from_numpy(np.transpose(hr_rgb, (2, 0, 1))).float().to(device)

            with torch.no_grad():
                lr_tensor_full = degrader.degrade(hr_tensor_full)  # [C, H//scale, W//scale]
                if not torch.isfinite(lr_tensor_full).all():
                    print(f"Warning: Degradation produced invalid values for {img_path}, skipping...")
                    continue
                # Nearest-neighbour upsample back towards HR size.
                lr_tensor_upsampled = upsample_layer(lr_tensor_full.unsqueeze(0)).squeeze(0)

            lr_bgr = _float_chw_to_uint8_bgr(lr_tensor_upsampled.cpu().numpy())

            # The upsampled LR can be smaller than the HR image when H or W is not a
            # multiple of `scale`; tile only the region both images cover so every
            # LR patch is full-sized and aligned with its HR patch.
            tiles_h = min(h, lr_bgr.shape[0]) // patch_size
            tiles_w = min(w, lr_bgr.shape[1]) // patch_size

            stem = os.path.splitext(os.path.basename(img_path))[0]
            for i in range(tiles_h):
                y0 = i * patch_size
                for j in range(tiles_w):
                    x0 = j * patch_size
                    patch_name = f"{stem}_patch_{i}_{j}.png"
                    hr_ok = cv2.imwrite(
                        os.path.join(hr_patches_folder, patch_name),
                        hr_bgr[y0:y0 + patch_size, x0:x0 + patch_size],
                    )
                    lr_ok = cv2.imwrite(
                        os.path.join(lr_patches_folder, patch_name),
                        lr_bgr[y0:y0 + patch_size, x0:x0 + patch_size],
                    )
                    # cv2.imwrite reports failure via its return value, not an exception.
                    if not (hr_ok and lr_ok):
                        raise OSError(f"cv2.imwrite failed for patch {patch_name}")
                    patch_count += 1
            processed_images += 1
        except Exception as e:
            # Best-effort: log the failure and move on to the next image.
            print(f"\nError processing {img_path}: {e}")
            import traceback
            traceback.print_exc()

    print(f"\n{dataset_name.upper()} Dataset Complete!")
    print(f" - Processed {processed_images}/{len(hr_image_paths)} images")
    print(f" - Generated {patch_count} patch pairs")
    print(f" - HR patches: {hr_patches_folder}")
    print(f" - LR patches: {lr_patches_folder}\n")
    return patch_count
def main():
    """Generate HR/LR patch pairs for the DIV2K train and valid splits, then save a preview.

    Reads ``_project_root``, ``scale`` and ``patch_size`` from ``config`` and looks
    for ``<project_root>/data/DIV2K_train_HR`` and ``.../DIV2K_valid_HR``. A
    missing split is reported and skipped. If training patches exist, up to five
    HR/LR pairs are rendered to ``data/sample_patches.png``.
    """
    # Configuration - use paths from config
    from config import _project_root, scale, patch_size

    # Force CPU usage
    device = torch.device("cpu")

    print("=" * 60)
    print("DiffusionSR Patch Generation (CPU Mode)")
    print("=" * 60)
    print(f"Using device: {device}")
    print(f"Scale factor: {scale}x")
    print(f"Patch size: {patch_size}x{patch_size}\n")

    # Dataset paths
    data_dir = os.path.join(_project_root, 'data')
    train_hr_folder = os.path.join(data_dir, 'DIV2K_train_HR')
    valid_hr_folder = os.path.join(data_dir, 'DIV2K_valid_HR')
    # Output base directory
    output_base_dir = data_dir

    total_train_patches = 0
    total_valid_patches = 0

    # Process training dataset
    if os.path.exists(train_hr_folder):
        total_train_patches = process_dataset(
            hr_folder=train_hr_folder,
            output_base_dir=output_base_dir,
            dataset_name='train',
            scale=scale,
            patch_size=patch_size,
            device=device
        )
    else:
        print(f"WARNING: Training folder not found: {train_hr_folder}\n")

    # Process validation dataset
    if os.path.exists(valid_hr_folder):
        total_valid_patches = process_dataset(
            hr_folder=valid_hr_folder,
            output_base_dir=output_base_dir,
            dataset_name='valid',
            scale=scale,
            patch_size=patch_size,
            device=device
        )
    else:
        print(f"WARNING: Validation folder not found: {valid_hr_folder}\n")

    # Summary
    print("=" * 60)
    print("GENERATION COMPLETE!")
    print("=" * 60)
    print(f"Training patches: {total_train_patches:,}")
    print(f"Validation patches: {total_valid_patches:,}")
    print(f"Total patches: {total_train_patches + total_valid_patches:,}")

    # Display sample patches from training set (folder names must match process_dataset)
    train_hr_patches_folder = os.path.join(output_base_dir, 'DIV2K_train_HR_patches_256x256')
    train_lr_patches_folder = os.path.join(output_base_dir, 'DIV2K_train_LR_patches_256x256_upsampled')
    sample_patches = sorted(glob.glob(os.path.join(train_hr_patches_folder, '*.png')))[:5]

    if sample_patches:
        print("\nDisplaying sample patches from training set...")
        # squeeze=False keeps `axes` 2-D even for a single row.
        fig, axes = plt.subplots(
            len(sample_patches), 2, figsize=(10, len(sample_patches) * 2), squeeze=False
        )
        for (hr_ax, lr_ax), hr_patch_path in zip(axes, sample_patches):
            basename = os.path.basename(hr_patch_path)
            lr_patch_path = os.path.join(train_lr_patches_folder, basename)
            hr_ax.axis('off')
            lr_ax.axis('off')
            if not os.path.exists(lr_patch_path):
                continue
            hr = cv2.imread(hr_patch_path)
            lr = cv2.imread(lr_patch_path)
            # cv2.imread returns None for unreadable files; leave that row empty.
            if hr is None or lr is None:
                continue
            hr_ax.imshow(cv2.cvtColor(hr, cv2.COLOR_BGR2RGB))
            hr_ax.set_title(f"HR Patch: {basename}", fontweight='bold')
            lr_ax.imshow(cv2.cvtColor(lr, cv2.COLOR_BGR2RGB))
            lr_ax.set_title(f"LR Patch (upsampled): {basename}", fontweight='bold')

        fig.tight_layout()
        sample_path = os.path.join(_project_root, 'data', 'sample_patches.png')
        fig.savefig(sample_path, dpi=150, bbox_inches='tight')
        plt.close(fig)  # release the figure; it is only needed for the saved file
        print(f"Sample visualization saved to: {sample_path}")

    print("\nDone! Dataset generation complete.")
    print("\nNext steps:")
    print("  1. Update config.py:")
    print(f"     - Set dir_HR = '{train_hr_patches_folder}'")
    print(f"     - Set dir_LR = '{train_lr_patches_folder}'")
    print("  2. The SRDataset will now:")
    print("     - Load pre-generated 256x256 HR patches")
    print("     - Load pre-generated 256x256 upsampled LR patches")
    print("     - Skip cropping (patches are already the right size)")
    print("     - Apply augmentations (flip, rotate)")
    print("  3. Training will use these patches directly (no upsampling needed)")


if __name__ == "__main__":
    main()