diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..14f3918cb71bf7326daa2ad887710eb42afbfaea 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.jpeg filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text +*.bmp filter=lfs diff=lfs merge=lfs -text +*.tiff filter=lfs diff=lfs merge=lfs -text +*.ipynb filter=lfs diff=lfs merge=lfs -text + diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..387a37c2fc21121c5f626567e07cb2323aebfbde --- /dev/null +++ b/.gitignore @@ -0,0 +1,78 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Jupyter Notebook +.ipynb_checkpoints + +# PyCharm +.idea/ + +# VS Code +.vscode/ + +# Environment +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +*.yml +!environment.yml + + + +# Data +data/ +datasets/ +*.zip +*.tar +*.gz + +# Logs +logs/ +*.log +tensorboard/ + +# Images and media +*.jpg +*.jpeg +*.png +*.gif +*.mp4 +!demo/**/*.jpg +!demo/**/*.png +!demo/**/*.jpeg + +# Output directories +outputs/ +results/ +predictions/ + +# OS specific +.DS_Store +Thumbs.db diff --git a/DeblurGanV2/README.md b/DeblurGanV2/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5896a636f05b0106576b45c8146214e549e81392 --- /dev/null +++ b/DeblurGanV2/README.md @@ -0,0 +1,123 @@ +# Debluring Application + +This application provides tools for deblurring images using deep learning. 
The system consists of: + +1. A **DeblurGAN module** for image deblurring using deep learning +2. A **FastAPI backend** for processing images via API endpoints +3. A **Streamlit frontend** for a user-friendly web interface + +## Features + +- Upload and deblur images with a single click +- Modern web interface with side-by-side comparison of original and deblurred images +- Download deblurred results +- GPU-accelerated processing (with fallback to CPU) +- Command-line interface for batch processing + +## Requirements + +- Python 3.7+ +- PyTorch +- CUDA (optional, for GPU acceleration) +- Other dependencies listed in requirements.txt + +## Setup Instructions + +1. Clone this repository: + +``` +git clone +cd Debluring +``` + +2. Install dependencies: + +``` +pip install -r requirements.txt +``` + +**Note**: This application requires the `pretrainedmodels` package for the deblurring model. If you encounter any issues, you can manually install the required packages: + +``` +pip install pretrainedmodels torchsummary albumentations opencv-python-headless +``` + +## Running the Application + +You can run the application in three different ways: + +### 1. Standalone DeblurGAN Module + +Process all images in the `inputs/` directory and save results to `outputs/`: + +``` +python deblur_module.py +``` + +### 2. API Server + +Start the FastAPI server: + +``` +python api.py +``` + +The API will be available at: http://localhost:8001 + +- POST `/deblur/` - Upload an image for deblurring +- GET `/status/` - Check API status +- GET `/diagnostics/` - Get system diagnostics + +### 3. Web Interface + +First, ensure the API server is running, then start the Streamlit interface: + +``` +python app.py +``` + +The web UI will be available at: http://localhost:8501 + +### 4. All-in-One (Optional) + +Use the run.py script to start both the API server and the Streamlit interface: + +``` +python run.py +``` + +## Using the Application + +1. 
**Command-line mode**: + + - Place your blurry images in the `inputs/` directory + - Run `python deblur_module.py` + - Find the deblurred results in the `outputs/` directory + +2. **Web interface mode**: + - Start both the API and the Streamlit app + - Upload an image through the web interface + - Click "Deblur Image" and wait for processing + - Download the deblurred image when ready + +## Troubleshooting + +If you encounter any issues: + +1. Check the API diagnostics endpoint: http://localhost:8001/diagnostics/ + +2. Review logs in the `logs/` directory for detailed error information. + +3. Make sure your system has sufficient resources for running the deblurring model. + +4. If you're experiencing CUDA errors, the application will automatically fall back to CPU processing. + +5. For large images, the application will automatically resize them to a maximum dimension of 2000 pixels to avoid memory issues. + +## License + +This project is licensed under the MIT License - see the LICENSE file for details. + +## Acknowledgements + +This application is based on the DeblurGAN architecture for image deblurring with modifications to create a user-friendly interface and robust error handling. 
"""FastAPI backend for the deblurring application.

Endpoints:
    POST /deblur/       -- deblur an uploaded image and return the PNG result
    GET  /status/       -- model / lock / CUDA-memory status
    GET  /clear-memory/ -- force GC + CUDA cache clear and release the lock
    GET  /diagnostics/  -- system, torch, OpenCV and directory diagnostics

The DeblurGAN model is created lazily on first use.  A module-level
``processing_lock`` flag serializes requests: the async handlers run on a
single event loop and there is no ``await`` between checking and setting
the flag, so the check-then-set is effectively atomic for one worker.
NOTE(review): the flag does NOT protect multi-worker deployments.
"""

import os
import uuid
import gc
from typing import Optional
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import logging
from logging_utils import setup_logger
import traceback
import shutil
import torch
import requests

from deblur_module import DeblurGAN

# Configure logging
logger = setup_logger(__name__)

# Base URL used by check_api_health().
API_URL = "http://localhost:8001"

app = FastAPI(
    title="Debluring Application API",
    description="API for deblurring images using deep learning",
    version="1.0.0"
)

# Allow the Streamlit frontend (served from another port) to call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Lazily-created model instance and cooperative busy flag (see module docstring).
model = None
processing_lock = False


def get_model():
    """Return the process-wide DeblurGAN instance, creating it on first call.

    Raises:
        RuntimeError: if the model cannot be initialized.
    """
    global model
    if model is None:
        logger.info("Initializing deblurring model...")
        try:
            model = DeblurGAN()
            logger.info("Model initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize model: {str(e)}")
            logger.error(traceback.format_exc())
            raise RuntimeError(f"Could not initialize DeblurGAN model: {str(e)}")
    return model


def cleanup_resources(input_path=None):
    """Free memory, optionally delete the temp input file, and release the lock."""
    # Force garbage collection
    gc.collect()

    # Clean up CUDA memory if available
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Remove input file if specified
    if input_path and os.path.exists(input_path):
        try:
            os.remove(input_path)
            logger.info(f"Removed temporary input file: {input_path}")
        except Exception as e:
            logger.warning(f"Could not remove temporary input file: {str(e)}")

    # Release processing lock
    global processing_lock
    processing_lock = False
    logger.info("Resources cleaned up")


@app.get("/")
async def root():
    """Liveness probe."""
    return {"message": "Debluring Application API is running"}


@app.post("/deblur/", response_class=FileResponse)
async def deblur_image(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
    """
    Deblur an uploaded image and return the processed image file.
    """
    global processing_lock
    input_path = None

    # Only one image is processed at a time; reject concurrent work early.
    if processing_lock:
        logger.warning("Another deblurring request is already in progress")
        raise HTTPException(
            status_code=429,
            detail="Server is busy processing another image. Please try again shortly."
        )

    # Set processing lock
    processing_lock = True

    try:
        # FIX: UploadFile.content_type can be None for some clients; the
        # original `file.content_type.startswith(...)` then raised
        # AttributeError (-> 500) instead of the intended 400.
        if not (file.content_type or "").startswith("image/"):
            logger.warning(f"Invalid file type: {file.content_type}")
            processing_lock = False  # Release lock
            raise HTTPException(status_code=400, detail="File must be an image")

        logger.info(f"Processing image: {file.filename}, size: {file.size} bytes, type: {file.content_type}")

        # Create input and output directories
        module_dir = os.path.dirname(os.path.abspath(__file__))
        input_dir = os.path.join(module_dir, 'inputs')
        output_dir = os.path.join(module_dir, 'outputs')
        os.makedirs(input_dir, exist_ok=True)
        os.makedirs(output_dir, exist_ok=True)

        # Unique names so requests never collide on disk.
        unique_id = uuid.uuid4().hex
        input_filename = f"input_{unique_id}.png"
        input_path = os.path.join(input_dir, input_filename)
        output_filename = f"deblurred_{unique_id}.png"
        output_path = os.path.join(output_dir, output_filename)

        try:
            # Read file contents first
            file_contents = await file.read()

            # Save uploaded file to disk
            with open(input_path, "wb") as buffer:
                buffer.write(file_contents)

            logger.info(f"Input image saved to: {input_path}")

            # Release file resources immediately
            await file.close()

            # Get the model
            deblur_model = get_model()

            # Process the image
            logger.info("Starting deblurring process...")
            deblurred_img = deblur_model.deblur_image(input_path)

            # Save the result
            deblur_model.save_image(deblurred_img, output_filename)

            logger.info(f"Image deblurred successfully, saved to: {output_path}")

            # BackgroundTasks run only after the response has been sent,
            # so deleting the temp input (and releasing the lock) is safe here.
            background_tasks.add_task(cleanup_resources, input_path)

            # Return the result file
            return FileResponse(
                output_path,
                media_type="image/png",
                filename=f"deblurred_{file.filename}"
            )
        except Exception as e:
            logger.error(f"Error in deblurring process: {str(e)}")
            logger.error(traceback.format_exc())
            # Always attempt cleanup on error (also releases the lock).
            cleanup_resources(input_path)
            raise HTTPException(status_code=500, detail=f"Deblurring failed: {str(e)}")

    except HTTPException:
        # Re-raise HTTP exceptions untouched (lock was released above).
        raise
    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        # Make sure lock is released
        processing_lock = False
        raise HTTPException(status_code=500, detail=error_msg)


@app.get("/status/")
async def status():
    """Check API status and model availability."""
    try:
        logger.info("Checking model status")

        # Check if we're currently processing
        if processing_lock:
            return {
                "status": "busy",
                "model_loaded": True,
                "message": "Currently processing an image"
            }

        # Otherwise do a full check (side effect: loads the model if needed).
        deblur_model = get_model()

        # Get memory stats if CUDA is available
        memory_info = {}
        if torch.cuda.is_available():
            memory_info["cuda_memory_allocated"] = f"{torch.cuda.memory_allocated() / 1024**2:.2f} MB"
            memory_info["cuda_memory_reserved"] = f"{torch.cuda.memory_reserved() / 1024**2:.2f} MB"
            memory_info["cuda_max_memory"] = f"{torch.cuda.max_memory_allocated() / 1024**2:.2f} MB"

        logger.info("Model is loaded and ready")
        return {
            "status": "ok",
            "model_loaded": True,
            "processing": processing_lock,
            "memory": memory_info
        }
    except Exception as e:
        error_msg = f"Error checking model status: {str(e)}"
        logger.error(error_msg)
        return {"status": "error", "model_loaded": False, "error": str(e)}


@app.get("/clear-memory/")
async def clear_memory():
    """Force clear memory and release resources."""
    try:
        # Force garbage collection
        gc.collect()

        # Clear CUDA cache if available
        if torch.cuda.is_available():
            before = torch.cuda.memory_allocated() / 1024**2
            torch.cuda.empty_cache()
            after = torch.cuda.memory_allocated() / 1024**2
            logger.info(f"CUDA memory cleared: {before:.2f} MB → {after:.2f} MB")

        # Reset processing lock (manual escape hatch if a request got stuck).
        global processing_lock
        was_locked = processing_lock
        processing_lock = False

        return {
            "status": "ok",
            "message": "Memory cleared successfully",
            "lock_released": was_locked
        }
    except Exception as e:
        error_msg = f"Error clearing memory: {str(e)}"
        logger.error(error_msg)
        return {"status": "error", "error": str(e)}


@app.get("/diagnostics/")
async def diagnostics():
    """Get diagnostic information about the system."""
    try:
        # Imported here so a missing optional dependency only breaks this route.
        import platform
        import sys
        import torch
        import cv2
        import psutil

        # Create diagnostics report
        diagnostic_info = {
            "system": {
                "platform": platform.platform(),
                "python_version": sys.version,
                "processor": platform.processor()
            },
            "memory": {
                "total": f"{psutil.virtual_memory().total / (1024**3):.2f} GB",
                "available": f"{psutil.virtual_memory().available / (1024**3):.2f} GB",
                "used_percent": f"{psutil.virtual_memory().percent}%"
            },
            "torch": {
                "version": torch.__version__,
                "cuda_available": torch.cuda.is_available(),
                "cuda_devices": torch.cuda.device_count() if torch.cuda.is_available() else 0,
                "cuda_version": torch.version.cuda if torch.cuda.is_available() else None
            },
            "opencv": {
                "version": cv2.__version__
            },
            "dirs": {
                "inputs_exists": os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'inputs')),
                "outputs_exists": os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'outputs'))
            },
            "api_state": {
                "processing_lock": processing_lock
            }
        }

        # Add CUDA memory info if available
        if torch.cuda.is_available():
            diagnostic_info["cuda_memory"] = {
                "allocated": f"{torch.cuda.memory_allocated() / 1024**2:.2f} MB",
                "reserved": f"{torch.cuda.memory_reserved() / 1024**2:.2f} MB",
                "max_allocated": f"{torch.cuda.max_memory_allocated() / 1024**2:.2f} MB"
            }

        # Try to check model-specific info
        try:
            from models.networks import INCEPTION_AVAILABLE
            diagnostic_info["model"] = {
                "inception_available": INCEPTION_AVAILABLE
            }
        except ImportError:
            diagnostic_info["model"] = {
                "inception_available": False,
                "error": "Could not import networks module"
            }

        return diagnostic_info
    except Exception as e:
        error_msg = f"Error in diagnostics: {str(e)}"
        logger.error(error_msg)
        return {"status": "error", "error": str(e)}


def check_api_health(timeout=2):
    """Check if the API is responsive"""
    try:
        # FIX: hit the canonical trailing-slash route (the bare path only
        # worked via FastAPI's 307 redirect).
        response = requests.get(f"{API_URL}/status/", timeout=timeout)
        return response.status_code == 200
    except requests.RequestException:
        # FIX: narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate.
        return False


if __name__ == "__main__":
    # Run the FastAPI server
    logger.info("Starting Debluring Application API server...")
    uvicorn.run("api:app", host="0.0.0.0", port=8001, reload=True)
"""Streamlit frontend for the deblurring application.

Uploads the chosen image to the FastAPI backend (``API_URL``, POST
/deblur/) and renders the original and deblurred images side by side.
"""

import streamlit as st
import requests
import os
import sys
from PIL import Image
import io
import time
from pathlib import Path

# Set API URL
API_URL = "http://localhost:8001"  # Local FastAPI server URL

st.set_page_config(
    page_title="Image Deblurring App",
    page_icon="🔍",
    layout="wide",
)

st.title("Image Deblurring Application")
st.markdown("""
Turn your blurry photos into clear, sharp images using AI technology.
Upload an image to get started!
""")

# File uploader
uploaded_file = st.file_uploader(
    "Choose a blurry image...", type=["jpg", "jpeg", "png", "bmp"])

# Sidebar controls
with st.sidebar:
    st.header("Deblurring Options")

    # Additional options could be added here in the future
    st.markdown("This application uses a deep learning model to remove blur from images.")
    st.markdown("---")

    # Check API status
    if st.button("Check API Status"):
        try:
            response = requests.get(f"{API_URL}/status/", timeout=5)
            if response.status_code == 200 and response.json().get("status") == "ok":
                st.success("✅ API is running and ready")
            else:
                st.error("❌ API is not responding properly")
        # FIX: narrowed from a bare `except:`.  requests failures and a
        # non-JSON body (response.json() raises ValueError) both mean the
        # API is unreachable/broken; anything else should surface.
        except (requests.RequestException, ValueError):
            st.error("❌ Cannot connect to API")

# Process when upload is ready
if uploaded_file is not None:
    # Display the original image
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Original Image")
        image = Image.open(uploaded_file)
        st.image(image, use_column_width=True)

    # Process image button
    process_button = st.button("Deblur Image")

    if process_button:
        with st.spinner("Deblurring your image... Please wait."):
            try:
                # Generic name/content-type is enough: the backend decodes
                # the raw bytes itself.
                files = {
                    "file": ("image.jpg", uploaded_file.getvalue(), "image/jpeg")
                }

                # Send request to API
                response = requests.post(f"{API_URL}/deblur/", files=files, timeout=60)

                if response.status_code == 200:
                    with col2:
                        st.subheader("Deblurred Result")
                        deblurred_img = Image.open(io.BytesIO(response.content))
                        st.image(deblurred_img, use_column_width=True)

                    # Option to download the deblurred image
                    st.download_button(
                        label="Download Deblurred Image",
                        data=response.content,
                        file_name=f"deblurred_{uploaded_file.name}",
                        mime="image/png"
                    )
                else:
                    # FIX: narrowed from a bare `except:` -- response.json()
                    # raises ValueError when the error body isn't JSON.
                    try:
                        error_details = response.json().get('detail', 'Unknown error')
                    except ValueError:
                        error_details = response.text

                    st.error(f"Error: {error_details}")
            except Exception as e:
                st.error(f"An error occurred: {str(e)}")

# Footer
st.markdown("---")
st.markdown("Powered by DeblurGAN - Image Restoration Project")
"""Albumentations-based augmentation pipelines for paired deblurring data."""

from typing import List

import albumentations as albu


def get_transforms(size: int, scope: str = 'geometric', crop='random'):
    """Return fn(image, target) applying the chosen augmentation scope,
    padding, and a `size` x `size` crop identically to both images.

    Args:
        size: side length of the output crop.
        scope: 'weak' (flip only) or 'geometric' (one random geometric op).
        crop: 'random' or 'center'.
    """
    augs = {'weak': albu.Compose([albu.HorizontalFlip(),
                                  ]),
            'geometric': albu.OneOf([albu.HorizontalFlip(always_apply=True),
                                     albu.ShiftScaleRotate(always_apply=True),
                                     albu.Transpose(always_apply=True),
                                     albu.OpticalDistortion(always_apply=True),
                                     albu.ElasticTransform(always_apply=True),
                                     ])
            }

    aug_fn = augs[scope]
    crop_fn = {'random': albu.RandomCrop(size, size, always_apply=True),
               'center': albu.CenterCrop(size, size, always_apply=True)}[crop]
    pad = albu.PadIfNeeded(size, size)

    # additional_targets makes 'target' receive the exact same transform as
    # 'image', keeping blurred/sharp pairs aligned.
    pipeline = albu.Compose([aug_fn, pad, crop_fn], additional_targets={'target': 'image'})

    def process(a, b):
        r = pipeline(image=a, target=b)
        return r['image'], r['target']

    return process


def get_normalize():
    """Return fn(image, target) normalizing both images with mean/std 0.5
    (i.e. into the [-1, 1] range the generator works in)."""
    normalize = albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    normalize = albu.Compose([normalize], additional_targets={'target': 'image'})

    def process(a, b):
        r = normalize(image=a, target=b)
        return r['image'], r['target']

    return process


def _resolve_aug_fn(name):
    """Map a corruption-config name to its albumentations transform class."""
    d = {
        'cutout': albu.Cutout,
        'rgb_shift': albu.RGBShift,
        'hsv_shift': albu.HueSaturationValue,
        'motion_blur': albu.MotionBlur,
        'median_blur': albu.MedianBlur,
        'snow': albu.RandomSnow,
        'shadow': albu.RandomShadow,
        'fog': albu.RandomFog,
        'brightness_contrast': albu.RandomBrightnessContrast,
        'gamma': albu.RandomGamma,
        'sun_flare': albu.RandomSunFlare,
        'sharpen': albu.Sharpen,
        'jpeg': albu.ImageCompression,
        'gray': albu.ToGray,
        'pixelize': albu.Downscale,
        # ToDo: partial gray
    }
    return d[name]


def get_corrupt_function(config: List[dict]):
    """Build fn(image) applying one randomly chosen corruption from `config`.

    Each config entry must contain a 'name' key, may contain 'prob'
    (default 0.5), and any remaining keys are passed to the transform.
    """
    augs = []
    for aug_params in config:
        # FIX: copy before popping -- the original popped 'name'/'prob'
        # directly out of the caller's dicts, mutating the config so it
        # could only be used once.
        params = dict(aug_params)
        name = params.pop('name')
        cls = _resolve_aug_fn(name)
        prob = params.pop('prob', .5)
        augs.append(cls(p=prob, **params))

    augs = albu.OneOf(augs)

    def process(x):
        return augs(image=x)['image']

    return process
class DeblurGAN:
    """Wrapper around the DeblurGAN generator for single-image deblurring.

    Loads the generator named in ``config/config.yaml``, moves it to CUDA
    when possible (with CPU fallback), and exposes :meth:`deblur_image` /
    :meth:`save_image` plus ``inputs``/``outputs`` directories located
    next to this module.
    """

    def __init__(self, weights_path: str = 'fpn_inception.h5', model_name: str = ''):
        """
        Initialize the DeblurGAN model.

        Args:
            weights_path: Path to model weights file
            model_name: Name of the model architecture (if empty, will be read from config)

        Raises:
            FileNotFoundError: if the weights file does not exist.
        """
        try:
            logger.info(f"Initializing DeblurGAN with weights: {weights_path}")
            # Resolve paths relative to this module so the CWD doesn't matter.
            module_dir = os.path.dirname(os.path.abspath(__file__))
            config_path = os.path.join(module_dir, 'config/config.yaml')
            if not os.path.isabs(weights_path):
                weights_path = os.path.join(module_dir, weights_path)

            # Check if weights file exists
            if not os.path.exists(weights_path):
                error_msg = f"Weights file not found: {weights_path}"
                logger.error(error_msg)
                raise FileNotFoundError(error_msg)

            # Load configuration
            logger.info(f"Loading configuration from {config_path}")
            with open(config_path, encoding='utf-8') as cfg:
                config = yaml.load(cfg, Loader=yaml.FullLoader)

            # Initialize model
            logger.info(f"Creating model with architecture: {model_name or config['model']['g_name']}")
            model = get_generator(model_name or config['model']['g_name'])

            logger.info("Loading model weights")
            # FIX: map_location='cpu' -- without it, torch.load() of a
            # CUDA-saved checkpoint crashes on CPU-only machines before the
            # CUDA/CPU fallback below ever runs.  GPU behavior is unchanged:
            # the state dict loads on CPU and model.cuda() moves it over.
            state = torch.load(weights_path, map_location='cpu')
            model.load_state_dict(state['model'])

            # Try CUDA first, fall back to CPU if necessary
            try:
                self.model = model.cuda()
                self.device = 'cuda'
                logger.info("Model moved to CUDA successfully")
            except Exception as e:
                logger.warning(f"Failed to move model to CUDA. Error: {str(e)}")
                logger.warning("Using CPU mode")
                self.model = model
                self.device = 'cpu'

            self.model.train(True)  # GAN inference uses train mode for batch norm stats
            self.normalize_fn = get_normalize()

            # inputs/ and outputs/ directories live next to this module.
            self.inputs_dir = os.path.join(module_dir, 'inputs')
            self.outputs_dir = os.path.join(module_dir, 'outputs')
            os.makedirs(self.inputs_dir, exist_ok=True)
            os.makedirs(self.outputs_dir, exist_ok=True)

            logger.info("Model initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize model: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    @staticmethod
    def _array_to_batch(x):
        """Convert an HWC numpy array to a 1xCxHxW torch tensor."""
        x = np.transpose(x, (2, 0, 1))
        x = np.expand_dims(x, 0)
        return torch.from_numpy(x)

    def _preprocess(self, x: np.ndarray) -> Tuple:
        """Normalize and zero-pad the input; returns ((img_batch, mask_batch), h, w).

        h/w are the pre-padding dimensions, used to crop the output back.
        """
        # Normalize to [-1, 1]
        x, _ = self.normalize_fn(x, x)
        mask = np.ones_like(x, dtype=np.float32)

        # Pad so both sides are divisible by block_size (FPN downsampling).
        h, w, _ = x.shape
        block_size = 32
        min_height = (h // block_size + 1) * block_size
        min_width = (w // block_size + 1) * block_size

        pad_params = {
            'mode': 'constant',
            'constant_values': 0,
            'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
        }
        x = np.pad(x, **pad_params)
        mask = np.pad(mask, **pad_params)

        # Tuple instead of a lazy map object; callers unpack two values.
        return tuple(map(self._array_to_batch, (x, mask))), h, w

    @staticmethod
    def _postprocess(x: torch.Tensor) -> np.ndarray:
        """Convert a 1xCxHxW model output in [-1, 1] to a uint8 HWC array."""
        x, = x  # strict: requires a batch of exactly one
        x = x.detach().cpu().float().numpy()
        x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
        return x.astype('uint8')

    def deblur_image(self, image: Union[str, np.ndarray, bytes]) -> np.ndarray:
        """
        Deblur an image.

        Args:
            image: Input image as a file path, numpy array, or bytes

        Returns:
            Deblurred image as a numpy array (RGB, uint8)

        Raises:
            ValueError: if the image cannot be decoded or has a bad shape.
        """
        try:
            # Handle different input types
            if isinstance(image, str):
                # Image path
                logger.info(f"Loading image from path: {image}")
                img = cv2.imread(image)
                if img is None:
                    raise ValueError(f"Failed to read image from {image}")
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            elif isinstance(image, bytes):
                # Bytes (e.g., from file upload)
                logger.info("Loading image from bytes")
                nparr = np.frombuffer(image, np.uint8)
                img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                if img is None:
                    # Try using PIL as a fallback
                    pil_img = Image.open(BytesIO(image))
                    img = np.array(pil_img.convert('RGB'))
                else:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            elif isinstance(image, np.ndarray):
                # Already a numpy array
                logger.info("Processing image from numpy array")
                img = image.copy()
                if img.shape[2] == 3 and img.dtype == np.uint8:
                    # NOTE(review): crude single-pixel BGR-vs-RGB heuristic;
                    # it can misclassify images -- confirm callers' channel
                    # order before relying on this.
                    if img[0, 0, 0] > img[0, 0, 2]:  # Simple BGR check
                        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")

            # Validate image
            if img is None or img.size == 0:
                raise ValueError("Image is empty or invalid")

            logger.info(f"Image shape: {img.shape}, dtype: {img.dtype}")

            # Ensure image has 3 channels
            if len(img.shape) != 3 or img.shape[2] != 3:
                raise ValueError(f"Image must have 3 channels, got shape {img.shape}")

            # Resize very large images to bound memory use.
            max_dim = max(img.shape[0], img.shape[1])
            if max_dim > 2000:
                scale_factor = 2000 / max_dim
                new_h = int(img.shape[0] * scale_factor)
                new_w = int(img.shape[1] * scale_factor)
                logger.warning(f"Image too large, resizing from {img.shape[:2]} to {(new_h, new_w)}")
                img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

            # Process the image
            logger.info("Preprocessing image")
            # NOTE(review): mask_batch is computed but never passed to the model.
            (img_batch, mask_batch), h, w = self._preprocess(img)

            logger.info("Running inference with model")
            with torch.no_grad():
                try:
                    # Try to use the device that was set during initialization
                    inputs = [img_batch.to(self.device)]
                    pred = self.model(*inputs)
                except Exception as e:
                    # If device fails, fall back to CPU
                    logger.warning(f"Error using {self.device}: {str(e)}. Falling back to CPU.")
                    if self.device == 'cuda':
                        torch.cuda.empty_cache()  # Free GPU memory
                    inputs = [img_batch.to('cpu')]
                    self.model = self.model.to('cpu')
                    self.device = 'cpu'
                    pred = self.model(*inputs)

            # Crop the padding back off the result.
            logger.info("Postprocessing image")
            result = self._postprocess(pred)[:h, :w, :]
            logger.info("Image deblurred successfully")
            return result
        except Exception as e:
            logger.error(f"Error in deblur_image: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    def save_image(self, image: np.ndarray, output_path: str) -> str:
        """Save an RGB image to the given path (relative paths go to outputs/)."""
        try:
            # Convert to BGR for OpenCV
            save_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            if not os.path.isabs(output_path):
                # Use the outputs directory by default
                output_path = os.path.join(self.outputs_dir, output_path)

            # Ensure the parent directory exists
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            cv2.imwrite(output_path, save_img)
            logger.info(f"Image saved to {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"Error saving image: {str(e)}")
            logger.error(traceback.format_exc())
            raise
+ """ + try: + # Initialize the DeblurGAN model + deblur_model = DeblurGAN() + + # Get the inputs directory + inputs_dir = deblur_model.inputs_dir + outputs_dir = deblur_model.outputs_dir + + # Check if there are any images in the inputs directory + input_files = [f for f in os.listdir(inputs_dir) if os.path.isfile(os.path.join(inputs_dir, f)) + and f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff'))] + + if not input_files: + logger.warning(f"No image files found in {inputs_dir}") + print(f"No image files found in {inputs_dir}. Please add some images and try again.") + return + + logger.info(f"Found {len(input_files)} images to process") + print(f"Found {len(input_files)} images to process") + + # Process each image + for input_file in input_files: + try: + input_path = os.path.join(inputs_dir, input_file) + output_file = f"deblurred_{input_file}" + output_path = os.path.join(outputs_dir, output_file) + + print(f"Processing {input_file}...") + + # Deblur the image + deblurred_img = deblur_model.deblur_image(input_path) + + # Save the deblurred image + deblur_model.save_image(deblurred_img, output_file) + + print(f"✅ Saved deblurred image to {output_path}") + + except Exception as e: + logger.error(f"Error processing {input_file}: {str(e)}") + print(f"❌ Failed to process {input_file}: {str(e)}") + + print(f"\nDeblurring complete! 
"""Shared logging setup: every logger writes to a rotating file and stdout."""

import os
import logging
from logging.handlers import RotatingFileHandler
import sys

# Directory for log files, created next to this module at import time.
logs_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')
os.makedirs(logs_dir, exist_ok=True)

# Shared application log file.
log_file_path = os.path.join(logs_dir, 'deblur_app.log')

# Loggers already configured by setup_logger(); prevents duplicate handlers.
_configured_loggers = set()


def _make_handlers():
    """Build the standard (file, console) handler pair used by every logger."""
    fmt = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # 10 MB per file, five rotated backups.
    rotating = RotatingFileHandler(
        log_file_path,
        maxBytes=10 * 1024 * 1024,
        backupCount=5
    )
    rotating.setLevel(logging.INFO)
    rotating.setFormatter(fmt)

    console = logging.StreamHandler(sys.stdout)
    console.setLevel(logging.INFO)
    console.setFormatter(fmt)

    return rotating, console


def setup_logger(name=None):
    """
    Setup and return a logger configured to log to both file and console

    Args:
        name: Logger name (default: root logger)

    Returns:
        Configured logger instance
    """
    logger = logging.getLogger(name)

    # Idempotent: a logger we have configured before, or one that already
    # carries handlers from elsewhere, is returned untouched.
    key = id(logger)
    if key in _configured_loggers or logger.handlers:
        return logger
    _configured_loggers.add(key)

    logger.setLevel(logging.INFO)
    logger.propagate = False  # avoid duplicate records via ancestor loggers

    for handler in _make_handlers():
        logger.addHandler(handler)

    return logger


# Configure the root logger exactly once at import time.
if not logging.getLogger().handlers:
    setup_logger()
logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + # File handler (with rotation) + file_handler = RotatingFileHandler( + log_file_path, + maxBytes=10*1024*1024, # 10 MB + backupCount=5 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(formatter) + + # Console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(formatter) + + # Add handlers to logger + logger.addHandler(file_handler) + logger.addHandler(console_handler) + + return logger + +# Configure root logger only once +if not logging.getLogger().handlers: + setup_logger() \ No newline at end of file diff --git a/DeblurGanV2/models/__init__.py b/DeblurGanV2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/DeblurGanV2/models/fpn_inception.py b/DeblurGanV2/models/fpn_inception.py new file mode 100644 index 0000000000000000000000000000000000000000..29a5f151af4ba0ca8f2c0f3ed64475ba129e0a1a --- /dev/null +++ b/DeblurGanV2/models/fpn_inception.py @@ -0,0 +1,167 @@ +import torch +import torch.nn as nn +from pretrainedmodels import inceptionresnetv2 +from torchsummary import summary +import torch.nn.functional as F + +class FPNHead(nn.Module): + def __init__(self, num_in, num_mid, num_out): + super().__init__() + + self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False) + self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False) + + def forward(self, x): + x = nn.functional.relu(self.block0(x), inplace=True) + x = nn.functional.relu(self.block1(x), inplace=True) + return x + +class ConvBlock(nn.Module): + def __init__(self, num_in, num_out, norm_layer): + super().__init__() + + self.block = nn.Sequential(nn.Conv2d(num_in, num_out, kernel_size=3, padding=1), + norm_layer(num_out), + nn.ReLU(inplace=True)) + + def forward(self, x): + x = self.block(x) + 
return x + + +class FPNInception(nn.Module): + + def __init__(self, norm_layer, output_ch=3, num_filters=128, num_filters_fpn=256): + super().__init__() + + # Feature Pyramid Network (FPN) with four feature maps of resolutions + # 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps. + self.fpn = FPN(num_filters=num_filters_fpn, norm_layer=norm_layer) + + # The segmentation heads on top of the FPN + + self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters) + self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters) + self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters) + self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters) + + self.smooth = nn.Sequential( + nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1), + norm_layer(num_filters), + nn.ReLU(), + ) + + self.smooth2 = nn.Sequential( + nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1), + norm_layer(num_filters // 2), + nn.ReLU(), + ) + + self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1) + + def unfreeze(self): + self.fpn.unfreeze() + + def forward(self, x): + map0, map1, map2, map3, map4 = self.fpn(x) + + map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest") + map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest") + map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest") + map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest") + + smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1)) + smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest") + smoothed = self.smooth2(smoothed + map0) + smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest") + + final = self.final(smoothed) + res = torch.tanh(final) + x + + return torch.clamp(res, min = -1,max = 1) + + +class FPN(nn.Module): + + def __init__(self, norm_layer, num_filters=256): + """Creates an `FPN` instance 
for feature extraction. + Args: + num_filters: the number of filters in each output pyramid level + pretrained: use ImageNet pre-trained backbone feature extractor + """ + + super().__init__() + self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet') + + self.enc0 = self.inception.conv2d_1a + self.enc1 = nn.Sequential( + self.inception.conv2d_2a, + self.inception.conv2d_2b, + self.inception.maxpool_3a, + ) # 64 + self.enc2 = nn.Sequential( + self.inception.conv2d_3b, + self.inception.conv2d_4a, + self.inception.maxpool_5a, + ) # 192 + self.enc3 = nn.Sequential( + self.inception.mixed_5b, + self.inception.repeat, + self.inception.mixed_6a, + ) # 1088 + self.enc4 = nn.Sequential( + self.inception.repeat_1, + self.inception.mixed_7a, + ) #2080 + self.td1 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1), + norm_layer(num_filters), + nn.ReLU(inplace=True)) + self.td2 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1), + norm_layer(num_filters), + nn.ReLU(inplace=True)) + self.td3 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1), + norm_layer(num_filters), + nn.ReLU(inplace=True)) + self.pad = nn.ReflectionPad2d(1) + self.lateral4 = nn.Conv2d(2080, num_filters, kernel_size=1, bias=False) + self.lateral3 = nn.Conv2d(1088, num_filters, kernel_size=1, bias=False) + self.lateral2 = nn.Conv2d(192, num_filters, kernel_size=1, bias=False) + self.lateral1 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False) + self.lateral0 = nn.Conv2d(32, num_filters // 2, kernel_size=1, bias=False) + + for param in self.inception.parameters(): + param.requires_grad = False + + def unfreeze(self): + for param in self.inception.parameters(): + param.requires_grad = True + + def forward(self, x): + + # Bottom-up pathway, from ResNet + enc0 = self.enc0(x) + + enc1 = self.enc1(enc0) # 256 + + enc2 = self.enc2(enc1) # 512 + + enc3 = self.enc3(enc2) # 1024 + + enc4 = self.enc4(enc3) # 2048 + + # 
Lateral connections + + lateral4 = self.pad(self.lateral4(enc4)) + lateral3 = self.pad(self.lateral3(enc3)) + lateral2 = self.lateral2(enc2) + lateral1 = self.pad(self.lateral1(enc1)) + lateral0 = self.lateral0(enc0) + + # Top-down pathway + pad = (1, 2, 1, 2) # pad last dim by 1 on each side + pad1 = (0, 1, 0, 1) + map4 = lateral4 + map3 = self.td1(lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest")) + map2 = self.td2(F.pad(lateral2, pad, "reflect") + nn.functional.upsample(map3, scale_factor=2, mode="nearest")) + map1 = self.td3(lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest")) + return F.pad(lateral0, pad1, "reflect"), map1, map2, map3, map4 diff --git a/DeblurGanV2/models/networks.py b/DeblurGanV2/models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..8bee1d8856ed0074600a396d33bce79d78b42e58 --- /dev/null +++ b/DeblurGanV2/models/networks.py @@ -0,0 +1,79 @@ +import torch +import torch.nn as nn +from torch.nn import init +import functools +from torch.autograd import Variable +import numpy as np +import logging +from logging_utils import setup_logger + +# Configure logging +logger = setup_logger(__name__) + +# Try to import the necessary modules, use fallback if not available +try: + from models.fpn_inception import FPNInception + INCEPTION_AVAILABLE = True + logger.info("Successfully imported FPNInception model") +except ImportError as e: + logger.error(f"Error importing FPNInception: {str(e)}") + INCEPTION_AVAILABLE = False + +# Simple fallback model for testing purposes +class FallbackDeblurModel(nn.Module): + def __init__(self): + super().__init__() + logger.info("Initializing fallback model for testing") + # Simple autoencoder-like structure + self.encoder = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(2) + ) + + self.decoder = nn.Sequential( + 
nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2), + nn.ReLU(inplace=True), + nn.Conv2d(64, 3, kernel_size=3, padding=1), + nn.Tanh() + ) + + def forward(self, x): + # Simple pass-through for testing + encoded = self.encoder(x) + decoded = self.decoder(encoded) + return torch.clamp(decoded + x, min=-1, max=1) + +def get_norm_layer(norm_type='instance'): + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True) + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + +def get_generator(model_config): + if isinstance(model_config, str): + generator_name = model_config + else: + generator_name = model_config['g_name'] + + # Try to use FPNInception if available + if generator_name == 'fpn_inception': + if INCEPTION_AVAILABLE: + try: + logger.info("Creating FPNInception model") + model_g = FPNInception(norm_layer=get_norm_layer(norm_type='instance')) + return nn.DataParallel(model_g) + except Exception as e: + logger.error(f"Error creating FPNInception model: {str(e)}") + logger.warning("Falling back to simple model for testing") + return FallbackDeblurModel() + else: + logger.warning("FPNInception not available, using fallback model") + return FallbackDeblurModel() + else: + raise ValueError("Generator Network [%s] not recognized." 
% generator_name) diff --git a/DeblurGanV2/requirements.txt b/DeblurGanV2/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..55d96ff796bf6c0d974dbf37a81550df01c5ba92 --- /dev/null +++ b/DeblurGanV2/requirements.txt @@ -0,0 +1,24 @@ +# Core model dependencies +torch>=1.0.1 +torchvision +numpy +opencv-python-headless +Pillow +PyYAML +pretrainedmodels>=0.7.4 +torchsummary + +# FastAPI dependencies +fastapi>=0.68.0 +python-multipart +uvicorn + +# Streamlit dependencies +streamlit>=1.12.0 +requests + +# Utilities +fire + +# Required for preprocessing +albumentations>=1.0.0 diff --git a/DeblurGanV2/run.py b/DeblurGanV2/run.py new file mode 100644 index 0000000000000000000000000000000000000000..d9b6931b3a02ed3b80a6f16d0edd8c830f89845c --- /dev/null +++ b/DeblurGanV2/run.py @@ -0,0 +1,248 @@ +import os +import subprocess +import sys +import time +import webbrowser +import requests +import signal +import psutil + +def main(): + """Run the Debluring Application by starting both API and Streamlit app""" + # Change to the directory of this script + os.chdir(os.path.dirname(os.path.abspath(__file__))) + + print("\n🚀 Starting Debluring Application...\n") + + # Define URLs + api_url = "http://localhost:8001" + streamlit_url = "http://localhost:8501" + + # Function to start or restart the API server + def start_api_server(): + nonlocal api_process + # Kill any existing uvicorn processes that might be hanging + for proc in psutil.process_iter(['pid', 'name', 'cmdline']): + try: + if 'uvicorn' in str(proc.info['cmdline']).lower() and '8001' in str(proc.info['cmdline']): + print(f"Killing existing uvicorn process (PID: {proc.info['pid']})") + psutil.Process(proc.info['pid']).kill() + except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): + pass + + print("Starting API server...") + return subprocess.Popen( + [sys.executable, "api.py"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + bufsize=1, + universal_newlines=True + ) + 
+ # Function to check API health + def check_api_health(timeout=2): + try: + response = requests.get(f"{api_url}/status", timeout=timeout) + return response.status_code == 200 + except: + return False + + # Start the API server + api_process = start_api_server() + + # Wait for API to actually be available + print("Waiting for API to start", end="") + api_ready = False + for _ in range(15): # Try for 15 seconds + if check_api_health(): + api_ready = True + print("\n✅ API server is running") + break + print(".", end="", flush=True) + time.sleep(1) + + if not api_ready: + print("\n⚠️ API server might not be fully ready, but continuing anyway") + + # Function to start or restart the Streamlit app + def start_streamlit_app(): + nonlocal streamlit_process + # Kill any existing streamlit processes that might be hanging + for proc in psutil.process_iter(['pid', 'name', 'cmdline']): + try: + if 'streamlit' in str(proc.info['cmdline']).lower() and 'run' in str(proc.info['cmdline']): + print(f"Killing existing streamlit process (PID: {proc.info['pid']})") + psutil.Process(proc.info['pid']).kill() + except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): + pass + + print("Starting Streamlit web interface...") + return subprocess.Popen( + [sys.executable, "-m", "streamlit", "run", "app.py"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + bufsize=1, + universal_newlines=True + ) + + # Start the Streamlit app + streamlit_process = start_streamlit_app() + + # Give Streamlit a moment to start + time.sleep(3) + print("✅ Streamlit web interface started") + + # Open web browser + try: + print("Opening web interface in your browser...") + webbrowser.open(streamlit_url) + except Exception as e: + print(f"Failed to open browser. 
Please open this URL manually: {streamlit_url}") + + # Print URLs + print("\n📋 Application URLs:") + print(f" - Web Interface: {streamlit_url}") + print(f" - API: {api_url}") + + # Set up a more graceful exit handler + def handle_exit(signum, frame): + print("\n👋 Shutting down gracefully...") + shutdown_services(api_process, streamlit_process) + sys.exit(0) + + signal.signal(signal.SIGINT, handle_exit) + + print("\n⌨️ Press Ctrl+C to stop the application...") + print("📊 Monitoring services for stability...\n") + + # Track request counts to detect stalled processes + request_count = 0 + last_check_time = time.time() + consecutive_failures = 0 + + try: + # Keep the script running until interrupted + while True: + # Check for crashed processes + api_status = api_process.poll() + streamlit_status = streamlit_process.poll() + + if api_status is not None: + print(f"⚠️ API server stopped unexpectedly (exit code: {api_status})") + # Restart the API server + api_process = start_api_server() + time.sleep(5) # Give it time to start + + if streamlit_status is not None: + print(f"⚠️ Streamlit app stopped unexpectedly (exit code: {streamlit_status})") + # Restart Streamlit + streamlit_process = start_streamlit_app() + time.sleep(5) # Give it time to start + + # More frequent health checks (every 10 seconds) + if time.time() - last_check_time > 10: + request_count += 1 + # Check API health + if check_api_health(): + print(f"✅ Health check #{request_count}: API is responsive") + consecutive_failures = 0 # Reset failure counter on success + else: + consecutive_failures += 1 + print(f"⚠️ Health check #{request_count}: API not responding (failure #{consecutive_failures})") + + # If we have 3 consecutive failures, restart the API + if consecutive_failures >= 3: + print("❌ API has been unresponsive for too long. 
Restarting...") + # Force terminate the API process + if api_process and api_process.poll() is None: + api_process.terminate() + time.sleep(1) + if api_process.poll() is None: + if sys.platform != 'win32': + os.kill(api_process.pid, signal.SIGKILL) + else: + subprocess.call(['taskkill', '/F', '/T', '/PID', str(api_process.pid)]) + + # Start a new API process + api_process = start_api_server() + time.sleep(5) # Give it time to start + consecutive_failures = 0 # Reset failure counter + + last_check_time = time.time() + + # Check memory usage of processes + try: + api_proc = psutil.Process(api_process.pid) + api_memory = api_proc.memory_info().rss / (1024 * 1024) # Convert to MB + + # If API is using too much memory (>500MB), restart it + if api_memory > 500: + print(f"⚠️ API server using excessive memory: {api_memory:.2f} MB. Restarting...") + # Restart the API + if api_process and api_process.poll() is None: + api_process.terminate() + time.sleep(1) + api_process = start_api_server() + time.sleep(5) + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + + # Sleep for a bit before checking again + time.sleep(2) + + except KeyboardInterrupt: + # Handle termination via Ctrl+C + print("\n👋 Shutting down...") + except Exception as e: + # Handle any other exceptions + print(f"\n❌ Error: {str(e)}") + finally: + shutdown_services(api_process, streamlit_process) + +def shutdown_services(api_process, streamlit_process): + """Safely shut down all running services""" + # Clean up processes + if api_process and api_process.poll() is None: + print("Stopping API server...") + try: + api_process.terminate() + # Give it a moment to terminate gracefully + time.sleep(1) + # Force kill if still running + if api_process.poll() is None: + if sys.platform != 'win32': + os.kill(api_process.pid, signal.SIGKILL) + else: + subprocess.call(['taskkill', '/F', '/T', '/PID', str(api_process.pid)]) + except: + print("Could not terminate API server gracefully") + + if streamlit_process and 
streamlit_process.poll() is None: + print("Stopping Streamlit app...") + try: + streamlit_process.terminate() + # Give it a moment to terminate gracefully + time.sleep(1) + # Force kill if still running + if streamlit_process.poll() is None: + if sys.platform != 'win32': + os.kill(streamlit_process.pid, signal.SIGKILL) + else: + subprocess.call(['taskkill', '/F', '/T', '/PID', str(streamlit_process.pid)]) + except: + print("Could not terminate Streamlit app gracefully") + + # Also kill any related processes that might still be hanging + for proc in psutil.process_iter(['pid', 'name', 'cmdline']): + try: + cmdline = str(proc.info['cmdline']).lower() + if ('uvicorn' in cmdline and '8001' in cmdline) or ('streamlit' in cmdline and 'run' in cmdline): + print(f"Killing leftover process: {proc.info['pid']}") + psutil.Process(proc.info['pid']).kill() + except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): + pass + + print("\n✅ Application stopped") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/NAFNET/README_API.md b/NAFNET/README_API.md new file mode 100644 index 0000000000000000000000000000000000000000..9408b237c1e614768eb1983ae4fddd2303eda7e9 --- /dev/null +++ b/NAFNET/README_API.md @@ -0,0 +1,88 @@ +# NAFNet API and Streamlit App + +This project provides a FastAPI server and Streamlit web application for image deblurring using the NAFNet model. + +## Features + +- REST API for image deblurring +- Web-based Streamlit UI for easy interaction +- High-quality deblurring using the state-of-the-art NAFNet model + +## Prerequisites + +- Python 3.8 or higher +- CUDA-compatible GPU (recommended for faster processing) + +## Installation + +1. Install the required dependencies: + +```bash +pip install -r requirements.txt +``` + +2. Make sure the NAFNet model is properly set up (pretrained weights should be in the appropriate location as specified in your config files). 
+ +## Running the API Server + +The API server provides endpoints for image deblurring. + +```bash +python api.py +``` + +This will start the server on http://localhost:8001 + +### API Endpoints + +- `GET /` - Check if the API is running +- `POST /deblur/` - Upload an image for deblurring +- `GET /status/` - Check the status of the model +- `GET /clear-memory/` - Force clear GPU memory +- `GET /diagnostics/` - Get system information + +## Running the Streamlit App + +The Streamlit app provides a user-friendly web interface. + +```bash +streamlit run app.py +``` + +This will start the Streamlit app and open it in your browser. + +## Using the Command Line + +You can also use NAFNet for deblurring via command line: + +```bash +python basicsr/demo.py -opt options/test/REDS/NAFNet-width64.yml --input_path ./demo/blurry.jpg --output_path ./demo/deblur_img.png +``` + +## Architecture + +This project consists of the following components: + +1. `deblur_module.py` - Core module for image deblurring using NAFNet +2. `api.py` - FastAPI server for exposing the deblurring functionality +3. `app.py` - Streamlit web application for user-friendly interaction + +## Directory Structure + +- `inputs/` - Directory for storing uploaded images +- `outputs/` - Directory for storing processed images +- `options/` - Configuration files for NAFNet model +- `experiments/` - Model weights and experiment results + +## Troubleshooting + +If you encounter issues: + +1. Make sure all dependencies are installed correctly +2. Check if the model weights are in the correct location +3. Ensure you have enough GPU memory for processing +4. Look for error messages in the terminal output + +## License + +See the NAFNet repository for license information. 
diff --git a/NAFNET/api.py b/NAFNET/api.py new file mode 100644 index 0000000000000000000000000000000000000000..a8d2357d587e18d071547b54031292c5aa2dbafd --- /dev/null +++ b/NAFNET/api.py @@ -0,0 +1,299 @@ +import os +import uuid +import gc +from typing import Optional +from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks +from fastapi.responses import FileResponse +from fastapi.middleware.cors import CORSMiddleware +import uvicorn +import logging +import traceback +import shutil +import torch + +from deblur_module import NAFNetDeblur, setup_logger + +# Configure logging +logger = setup_logger(__name__) + +# Define API URL +API_URL = "http://localhost:8001" + +# Initialize FastAPI app +app = FastAPI( + title="NAFNet Debluring API", + description="API for deblurring images using NAFNet deep learning model", + version="1.0.0" +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Initialize NAFNet model +model = None +processing_lock = False + +def get_model(): + global model + if model is None: + logger.info("Initializing NAFNet deblurring model...") + try: + model = NAFNetDeblur() + logger.info("Model initialized successfully") + except Exception as e: + logger.error(f"Failed to initialize model: {str(e)}") + logger.error(traceback.format_exc()) + raise RuntimeError(f"Could not initialize NAFNet model: {str(e)}") + return model + +def cleanup_resources(input_path=None): + """Clean up resources after processing""" + # Force garbage collection + gc.collect() + + # Clean up CUDA memory if available + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Remove input file if specified + if input_path and os.path.exists(input_path): + try: + os.remove(input_path) + logger.info(f"Removed temporary input file: {input_path}") + except Exception as e: + logger.warning(f"Could not remove temporary input file: {str(e)}") + + # 
Release processing lock + global processing_lock + processing_lock = False + logger.info("Resources cleaned up") + +@app.get("/") +async def root(): + return {"message": "NAFNet Debluring API is running"} + +@app.post("/deblur/", response_class=FileResponse) +async def deblur_image(background_tasks: BackgroundTasks, file: UploadFile = File(...)): + """ + Deblur an uploaded image and return the processed image file. + """ + global processing_lock + input_path = None + + # Check if another processing request is currently running + if processing_lock: + logger.warning("Another deblurring request is already in progress") + raise HTTPException( + status_code=429, + detail="Server is busy processing another image. Please try again shortly." + ) + + # Set processing lock + processing_lock = True + + try: + # Validate file type + if not file.content_type.startswith("image/"): + logger.warning(f"Invalid file type: {file.content_type}") + processing_lock = False # Release lock + raise HTTPException(status_code=400, detail="File must be an image") + + logger.info(f"Processing image: {file.filename}, size: {file.size} bytes, type: {file.content_type}") + + # Create input and output directories + module_dir = os.path.dirname(os.path.abspath(__file__)) + input_dir = os.path.join(module_dir, 'inputs') + output_dir = os.path.join(module_dir, 'outputs') + os.makedirs(input_dir, exist_ok=True) + os.makedirs(output_dir, exist_ok=True) + + # Generate unique filenames + unique_id = uuid.uuid4().hex + input_filename = f"input_{unique_id}.png" + input_path = os.path.join(input_dir, input_filename) + output_filename = f"deblurred_{unique_id}.png" + output_path = os.path.join(output_dir, output_filename) + + try: + # Read file contents first + file_contents = await file.read() + + # Save uploaded file to disk + with open(input_path, "wb") as buffer: + buffer.write(file_contents) + + logger.info(f"Input image saved to: {input_path}") + + # Release file resources immediately + await 
file.close() + + # Get the model + deblur_model = get_model() + + # Process the image + logger.info("Starting deblurring process...") + deblurred_img = deblur_model.deblur_image(input_path) + + # Save the result + deblur_model.save_image(deblurred_img, output_filename) + + logger.info(f"Image deblurred successfully, saved to: {output_path}") + + # Schedule cleanup after response is sent + background_tasks.add_task(cleanup_resources, input_path) + + # Return the result file + return FileResponse( + output_path, + media_type="image/png", + filename=f"deblurred_{file.filename}" + ) + except Exception as e: + logger.error(f"Error in deblurring process: {str(e)}") + logger.error(traceback.format_exc()) + # Always attempt cleanup on error + cleanup_resources(input_path) + raise HTTPException(status_code=500, detail=f"Deblurring failed: {str(e)}") + + except HTTPException: + # Re-raise HTTP exceptions + raise + except Exception as e: + error_msg = f"Error processing image: {str(e)}" + logger.error(error_msg) + logger.error(traceback.format_exc()) + # Make sure lock is released + processing_lock = False + raise HTTPException(status_code=500, detail=error_msg) + +@app.get("/status/") +async def status(): + """Check API status and model availability.""" + try: + logger.info("Checking model status") + + # Check if we're currently processing + if processing_lock: + return { + "status": "busy", + "model_loaded": True, + "message": "Currently processing an image" + } + + # Otherwise do a full check + deblur_model = get_model() + + # Get memory stats if CUDA is available + memory_info = {} + if torch.cuda.is_available(): + memory_info["cuda_memory_allocated"] = f"{torch.cuda.memory_allocated() / 1024**2:.2f} MB" + memory_info["cuda_memory_reserved"] = f"{torch.cuda.memory_reserved() / 1024**2:.2f} MB" + memory_info["cuda_max_memory"] = f"{torch.cuda.max_memory_allocated() / 1024**2:.2f} MB" + + logger.info("Model is loaded and ready") + return { + "status": "ok", + 
"model_loaded": True, + "processing": processing_lock, + "memory": memory_info + } + except Exception as e: + error_msg = f"Error checking model status: {str(e)}" + logger.error(error_msg) + return {"status": "error", "model_loaded": False, "error": str(e)} + +@app.get("/clear-memory/") +async def clear_memory(): + """Force clear memory and release resources.""" + try: + # Force garbage collection + gc.collect() + + # Clear CUDA cache if available + if torch.cuda.is_available(): + before = torch.cuda.memory_allocated() / 1024**2 + torch.cuda.empty_cache() + after = torch.cuda.memory_allocated() / 1024**2 + logger.info(f"CUDA memory cleared: {before:.2f} MB → {after:.2f} MB") + + # Reset processing lock + global processing_lock + was_locked = processing_lock + processing_lock = False + + return { + "status": "ok", + "message": "Memory cleared successfully", + "lock_released": was_locked + } + except Exception as e: + error_msg = f"Error clearing memory: {str(e)}" + logger.error(error_msg) + return {"status": "error", "error": str(e)} + +@app.get("/diagnostics/") +async def diagnostics(): + """Get diagnostic information about the system.""" + try: + # Check required components + import platform + import sys + import torch + import cv2 + import numpy as np + + # Collect system information + system_info = { + "platform": platform.platform(), + "python_version": sys.version, + "torch_version": torch.__version__, + "cuda_available": torch.cuda.is_available(), + "opencv_version": cv2.__version__, + "numpy_version": np.__version__, + } + + # Get GPU information if available + if torch.cuda.is_available(): + system_info["cuda_version"] = torch.version.cuda + system_info["cuda_device_count"] = torch.cuda.device_count() + system_info["cuda_current_device"] = torch.cuda.current_device() + system_info["cuda_device_name"] = torch.cuda.get_device_name(0) + system_info["cuda_memory_allocated"] = f"{torch.cuda.memory_allocated() / 1024**2:.2f} MB" + 
system_info["cuda_memory_reserved"] = f"{torch.cuda.memory_reserved() / 1024**2:.2f} MB" + + # Check model state + if model is not None: + system_info["model_loaded"] = True + else: + system_info["model_loaded"] = False + + # Check disk space + if os.name == 'posix': # Linux/Mac + import shutil + total, used, free = shutil.disk_usage("/") + system_info["disk_total"] = f"{total // (2**30)} GB" + system_info["disk_used"] = f"{used // (2**30)} GB" + system_info["disk_free"] = f"{free // (2**30)} GB" + + return { + "status": "ok", + "system_info": system_info + } + except Exception as e: + error_msg = f"Error gathering diagnostics: {str(e)}" + logger.error(error_msg) + logger.error(traceback.format_exc()) + return {"status": "error", "error": str(e)} + +def run_server(host="0.0.0.0", port=8001): + """Run the FastAPI server""" + uvicorn.run(app, host=host, port=port) + +if __name__ == "__main__": + run_server() \ No newline at end of file diff --git a/NAFNET/app.py b/NAFNET/app.py new file mode 100644 index 0000000000000000000000000000000000000000..0112e94a27166843e50981ea6bd1ec9dbd52a766 --- /dev/null +++ b/NAFNET/app.py @@ -0,0 +1,115 @@ +import streamlit as st +import requests +import os +import sys +from PIL import Image +import io +import time +from pathlib import Path + +# Set API URL +API_URL = "http://localhost:8001" # Local FastAPI server URL + +st.set_page_config( + page_title="NAFNet Image Deblurring", + page_icon="🔍", + layout="wide", +) + +st.title("NAFNet Image Deblurring Application") +st.markdown(""" +Transform your blurry photos into clear, sharp images using the state-of-the-art NAFNet AI model. +Upload an image to get started! 
+""") + +# File uploader +uploaded_file = st.file_uploader( + "Choose a blurry image...", type=["jpg", "jpeg", "png", "bmp"]) + +# Sidebar controls +with st.sidebar: + st.header("About NAFNet") + + st.markdown(""" + **NAFNet** (Nonlinear Activation Free Network) is a state-of-the-art image restoration model designed for tasks like deblurring. + + Key features: + - High-quality image deblurring + - Fast processing time + - Preservation of image details + """) + + st.markdown("---") + + # Check API status + if st.button("Check API Status"): + try: + response = requests.get(f"{API_URL}/status/", timeout=5) + if response.status_code == 200 and response.json().get("status") == "ok": + st.success("✅ API is running and ready") + + # Display additional info if available + memory_info = response.json().get("memory", {}) + if memory_info: + st.info(f"CUDA Memory: {memory_info.get('cuda_memory_allocated', 'N/A')}") + else: + st.error("❌ API is not responding properly") + except: + st.error("❌ Cannot connect to API") + +# Process when upload is ready +if uploaded_file is not None: + # Display the original image + col1, col2 = st.columns(2) + + with col1: + st.subheader("Original Image") + image = Image.open(uploaded_file) + st.image(image, use_container_width=True) + + # Process image button + process_button = st.button("Deblur Image") + + if process_button: + with st.spinner("Deblurring your image... 
Please wait."): + try: + # Prepare simplified file structure + files = { + "file": ("image.jpg", uploaded_file.getvalue(), "image/jpeg") + } + + # Send request to API + response = requests.post(f"{API_URL}/deblur/", files=files, timeout=60) + + if response.status_code == 200: + with col2: + st.subheader("Deblurred Result") + deblurred_img = Image.open(io.BytesIO(response.content)) + st.image(deblurred_img, use_column_width=True) + + # Option to download the deblurred image + st.download_button( + label="Download Deblurred Image", + data=response.content, + file_name=f"deblurred_{uploaded_file.name}", + mime="image/png" + ) + else: + try: + error_details = response.json().get('detail', 'Unknown error') + except: + error_details = response.text + + st.error(f"Error: {error_details}") + except Exception as e: + st.error(f"An error occurred: {str(e)}") + +# Footer +st.markdown("---") +st.markdown("Powered by NAFNet - Image Restoration Project") + +def main(): + pass # Streamlit already runs the script from top to bottom + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/NAFNET/basicsr/demo.py b/NAFNET/basicsr/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..bf2d387cb3efd2bba7d13e0be7281f023d064913 --- /dev/null +++ b/NAFNET/basicsr/demo.py @@ -0,0 +1,62 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. 
def main():
    """Run single-image inference: read one image, restore it, save the result.

    Input/output paths come from ``opt['img_path']`` (keys ``input_img`` /
    ``output_img``) as produced by :func:`parse_options`.
    """
    # parse options, set distributed setting, set random seed
    opt = parse_options(is_train=False)
    opt['num_gpu'] = torch.cuda.device_count()

    img_path = opt['img_path'].get('input_img')
    output_path = opt['img_path'].get('output_img')

    ## 1. read image
    file_client = FileClient('disk')

    img_bytes = file_client.get(img_path, None)
    try:
        img = imfrombytes(img_bytes, float32=True)
    except Exception as e:
        # Was a bare `except:`, which also caught KeyboardInterrupt/SystemExit;
        # chain the cause so the real decode failure stays visible.
        raise Exception("path {} not working".format(img_path)) from e

    img = img2tensor(img, bgr2rgb=True, float32=True)

    ## 2. run inference
    opt['dist'] = False
    model = create_model(opt)

    model.feed_data(data={'lq': img.unsqueeze(dim=0)})

    # optional tiled ("grids") inference for large inputs
    if model.opt['val'].get('grids', False):
        model.grids()

    model.test()

    if model.opt['val'].get('grids', False):
        model.grids_inverse()

    visuals = model.get_current_visuals()
    sr_img = tensor2img([visuals['result']])
    imwrite(sr_img, output_path)

    print(f'inference {img_path} .. finished. saved to {output_path}')

if __name__ == '__main__':
    main()
def imread(img_path):
    """Read an image from disk and return it as an RGB float32 tensor.

    Args:
        img_path (str): Path to the image file.

    Returns:
        Tensor: CHW float32 tensor, channels converted BGR -> RGB.

    Raises:
        Exception: if the file cannot be decoded; the original decoding
            error is chained as the cause.
    """
    file_client = FileClient('disk')
    img_bytes = file_client.get(img_path, None)
    try:
        img = imfrombytes(img_bytes, float32=True)
    except Exception as e:
        # Was a bare `except:`; narrow it and keep the original traceback.
        raise Exception("path {} not working".format(img_path)) from e

    img = img2tensor(img, bgr2rgb=True, float32=True)
    return img
run inference + opt['dist'] = False + model = create_model(opt) + + model.feed_data(data={'lq': img.unsqueeze(dim=0)}) + + if model.opt['val'].get('grids', False): + model.grids() + + model.test() + + if model.opt['val'].get('grids', False): + model.grids_inverse() + + visuals = model.get_current_visuals() + sr_img_l = visuals['result'][:,:3] + sr_img_r = visuals['result'][:,3:] + sr_img_l, sr_img_r = tensor2img([sr_img_l, sr_img_r]) + imwrite(sr_img_l, output_l_path) + imwrite(sr_img_r, output_r_path) + + print(f'inference {img_l_path} .. finished. saved to {output_l_path}') + print(f'inference {img_r_path} .. finished. saved to {output_r_path}') + +if __name__ == '__main__': + main() + diff --git a/NAFNET/basicsr/metrics/__init__.py b/NAFNET/basicsr/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..83fc13da5a35bc31d7fed825c0c25988f2fc72ec --- /dev/null +++ b/NAFNET/basicsr/metrics/__init__.py @@ -0,0 +1,10 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. 
def load_patched_inception_v3(device='cuda',
                              resize_input=True,
                              normalize_input=False):
    """Build the InceptionV3 feature extractor used for FID (output index 3).

    Args:
        device (str): Device to move the model to. Default: 'cuda'.
        resize_input (bool): Whether the network resizes its input.
            Default: True.
        normalize_input (bool): Whether the network normalises its input.
            Default: False.

    Returns:
        nn.Module: eval-mode ``InceptionV3`` wrapped in ``nn.DataParallel``.
    """
    # we may not resize the input, but [rosinality/stylegan2-pytorch]
    # does resize it.
    extractor = InceptionV3([3],
                            resize_input=resize_input,
                            normalize_input=normalize_input)
    return nn.DataParallel(extractor).eval().to(device)
def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
    d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.

    Args:
        mu1 (np.array): The sample mean over activations.
        sigma1 (np.array): The covariance matrix over activations for
            generated samples.
        mu2 (np.array): The sample mean over activations, precalculated on a
            representative data set.
        sigma2 (np.array): The covariance matrix over activations,
            precalculated on a representative data set.
        eps (float): Regulariser added to the covariance diagonals when the
            product is (near-)singular. Default: 1e-6.

    Returns:
        float: The Frechet Distance.
    """
    assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, (
        'Two covariances have different dimensions')

    cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)

    # Product might be almost singular: retry with a regularised diagonal.
    if not np.isfinite(cov_sqrt).all():
        # Bug fix: the message was a plain string, so '{eps}' was printed
        # literally instead of the actual value.
        print(f'Product of cov matrices is singular. Adding {eps} to diagonal '
              'of cov estimates')
        offset = np.eye(sigma1.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))

    # Numerical error might give a slight imaginary component.
    if np.iscomplexobj(cov_sqrt):
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))
            raise ValueError(f'Imaginary component {m}')
        cov_sqrt = cov_sqrt.real

    mean_diff = mu1 - mu2
    mean_norm = mean_diff @ mean_diff
    trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
    fid = mean_norm + trace

    return fid
def to_y_channel(img):
    """Extract the Y channel of YCbCr from a BGR image.

    Args:
        img (ndarray): Image with range [0, 255]; either (h, w) or
            (h, w, 3) in BGR order.

    Returns:
        ndarray: Image with range [0, 255] (float type) without rounding;
            colour inputs come back as (h, w, 1).
    """
    scaled = img.astype(np.float32) / 255.
    # Only genuine 3-channel colour images are converted; gray inputs
    # pass through unchanged.
    if scaled.ndim == 3 and scaled.shape[2] == 3:
        scaled = bgr2ycbcr(scaled, y_only=True)[..., None]
    return scaled * 255.
+ """ + block = block.flatten() + gam = np.arange(0.2, 10.001, 0.001) # len = 9801 + gam_reciprocal = np.reciprocal(gam) + r_gam = np.square(gamma(gam_reciprocal * 2)) / ( + gamma(gam_reciprocal) * gamma(gam_reciprocal * 3)) + + left_std = np.sqrt(np.mean(block[block < 0]**2)) + right_std = np.sqrt(np.mean(block[block > 0]**2)) + gammahat = left_std / right_std + rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2) + rhatnorm = (rhat * (gammahat**3 + 1) * + (gammahat + 1)) / ((gammahat**2 + 1)**2) + array_position = np.argmin((r_gam - rhatnorm)**2) + + alpha = gam[array_position] + beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha)) + beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha)) + return (alpha, beta_l, beta_r) + + +def compute_feature(block): + """Compute features. + + Args: + block (ndarray): 2D Image block. + + Returns: + list: Features with length of 18. + """ + feat = [] + alpha, beta_l, beta_r = estimate_aggd_param(block) + feat.extend([alpha, (beta_l + beta_r) / 2]) + + # distortions disturb the fairly regular structure of natural images. + # This deviation can be captured by analyzing the sample distribution of + # the products of pairs of adjacent coefficients computed along + # horizontal, vertical and diagonal orientations. + shifts = [[0, 1], [1, 0], [1, 1], [1, -1]] + for i in range(len(shifts)): + shifted_block = np.roll(block, shifts[i], axis=(0, 1)) + alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block) + # Eq. 8 + mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha)) + feat.extend([alpha, mean, beta_l, beta_r]) + return feat + + +def niqe(img, + mu_pris_param, + cov_pris_param, + gaussian_window, + block_size_h=96, + block_size_w=96): + """Calculate NIQE (Natural Image Quality Evaluator) metric. + + Ref: Making a "Completely Blind" Image Quality Analyzer. 
def niqe(img,
         mu_pris_param,
         cov_pris_param,
         gaussian_window,
         block_size_h=96,
         block_size_w=96):
    """Calculate NIQE (Natural Image Quality Evaluator) metric.

    Ref: Making a "Completely Blind" Image Quality Analyzer.
    This implementation could produce almost the same results as the official
    MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip

    Note that we do not include block overlap height and width, since they are
    always 0 in the official implementation.

    For good performance, it is advisable by the official implementation to
    divide the distorted image into the same size patches as used for the
    construction of the multivariate Gaussian model.

    Args:
        img (ndarray): Input image whose quality needs to be computed. The
            image must be a gray or Y (of YCbCr) image with shape (h, w).
            Range [0, 255] with float type.
        mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
            model calculated on the pristine dataset.
        cov_pris_param (ndarray): Covariance of a pre-defined multivariate
            Gaussian model calculated on the pristine dataset.
        gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the
            image.
        block_size_h (int): Height of the blocks into which image is divided.
            Default: 96 (the official recommended value).
        block_size_w (int): Width of the blocks into which image is divided.
            Default: 96 (the official recommended value).

    Returns:
        NIQE score (lower presumably means better quality; scalar ndarray).
    """
    assert img.ndim == 2, (
        'Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
    # crop image so it tiles exactly into block_size_h x block_size_w blocks
    h, w = img.shape
    num_block_h = math.floor(h / block_size_h)
    num_block_w = math.floor(w / block_size_w)
    img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]

    distparam = []  # dist param is actually the multiscale features
    for scale in (1, 2):  # perform on two scales (1, 2)
        # local mean / local std via Gaussian filtering
        mu = convolve(img, gaussian_window, mode='nearest')
        sigma = np.sqrt(
            np.abs(
                convolve(np.square(img), gaussian_window, mode='nearest') -
                np.square(mu)))
        # normalize, as in Eq. 1 in the paper
        img_nomalized = (img - mu) / (sigma + 1)

        feat = []
        for idx_w in range(num_block_w):
            for idx_h in range(num_block_h):
                # process each block; the block size shrinks with the scale
                # because the image is halved after the first pass
                block = img_nomalized[idx_h * block_size_h //
                                      scale:(idx_h + 1) * block_size_h //
                                      scale, idx_w * block_size_w //
                                      scale:(idx_w + 1) * block_size_w //
                                      scale]
                feat.append(compute_feature(block))

        distparam.append(np.array(feat))

        # TODO: matlab bicubic downsample with anti-aliasing
        # for simplicity, now we use opencv instead, which will result in
        # a slight difference.
        if scale == 1:
            h, w = img.shape
            img = cv2.resize(
                img / 255., (w // 2, h // 2), interpolation=cv2.INTER_LINEAR)
            img = img * 255.

    distparam = np.concatenate(distparam, axis=1)

    # fit a MVG (multivariate Gaussian) model to distorted patch features
    mu_distparam = np.nanmean(distparam, axis=0)
    # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
    distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
    cov_distparam = np.cov(distparam_no_nan, rowvar=False)

    # compute niqe quality, Eq. 10 in the paper
    invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
    quality = np.matmul(
        np.matmul((mu_pris_param - mu_distparam), invcov_param),
        np.transpose((mu_pris_param - mu_distparam)))

    quality = np.sqrt(quality)

    return quality
def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y'):
    """Compute the NIQE (Natural Image Quality Evaluator) score of one image.

    Uses the official pristine-dataset parameters bundled with the repo and
    the recommended 96x96 non-overlapping blocks.

    Args:
        img (ndarray): Image in range [0, 255], float or int. Layout given by
            ``input_order`` ('HW', 'HWC' or 'CHW'); colour inputs are BGR and
            are converted according to ``convert_to``.
        crop_border (int): Pixels cropped from every edge before scoring.
        input_order (str): 'HW', 'HWC' or 'CHW'. Default: 'HWC'.
        convert_to (str): Convert colour input to 'y' (MATLAB YCbCr) or
            'gray'. Default: 'y'.

    Returns:
        float: NIQE result.
    """
    # official params estimated from the pristine dataset
    pris_params = np.load('basicsr/metrics/niqe_pris_params.npz')
    mu_pris_param = pris_params['mu_pris_param']
    cov_pris_param = pris_params['cov_pris_param']
    gaussian_window = pris_params['gaussian_window']

    img = img.astype(np.float32)
    if input_order != 'HW':
        img = reorder_image(img, input_order=input_order)
        if convert_to == 'y':
            img = to_y_channel(img)
        elif convert_to == 'gray':
            img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
        img = np.squeeze(img)

    if crop_border != 0:
        img = img[crop_border:-crop_border, crop_border:-crop_border]

    return niqe(img, mu_pris_param, cov_pris_param, gaussian_window)
+ crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the PSNR calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: psnr result. + """ + + assert img1.shape == img2.shape, ( + f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f'Wrong input_order {input_order}. Supported input_orders are ' + '"HWC" and "CHW"') + if type(img1) == torch.Tensor: + if len(img1.shape) == 4: + img1 = img1.squeeze(0) + img1 = img1.detach().cpu().numpy().transpose(1,2,0) + if type(img2) == torch.Tensor: + if len(img2.shape) == 4: + img2 = img2.squeeze(0) + img2 = img2.detach().cpu().numpy().transpose(1,2,0) + + img1 = reorder_image(img1, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + def _psnr(img1, img2): + if test_y_channel: + img1 = to_y_channel(img1) + img2 = to_y_channel(img2) + + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + max_value = 1. if img1.max() <= 1 else 255. + return 20. 
def _ssim(img1, img2, max_value):
    """SSIM for one image pair via an 11x11 Gaussian window (sigma = 1.5).

    Mirrors the MATLAB reference: Gaussian filtering, 5-pixel border crop,
    then the mean of the per-pixel SSIM map. Called by :func:`calculate_ssim`.

    Args:
        img1 (ndarray): Image with range [0, max_value], order 'HWC'.
        img2 (ndarray): Image with range [0, max_value], order 'HWC'.
        max_value (float): Dynamic range of the inputs (1 or 255).

    Returns:
        float: ssim result.
    """
    c1 = (0.01 * max_value) ** 2
    c2 = (0.03 * max_value) ** 2

    a = img1.astype(np.float64)
    b = img2.astype(np.float64)
    g = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(g, g.transpose())

    mu1 = cv2.filter2D(a, -1, window)[5:-5, 5:-5]
    mu2 = cv2.filter2D(b, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(a ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(b ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(a * b, -1, window)[5:-5, 5:-5] - mu1_mu2

    numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
    return (numerator / denominator).mean()
def prepare_for_ssim_rgb(img, k):
    """Mean-filter each RGB channel with a k x k box kernel, then subsample by k.

    Args:
        img (ndarray): H x W x 3 image.
        k (int): Window size and subsampling stride.

    Returns:
        ndarray: Filtered image of shape (ceil((H+2*(k//2)-k+1)/k), ..., 3).
    """
    import torch
    with torch.no_grad():
        img = torch.from_numpy(img).float()  # HxWx3

        # Bug fix: the conv was created with the default bias=True, so a
        # randomly initialised bias was added to every output pixel, making
        # the "mean filter" non-deterministic. bias=False restores the
        # intended pure box average.
        conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2,
                               bias=False, padding_mode='reflect')
        conv.weight.requires_grad = False
        conv.weight[:, :, :, :] = 1. / (k * k)

        new_img = []
        for i in range(3):
            channel = img[:, :, i].unsqueeze(0).unsqueeze(0)
            new_img.append(conv(channel).squeeze(0).squeeze(0)[0::k, 0::k])

        return torch.stack(new_img, dim=2).detach().cpu().numpy()
+ """ + C1 = (0.01 * max_value) ** 2 + C2 = (0.03 * max_value) ** 2 + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + kernel = _generate_3d_gaussian_kernel().cuda() + + img1 = torch.tensor(img1).float().cuda() + img2 = torch.tensor(img2).float().cuda() + + + mu1 = _3d_gaussian_calculator(img1, kernel) + mu2 = _3d_gaussian_calculator(img2, kernel) + + mu1_sq = mu1 ** 2 + mu2_sq = mu2 ** 2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = _3d_gaussian_calculator(img1 ** 2, kernel) - mu1_sq + sigma2_sq = _3d_gaussian_calculator(img2 ** 2, kernel) - mu2_sq + sigma12 = _3d_gaussian_calculator(img1*img2, kernel) - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * + (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return float(ssim_map.mean()) + +def _ssim_cly(img1, img2): + assert len(img1.shape) == 2 and len(img2.shape) == 2 + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img1 (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: ssim result. 
+ """ + + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + kernel = cv2.getGaussianKernel(11, 1.5) + # print(kernel) + window = np.outer(kernel, kernel.transpose()) + + bt = cv2.BORDER_REPLICATE + + mu1 = cv2.filter2D(img1, -1, window, borderType=bt) + mu2 = cv2.filter2D(img2, -1, window,borderType=bt) + + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window, borderType=bt) - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window, borderType=bt) - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window, borderType=bt) - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * + (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +def calculate_ssim(img1, + img2, + crop_border, + input_order='HWC', + test_y_channel=False, + ssim3d=True): + """Calculate SSIM (structural similarity). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the SSIM calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: ssim result. + """ + + assert img1.shape == img2.shape, ( + f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f'Wrong input_order {input_order}. 
def calculate_ssim(img1,
                   img2,
                   crop_border,
                   input_order='HWC',
                   test_y_channel=False,
                   ssim3d=True):
    """Calculate SSIM (structural similarity).

    Ref:
    Image quality assessment: From error visibility to structural similarity

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    For three-channel images, SSIM is calculated for each channel and then
    averaged.

    Args:
        img1 (ndarray/tensor): Images with range [0, 255]/[0, 1].
        img2 (ndarray/tensor): Images with range [0, 255]/[0, 1].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the SSIM calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
        ssim3d (bool): Use the 3D-Gaussian variant (_ssim_3d, runs on CUDA)
            instead of the 2D _ssim. Default: True.

    Returns:
        float: ssim result.
    """

    assert img1.shape == img2.shape, (
        f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            '"HWC" and "CHW"')

    # Accept torch tensors (optionally batched NCHW); convert to HWC numpy.
    if type(img1) == torch.Tensor:
        if len(img1.shape) == 4:
            img1 = img1.squeeze(0)
        img1 = img1.detach().cpu().numpy().transpose(1,2,0)
    if type(img2) == torch.Tensor:
        if len(img2.shape) == 4:
            img2 = img2.squeeze(0)
        img2 = img2.detach().cpu().numpy().transpose(1,2,0)

    img1 = reorder_image(img1, input_order=input_order)
    img2 = reorder_image(img2, input_order=input_order)

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    if crop_border != 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    def _cal_ssim(img1, img2):
        # Y-channel path: MATLAB-style single-channel SSIM.
        if test_y_channel:
            img1 = to_y_channel(img1)
            img2 = to_y_channel(img2)
            return _ssim_cly(img1[..., 0], img2[..., 0])

        ssims = []
        # Dynamic range is inferred from the data: [0, 1] or [0, 255].
        max_value = 1 if img1.max() <= 1 else 255
        with torch.no_grad():
            final_ssim = _ssim_3d(img1, img2, max_value) if ssim3d else _ssim(img1, img2, max_value)
            ssims.append(final_ssim)

        return np.array(ssims).mean()

    # 6-channel input: presumably a stereo pair stacked as [left | right]
    # (cf. the *_left metric variants) — score each half and average.
    if img1.ndim == 3 and img1.shape[2] == 6:
        l1, r1 = img1[:,:,:3], img1[:,:,3:]
        l2, r2 = img2[:,:,:3], img2[:,:,3:]
        return (_cal_ssim(l1, l2) + _cal_ssim(r1, r2))/2
    else:
        return _cal_ssim(img1, img2)
def create_model(opt):
    """Instantiate the model class named by the configuration.

    Searches every auto-imported ``*_model.py`` module for a class whose
    name equals ``opt['model_type']`` and constructs it with ``opt``.

    Args:
        opt (dict): Configuration. It contains:
            model_type (str): Name of the model class to instantiate.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if no scanned module defines ``opt['model_type']``.
    """
    model_type = opt['model_type']

    # dynamic instantiation: the first module defining the class wins
    for mod in _model_modules:
        model_cls = getattr(mod, model_type, None)
        if model_cls is not None:
            break
    if model_cls is None:
        raise ValueError(f'Model {model_type} is not found.')

    model = model_cls(opt)
    get_root_logger().info(f'Model [{model.__class__.__name__}] is created.')
    return model
+# ------------------------------------------------------------------------ + +''' +Simple Baselines for Image Restoration + +@article{chen2022simple, + title={Simple Baselines for Image Restoration}, + author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian}, + journal={arXiv preprint arXiv:2204.04676}, + year={2022} +} +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +from basicsr.models.archs.arch_util import LayerNorm2d +from basicsr.models.archs.local_arch import Local_Base + +class BaselineBlock(nn.Module): + def __init__(self, c, DW_Expand=1, FFN_Expand=2, drop_out_rate=0.): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, + bias=True) + self.conv3 = nn.Conv2d(in_channels=dw_channel, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + # Channel Attention + self.se = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, + groups=1, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, + groups=1, bias=True), + nn.Sigmoid() + ) + + # GELU + self.gelu = nn.GELU() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv5 = nn.Conv2d(in_channels=ffn_channel, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. 
else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, inp): + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.gelu(x) + x = x * self.se(x) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + x = self.conv4(self.norm2(y)) + x = self.gelu(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + + +class Baseline(nn.Module): + + def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[], dw_expand=1, ffn_expand=2): + super().__init__() + + self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, + bias=True) + self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1, + bias=True) + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.middle_blks = nn.ModuleList() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + + chan = width + for num in enc_blk_nums: + self.encoders.append( + nn.Sequential( + *[BaselineBlock(chan, dw_expand, ffn_expand) for _ in range(num)] + ) + ) + self.downs.append( + nn.Conv2d(chan, 2*chan, 2, 2) + ) + chan = chan * 2 + + self.middle_blks = \ + nn.Sequential( + *[BaselineBlock(chan, dw_expand, ffn_expand) for _ in range(middle_blk_num)] + ) + + for num in dec_blk_nums: + self.ups.append( + nn.Sequential( + nn.Conv2d(chan, chan * 2, 1, bias=False), + nn.PixelShuffle(2) + ) + ) + chan = chan // 2 + self.decoders.append( + nn.Sequential( + *[BaselineBlock(chan, dw_expand, ffn_expand) for _ in range(num)] + ) + ) + + self.padder_size = 2 ** len(self.encoders) + + def forward(self, inp): + B, C, H, W = inp.shape + inp = self.check_image_size(inp) + + x = self.intro(inp) + + encs = [] + + for encoder, down in zip(self.encoders, self.downs): + x = 
encoder(x) + encs.append(x) + x = down(x) + + x = self.middle_blks(x) + + for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): + x = up(x) + x = x + enc_skip + x = decoder(x) + + x = self.ending(x) + x = x + inp + + return x[:, :, :H, :W] + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x + +class BaselineLocal(Local_Base, Baseline): + def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs): + Local_Base.__init__(self) + Baseline.__init__(self, *args, **kwargs) + + N, C, H, W = train_size + base_size = (int(H * 1.5), int(W * 1.5)) + + self.eval() + with torch.no_grad(): + self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp) + +if __name__ == '__main__': + img_channel = 3 + width = 32 + + dw_expand = 1 + ffn_expand = 2 + + # enc_blks = [2, 2, 4, 8] + # middle_blk_num = 12 + # dec_blks = [2, 2, 2, 2] + + enc_blks = [1, 1, 1, 28] + middle_blk_num = 1 + dec_blks = [1, 1, 1, 1] + + net = Baseline(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num, + enc_blk_nums=enc_blks, dec_blk_nums=dec_blks, dw_expand=dw_expand, ffn_expand=ffn_expand) + + inp_shape = (3, 256, 256) + + from ptflops import get_model_complexity_info + + macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False) + + params = float(params[:-3]) + macs = float(macs[:-4]) + + print(macs, params) diff --git a/NAFNET/basicsr/models/archs/NAFNet_arch.py b/NAFNET/basicsr/models/archs/NAFNet_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..5735e0963b4b1db46f34807e6607c04e70702e91 --- /dev/null +++ b/NAFNET/basicsr/models/archs/NAFNet_arch.py @@ -0,0 +1,202 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 
class SimpleGate(nn.Module):
    """Gating unit used by NAFNet in place of nonlinear activations.

    Splits the channel dimension in half and returns the elementwise
    product of the two halves, so an input with C channels yields an
    output with C // 2 channels.
    """

    def forward(self, x):
        first_half, second_half = torch.chunk(x, 2, dim=1)
        return first_half * second_half
else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, inp): + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + x = self.conv4(self.norm2(y)) + x = self.sg(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + + +class NAFNet(nn.Module): + + def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[]): + super().__init__() + + self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, + bias=True) + self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1, + bias=True) + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.middle_blks = nn.ModuleList() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + + chan = width + for num in enc_blk_nums: + self.encoders.append( + nn.Sequential( + *[NAFBlock(chan) for _ in range(num)] + ) + ) + self.downs.append( + nn.Conv2d(chan, 2*chan, 2, 2) + ) + chan = chan * 2 + + self.middle_blks = \ + nn.Sequential( + *[NAFBlock(chan) for _ in range(middle_blk_num)] + ) + + for num in dec_blk_nums: + self.ups.append( + nn.Sequential( + nn.Conv2d(chan, chan * 2, 1, bias=False), + nn.PixelShuffle(2) + ) + ) + chan = chan // 2 + self.decoders.append( + nn.Sequential( + *[NAFBlock(chan) for _ in range(num)] + ) + ) + + self.padder_size = 2 ** len(self.encoders) + + def forward(self, inp): + B, C, H, W = inp.shape + inp = self.check_image_size(inp) + + x = self.intro(inp) + + encs = [] + + for encoder, down in zip(self.encoders, self.downs): + x = encoder(x) + encs.append(x) + x = down(x) + + x = self.middle_blks(x) + + for decoder, up, enc_skip in zip(self.decoders, 
self.ups, encs[::-1]): + x = up(x) + x = x + enc_skip + x = decoder(x) + + x = self.ending(x) + x = x + inp + + return x[:, :, :H, :W] + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x + +class NAFNetLocal(Local_Base, NAFNet): + def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs): + Local_Base.__init__(self) + NAFNet.__init__(self, *args, **kwargs) + + N, C, H, W = train_size + base_size = (int(H * 1.5), int(W * 1.5)) + + self.eval() + with torch.no_grad(): + self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp) + + +if __name__ == '__main__': + img_channel = 3 + width = 32 + + # enc_blks = [2, 2, 4, 8] + # middle_blk_num = 12 + # dec_blks = [2, 2, 2, 2] + + enc_blks = [1, 1, 1, 28] + middle_blk_num = 1 + dec_blks = [1, 1, 1, 1] + + net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num, + enc_blk_nums=enc_blks, dec_blk_nums=dec_blks) + + + inp_shape = (3, 256, 256) + + from ptflops import get_model_complexity_info + + macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False) + + params = float(params[:-3]) + macs = float(macs[:-4]) + + print(macs, params) diff --git a/NAFNET/basicsr/models/archs/NAFSSR_arch.py b/NAFNET/basicsr/models/archs/NAFSSR_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..d189b0ef3c846a55ce5e06e3949b7ed51156903c --- /dev/null +++ b/NAFNET/basicsr/models/archs/NAFSSR_arch.py @@ -0,0 +1,170 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. 
+# ------------------------------------------------------------------------ + +''' +NAFSSR: Stereo Image Super-Resolution Using NAFNet + +@InProceedings{Chu2022NAFSSR, + author = {Xiaojie Chu and Liangyu Chen and Wenqing Yu}, + title = {NAFSSR: Stereo Image Super-Resolution Using NAFNet}, + booktitle = {CVPRW}, + year = {2022}, +} +''' + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from basicsr.models.archs.NAFNet_arch import LayerNorm2d, NAFBlock +from basicsr.models.archs.arch_util import MySequential +from basicsr.models.archs.local_arch import Local_Base + +class SCAM(nn.Module): + ''' + Stereo Cross Attention Module (SCAM) + ''' + def __init__(self, c): + super().__init__() + self.scale = c ** -0.5 + + self.norm_l = LayerNorm2d(c) + self.norm_r = LayerNorm2d(c) + self.l_proj1 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0) + self.r_proj1 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + self.l_proj2 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0) + self.r_proj2 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0) + + def forward(self, x_l, x_r): + Q_l = self.l_proj1(self.norm_l(x_l)).permute(0, 2, 3, 1) # B, H, W, c + Q_r_T = self.r_proj1(self.norm_r(x_r)).permute(0, 2, 1, 3) # B, H, c, W (transposed) + + V_l = self.l_proj2(x_l).permute(0, 2, 3, 1) # B, H, W, c + V_r = self.r_proj2(x_r).permute(0, 2, 3, 1) # B, H, W, c + + # (B, H, W, c) x (B, H, c, W) -> (B, H, W, W) + attention = torch.matmul(Q_l, Q_r_T) * self.scale + + F_r2l = torch.matmul(torch.softmax(attention, dim=-1), V_r) #B, H, W, c + F_l2r = torch.matmul(torch.softmax(attention.permute(0, 1, 3, 2), dim=-1), V_l) #B, H, W, c + + # scale + F_r2l = F_r2l.permute(0, 3, 1, 2) * self.beta + F_l2r = F_l2r.permute(0, 3, 1, 2) * self.gamma + return x_l + F_r2l, x_r + F_l2r + +class 
class DropPath(nn.Module):
    """Stochastic-depth wrapper around a block (used by NAFSSR).

    During training, with probability ``drop_rate`` the wrapped module is
    skipped entirely and the inputs are returned unchanged. When the module
    does run during training, its residual contribution is rescaled by
    1 / (1 - drop_rate) so the expected output matches evaluation mode.
    """

    def __init__(self, drop_rate, module):
        super().__init__()
        self.drop_rate = drop_rate
        self.module = module

    def forward(self, *feats):
        # Training-time skip: drop the whole branch, pass inputs through.
        if self.training and np.random.rand() < self.drop_rate:
            return feats

        outputs = self.module(*feats)
        scale = 1. / (1 - self.drop_rate) if self.training else 1.

        if self.training and scale != 1.:
            # Rescale the delta (out - in) so E[output] is unchanged.
            outputs = tuple(inp + scale * (out - inp)
                            for inp, out in zip(feats, outputs))
        return outputs
feats], dim=1) + out = out + inp_hr + return out + +class NAFSSR(Local_Base, NAFNetSR): + def __init__(self, *args, train_size=(1, 6, 30, 90), fast_imp=False, fusion_from=-1, fusion_to=1000, **kwargs): + Local_Base.__init__(self) + NAFNetSR.__init__(self, *args, img_channel=3, fusion_from=fusion_from, fusion_to=fusion_to, dual=True, **kwargs) + + N, C, H, W = train_size + base_size = (int(H * 1.5), int(W * 1.5)) + + self.eval() + with torch.no_grad(): + self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp) + +if __name__ == '__main__': + num_blks = 128 + width = 128 + droppath=0.1 + train_size = (1, 6, 30, 90) + + net = NAFSSR(up_scale=2,train_size=train_size, fast_imp=True, width=width, num_blks=num_blks, drop_path_rate=droppath) + + inp_shape = (6, 64, 64) + + from ptflops import get_model_complexity_info + FLOPS = 0 + macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True) + + # params = float(params[:-4]) + print(params) + macs = float(macs[:-4]) + FLOPS / 10 ** 9 + + print('mac', macs, params) + + # from basicsr.models.archs.arch_util import measure_inference_speed + # net = net.cuda() + # data = torch.randn((1, 6, 128, 128)).cuda() + # measure_inference_speed(net, (data,)) + + + + diff --git a/NAFNET/basicsr/models/archs/__init__.py b/NAFNET/basicsr/models/archs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..11cfbdc1f1db6a8c54cbafdf0a3ad86c18115b23 --- /dev/null +++ b/NAFNET/basicsr/models/archs/__init__.py @@ -0,0 +1,52 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. 
def dynamic_instantiation(modules, cls_type, opt):
    """Locate class ``cls_type`` in ``modules`` and instantiate it.

    Args:
        modules (list): Imported modules to search, in order.
        cls_type (str): Name of the class to look up.
        opt (dict): Keyword arguments forwarded to the class constructor.

    Returns:
        object: Instance of the first matching class.

    Raises:
        ValueError: If no module defines ``cls_type``.
    """
    target_cls = None
    for candidate in modules:
        target_cls = getattr(candidate, cls_type, None)
        if target_cls is not None:
            break
    if target_cls is None:
        raise ValueError(f'{cls_type} is not found.')
    return target_cls(**opt)
+# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +import math +import torch +from torch import nn as nn +from torch.nn import functional as F +from torch.nn import init as init +from torch.nn.modules.batchnorm import _BatchNorm + +from basicsr.utils import get_root_logger + +# try: +# from basicsr.models.ops.dcn import (ModulatedDeformConvPack, +# modulated_deform_conv) +# except ImportError: +# # print('Cannot import dcn. Ignore this warning if dcn is not used. ' +# # 'Otherwise install BasicSR with compiling dcn.') +# + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. 
+ """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +class ResidualBlockNoBN(nn.Module): + """Residual block without BN. + + It has a style of: + ---Conv-ReLU-Conv-+- + |________________| + + Args: + num_feat (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Residual scale. Default: 1. + pytorch_init (bool): If set to True, use pytorch default init, + otherwise, use default_init_weights. Default: False. + """ + + def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): + super(ResidualBlockNoBN, self).__init__() + self.res_scale = res_scale + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU(inplace=True) + + if not pytorch_init: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' + 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +def flow_warp(x, + flow, + interp_mode='bilinear', + padding_mode='zeros', + align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. 
+ interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + Returns: + Tensor: Warped image or feature map. + """ + assert x.size()[-2:] == flow.size()[1:3] + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid( + torch.arange(0, h).type_as(x), + torch.arange(0, w).type_as(x)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample( + x, + vgrid_scaled, + mode=interp_mode, + padding_mode=padding_mode, + align_corners=align_corners) + + # TODO, what if align_corners=False + return output + + +def resize_flow(flow, + size_type, + sizes, + interp_mode='bilinear', + align_corners=False): + """Resize a flow according to ratio or shape. + + Args: + flow (Tensor): Precomputed flow. shape [N, 2, H, W]. + size_type (str): 'ratio' or 'shape'. + sizes (list[int | float]): the ratio for resizing or the final output + shape. + 1) The order of ratio should be [ratio_h, ratio_w]. For + downsampling, the ratio should be smaller than 1.0 (i.e., ratio + < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., + ratio > 1.0). + 2) The order of output_size should be [out_h, out_w]. + interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. 
+ """ + _, _, flow_h, flow_w = flow.size() + if size_type == 'ratio': + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == 'shape': + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError( + f'Size type should be ratio or shape, but got type {size_type}.') + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, + size=(output_h, output_w), + mode=interp_mode, + align_corners=align_corners) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +# class DCNv2Pack(ModulatedDeformConvPack): +# """Modulated deformable conv for deformable alignment. +# +# Different from the official DCNv2Pack, which generates offsets and masks +# from the preceding features, this DCNv2Pack takes another different +# features to generate offsets and masks. +# +# Ref: +# Delving Deep into Deformable Alignment in Video Super-Resolution. 
+# """ +# +# def forward(self, x, feat): +# out = self.conv_offset(feat) +# o1, o2, mask = torch.chunk(out, 3, dim=1) +# offset = torch.cat((o1, o2), dim=1) +# mask = torch.sigmoid(mask) +# +# offset_absmean = torch.mean(torch.abs(offset)) +# if offset_absmean > 50: +# logger = get_root_logger() +# logger.warning( +# f'Offset abs mean is {offset_absmean}, larger than 50.') +# +# return modulated_deform_conv(x, offset, mask, self.weight, self.bias, +# self.stride, self.padding, self.dilation, +# self.groups, self.deformable_groups) + + +class LayerNormFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1. 
class MySequential(nn.Sequential):
    """``nn.Sequential`` variant whose modules may take multiple tensors.

    A plain tuple flowing through the chain is unpacked into positional
    arguments for the next module; anything else is passed as a single
    argument. Lets stereo (left, right) feature pairs traverse a chain of
    two-input blocks.
    """

    def forward(self, *inputs):
        for module in self._modules.values():
            if type(inputs) is tuple:
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs
b/NAFNET/basicsr/models/archs/local_arch.py @@ -0,0 +1,104 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +class AvgPool2d(nn.Module): + def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None): + super().__init__() + self.kernel_size = kernel_size + self.base_size = base_size + self.auto_pad = auto_pad + + # only used for fast implementation + self.fast_imp = fast_imp + self.rs = [5, 4, 3, 2, 1] + self.max_r1 = self.rs[0] + self.max_r2 = self.rs[0] + self.train_size = train_size + + def extra_repr(self) -> str: + return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format( + self.kernel_size, self.base_size, self.kernel_size, self.fast_imp + ) + + def forward(self, x): + if self.kernel_size is None and self.base_size: + train_size = self.train_size + if isinstance(self.base_size, int): + self.base_size = (self.base_size, self.base_size) + self.kernel_size = list(self.base_size) + self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2] + self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1] + + # only used for fast implementation + self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2]) + self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1]) + + if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1): + return F.adaptive_avg_pool2d(x, 1) + + if self.fast_imp: # Non-equivalent implementation but faster + h, w = x.shape[2:] + if self.kernel_size[0] >= h and self.kernel_size[1] >= w: + out = F.adaptive_avg_pool2d(x, 1) + else: + r1 = [r for r in self.rs if h % r == 0][0] + r2 = [r for r in self.rs if w % r == 0][0] + # reduction_constraint + r1 = min(self.max_r1, r1) + r2 = min(self.max_r2, r2) 
+ s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2) + n, c, h, w = s.shape + k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2) + out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2) + out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2)) + else: + n, c, h, w = x.shape + s = x.cumsum(dim=-1).cumsum_(dim=-2) + s = torch.nn.functional.pad(s, (1, 0, 1, 0)) # pad 0 for convenience + k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1]) + s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:] + out = s4 + s1 - s2 - s3 + out = out / (k1 * k2) + + if self.auto_pad: + n, c, h, w = x.shape + _h, _w = out.shape[2:] + # print(x.shape, self.kernel_size) + pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2) + out = torch.nn.functional.pad(out, pad2d, mode='replicate') + + return out + +def replace_layers(model, base_size, train_size, fast_imp, **kwargs): + for n, m in model.named_children(): + if len(list(m.children())) > 0: + ## compound module, go inside it + replace_layers(m, base_size, train_size, fast_imp, **kwargs) + + if isinstance(m, nn.AdaptiveAvgPool2d): + pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size) + assert m.output_size == 1 + setattr(model, n, pool) + + +''' +ref. 
+@article{chu2021tlsc,
+    title={Revisiting Global Statistics Aggregation for Improving Image Restoration},
+    author={Chu, Xiaojie and Chen, Liangyu and Chen, Chengpeng and Lu, Xin},
+    journal={arXiv preprint arXiv:2112.04491},
+    year={2021}
+}
+'''
+class Local_Base():
+    def convert(self, *args, train_size, **kwargs):
+        replace_layers(self, *args, train_size=train_size, **kwargs)
+        imgs = torch.rand(train_size)
+        with torch.no_grad():
+            self.forward(imgs)
diff --git a/NAFNET/basicsr/models/base_model.py b/NAFNET/basicsr/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..91d23feab7f7881f5cabf8fa568af490324245c3
--- /dev/null
+++ b/NAFNET/basicsr/models/base_model.py
@@ -0,0 +1,356 @@
+# ------------------------------------------------------------------------
+# Copyright (c) 2022 megvii-model. All Rights Reserved.
+# ------------------------------------------------------------------------
+# Modified from BasicSR (https://github.com/xinntao/BasicSR)
+# Copyright 2018-2020 BasicSR Authors
+# ------------------------------------------------------------------------
+import logging
+import os
+import torch
+from collections import OrderedDict
+from copy import deepcopy
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+from basicsr.models import lr_scheduler as lr_scheduler
+from basicsr.utils.dist_util import master_only
+
+logger = logging.getLogger('basicsr')
+
+
+class BaseModel():
+    """Base model."""
+
+    def __init__(self, opt):
+        self.opt = opt
+        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+        self.is_train = opt['is_train']
+        self.schedulers = []
+        self.optimizers = []
+
+    def feed_data(self, data):
+        pass
+
+    def optimize_parameters(self):
+        pass
+
+    def get_current_visuals(self):
+        pass
+
+    def save(self, epoch, current_iter):
+        """Save networks and training state."""
+        pass
+
+    def validation(self, dataloader, current_iter, tb_logger, save_img=False, rgb2bgr=True, 
use_image=True): + """Validation function. + + Args: + dataloader (torch.utils.data.DataLoader): Validation dataloader. + current_iter (int): Current iteration. + tb_logger (tensorboard logger): Tensorboard logger. + save_img (bool): Whether to save images. Default: False. + rgb2bgr (bool): Whether to save images using rgb2bgr. Default: True + use_image (bool): Whether to use saved images to compute metrics (PSNR, SSIM), if not, then use data directly from network' output. Default: True + """ + if self.opt['dist']: + return self.dist_validation(dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image) + else: + return self.nondist_validation(dataloader, current_iter, tb_logger, + save_img, rgb2bgr, use_image) + + def get_current_log(self): + return self.log_dict + + def model_to_device(self, net): + """Model to device. It also warps models with DistributedDataParallel + or DataParallel. + + Args: + net (nn.Module) + """ + + net = net.to(self.device) + if self.opt['dist']: + find_unused_parameters = self.opt.get('find_unused_parameters', + False) + net = DistributedDataParallel( + net, + device_ids=[torch.cuda.current_device()], + find_unused_parameters=find_unused_parameters) + elif self.opt['num_gpu'] > 1: + net = DataParallel(net) + return net + + def setup_schedulers(self): + """Set up schedulers.""" + train_opt = self.opt['train'] + scheduler_type = train_opt['scheduler'].pop('type') + if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.MultiStepRestartLR(optimizer, + **train_opt['scheduler'])) + elif scheduler_type == 'CosineAnnealingRestartLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingRestartLR( + optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'TrueCosineAnnealingLR': + print('..', 'cosineannealingLR') + for optimizer in self.optimizers: + self.schedulers.append( + 
torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'LinearLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.LinearLR( + optimizer, train_opt['total_iter'])) + elif scheduler_type == 'VibrateLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.VibrateLR( + optimizer, train_opt['total_iter'])) + else: + raise NotImplementedError( + f'Scheduler {scheduler_type} is not implemented yet.') + + def get_bare_model(self, net): + """Get bare model, especially under wrapping with + DistributedDataParallel or DataParallel. + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + return net + + @master_only + def print_network(self, net): + """Print the str and parameter number of a network. + + Args: + net (nn.Module) + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net_cls_str = (f'{net.__class__.__name__} - ' + f'{net.module.__class__.__name__}') + else: + net_cls_str = f'{net.__class__.__name__}' + + net = self.get_bare_model(net) + net_str = str(net) + net_params = sum(map(lambda x: x.numel(), net.parameters())) + + logger.info( + f'Network: {net_cls_str}, with parameters: {net_params:,d}') + logger.info(net_str) + + def _set_lr(self, lr_groups_l): + """Set learning rate for warmup. + + Args: + lr_groups_l (list): List for lr_groups, each for an optimizer. + """ + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def _get_init_lr(self): + """Get the initial lr, which is set by the scheduler. + """ + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append( + [v['initial_lr'] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, current_iter, warmup_iter=-1): + """Update learning rate. 
+ + Args: + current_iter (int): Current iteration. + warmup_iter (int): Warmup iter numbers. -1 for no warmup. + Default: -1. + """ + if current_iter > 1: + for scheduler in self.schedulers: + scheduler.step() + # set up warm-up learning rate + if current_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + # currently only support linearly warm up + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append( + [v / warmup_iter * current_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + return [ + param_group['lr'] + for param_group in self.optimizers[0].param_groups + ] + + @master_only + def save_network(self, net, net_label, current_iter, param_key='params'): + """Save networks. + + Args: + net (nn.Module | list[nn.Module]): Network(s) to be saved. + net_label (str): Network label. + current_iter (int): Current iter number. + param_key (str | list[str]): The parameter key(s) to save network. + Default: 'params'. + """ + if current_iter == -1: + current_iter = 'latest' + save_filename = f'{net_label}_{current_iter}.pth' + save_path = os.path.join(self.opt['path']['models'], save_filename) + + net = net if isinstance(net, list) else [net] + param_key = param_key if isinstance(param_key, list) else [param_key] + assert len(net) == len( + param_key), 'The lengths of net and param_key should be the same.' + + save_dict = {} + for net_, param_key_ in zip(net, param_key): + net_ = self.get_bare_model(net_) + state_dict = net_.state_dict() + for key, param in state_dict.items(): + if key.startswith('module.'): # remove unnecessary 'module.' + key = key[7:] + state_dict[key] = param.cpu() + save_dict[param_key_] = state_dict + + torch.save(save_dict, save_path) + + def _print_different_keys_loading(self, crt_net, load_net, strict=True): + """Print keys with differnet name or different size when loading models. 
+
+        1. Print keys with different names.
+        2. If strict=False, print the same key but with different tensor size.
+           It also ignores these keys with different sizes (not load).
+
+        Args:
+            crt_net (torch model): Current network.
+            load_net (dict): Loaded network.
+            strict (bool): Whether strictly loaded. Default: True.
+        """
+        crt_net = self.get_bare_model(crt_net)
+        crt_net = crt_net.state_dict()
+        crt_net_keys = set(crt_net.keys())
+        load_net_keys = set(load_net.keys())
+
+        if crt_net_keys != load_net_keys:
+            logger.warning('Current net - loaded net:')
+            for v in sorted(list(crt_net_keys - load_net_keys)):
+                logger.warning(f'  {v}')
+            logger.warning('Loaded net - current net:')
+            for v in sorted(list(load_net_keys - crt_net_keys)):
+                logger.warning(f'  {v}')
+
+        # check the size for the same keys
+        if not strict:
+            common_keys = crt_net_keys & load_net_keys
+            for k in common_keys:
+                if crt_net[k].size() != load_net[k].size():
+                    logger.warning(
+                        f'Size different, ignore [{k}]: crt_net: '
+                        f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
+                    load_net[k + '.ignore'] = load_net.pop(k)
+
+    def load_network(self, net, load_path, strict=True, param_key='params'):
+        """Load network.
+
+        Args:
+            load_path (str): The path of networks to be loaded.
+            net (nn.Module): Network.
+            strict (bool): Whether strictly loaded.
+            param_key (str): The parameter key of loaded network. If set to
+                None, use the root 'path'.
+                Default: 'params'.
+        """
+        net = self.get_bare_model(net)
+        logger.info(
+            f'Loading {net.__class__.__name__} model from {load_path}.')
+        load_net = torch.load(
+            load_path, map_location=lambda storage, loc: storage)
+        if param_key is not None:
+            load_net = load_net[param_key]
+        print(' load net keys', load_net.keys)
+        # remove unnecessary 'module.' 
+ for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + self._print_different_keys_loading(net, load_net, strict) + net.load_state_dict(load_net, strict=strict) + + @master_only + def save_training_state(self, epoch, current_iter): + """Save training states during training, which will be used for + resuming. + + Args: + epoch (int): Current epoch. + current_iter (int): Current iteration. + """ + if current_iter != -1: + state = { + 'epoch': epoch, + 'iter': current_iter, + 'optimizers': [], + 'schedulers': [] + } + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + save_filename = f'{current_iter}.state' + save_path = os.path.join(self.opt['path']['training_states'], + save_filename) + torch.save(state, save_path) + + def resume_training(self, resume_state): + """Reload the optimizers and schedulers for resumed training. + + Args: + resume_state (dict): Resume state. + """ + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + assert len(resume_optimizers) == len( + self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len( + self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + self.schedulers[i].load_state_dict(s) + + def reduce_loss_dict(self, loss_dict): + """reduce loss dict. + + In distributed training, it averages the losses among different GPUs . + + Args: + loss_dict (OrderedDict): Loss dict. 
+ """ + with torch.no_grad(): + if self.opt['dist']: + keys = [] + losses = [] + for name, value in loss_dict.items(): + keys.append(name) + losses.append(value) + losses = torch.stack(losses, 0) + torch.distributed.reduce(losses, dst=0) + if self.opt['rank'] == 0: + losses /= self.opt['world_size'] + loss_dict = {key: loss for key, loss in zip(keys, losses)} + + log_dict = OrderedDict() + for name, value in loss_dict.items(): + log_dict[name] = value.mean().item() + + return log_dict diff --git a/NAFNET/basicsr/models/image_restoration_model.py b/NAFNET/basicsr/models/image_restoration_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1eec564b8ba5a57160e379ca860843bdda4fe47d --- /dev/null +++ b/NAFNET/basicsr/models/image_restoration_model.py @@ -0,0 +1,413 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +import importlib +import torch +import torch.nn.functional as F +from collections import OrderedDict +from copy import deepcopy +from os import path as osp +from tqdm import tqdm + +from basicsr.models.archs import define_network +from basicsr.models.base_model import BaseModel +from basicsr.utils import get_root_logger, imwrite, tensor2img +from basicsr.utils.dist_util import get_dist_info + +loss_module = importlib.import_module('basicsr.models.losses') +metric_module = importlib.import_module('basicsr.metrics') + +class ImageRestorationModel(BaseModel): + """Base Deblur model for single image deblur.""" + + def __init__(self, opt): + super(ImageRestorationModel, self).__init__(opt) + + # define network + self.net_g = define_network(deepcopy(opt['network_g'])) + self.net_g = 
self.model_to_device(self.net_g) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g, load_path, + self.opt['path'].get('strict_load_g', True), param_key=self.opt['path'].get('param_key', 'params')) + + if self.is_train: + self.init_training_settings() + + self.scale = int(opt['scale']) + + def init_training_settings(self): + self.net_g.train() + train_opt = self.opt['train'] + + # define losses + if train_opt.get('pixel_opt'): + pixel_type = train_opt['pixel_opt'].pop('type') + cri_pix_cls = getattr(loss_module, pixel_type) + self.cri_pix = cri_pix_cls(**train_opt['pixel_opt']).to( + self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + percep_type = train_opt['perceptual_opt'].pop('type') + cri_perceptual_cls = getattr(loss_module, percep_type) + self.cri_perceptual = cri_perceptual_cls( + **train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if self.cri_pix is None and self.cri_perceptual is None: + raise ValueError('Both pixel and perceptual losses are None.') + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + optim_params = [] + + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + # if k.startswith('module.offsets') or k.startswith('module.dcns'): + # optim_params_lowlr.append(v) + # else: + optim_params.append(v) + # else: + # logger = get_root_logger() + # logger.warning(f'Params {k} will not be optimized.') + # print(optim_params) + # ratio = 0.1 + + optim_type = train_opt['optim_g'].pop('type') + if optim_type == 'Adam': + self.optimizer_g = torch.optim.Adam([{'params': optim_params}], + **train_opt['optim_g']) + elif optim_type == 'SGD': + self.optimizer_g = torch.optim.SGD(optim_params, + **train_opt['optim_g']) + elif optim_type == 'AdamW': + self.optimizer_g = 
torch.optim.AdamW([{'params': optim_params}], + **train_opt['optim_g']) + pass + else: + raise NotImplementedError( + f'optimizer {optim_type} is not supperted yet.') + self.optimizers.append(self.optimizer_g) + + def feed_data(self, data, is_val=False): + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + + def grids(self): + b, c, h, w = self.gt.size() + self.original_size = (b, c, h, w) + + assert b == 1 + if 'crop_size_h' in self.opt['val']: + crop_size_h = self.opt['val']['crop_size_h'] + else: + crop_size_h = int(self.opt['val'].get('crop_size_h_ratio') * h) + + if 'crop_size_w' in self.opt['val']: + crop_size_w = self.opt['val'].get('crop_size_w') + else: + crop_size_w = int(self.opt['val'].get('crop_size_w_ratio') * w) + + + crop_size_h, crop_size_w = crop_size_h // self.scale * self.scale, crop_size_w // self.scale * self.scale + #adaptive step_i, step_j + num_row = (h - 1) // crop_size_h + 1 + num_col = (w - 1) // crop_size_w + 1 + + import math + step_j = crop_size_w if num_col == 1 else math.ceil((w - crop_size_w) / (num_col - 1) - 1e-8) + step_i = crop_size_h if num_row == 1 else math.ceil((h - crop_size_h) / (num_row - 1) - 1e-8) + + scale = self.scale + step_i = step_i//scale*scale + step_j = step_j//scale*scale + + parts = [] + idxes = [] + + i = 0 # 0~h-1 + last_i = False + while i < h and not last_i: + j = 0 + if i + crop_size_h >= h: + i = h - crop_size_h + last_i = True + + last_j = False + while j < w and not last_j: + if j + crop_size_w >= w: + j = w - crop_size_w + last_j = True + parts.append(self.lq[:, :, i // scale :(i + crop_size_h) // scale, j // scale:(j + crop_size_w) // scale]) + idxes.append({'i': i, 'j': j}) + j = j + step_j + i = i + step_i + + self.origin_lq = self.lq + self.lq = torch.cat(parts, dim=0) + self.idxes = idxes + + def grids_inverse(self): + preds = torch.zeros(self.original_size) + b, c, h, w = self.original_size + + count_mt = torch.zeros((b, 1, h, w)) + if 'crop_size_h' 
in self.opt['val']: + crop_size_h = self.opt['val']['crop_size_h'] + else: + crop_size_h = int(self.opt['val'].get('crop_size_h_ratio') * h) + + if 'crop_size_w' in self.opt['val']: + crop_size_w = self.opt['val'].get('crop_size_w') + else: + crop_size_w = int(self.opt['val'].get('crop_size_w_ratio') * w) + + crop_size_h, crop_size_w = crop_size_h // self.scale * self.scale, crop_size_w // self.scale * self.scale + + for cnt, each_idx in enumerate(self.idxes): + i = each_idx['i'] + j = each_idx['j'] + preds[0, :, i: i + crop_size_h, j: j + crop_size_w] += self.outs[cnt] + count_mt[0, 0, i: i + crop_size_h, j: j + crop_size_w] += 1. + + self.output = (preds / count_mt).to(self.device) + self.lq = self.origin_lq + + def optimize_parameters(self, current_iter, tb_logger): + self.optimizer_g.zero_grad() + + if self.opt['train'].get('mixup', False): + self.mixup_aug() + + preds = self.net_g(self.lq) + if not isinstance(preds, list): + preds = [preds] + + self.output = preds[-1] + + l_total = 0 + loss_dict = OrderedDict() + # pixel loss + if self.cri_pix: + l_pix = 0. + for pred in preds: + l_pix += self.cri_pix(pred, self.gt) + + # print('l pix ... ', l_pix) + l_total += l_pix + loss_dict['l_pix'] = l_pix + + # perceptual loss + if self.cri_perceptual: + l_percep, l_style = self.cri_perceptual(self.output, self.gt) + # + if l_percep is not None: + l_total += l_percep + loss_dict['l_percep'] = l_percep + if l_style is not None: + l_total += l_style + loss_dict['l_style'] = l_style + + + l_total = l_total + 0. 
* sum(p.sum() for p in self.net_g.parameters()) + + l_total.backward() + use_grad_clip = self.opt['train'].get('use_grad_clip', True) + if use_grad_clip: + torch.nn.utils.clip_grad_norm_(self.net_g.parameters(), 0.01) + self.optimizer_g.step() + + + self.log_dict = self.reduce_loss_dict(loss_dict) + + def test(self): + self.net_g.eval() + with torch.no_grad(): + n = len(self.lq) + outs = [] + m = self.opt['val'].get('max_minibatch', n) + i = 0 + while i < n: + j = i + m + if j >= n: + j = n + pred = self.net_g(self.lq[i:j]) + if isinstance(pred, list): + pred = pred[-1] + outs.append(pred.detach().cpu()) + i = j + + self.output = torch.cat(outs, dim=0) + self.net_g.train() + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = { + metric: 0 + for metric in self.opt['val']['metrics'].keys() + } + + rank, world_size = get_dist_info() + if rank == 0: + pbar = tqdm(total=len(dataloader), unit='image') + + cnt = 0 + + for idx, val_data in enumerate(dataloader): + if idx % world_size != rank: + continue + + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + + self.feed_data(val_data, is_val=True) + if self.opt['val'].get('grids', False): + self.grids() + + self.test() + + if self.opt['val'].get('grids', False): + self.grids_inverse() + + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']], rgb2bgr=rgb2bgr) + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']], rgb2bgr=rgb2bgr) + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if sr_img.shape[2] == 6: + L_img = sr_img[:, :, :3] + R_img = sr_img[:, :, 3:] + + # visual_dir = osp.join('visual_results', dataset_name, self.opt['name']) + visual_dir = osp.join(self.opt['path']['visualization'], dataset_name) + 
+ imwrite(L_img, osp.join(visual_dir, f'{img_name}_L.png')) + imwrite(R_img, osp.join(visual_dir, f'{img_name}_R.png')) + else: + if self.opt['is_train']: + + save_img_path = osp.join(self.opt['path']['visualization'], + img_name, + f'{img_name}_{current_iter}.png') + + save_gt_img_path = osp.join(self.opt['path']['visualization'], + img_name, + f'{img_name}_{current_iter}_gt.png') + else: + save_img_path = osp.join( + self.opt['path']['visualization'], dataset_name, + f'{img_name}.png') + save_gt_img_path = osp.join( + self.opt['path']['visualization'], dataset_name, + f'{img_name}_gt.png') + + imwrite(sr_img, save_img_path) + imwrite(gt_img, save_gt_img_path) + + if with_metrics: + # calculate metrics + opt_metric = deepcopy(self.opt['val']['metrics']) + if use_image: + for name, opt_ in opt_metric.items(): + metric_type = opt_.pop('type') + self.metric_results[name] += getattr( + metric_module, metric_type)(sr_img, gt_img, **opt_) + else: + for name, opt_ in opt_metric.items(): + metric_type = opt_.pop('type') + self.metric_results[name] += getattr( + metric_module, metric_type)(visuals['result'], visuals['gt'], **opt_) + + cnt += 1 + if rank == 0: + for _ in range(world_size): + pbar.update(1) + pbar.set_description(f'Test {img_name}') + if rank == 0: + pbar.close() + + # current_metric = 0. 
+ collected_metrics = OrderedDict() + if with_metrics: + for metric in self.metric_results.keys(): + collected_metrics[metric] = torch.tensor(self.metric_results[metric]).float().to(self.device) + collected_metrics['cnt'] = torch.tensor(cnt).float().to(self.device) + + self.collected_metrics = collected_metrics + + keys = [] + metrics = [] + for name, value in self.collected_metrics.items(): + keys.append(name) + metrics.append(value) + metrics = torch.stack(metrics, 0) + torch.distributed.reduce(metrics, dst=0) + if self.opt['rank'] == 0: + metrics_dict = {} + cnt = 0 + for key, metric in zip(keys, metrics): + if key == 'cnt': + cnt = float(metric) + continue + metrics_dict[key] = float(metric) + + for key in metrics_dict: + metrics_dict[key] /= cnt + + self._log_validation_metric_values(current_iter, dataloader.dataset.opt['name'], + tb_logger, metrics_dict) + return 0. + + def nondist_validation(self, *args, **kwargs): + logger = get_root_logger() + logger.warning('nondist_validation is not implemented. 
Run dist_validation.') + self.dist_validation(*args, **kwargs) + + + def _log_validation_metric_values(self, current_iter, dataset_name, + tb_logger, metric_dict): + log_str = f'Validation {dataset_name}, \t' + for metric, value in metric_dict.items(): + log_str += f'\t # {metric}: {value:.4f}' + logger = get_root_logger() + logger.info(log_str) + + log_dict = OrderedDict() + # for name, value in loss_dict.items(): + for metric, value in metric_dict.items(): + log_dict[f'm_{metric}'] = value + + self.log_dict = log_dict + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['lq'] = self.lq.detach().cpu() + out_dict['result'] = self.output.detach().cpu() + if hasattr(self, 'gt'): + out_dict['gt'] = self.gt.detach().cpu() + return out_dict + + def save(self, epoch, current_iter): + self.save_network(self.net_g, 'net_g', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/NAFNET/basicsr/models/losses/__init__.py b/NAFNET/basicsr/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1c2db43241187d7b211dd3db7ee2ef1fd934f173 --- /dev/null +++ b/NAFNET/basicsr/models/losses/__init__.py @@ -0,0 +1,11 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. 
+# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +from .losses import (L1Loss, MSELoss, PSNRLoss) + +__all__ = [ + 'L1Loss', 'MSELoss', 'PSNRLoss', +] diff --git a/NAFNET/basicsr/models/losses/loss_util.py b/NAFNET/basicsr/models/losses/loss_util.py new file mode 100644 index 0000000000000000000000000000000000000000..52d7919c41068ef2657ebef523bb6a6e8e7f6a69 --- /dev/null +++ b/NAFNET/basicsr/models/losses/loss_util.py @@ -0,0 +1,101 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +import functools +from torch.nn import functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are 'none', 'mean' and 'sum'. + + Returns: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + else: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean'): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. Default: None. + reduction (str): Same as built-in losses of PyTorch. Options are + 'none', 'mean' and 'sum'. Default: 'mean'. + + Returns: + Tensor: Loss values. 
+ """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if weight is not specified or reduction is sum, just reduce the loss + if weight is None or reduction == 'sum': + loss = reduce_loss(loss, reduction) + # if reduction is mean, then compute mean over weight region + elif reduction == 'mean': + if weight.size(1) > 1: + weight = weight.sum() + else: + weight = weight.sum() * loss.size(1) + loss = loss.sum() / weight + + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.5000) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, reduction='sum') + tensor(3.) 
+ """ + + @functools.wraps(loss_func) + def wrapper(pred, target, weight=None, reduction='mean', **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction) + return loss + + return wrapper diff --git a/NAFNET/basicsr/models/losses/losses.py b/NAFNET/basicsr/models/losses/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..3d5b688dc9c0ffb2176d1a4e5b63d89a22528e91 --- /dev/null +++ b/NAFNET/basicsr/models/losses/losses.py @@ -0,0 +1,116 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +import torch +from torch import nn as nn +from torch.nn import functional as F +import numpy as np + +from basicsr.models.losses.loss_util import weighted_loss + +_reduction_modes = ['none', 'mean', 'sum'] + + +@weighted_loss +def l1_loss(pred, target): + return F.l1_loss(pred, target, reduction='none') + + +@weighted_loss +def mse_loss(pred, target): + return F.mse_loss(pred, target, reduction='none') + + +# @weighted_loss +# def charbonnier_loss(pred, target, eps=1e-12): +# return torch.sqrt((pred - target)**2 + eps) + + +class L1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(L1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. 
' + f'Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise + weights. Default: None. + """ + return self.loss_weight * l1_loss( + pred, target, weight, reduction=self.reduction) + +class MSELoss(nn.Module): + """MSE (L2) loss. + + Args: + loss_weight (float): Loss weight for MSE loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(MSELoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. ' + f'Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise + weights. Default: None. 
+ """ + return self.loss_weight * mse_loss( + pred, target, weight, reduction=self.reduction) + +class PSNRLoss(nn.Module): + + def __init__(self, loss_weight=1.0, reduction='mean', toY=False): + super(PSNRLoss, self).__init__() + assert reduction == 'mean' + self.loss_weight = loss_weight + self.scale = 10 / np.log(10) + self.toY = toY + self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1) + self.first = True + + def forward(self, pred, target): + assert len(pred.size()) == 4 + if self.toY: + if self.first: + self.coef = self.coef.to(pred.device) + self.first = False + + pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. + target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. + + pred, target = pred / 255., target / 255. + pass + assert len(pred.size()) == 4 + + return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean() + diff --git a/NAFNET/basicsr/models/lr_scheduler.py b/NAFNET/basicsr/models/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..2c98233baeb219378f71de1d58de6c1ab2432111 --- /dev/null +++ b/NAFNET/basicsr/models/lr_scheduler.py @@ -0,0 +1,189 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +import math +from collections import Counter +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepRestartLR(_LRScheduler): + """ MultiStep with restarts learning rate scheme. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + restarts (list): Restart iterations. 
class LinearLR(_LRScheduler):
    """Linearly decay the learning rate from its initial value toward 0.

    The weight applied to each param group's ``initial_lr`` is
    ``1 - last_epoch / total_iter``.

    Args:
        optimizer (torch.optim.Optimizer): Torch optimizer.
        total_iter (int): Total number of training iterations.
        last_epoch (int): Used in _LRScheduler. Default: -1.
    """

    def __init__(self,
                 optimizer,
                 total_iter,
                 last_epoch=-1):
        self.total_iter = total_iter
        super(LinearLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Fraction of training completed so far.
        process = self.last_epoch / self.total_iter
        weight = (1 - process)
        return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
+ """ + + def __init__(self, + optimizer, + total_iter, + last_epoch=-1): + self.total_iter = total_iter + super(VibrateLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + process = self.last_epoch / self.total_iter + + f = 0.1 + if process < 3 / 8: + f = 1 - process * 8 / 3 + elif process < 5 / 8: + f = 0.2 + + T = self.total_iter // 80 + Th = T // 2 + + t = self.last_epoch % T + + f2 = t / Th + if t >= Th: + f2 = 2 - f2 + + weight = f * f2 + + if self.last_epoch < Th: + weight = max(0.1, weight) + + # print('f {}, T {}, Th {}, t {}, f2 {}'.format(f, T, Th, t, f2)) + return [weight * group['initial_lr'] for group in self.optimizer.param_groups] + +def get_position_from_periods(iteration, cumulative_period): + """Get the position from a period list. + + It will return the index of the right-closest number in the period list. + For example, the cumulative_period = [100, 200, 300, 400], + if iteration == 50, return 0; + if iteration == 210, return 2; + if iteration == 300, return 2. + + Args: + iteration (int): Current iteration. + cumulative_period (list[int]): Cumulative period list. + + Returns: + int: The position of the right-closest number in the period list. + """ + for i, period in enumerate(cumulative_period): + if iteration <= period: + return i + + +class CosineAnnealingRestartLR(_LRScheduler): + """ Cosine annealing with restarts learning rate scheme. + + An example of config: + periods = [10, 10, 10, 10] + restart_weights = [1, 0.5, 0.5, 0.5] + eta_min=1e-7 + + It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the + scheduler will restart with the weights in restart_weights. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + periods (list): Period for each cosine anneling cycle. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + eta_min (float): The mimimum lr. Default: 0. + last_epoch (int): Used in _LRScheduler. Default: -1. 
+ """ + + def __init__(self, + optimizer, + periods, + restart_weights=(1, ), + eta_min=0, + last_epoch=-1): + self.periods = periods + self.restart_weights = restart_weights + self.eta_min = eta_min + assert (len(self.periods) == len(self.restart_weights) + ), 'periods and restart_weights should have the same length.' + self.cumulative_period = [ + sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) + ] + super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + idx = get_position_from_periods(self.last_epoch, + self.cumulative_period) + current_weight = self.restart_weights[idx] + nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] + current_period = self.periods[idx] + + return [ + self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * + (1 + math.cos(math.pi * ( + (self.last_epoch - nearest_restart) / current_period))) + for base_lr in self.base_lrs + ] diff --git a/NAFNET/basicsr/test.py b/NAFNET/basicsr/test.py new file mode 100644 index 0000000000000000000000000000000000000000..bf04219925867f725da5e83e7fcab6162594eead --- /dev/null +++ b/NAFNET/basicsr/test.py @@ -0,0 +1,70 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. 
+# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +import logging +import torch +from os import path as osp + +from basicsr.data import create_dataloader, create_dataset +from basicsr.models import create_model +from basicsr.train import parse_options +from basicsr.utils import (get_env_info, get_root_logger, get_time_str, + make_exp_dirs) +from basicsr.utils.options import dict2str + + +def main(): + # parse options, set distributed setting, set ramdom seed + opt = parse_options(is_train=False) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # mkdir and initialize loggers + make_exp_dirs(opt) + log_file = osp.join(opt['path']['log'], + f"test_{opt['name']}_{get_time_str()}.log") + logger = get_root_logger( + logger_name='basicsr', log_level=logging.INFO, log_file=log_file) + logger.info(get_env_info()) + logger.info(dict2str(opt)) + + # create test dataset and dataloader + test_loaders = [] + for phase, dataset_opt in sorted(opt['datasets'].items()): + if 'test' in phase: + dataset_opt['phase'] = 'test' + test_set = create_dataset(dataset_opt) + test_loader = create_dataloader( + test_set, + dataset_opt, + num_gpu=opt['num_gpu'], + dist=opt['dist'], + sampler=None, + seed=opt['manual_seed']) + logger.info( + f"Number of test images in {dataset_opt['name']}: {len(test_set)}") + test_loaders.append(test_loader) + + # create model + model = create_model(opt) + + for test_loader in test_loaders: + test_set_name = test_loader.dataset.opt['name'] + logger.info(f'Testing {test_set_name}...') + rgb2bgr = opt['val'].get('rgb2bgr', True) + # wheather use uint8 image to compute metrics + use_image = opt['val'].get('use_image', True) + model.validation( + test_loader, + current_iter=opt['name'], + tb_logger=None, + 
save_img=opt['val']['save_img'], + rgb2bgr=rgb2bgr, use_image=use_image) + + +if __name__ == '__main__': + main() diff --git a/NAFNET/basicsr/train.py b/NAFNET/basicsr/train.py new file mode 100644 index 0000000000000000000000000000000000000000..9cc8f2a01c00a336d2fbfc0354e84dd48978d3f8 --- /dev/null +++ b/NAFNET/basicsr/train.py @@ -0,0 +1,305 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +import argparse +import datetime +import logging +import math +import random +import time +import torch +from os import path as osp + +from basicsr.data import create_dataloader, create_dataset +from basicsr.data.data_sampler import EnlargedSampler +from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher +from basicsr.models import create_model +from basicsr.utils import (MessageLogger, check_resume, get_env_info, + get_root_logger, get_time_str, init_tb_logger, + init_wandb_logger, make_exp_dirs, mkdir_and_rename, + set_random_seed) +from basicsr.utils.dist_util import get_dist_info, init_dist +from basicsr.utils.options import dict2str, parse + + +def parse_options(is_train=True): + parser = argparse.ArgumentParser() + parser.add_argument( + '-opt', type=str, required=True, help='Path to option YAML file.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + + parser.add_argument('--input_path', type=str, required=False, help='The path to the input image. For single image inference only.') + parser.add_argument('--output_path', type=str, required=False, help='The path to the output image. 
def init_loggers(opt):
    """Create the root file logger plus optional wandb/tensorboard loggers.

    Args:
        opt (dict): Parsed option dictionary.

    Returns:
        tuple: (logger, tb_logger); tb_logger is None unless
        use_tb_logger is enabled and this is not a debug run.
    """
    log_file = osp.join(opt['path']['log'],
                        f"train_{opt['name']}_{get_time_str()}.log")
    logger = get_root_logger(
        logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
    logger.info(get_env_info())
    logger.info(dict2str(opt))

    # Initialize the wandb logger before tensorboard to allow proper sync.
    wandb_opt = opt['logger'].get('wandb')
    if (wandb_opt is not None and wandb_opt.get('project') is not None
            and 'debug' not in opt['name']):
        assert opt['logger'].get('use_tb_logger') is True, (
            'should turn on tensorboard when using wandb')
        init_wandb_logger(opt)

    tb_logger = None
    if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
        tb_logger = init_tb_logger(log_dir=osp.join('logs', opt['name']))
    return logger, tb_logger
train_set = create_dataset(dataset_opt) + train_sampler = EnlargedSampler(train_set, opt['world_size'], + opt['rank'], dataset_enlarge_ratio) + train_loader = create_dataloader( + train_set, + dataset_opt, + num_gpu=opt['num_gpu'], + dist=opt['dist'], + sampler=train_sampler, + seed=opt['manual_seed']) + + num_iter_per_epoch = math.ceil( + len(train_set) * dataset_enlarge_ratio / + (dataset_opt['batch_size_per_gpu'] * opt['world_size'])) + total_iters = int(opt['train']['total_iter']) + total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) + logger.info( + 'Training statistics:' + f'\n\tNumber of train images: {len(train_set)}' + f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}' + f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' + f'\n\tWorld size (gpu number): {opt["world_size"]}' + f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' + f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') + + elif phase == 'val': + val_set = create_dataset(dataset_opt) + val_loader = create_dataloader( + val_set, + dataset_opt, + num_gpu=opt['num_gpu'], + dist=opt['dist'], + sampler=None, + seed=opt['manual_seed']) + logger.info( + f'Number of val images/folders in {dataset_opt["name"]}: ' + f'{len(val_set)}') + else: + raise ValueError(f'Dataset phase {phase} is not recognized.') + + return train_loader, train_sampler, val_loader, total_epochs, total_iters + + +def main(): + # parse options, set distributed setting, set ramdom seed + opt = parse_options(is_train=True) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # automatic resume .. + state_folder_path = 'experiments/{}/training_states/'.format(opt['name']) + import os + try: + states = os.listdir(state_folder_path) + except: + states = [] + + resume_state = None + if len(states) > 0: + print('!!!!!! resume state .. 
', states, state_folder_path) + max_state_file = '{}.state'.format(max([int(x[0:-6]) for x in states])) + resume_state = os.path.join(state_folder_path, max_state_file) + opt['path']['resume_state'] = resume_state + + # load resume states if necessary + if opt['path'].get('resume_state'): + device_id = torch.cuda.current_device() + resume_state = torch.load( + opt['path']['resume_state'], + map_location=lambda storage, loc: storage.cuda(device_id)) + else: + resume_state = None + + # mkdir for experiments and logger + if resume_state is None: + make_exp_dirs(opt) + if opt['logger'].get('use_tb_logger') and 'debug' not in opt[ + 'name'] and opt['rank'] == 0: + mkdir_and_rename(osp.join('tb_logger', opt['name'])) + + # initialize loggers + logger, tb_logger = init_loggers(opt) + + # create train and validation dataloaders + result = create_train_val_dataloader(opt, logger) + train_loader, train_sampler, val_loader, total_epochs, total_iters = result + + # create model + if resume_state: # resume training + check_resume(opt, resume_state['iter']) + model = create_model(opt) + model.resume_training(resume_state) # handle optimizers and schedulers + logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " + f"iter: {resume_state['iter']}.") + start_epoch = resume_state['epoch'] + current_iter = resume_state['iter'] + else: + model = create_model(opt) + start_epoch = 0 + current_iter = 0 + + # create message logger (formatted outputs) + msg_logger = MessageLogger(opt, current_iter, tb_logger) + + # dataloader prefetcher + prefetch_mode = opt['datasets']['train'].get('prefetch_mode') + if prefetch_mode is None or prefetch_mode == 'cpu': + prefetcher = CPUPrefetcher(train_loader) + elif prefetch_mode == 'cuda': + prefetcher = CUDAPrefetcher(train_loader, opt) + logger.info(f'Use {prefetch_mode} prefetch dataloader') + if opt['datasets']['train'].get('pin_memory') is not True: + raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') + else: + raise 
ValueError(f'Wrong prefetch_mode {prefetch_mode}.' + "Supported ones are: None, 'cuda', 'cpu'.") + + # training + logger.info( + f'Start training from epoch: {start_epoch}, iter: {current_iter}') + data_time, iter_time = time.time(), time.time() + start_time = time.time() + + # for epoch in range(start_epoch, total_epochs + 1): + epoch = start_epoch + while current_iter <= total_iters: + train_sampler.set_epoch(epoch) + prefetcher.reset() + train_data = prefetcher.next() + + while train_data is not None: + data_time = time.time() - data_time + + current_iter += 1 + if current_iter > total_iters: + break + # update learning rate + model.update_learning_rate( + current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) + # training + model.feed_data(train_data, is_val=False) + result_code = model.optimize_parameters(current_iter, tb_logger) + # if result_code == -1 and tb_logger: + # print('loss explode .. ') + # exit(0) + iter_time = time.time() - iter_time + # log + if current_iter % opt['logger']['print_freq'] == 0: + log_vars = {'epoch': epoch, 'iter': current_iter, 'total_iter': total_iters} + log_vars.update({'lrs': model.get_current_learning_rate()}) + log_vars.update({'time': iter_time, 'data_time': data_time}) + log_vars.update(model.get_current_log()) + # print('msg logger .. 
', current_iter) + msg_logger(log_vars) + + # save models and training states + if current_iter % opt['logger']['save_checkpoint_freq'] == 0: + logger.info('Saving models and training states.') + model.save(epoch, current_iter) + + # validation + if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0 or current_iter == 1000): + # if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0): + rgb2bgr = opt['val'].get('rgb2bgr', True) + # wheather use uint8 image to compute metrics + use_image = opt['val'].get('use_image', True) + model.validation(val_loader, current_iter, tb_logger, + opt['val']['save_img'], rgb2bgr, use_image ) + log_vars = {'epoch': epoch, 'iter': current_iter, 'total_iter': total_iters} + log_vars.update({'lrs': model.get_current_learning_rate()}) + log_vars.update(model.get_current_log()) + msg_logger(log_vars) + + + data_time = time.time() + iter_time = time.time() + train_data = prefetcher.next() + # end of iter + epoch += 1 + + # end of epoch + + consumed_time = str( + datetime.timedelta(seconds=int(time.time() - start_time))) + logger.info(f'End of training. Time consumed: {consumed_time}') + logger.info('Save the latest model.') + model.save(epoch=-1, current_iter=-1) # -1 stands for the latest + if opt.get('val') is not None: + rgb2bgr = opt['val'].get('rgb2bgr', True) + use_image = opt['val'].get('use_image', True) + metric = model.validation(val_loader, current_iter, tb_logger, + opt['val']['save_img'], rgb2bgr, use_image) + # if tb_logger: + # print('xxresult! 
', opt['name'], ' ', metric) + if tb_logger: + tb_logger.close() + + +if __name__ == '__main__': + import os + os.environ['GRPC_POLL_STRATEGY']='epoll1' + main() diff --git a/NAFNET/basicsr/utils/__init__.py b/NAFNET/basicsr/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c2a5e1e9a26f1508c1b9b7713a84a191d15a7430 --- /dev/null +++ b/NAFNET/basicsr/utils/__init__.py @@ -0,0 +1,43 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +from .file_client import FileClient +from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding +from .logger import (MessageLogger, get_env_info, get_root_logger, + init_tb_logger, init_wandb_logger) +from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, + scandir, scandir_SIDD, set_random_seed, sizeof_fmt) +from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k) + +__all__ = [ + # file_client.py + 'FileClient', + # img_util.py + 'img2tensor', + 'tensor2img', + 'imfrombytes', + 'imwrite', + 'crop_border', + # logger.py + 'MessageLogger', + 'init_tb_logger', + 'init_wandb_logger', + 'get_root_logger', + 'get_env_info', + # misc.py + 'set_random_seed', + 'get_time_str', + 'mkdir_and_rename', + 'make_exp_dirs', + 'scandir', + 'scandir_SIDD', + 'check_resume', + 'sizeof_fmt', + 'padding', + 'create_lmdb_for_reds', + 'create_lmdb_for_gopro', + 'create_lmdb_for_rain13k', +] diff --git a/NAFNET/basicsr/utils/create_lmdb.py b/NAFNET/basicsr/utils/create_lmdb.py new file mode 100644 index 0000000000000000000000000000000000000000..6194203ceb2da0d6afc84505325cef7182512f57 --- /dev/null 
def prepare_keys(folder_path, suffix='png'):
    """Prepare an image path list and LMDB keys for a flat image folder.

    Args:
        folder_path (str): Folder path to scan (non-recursive).
        suffix (str): Image filename suffix to match. Default: 'png'.

    Returns:
        list[str]: Sorted image path list.
        list[str]: Key list (paths with the '.<suffix>' extension removed).
    """
    print('Reading image path list ...')
    img_path_list = sorted(
        scandir(folder_path, suffix=suffix, recursive=False))
    # Strip only the trailing extension; splitting on '.<suffix>' would
    # truncate filenames that contain the suffix mid-string
    # (e.g. 'a.png.b.png' -> 'a' instead of 'a.png.b').
    ext = '.' + suffix
    keys = [
        path[:-len(ext)] if path.endswith(ext) else path
        for path in img_path_list
    ]

    return img_path_list, keys
def create_lmdb_for_gopro():
    """Create LMDB files for the GoPro training crops (blur and sharp)."""
    jobs = (
        ('./datasets/GoPro/train/blur_crops',
         './datasets/GoPro/train/blur_crops.lmdb'),
        ('./datasets/GoPro/train/sharp_crops',
         './datasets/GoPro/train/sharp_crops.lmdb'),
    )
    for folder_path, lmdb_path in jobs:
        img_path_list, keys = prepare_keys(folder_path, 'png')
        make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)


def create_lmdb_for_rain13k():
    """Create LMDB files for the Rain13k training input/target images."""
    jobs = (
        ('./datasets/Rain13k/train/input',
         './datasets/Rain13k/train/input.lmdb'),
        ('./datasets/Rain13k/train/target',
         './datasets/Rain13k/train/target.lmdb'),
    )
    for folder_path, lmdb_path in jobs:
        img_path_list, keys = prepare_keys(folder_path, 'jpg')
        make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
''' + + folder_path = './datasets/SIDD/val/input_crops' + lmdb_path = './datasets/SIDD/val/input_crops.lmdb' + mat_path = './datasets/SIDD/ValidationNoisyBlocksSrgb.mat' + if not osp.exists(folder_path): + os.makedirs(folder_path) + assert osp.exists(mat_path) + data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb'] + N, B, H ,W, C = data.shape + data = data.reshape(N*B, H, W, C) + for i in tqdm(range(N*B)): + cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) + img_path_list, keys = prepare_keys(folder_path, 'png') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) + + folder_path = './datasets/SIDD/val/gt_crops' + lmdb_path = './datasets/SIDD/val/gt_crops.lmdb' + mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat' + if not osp.exists(folder_path): + os.makedirs(folder_path) + assert osp.exists(mat_path) + data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb'] + N, B, H ,W, C = data.shape + data = data.reshape(N*B, H, W, C) + for i in tqdm(range(N*B)): + cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) + img_path_list, keys = prepare_keys(folder_path, 'png') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) + ''' diff --git a/NAFNET/basicsr/utils/dist_util.py b/NAFNET/basicsr/utils/dist_util.py new file mode 100644 index 0000000000000000000000000000000000000000..cf3411ebb28f711ea96bf913c6b78116e7bc25fc --- /dev/null +++ b/NAFNET/basicsr/utils/dist_util.py @@ -0,0 +1,90 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. 
def init_dist(launcher, backend='nccl', **kwargs):
    """Initialize distributed training for the given launcher type.

    Args:
        launcher (str): Either 'pytorch' or 'slurm'.
        backend (str): torch.distributed backend. Default: 'nccl'.

    Raises:
        ValueError: If `launcher` is not a supported type.
    """
    # Multiprocessing start method must be set before any worker spawns.
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
        return
    if launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
        return
    raise ValueError(f'Invalid launcher type: {launcher}')


def _init_dist_pytorch(backend, **kwargs):
    """Bind this process to a GPU from RANK and init the process group."""
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)
+ """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper diff --git a/NAFNET/basicsr/utils/download_util.py b/NAFNET/basicsr/utils/download_util.py new file mode 100644 index 0000000000000000000000000000000000000000..34ebc2eb9417842802615db1a02a9f80ad2bf838 --- /dev/null +++ b/NAFNET/basicsr/utils/download_util.py @@ -0,0 +1,76 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. 
def download_file_from_google_drive(file_id, save_path):
    """Download files from google drive.

    Ref:
    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501

    Args:
        file_id (str): File id.
        save_path (str): Save path.
    """
    URL = 'https://docs.google.com/uc?export=download'
    params = {'id': file_id}
    session = requests.Session()

    response = session.get(URL, params=params, stream=True)
    token = get_confirm_token(response)
    if token:
        params['confirm'] = token
        response = session.get(URL, params=params, stream=True)

    # Probe the total size via a ranged request so the progress bar can
    # display it; fall back to None when Content-Range is missing.
    probe = session.get(
        URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
    file_size = None
    if 'Content-Range' in probe.headers:
        file_size = int(probe.headers['Content-Range'].split('/')[1])

    save_response_content(response, save_path, file_size)


def get_confirm_token(response):
    """Return the Google Drive download-warning cookie value, if any."""
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None


def save_response_content(response,
                          destination,
                          file_size=None,
                          chunk_size=32768):
    """Stream a response body to `destination`, with an optional tqdm bar."""
    pbar = None
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
        readable_file_size = sizeof_fmt(file_size)

    with open(destination, 'wb') as f:
        downloaded_size = 0
        for chunk in response.iter_content(chunk_size):
            downloaded_size += chunk_size
            if pbar is not None:
                pbar.update(1)
                pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
                                     f'/ {readable_file_size}')
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
        if pbar is not None:
            pbar.close()
class FaceRestorationHelper(object):
    """Helper for the dlib-based face restoration pipeline.

    Detects faces in an input image, extracts 5/68-point landmarks, warps
    and crops aligned faces, and pastes restored faces back onto the
    upscaled input image.

    Args:
        upscale_factor (int): Upscaling factor applied to the background
            image when pasting restored faces back.
        face_size (int): Side length of the cropped, aligned face.
            Default: 512.
    """

    def __init__(self, upscale_factor, face_size=512):
        self.upscale_factor = upscale_factor
        self.face_size = (face_size, face_size)

        # standard 5 landmarks for FFHQ faces with 1024 x 1024
        self.face_template = np.array([[686.77227723, 488.62376238],
                                       [586.77227723, 493.59405941],
                                       [337.91089109, 488.38613861],
                                       [437.95049505, 493.51485149],
                                       [513.58415842, 678.5049505]])
        # FIX: the original scaled with `/ (1024 // face_size)`; the integer
        # floor division is wrong for face sizes that do not divide 1024
        # evenly (e.g. 400 -> //2 instead of the exact 2.56). Multiplying by
        # the exact ratio is bit-identical for the power-of-two sizes
        # (512, 256, ...) and correct for all others.
        self.face_template = self.face_template * (face_size / 1024)
        # for estimating the 2D similarity transformation
        self.similarity_trans = trans.SimilarityTransform()

        self.all_landmarks_5 = []
        self.all_landmarks_68 = []
        self.affine_matrices = []
        self.inverse_affine_matrices = []
        self.cropped_faces = []
        self.restored_faces = []
        self.save_png = True

    def init_dlib(self, detection_path, landmark5_path, landmark68_path):
        """Initialize the dlib detector and the 5/68-landmark predictors."""
        self.face_detector = dlib.cnn_face_detection_model_v1(detection_path)
        self.shape_predictor_5 = dlib.shape_predictor(landmark5_path)
        self.shape_predictor_68 = dlib.shape_predictor(landmark68_path)

    def free_dlib_gpu_memory(self):
        """Release the dlib models (frees their GPU memory)."""
        del self.face_detector
        del self.shape_predictor_5
        del self.shape_predictor_68

    def read_input_image(self, img_path):
        """Load the input image; stored as (h, w, c) ndarray in RGB order."""
        self.input_img = dlib.load_rgb_image(img_path)

    def detect_faces(self,
                     img_path,
                     upsample_num_times=1,
                     only_keep_largest=False):
        """Detect faces in the image at `img_path`.

        Args:
            img_path (str): Image path.
            upsample_num_times (int): Upsamples the image before running the
                face detector.
            only_keep_largest (bool): If True, keep only the largest face.

        Returns:
            int: Number of detected faces.
        """
        self.read_input_image(img_path)
        det_faces = self.face_detector(self.input_img, upsample_num_times)
        if len(det_faces) == 0:
            print('No face detected. Try to increase upsample_num_times.')
            # FIX: the original never assigned self.det_faces in this branch,
            # so the `len(self.det_faces)` below raised AttributeError on the
            # first image without a face (or returned a stale count).
            self.det_faces = []
        elif only_keep_largest:
            print('Detect several faces and only keep the largest.')
            face_areas = []
            for i in range(len(det_faces)):
                face_area = (det_faces[i].rect.right() -
                             det_faces[i].rect.left()) * (
                                 det_faces[i].rect.bottom() -
                                 det_faces[i].rect.top())
                face_areas.append(face_area)
            largest_idx = face_areas.index(max(face_areas))
            self.det_faces = [det_faces[largest_idx]]
        else:
            self.det_faces = det_faces
        return len(self.det_faces)

    def get_face_landmarks_5(self):
        """Predict 5 landmarks for every detected face.

        Returns:
            int: Total number of landmark sets accumulated so far.
        """
        for face in self.det_faces:
            shape = self.shape_predictor_5(self.input_img, face.rect)
            landmark = np.array([[part.x, part.y] for part in shape.parts()])
            self.all_landmarks_5.append(landmark)
        return len(self.all_landmarks_5)

    def get_face_landmarks_68(self):
        """Get 68 landmarks for each cropped face.

        Should only have one face at most in each cropped image; appends
        None for crops where re-detection fails.

        Returns:
            int: Number of cropped faces with successfully detected
                landmarks.
        """
        num_detected_face = 0
        for idx, face in enumerate(self.cropped_faces):
            # re-detect on the crop to anchor the 68-landmark predictor
            det_face = self.face_detector(face, 1)  # TODO: can we remove it?
            if len(det_face) == 0:
                print(f'Cannot find faces in cropped image with index {idx}.')
                self.all_landmarks_68.append(None)
            else:
                if len(det_face) > 1:
                    print('Detect several faces in the cropped face. Use the '
                          ' largest one. Note that it will also cause overlap '
                          'during paste_faces_to_input_image.')
                    face_areas = []
                    for i in range(len(det_face)):
                        face_area = (det_face[i].rect.right() -
                                     det_face[i].rect.left()) * (
                                         det_face[i].rect.bottom() -
                                         det_face[i].rect.top())
                        face_areas.append(face_area)
                    largest_idx = face_areas.index(max(face_areas))
                    face_rect = det_face[largest_idx].rect
                else:
                    face_rect = det_face[0].rect
                shape = self.shape_predictor_68(face, face_rect)
                landmark = np.array([[part.x, part.y]
                                     for part in shape.parts()])
                self.all_landmarks_68.append(landmark)
                num_detected_face += 1
        return num_detected_face

    def warp_crop_faces(self,
                        save_cropped_path=None,
                        save_inverse_affine_path=None):
        """Estimate affine matrices, then warp and crop the faces.

        Also computes inverse affine matrices for post-processing
        (pasting restored faces back).

        Args:
            save_cropped_path (str | None): If given, each cropped face is
                saved to `<path>_<idx>.png` (or the original extension when
                `save_png` is False).
            save_inverse_affine_path (str | None): If given, each inverse
                affine matrix is saved with torch to `<path>_<idx>.pth`.
        """
        for idx, landmark in enumerate(self.all_landmarks_5):
            # use 5 landmarks to get affine matrix
            self.similarity_trans.estimate(landmark, self.face_template)
            affine_matrix = self.similarity_trans.params[0:2, :]
            self.affine_matrices.append(affine_matrix)
            # warp and crop faces
            cropped_face = cv2.warpAffine(self.input_img, affine_matrix,
                                          self.face_size)
            self.cropped_faces.append(cropped_face)
            # save the cropped face
            if save_cropped_path is not None:
                path, ext = os.path.splitext(save_cropped_path)
                if self.save_png:
                    save_path = f'{path}_{idx:02d}.png'
                else:
                    save_path = f'{path}_{idx:02d}{ext}'
                imwrite(
                    cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path)

            # get inverse affine matrix (map template onto the upscaled image)
            self.similarity_trans.estimate(self.face_template,
                                           landmark * self.upscale_factor)
            inverse_affine = self.similarity_trans.params[0:2, :]
            self.inverse_affine_matrices.append(inverse_affine)
            # save inverse affine matrices
            if save_inverse_affine_path is not None:
                path, _ = os.path.splitext(save_inverse_affine_path)
                save_path = f'{path}_{idx:02d}.pth'
                torch.save(inverse_affine, save_path)

    def add_restored_face(self, face):
        """Queue a restored face for `paste_faces_to_input_image`."""
        self.restored_faces.append(face)

    def paste_faces_to_input_image(self, save_path):
        """Paste all restored faces back onto the upscaled input image.

        Each face is warped back with its inverse affine matrix and blended
        with a soft (eroded + Gaussian-blurred) mask to hide the seams.

        Args:
            save_path (str): Output image path (forced to .png when
                `save_png` is True).
        """
        # operate in the BGR order
        input_img = cv2.cvtColor(self.input_img, cv2.COLOR_RGB2BGR)
        h, w, _ = input_img.shape
        h_up, w_up = h * self.upscale_factor, w * self.upscale_factor
        # simply resize the background
        upsample_img = cv2.resize(input_img, (w_up, h_up))
        assert len(self.restored_faces) == len(self.inverse_affine_matrices), (
            'length of restored_faces and affine_matrices are different.')
        for restored_face, inverse_affine in zip(self.restored_faces,
                                                 self.inverse_affine_matrices):
            inv_restored = cv2.warpAffine(restored_face, inverse_affine,
                                          (w_up, h_up))
            mask = np.ones((*self.face_size, 3), dtype=np.float32)
            inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
            # remove the black borders introduced by the warp
            inv_mask_erosion = cv2.erode(
                inv_mask,
                np.ones((2 * self.upscale_factor, 2 * self.upscale_factor),
                        np.uint8))
            inv_restored_remove_border = inv_mask_erosion * inv_restored
            total_face_area = np.sum(inv_mask_erosion) // 3
            # compute the fusion edge based on the area of face
            w_edge = int(total_face_area**0.5) // 20
            erosion_radius = w_edge * 2
            inv_mask_center = cv2.erode(
                inv_mask_erosion,
                np.ones((erosion_radius, erosion_radius), np.uint8))
            blur_size = w_edge * 2
            inv_soft_mask = cv2.GaussianBlur(inv_mask_center,
                                             (blur_size + 1, blur_size + 1),
                                             0)
            upsample_img = inv_soft_mask * inv_restored_remove_border + (
                1 - inv_soft_mask) * upsample_img
        if self.save_png:
            save_path = save_path.replace('.jpg',
                                          '.png').replace('.jpeg', '.png')
        imwrite(upsample_img.astype(np.uint8), save_path)

    def clean_all(self):
        """Reset all per-image state so the helper can be reused."""
        self.all_landmarks_5 = []
        self.all_landmarks_68 = []
        self.restored_faces = []
        self.affine_matrices = []
        self.cropped_faces = []
        self.inverse_affine_matrices = []
class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract base class for storage backends.

    Every concrete backend must provide two APIs: ``get()``, which reads a
    file as a byte stream, and ``get_text()``, which reads it as text.
    """

    @abstractmethod
    def get(self, filepath):
        pass

    @abstractmethod
    def get_text(self, filepath):
        pass


class MemcachedBackend(BaseStorageBackend):
    """Backend that fetches values from a memcached cluster.

    Attributes:
        server_list_cfg (str): Config file for the memcached server list.
        client_cfg (str): Config file for the memcached client.
        sys_path (str | None): Extra path appended to ``sys.path`` before
            importing ``mc``. Default: None.
    """

    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
        if sys_path is not None:
            import sys
            sys.path.append(sys_path)
        try:
            import mc
        except ImportError:
            raise ImportError(
                'Please install memcached to enable MemcachedBackend.')

        self.server_list_cfg = server_list_cfg
        self.client_cfg = client_cfg
        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
                                                      self.client_cfg)
        # mc.pyvector acts as a handle pointing into the memory cache
        self._mc_buffer = mc.pyvector()

    def get(self, filepath):
        import mc
        self._client.Get(str(filepath), self._mc_buffer)
        return mc.ConvertBuffer(self._mc_buffer)

    def get_text(self, filepath):
        raise NotImplementedError


class HardDiskBackend(BaseStorageBackend):
    """Backend that reads straight from the local filesystem."""

    def get(self, filepath):
        with open(str(filepath), 'rb') as f:
            return f.read()

    def get_text(self, filepath):
        with open(str(filepath), 'r') as f:
            return f.read()


class LmdbBackend(BaseStorageBackend):
    """Backend backed by one or more lmdb environments.

    Args:
        db_paths (str | list[str]): Lmdb database paths.
        client_keys (str | list[str]): Keys naming each lmdb env.
            Default: 'default'.
        readonly (bool, optional): If True, disallow any write operation.
            Default: True.
        lock (bool, optional): If False, do not lock the database on
            concurrent access. Default: False.
        readahead (bool, optional): If False, disable OS readahead, which
            can help random reads on databases larger than RAM.
            Default: False.

    Attributes:
        db_paths (list): Lmdb database paths.
        _client (dict): Mapping from client key to an opened lmdb env.
    """

    def __init__(self,
                 db_paths,
                 client_keys='default',
                 readonly=True,
                 lock=False,
                 readahead=False,
                 **kwargs):
        try:
            import lmdb
        except ImportError:
            raise ImportError('Please install lmdb to enable LmdbBackend.')

        if isinstance(client_keys, str):
            client_keys = [client_keys]
        if isinstance(db_paths, list):
            self.db_paths = [str(v) for v in db_paths]
        elif isinstance(db_paths, str):
            self.db_paths = [str(db_paths)]
        assert len(client_keys) == len(self.db_paths), (
            'client_keys and db_paths should have the same length, '
            f'but received {len(client_keys)} and {len(self.db_paths)}.')

        self._client = {
            key: lmdb.open(
                path,
                readonly=readonly,
                lock=lock,
                readahead=readahead,
                map_size=8 * 1024 * 10485760,
                **kwargs)
            for key, path in zip(client_keys, self.db_paths)
        }

    def get(self, filepath, client_key):
        """Fetch the value stored under `filepath` in the env `client_key`.

        Args:
            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
            client_key (str): Selects which lmdb env to read from.
        """
        assert client_key in self._client, (f'client_key {client_key} is not '
                                            'in lmdb clients.')
        env = self._client[client_key]
        with env.begin(write=False) as txn:
            return txn.get(str(filepath).encode('ascii'))

    def get_text(self, filepath):
        raise NotImplementedError


class FileClient(object):
    """General file client that dispatches to a chosen storage backend.

    Loads a file (binary or text) from its path via the selected backend;
    further backends can be added to ``_backends``.

    Attributes:
        backend (str): Storage backend type; one of "disk", "memcached",
            "lmdb".
        client (:obj:`BaseStorageBackend`): The backend instance.
    """

    _backends = {
        'disk': HardDiskBackend,
        'memcached': MemcachedBackend,
        'lmdb': LmdbBackend,
    }

    def __init__(self, backend='disk', **kwargs):
        if backend not in self._backends:
            raise ValueError(
                f'Backend {backend} is not supported. Currently supported ones'
                f' are {list(self._backends.keys())}')
        self.backend = backend
        self.client = self._backends[backend](**kwargs)

    def get(self, filepath, client_key='default'):
        # client_key is only meaningful for lmdb, where separate clients
        # map to separate lmdb environments.
        if self.backend == 'lmdb':
            return self.client.get(filepath, client_key)
        return self.client.get(filepath)

    def get_text(self, filepath):
        return self.client.get_text(filepath)
def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
    """Read an optical flow map.

    Args:
        flow_path (ndarray or str): Flow path.
        quantize (bool): Whether to read a quantized pair; if True, the
            remaining args are passed to :func:`dequantize_flow`.
        concat_axis (int): The axis that dx and dy are concatenated on,
            either 0 or 1. Ignored if quantize is False.

    Returns:
        ndarray: Optical flow represented as a (h, w, 2) numpy array.

    Raises:
        IOError: If the file is not a valid .flo / quantized flow file.
    """
    if quantize:
        assert concat_axis in [0, 1]
        cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
        if cat_flow.ndim != 2:
            raise IOError(f'{flow_path} is not a valid quantized flow file, '
                          f'its dimension is {cat_flow.ndim}.')
        assert cat_flow.shape[concat_axis] % 2 == 0
        dx, dy = np.split(cat_flow, 2, axis=concat_axis)
        flow = dequantize_flow(dx, dy, *args, **kwargs)
    else:
        with open(flow_path, 'rb') as f:
            try:
                header = f.read(4).decode('utf-8')
            except Exception:
                raise IOError(f'Invalid flow file: {flow_path}')
            else:
                if header != 'PIEH':
                    raise IOError(f'Invalid flow file: {flow_path}, '
                                  'header does not contain PIEH')

            # .flo layout: 'PIEH', int32 width, int32 height, then
            # h*w*2 float32 values (dx, dy interleaved per pixel)
            w = np.fromfile(f, np.int32, 1).squeeze()
            h = np.fromfile(f, np.int32, 1).squeeze()
            flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))

    return flow.astype(np.float32)


def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
    """Write optical flow to file.

    If the flow is not quantized, it will be saved as a .flo file
    losslessly; otherwise as an image which is lossy but much smaller
    (dx and dy are concatenated into a single image if quantize is True).

    Args:
        flow (ndarray): (h, w, 2) array of optical flow.
        filename (str): Output filepath.
        quantize (bool): Whether to quantize the flow before saving it as
            an image. If True, remaining args are passed to
            :func:`quantize_flow`.
        concat_axis (int): The axis that dx and dy are concatenated on,
            either 0 or 1. Ignored if quantize is False.
    """
    if not quantize:
        with open(filename, 'wb') as f:
            f.write('PIEH'.encode('utf-8'))
            np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
            flow = flow.astype(np.float32)
            flow.tofile(f)
            f.flush()
    else:
        assert concat_axis in [0, 1]
        dx, dy = quantize_flow(flow, *args, **kwargs)
        dxdy = np.concatenate((dx, dy), axis=concat_axis)
        # BUG FIX: the original called os.makedirs(filename, ...), creating
        # a *directory* with the output file's name, and then called
        # cv2.imwrite(dxdy, filename) with swapped arguments (cv2.imwrite
        # takes the filename first; the swapped order comes from mmcv's own
        # imwrite(img, path) helper). Create the parent directory and write
        # the image with the correct argument order.
        dir_name = os.path.dirname(os.path.abspath(filename))
        os.makedirs(dir_name, exist_ok=True)
        cv2.imwrite(filename, dxdy)


def quantize_flow(flow, max_val=0.02, norm=True):
    """Quantize flow to [0, 255].

    After this step, the size of flow will be much smaller, and it can be
    dumped as a jpeg image.

    Args:
        flow (ndarray): (h, w, 2) array of optical flow.
        max_val (float): Maximum value of flow; values beyond
            [-max_val, max_val] will be truncated.
        norm (bool): Whether to divide flow values by image width/height
            first.

    Returns:
        tuple[ndarray]: Quantized dx and dy.
    """
    h, w, _ = flow.shape
    dx = flow[..., 0]
    dy = flow[..., 1]
    if norm:
        dx = dx / w  # avoid inplace operations
        dy = dy / h
    # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
    flow_comps = [
        quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
    ]
    return tuple(flow_comps)


def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
    """Recover a flow map from quantized components.

    Args:
        dx (ndarray): Quantized dx.
        dy (ndarray): Quantized dy.
        max_val (float): Maximum value used when quantizing.
        denorm (bool): Whether to multiply flow values by width/height.

    Returns:
        ndarray: Dequantized (h, w, 2) flow.
    """
    assert dx.shape == dy.shape
    assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)

    dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]

    if denorm:
        dx *= dx.shape[1]
        dy *= dx.shape[0]
    flow = np.dstack((dx, dy))
    return flow


def quantize(arr, min_val, max_val, levels, dtype=np.int64):
    """Quantize an array of (-inf, inf) to [0, levels-1].

    Args:
        arr (ndarray): Input array.
        min_val (scalar): Minimum value to be clipped.
        max_val (scalar): Maximum value to be clipped.
        levels (int): Quantization levels.
        dtype (np.type): The type of the quantized array.

    Returns:
        ndarray: Quantized array.

    Raises:
        ValueError: If levels is not an int > 1, or min_val >= max_val.
    """
    if not (isinstance(levels, int) and levels > 1):
        raise ValueError(
            f'levels must be a positive integer, but got {levels}')
    if min_val >= max_val:
        raise ValueError(
            f'min_val ({min_val}) must be smaller than max_val ({max_val})')

    arr = np.clip(arr, min_val, max_val) - min_val
    # top bin is clamped so max_val maps to levels-1, not levels
    quantized_arr = np.minimum(
        np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)

    return quantized_arr


def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
    """Dequantize an array (inverse of :func:`quantize`).

    Args:
        arr (ndarray): Input array.
        min_val (scalar): Minimum value used when quantizing.
        max_val (scalar): Maximum value used when quantizing.
        levels (int): Quantization levels.
        dtype (np.type): The type of the dequantized array.

    Returns:
        ndarray: Dequantized array (bin centers, hence the +0.5).

    Raises:
        ValueError: If levels is not an int > 1, or min_val >= max_val.
    """
    if not (isinstance(levels, int) and levels > 1):
        raise ValueError(
            f'levels must be a positive integer, but got {levels}')
    if min_val >= max_val:
        raise ValueError(
            f'min_val ({min_val}) must be smaller than max_val ({max_val})')

    dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
                                                   min_val) / levels + min_val

    return dequantized_arr
+ """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError( + f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError( + f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - + min_val) / levels + min_val + + return dequantized_arr diff --git a/NAFNET/basicsr/utils/img_util.py b/NAFNET/basicsr/utils/img_util.py new file mode 100644 index 0000000000000000000000000000000000000000..6374861971839e446c88b723c7c8471db4085ca6 --- /dev/null +++ b/NAFNET/basicsr/utils/img_util.py @@ -0,0 +1,186 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +import cv2 +import math +import numpy as np +import os +import torch +from torchvision.utils import make_grid + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. 
+ + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. + """ + if not (torch.is_tensor(tensor) or + (isinstance(tensor, list) + and all(torch.is_tensor(t) for t in tensor))): + raise TypeError( + f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid( + _tensor, nrow=int(math.sqrt(_tensor.size(0))), + normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + elif img_np.shape[2] == 3: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError('Only support 4D, 3D or 2D tensor. ' + f'But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 
+ img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = { + 'color': cv2.IMREAD_COLOR, + 'grayscale': cv2.IMREAD_GRAYSCALE, + 'unchanged': cv2.IMREAD_UNCHANGED + } + if img_np is None: + raise Exception('None .. !!!') + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255. + return img + +def padding(img_lq, img_gt, gt_size): + h, w, _ = img_lq.shape + + h_pad = max(0, gt_size - h) + w_pad = max(0, gt_size - w) + + if h_pad == 0 and w_pad == 0: + return img_lq, img_gt + + img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) + img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) + # print('img_lq', img_lq.shape, img_gt.shape) + return img_lq, img_gt + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def crop_border(imgs, crop_border): + """Crop borders of images. 
def make_lmdb_from_imgs(data_path,
                        lmdb_path,
                        img_path_list,
                        keys,
                        batch=5000,
                        compress_level=1,
                        multiprocessing_read=False,
                        n_thread=40,
                        map_size=None):
    """Make lmdb from images.

    The resulting ``example.lmdb`` folder contains the standard lmdb files
    ``data.mdb`` and ``lock.mdb``, plus a ``meta_info.txt`` recording, per
    line: image name (with extension), image shape, and compression level,
    separated by white spaces, e.g. ``000_00000000.png (720,1280,3) 1``.
    The image name without extension is used as the lmdb key.

    If `multiprocessing_read` is True, all images are first read into
    memory with a multiprocessing pool (the server needs enough memory).

    Args:
        data_path (str): Data path for reading images.
        lmdb_path (str): Lmdb save path; must end with '.lmdb'.
        img_path_list (str): Image path list.
        keys (str): Used for lmdb keys.
        batch (int): After processing this many images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images.
            Default: 1.
        multiprocessing_read (bool): Whether to read all images to memory
            with multiprocessing. Default: False.
        n_thread (int): Pool size for multiprocessing.
        map_size (int | None): Map size for the lmdb env. If None, an
            estimate from the first image's encoded size is used.
            Default: None.

    Raises:
        ValueError: If `lmdb_path` does not end with '.lmdb'.
        SystemExit: If `lmdb_path` already exists.
    """
    assert len(img_path_list) == len(keys), (
        'img_path_list and keys should have the same length, '
        f'but got {len(img_path_list)} and {len(keys)}')
    print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
    print(f'Total images: {len(img_path_list)}')
    if not lmdb_path.endswith('.lmdb'):
        raise ValueError("lmdb_path must end with '.lmdb'.")
    if osp.exists(lmdb_path):
        print(f'Folder {lmdb_path} already exists. Exit.')
        sys.exit(1)

    if multiprocessing_read:
        # read all the images to memory (multiprocessing)
        dataset = {}  # use dict to keep the order for multiprocessing
        shapes = {}
        print(f'Read images with multiprocessing, #thread: {n_thread} ...')
        pbar = tqdm(total=len(img_path_list), unit='image')

        def callback(arg):
            """Store one worker result and advance the progress bar."""
            key, dataset[key], shapes[key] = arg
            pbar.update(1)
            pbar.set_description(f'Read {key}')

        pool = Pool(n_thread)
        for path, key in zip(img_path_list, keys):
            pool.apply_async(
                read_img_worker,
                args=(osp.join(data_path, path), key, compress_level),
                callback=callback)
        pool.close()
        pool.join()
        pbar.close()
        print(f'Finish reading {len(img_path_list)} images.')

    # create lmdb environment
    if map_size is None:
        # estimate from the first image's encoded size, with 10x headroom
        img = cv2.imread(
            osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
        _, img_byte = cv2.imencode(
            '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
        data_size_per_img = img_byte.nbytes
        print('Data size per image is: ', data_size_per_img)
        data_size = data_size_per_img * len(img_path_list)
        map_size = data_size * 10

    env = lmdb.open(lmdb_path, map_size=map_size)

    # write data to lmdb
    pbar = tqdm(total=len(img_path_list), unit='chunk')
    txn = env.begin(write=True)
    # FIX: open meta_info.txt with a context manager so the handle is
    # closed even if encoding/writing raises mid-loop (the original leaked
    # the open file on error).
    with open(osp.join(lmdb_path, 'meta_info.txt'), 'w') as txt_file:
        for idx, (path, key) in enumerate(zip(img_path_list, keys)):
            pbar.update(1)
            pbar.set_description(f'Write {key}')
            key_byte = key.encode('ascii')
            if multiprocessing_read:
                img_byte = dataset[key]
                h, w, c = shapes[key]
            else:
                _, img_byte, img_shape = read_img_worker(
                    osp.join(data_path, path), key, compress_level)
                h, w, c = img_shape

            txn.put(key_byte, img_byte)
            # write meta information
            txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
            # FIX: the original tested `idx % batch == 0`, which also
            # committed (uselessly) right after the very first image.
            if (idx + 1) % batch == 0:
                txn.commit()
                txn = env.begin(write=True)
        pbar.close()
        txn.commit()
        env.close()
    print('\nFinish writing lmdb.')


def read_img_worker(path, key, compress_level):
    """Read one image and PNG-encode it.

    Args:
        path (str): Image path.
        key (str): Image key.
        compress_level (int): Compress level when encoding images.

    Returns:
        tuple: (key (str), image bytes (byte), image shape (tuple[int])).
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img.ndim == 2:
        h, w = img.shape
        c = 1
    else:
        h, w, c = img.shape
    _, img_byte = cv2.imencode('.png', img,
                               [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
    return (key, img_byte, (h, w, c))


class LmdbMaker():
    """Incremental LMDB writer (same layout as make_lmdb_from_imgs).

    Args:
        lmdb_path (str): Lmdb save path; must end with '.lmdb'.
        map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
        batch (int): After processing this many images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images.
            Default: 1.

    Raises:
        ValueError: If `lmdb_path` does not end with '.lmdb'.
        SystemExit: If `lmdb_path` already exists.
    """

    def __init__(self,
                 lmdb_path,
                 map_size=1024**4,
                 batch=5000,
                 compress_level=1):
        if not lmdb_path.endswith('.lmdb'):
            raise ValueError("lmdb_path must end with '.lmdb'.")
        if osp.exists(lmdb_path):
            print(f'Folder {lmdb_path} already exists. Exit.')
            sys.exit(1)

        self.lmdb_path = lmdb_path
        self.batch = batch
        self.compress_level = compress_level
        self.env = lmdb.open(lmdb_path, map_size=map_size)
        self.txn = self.env.begin(write=True)
        self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
        self.counter = 0

    def put(self, img_byte, key, img_shape):
        """Store one encoded image and append its meta line; commit each
        `batch` puts."""
        self.counter += 1
        key_byte = key.encode('ascii')
        self.txn.put(key_byte, img_byte)
        # write meta information
        h, w, c = img_shape
        self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
        if self.counter % self.batch == 0:
            self.txn.commit()
            self.txn = self.env.begin(write=True)

    def close(self):
        """Commit the pending transaction and close env and meta file."""
        self.txn.commit()
        self.env.close()
        self.txt_file.close()
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import datetime
import logging
import time

from .dist_util import get_dist_info, master_only


class MessageLogger():
    """Message logger for printing.

    Args:
        opt (dict): Config. It contains the following keys:
            name (str): Exp name.
            logger (dict): Contains 'print_freq' (str) for logger interval.
            train (dict): Contains 'total_iter' (int) for total iters.
            use_tb_logger (bool): Use tensorboard logger.
        start_iter (int): Start iter. Default: 1.
        tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
    """

    def __init__(self, opt, start_iter=1, tb_logger=None):
        self.exp_name = opt['name']
        self.interval = opt['logger']['print_freq']
        self.start_iter = start_iter
        self.max_iters = opt['train']['total_iter']
        self.use_tb_logger = opt['logger']['use_tb_logger']
        self.tb_logger = tb_logger
        # wall-clock anchor used for the running ETA estimate in __call__
        self.start_time = time.time()
        self.logger = get_root_logger()

    @master_only
    def __call__(self, log_vars):
        """Format logging message.

        Args:
            log_vars (dict): It contains the following keys:
                epoch (int): Epoch number.
                iter (int): Current iter.
                lrs (list): List for learning rates.

                time (float): Iter time.
                data_time (float): Data time for each iter.
        """
        # epoch, iter, learning rates
        epoch = log_vars.pop('epoch')
        current_iter = log_vars.pop('iter')
        total_iter = log_vars.pop('total_iter')
        lrs = log_vars.pop('lrs')

        message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '
                   f'iter:{current_iter:8,d}, lr:(')
        for v in lrs:
            message += f'{v:.3e},'
        message += ')] '

        # time and estimated time
        if 'time' in log_vars.keys():
            iter_time = log_vars.pop('time')
            data_time = log_vars.pop('data_time')

            # average seconds/iter since logging started, extrapolated
            # over the remaining iterations
            total_time = time.time() - self.start_time
            time_sec_avg = total_time / (current_iter - self.start_iter + 1)
            eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
            eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
            message += f'[eta: {eta_str}, '
            message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '

        # other items, especially losses
        for k, v in log_vars.items():
            message += f'{k}: {v:.4e} '
            # tensorboard logger
            if self.use_tb_logger and 'debug' not in self.exp_name:
                # NOTE(review): steps are normalized onto a fixed 0..10000
                # axis so runs with different total_iter line up in TB.
                normed_step = 10000 * (current_iter / total_iter)
                normed_step = int(normed_step)

                if k.startswith('l_'):
                    self.tb_logger.add_scalar(f'losses/{k}', v, normed_step)
                elif k.startswith('m_'):
                    self.tb_logger.add_scalar(f'metrics/{k}', v, normed_step)
                else:
                    # NOTE(review): any scalar not prefixed 'l_'/'m_' trips
                    # this assert — deliberately strict; confirm intent
                    # before relaxing.
                    assert 1 == 0
                # else:
                #     self.tb_logger.add_scalar(k, v, current_iter)
        self.logger.info(message)


@master_only
def init_tb_logger(log_dir):
    # imported lazily so tensorboard is only required when actually used
    from torch.utils.tensorboard import SummaryWriter
    tb_logger = SummaryWriter(log_dir=log_dir)
    return tb_logger


@master_only
def init_wandb_logger(opt):
    """We now only use wandb to sync tensorboard log."""
    import wandb
    logger = logging.getLogger('basicsr')

    project = opt['logger']['wandb']['project']
    resume_id = opt['logger']['wandb'].get('resume_id')
    if resume_id:
        # resume an existing wandb run instead of starting a fresh one
        wandb_id = resume_id
        resume = 'allow'
        logger.warning(f'Resume wandb logger with id={wandb_id}.')
    else:
        wandb_id = wandb.util.generate_id()
        resume = 'never'

    wandb.init(
        id=wandb_id,
        resume=resume,
        name=opt['name'],
        config=opt,
        project=project,
        sync_tensorboard=True)

    logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')


def get_root_logger(logger_name='basicsr',
                    log_level=logging.INFO,
                    log_file=None):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added.

    Args:
        logger_name (str): root logger name. Default: 'basicsr'.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.

    Returns:
        logging.Logger: The root logger.
    """
    logger = logging.getLogger(logger_name)
    # if the logger has been initialized, just return it
    if logger.hasHandlers():
        return logger

    format_str = '%(asctime)s %(levelname)s: %(message)s'
    logging.basicConfig(format=format_str, level=log_level)
    rank, _ = get_dist_info()
    if rank != 0:
        # non-master ranks are silenced so the console shows one copy of logs
        logger.setLevel('ERROR')
    elif log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        file_handler.setFormatter(logging.Formatter(format_str))
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)

    return logger
+ """ + logger = logging.getLogger(logger_name) + # if the logger has been initialized, just return it + if logger.hasHandlers(): + return logger + + format_str = '%(asctime)s %(levelname)s: %(message)s' + logging.basicConfig(format=format_str, level=log_level) + rank, _ = get_dist_info() + if rank != 0: + logger.setLevel('ERROR') + elif log_file is not None: + file_handler = logging.FileHandler(log_file, 'w') + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + + return logger + + +def get_env_info(): + """Get environment information. + + Currently, only log the software version. + """ + import torch + import torchvision + + from basicsr.version import __version__ + msg = r""" + ____ _ _____ ____ + / __ ) ____ _ _____ (_)_____/ ___/ / __ \ + / __ |/ __ `// ___// // ___/\__ \ / /_/ / + / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ + /_____/ \__,_//____//_/ \___//____//_/ |_| + ______ __ __ __ __ + / ____/____ ____ ____/ / / / __ __ _____ / /__ / / + / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / + / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ + \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) + """ + msg += ('\nVersion Information: ' + f'\n\tBasicSR: {__version__}' + f'\n\tPyTorch: {torch.__version__}' + f'\n\tTorchVision: {torchvision.__version__}') + return msg diff --git a/NAFNET/basicsr/utils/matlab_functions.py b/NAFNET/basicsr/utils/matlab_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..7657cf4b0ba24d30a795a18786181ac2fe8b13ce --- /dev/null +++ b/NAFNET/basicsr/utils/matlab_functions.py @@ -0,0 +1,367 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. 
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import math
import numpy as np
import torch


def cubic(x):
    """Bicubic convolution kernel, used for calculate_weights_indices.

    Piecewise cubic with support [-2, 2] (Keys kernel, a = -0.5), matching
    MATLAB's bicubic interpolation kernel.
    """
    absx = torch.abs(x)
    absx2 = absx**2
    absx3 = absx**3
    return (1.5 * absx3 - 2.5 * absx2 + 1) * (
        (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx +
                                      2) * (((absx > 1) *
                                             (absx <= 2)).type_as(absx))


def calculate_weights_indices(in_length, out_length, scale, kernel,
                              kernel_width, antialiasing):
    """Calculate weights and indices, used for imresize function.

    Args:
        in_length (int): Input length.
        out_length (int): Output length.
        scale (float): Scale factor.
        kernel (str): Kernel name; only 'cubic' is used in practice
            (the argument is accepted for API symmetry and not consulted).
        kernel_width (int): Kernel width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.

    Returns:
        tuple: (weights, indices, sym_len_s, sym_len_e) — per-output-pixel
        kernel weights, input indices (shifted for symmetric padding), and
        the start/end padding lengths.
    """

    if (scale < 1) and antialiasing:
        # Use a modified kernel (larger kernel width) to simultaneously
        # interpolate and antialias
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5 + scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    p = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(
        0, p - 1, p).view(1, p).expand(out_length, p)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices

    # apply cubic kernel (scaled when antialiasing a downsample)
    if (scale < 1) and antialiasing:
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, p)

    # If a column in weights is all zero, get rid of it. only consider the
    # first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, p - 2)
        weights = weights.narrow(1, 1, p - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, p - 2)
        weights = weights.narrow(1, 0, p - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # shift indices into the symmetrically-padded coordinate frame
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)


@torch.no_grad()
def imresize(img, scale, antialiasing=True):
    """imresize function same as MATLAB.

    It now only supports bicubic.
    The same scale applies for both height and width.

    Args:
        img (Tensor | Numpy array):
            Tensor: Input image with shape (c, h, w), [0, 1] range.
            Numpy: Input image with shape (h, w, c), [0, 1] range.
        scale (float): Scale factor. The same scale applies for both height
            and width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
            Default: True.

    Returns:
        Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
    """
    if type(img).__module__ == np.__name__:  # numpy type
        numpy_type = True
        # HWC numpy -> CHW tensor; converted back before returning
        img = torch.from_numpy(img.transpose(2, 0, 1)).float()
    else:
        numpy_type = False

    in_c, in_h, in_w = img.size()
    out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
    kernel_width = 4
    kernel = 'cubic'

    # get weights and indices
    weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(
        in_h, out_h, scale, kernel, kernel_width, antialiasing)
    weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(
        in_w, out_w, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying (mirror-pad top and bottom, as MATLAB does)
    img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
    img_aug.narrow(1, sym_len_hs, in_h).copy_(img)

    sym_patch = img[:, :sym_len_hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_he:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_c, out_h, in_w)
    kernel_width = weights_h.size(1)
    for i in range(out_h):
        idx = int(indices_h[i][0])
        for j in range(in_c):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(
                0, 1).mv(weights_h[i])

    # process W dimension
    # symmetric copying (mirror-pad left and right)
    out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
    out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_we:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_c, out_h, out_w)
    kernel_width = weights_w.size(1)
    for i in range(out_w):
        idx = int(indices_w[i][0])
        for j in range(in_c):
            out_2[j, :, i] = out_1_aug[j, :,
                                       idx:idx + kernel_width].mv(weights_w[i])

    if numpy_type:
        out_2 = out_2.numpy().transpose(1, 2, 0)
    return out_2
sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_c, out_h, out_w) + kernel_width = weights_w.size(1) + for i in range(out_w): + idx = int(indices_w[i][0]) + for j in range(in_c): + out_2[j, :, i] = out_1_aug[j, :, + idx:idx + kernel_width].mv(weights_w[i]) + + if numpy_type: + out_2 = out_2.numpy().transpose(1, 2, 0) + return out_2 + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. 
+ In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], + [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [ + -222.921, 135.576, -276.836 + ] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. + + The bgr version of ycbcr2rgb. 
+ It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], + [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [ + -276.836, 135.576, -222.921 + ] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + convertion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError('The img type should be np.float32 or np.uint8, ' + f'but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. 
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import numpy as np
import os
import random
import time
import torch
from os import path as osp

from .dist_util import master_only
from .logger import get_root_logger


def set_random_seed(seed):
    """Set random seeds for python, numpy and torch (CPU and all GPUs)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_time_str():
    """Return local time as 'YYYYmmdd_HHMMSS' (used for archive suffixes)."""
    return time.strftime('%Y%m%d_%H%M%S', time.localtime())


def mkdir_and_rename(path):
    """mkdirs. If path exists, rename it with timestamp and create a new one.

    Args:
        path (str): Folder path.
    """
    if osp.exists(path):
        new_name = path + '_archived_' + get_time_str()
        print(f'Path already exists. Rename it to {new_name}', flush=True)
        os.rename(path, new_name)
    os.makedirs(path, exist_ok=True)


@master_only
def make_exp_dirs(opt):
    """Make dirs for experiments."""
    path_opt = opt['path'].copy()
    if opt['is_train']:
        mkdir_and_rename(path_opt.pop('experiments_root'))
    else:
        mkdir_and_rename(path_opt.pop('results_root'))
    for key, path in path_opt.items():
        # skip path entries that are options/files rather than directories
        if ('strict_load' not in key) and ('pretrain_network'
                                           not in key) and ('resume'
                                                            not in key):
            os.makedirs(path, exist_ok=True)


def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with relative paths.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            # hidden files/dirs (leading '.') are always skipped
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if suffix is None:
                    yield return_path
                elif return_path.endswith(suffix):
                    yield return_path
            else:
                if recursive:
                    yield from _scandir(
                        entry.path, suffix=suffix, recursive=recursive)
                else:
                    continue

    return _scandir(dir_path, suffix=suffix, recursive=recursive)


def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str): Path of the directory.
        keywords (str | tuple(str), optional): File keywords that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with relative paths.
    """

    if (keywords is not None) and not isinstance(keywords, (str, tuple)):
        raise TypeError('"keywords" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, keywords, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if keywords is None:
                    yield return_path
                # Fix: the original used `return_path.find(keywords) > 0`,
                # which silently dropped files whose keyword match starts at
                # index 0 (e.g. relative paths beginning with the keyword).
                elif keywords in return_path:
                    yield return_path
            else:
                if recursive:
                    yield from _scandir(
                        entry.path, keywords=keywords, recursive=recursive)
                else:
                    continue

    return _scandir(dir_path, keywords=keywords, recursive=recursive)


def check_resume(opt, resume_iter):
    """Check resume states and pretrain_network paths.

    Args:
        opt (dict): Options.
        resume_iter (int): Resume iteration.
    """
    logger = get_root_logger()
    if opt['path']['resume_state']:
        # get all the networks
        networks = [key for key in opt.keys() if key.startswith('network_')]
        flag_pretrain = False
        for network in networks:
            if opt['path'].get(f'pretrain_{network}') is not None:
                flag_pretrain = True
        if flag_pretrain:
            logger.warning(
                'pretrain_network path will be ignored during resuming.')
        # set pretrained model paths to the checkpoints of resume_iter
        for network in networks:
            name = f'pretrain_{network}'
            basename = network.replace('network_', '')
            if opt['path'].get('ignore_resume_networks') is None or (
                    basename not in opt['path']['ignore_resume_networks']):
                opt['path'][name] = osp.join(
                    opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
                logger.info(f"Set {name} to {opt['path'][name]}")


def sizeof_fmt(size, suffix='B'):
    """Get human readable file size.

    Args:
        size (int): File size.
        suffix (str): Suffix. Default: 'B'.

    Return:
        str: Formatted file size.
    """
    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
        if abs(size) < 1024.0:
            return f'{size:3.1f} {unit}{suffix}'
        size /= 1024.0
    return f'{size:3.1f} Y{suffix}'
+ """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' diff --git a/NAFNET/basicsr/utils/options.py b/NAFNET/basicsr/utils/options.py new file mode 100644 index 0000000000000000000000000000000000000000..2f93a6c090937bbd019a71c9aa542e018f5ac22e --- /dev/null +++ b/NAFNET/basicsr/utils/options.py @@ -0,0 +1,117 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +import yaml +from collections import OrderedDict +from os import path as osp + + +def ordered_yaml(): + """Support OrderedDict for yaml. + + Returns: + yaml Loader and Dumper. + """ + try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader + except ImportError: + from yaml import Dumper, Loader + + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +def parse(opt_path, is_train=True): + """Parse option file. + + Args: + opt_path (str): Option file path. + is_train (str): Indicate whether in training or not. Default: True. + + Returns: + (dict): Options. 
+ """ + with open(opt_path, mode='r') as f: + Loader, _ = ordered_yaml() + opt = yaml.load(f, Loader=Loader) + + opt['is_train'] = is_train + + # datasets + if 'datasets' in opt: + for phase, dataset in opt['datasets'].items(): + # for several datasets, e.g., test_1, test_2 + phase = phase.split('_')[0] + dataset['phase'] = phase + if 'scale' in opt: + dataset['scale'] = opt['scale'] + if dataset.get('dataroot_gt') is not None: + dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) + if dataset.get('dataroot_lq') is not None: + dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) + + # paths + for key, val in opt['path'].items(): + if (val is not None) and ('resume_state' in key + or 'pretrain_network' in key): + opt['path'][key] = osp.expanduser(val) + opt['path']['root'] = osp.abspath( + osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) + if is_train: + experiments_root = osp.join(opt['path']['root'], 'experiments', + opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + opt['path']['training_states'] = osp.join(experiments_root, + 'training_states') + opt['path']['log'] = experiments_root + opt['path']['visualization'] = osp.join(experiments_root, + 'visualization') + + # change some options for debug mode + if 'debug' in opt['name']: + if 'val' in opt: + opt['val']['val_freq'] = 8 + opt['logger']['print_freq'] = 1 + opt['logger']['save_checkpoint_freq'] = 8 + else: # test + results_root = osp.join(opt['path']['root'], 'results', opt['name']) + opt['path']['results_root'] = results_root + opt['path']['log'] = results_root + opt['path']['visualization'] = osp.join(results_root, 'visualization') + + return opt + + +def dict2str(opt, indent_level=1): + """dict to string for printing options. + + Args: + opt (dict): Option dict. + indent_level (int): Indent level. Default: 1. + + Return: + (str): Option string for printing. 
+ """ + msg = '\n' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_level * 2) + k + ':[' + msg += dict2str(v, indent_level + 1) + msg += ' ' * (indent_level * 2) + ']\n' + else: + msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' + return msg diff --git a/NAFNET/basicsr/version.py b/NAFNET/basicsr/version.py new file mode 100644 index 0000000000000000000000000000000000000000..4751ed9f2e5c06661ba8c1b17a12fbcb73376a7c --- /dev/null +++ b/NAFNET/basicsr/version.py @@ -0,0 +1,5 @@ +# GENERATED VERSION FILE +# TIME: Sun Apr 13 19:21:37 2025 +__version__ = '1.2.0+2b4af71' +short_version = '1.2.0' +version_info = (1, 2, 0) diff --git a/NAFNET/deblur_module.py b/NAFNET/deblur_module.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ea657a1c53f8bb551133292af8f446747a1779 --- /dev/null +++ b/NAFNET/deblur_module.py @@ -0,0 +1,260 @@ +import os +import cv2 +import numpy as np +import torch +import yaml +from typing import Optional, Tuple, Union +from io import BytesIO +from PIL import Image +import logging +import traceback + +from basicsr.models import create_model +from basicsr.utils import img2tensor as _img2tensor, tensor2img, imwrite +from basicsr.utils.options import parse + +# Configure logging +def setup_logger(name, log_level=logging.INFO): + """Set up logger.""" + logger = logging.getLogger(name) + logger.setLevel(log_level) + + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + # Console handler + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + return logger + +logger = setup_logger(__name__) + +class NAFNetDeblur: + def __init__(self, config_path: str = 'options/test/REDS/NAFNet-width64.yml'): + """ + Initialize the NAFNet deblurring model. 
+ + Args: + config_path: Path to the model configuration YAML file + """ + try: + logger.info(f"Initializing NAFNet with config: {config_path}") + # Make paths relative to the module directory + module_dir = os.path.dirname(os.path.abspath(__file__)) + if not os.path.isabs(config_path): + config_path = os.path.join(module_dir, config_path) + + # Check if config file exists + if not os.path.exists(config_path): + error_msg = f"Config file not found: {config_path}" + logger.error(error_msg) + raise FileNotFoundError(error_msg) + + # Parse configuration + opt = parse(config_path, is_train=False) + opt["dist"] = False + + # Create model + logger.info("Creating model") + self.model = create_model(opt) + + # Set device + try: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + logger.info(f"Using device: {self.device}") + except Exception as e: + logger.warning(f"Failed to set device. Error: {str(e)}") + logger.warning("Using CPU mode") + self.device = torch.device('cpu') + + # Create directories for inputs and outputs + self.inputs_dir = os.path.join(module_dir, 'inputs') + self.outputs_dir = os.path.join(module_dir, 'outputs') + + # Ensure directories exist + os.makedirs(self.inputs_dir, exist_ok=True) + os.makedirs(self.outputs_dir, exist_ok=True) + + logger.info("Model initialized successfully") + except Exception as e: + logger.error(f"Failed to initialize model: {str(e)}") + logger.error(traceback.format_exc()) + raise + + def imread(self, img_path): + """Read an image from file.""" + img = cv2.imread(img_path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + def img2tensor(self, img, bgr2rgb=False, float32=True): + """Convert image to tensor.""" + img = img.astype(np.float32) / 255.0 + return _img2tensor(img, bgr2rgb=bgr2rgb, float32=float32) + + def deblur_image(self, image: Union[str, np.ndarray, bytes]) -> np.ndarray: + """ + Deblur an image. 
+ + Args: + image: Input image as a file path, numpy array, or bytes + + Returns: + Deblurred image as a numpy array + """ + try: + # Handle different input types + if isinstance(image, str): + # Image path + logger.info(f"Loading image from path: {image}") + img = self.imread(image) + if img is None: + raise ValueError(f"Failed to read image from {image}") + elif isinstance(image, bytes): + # Bytes (e.g., from file upload) + logger.info("Loading image from bytes") + nparr = np.frombuffer(image, np.uint8) + img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + if img is None: + # Try using PIL as a fallback + pil_img = Image.open(BytesIO(image)) + img = np.array(pil_img.convert('RGB')) + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + elif isinstance(image, np.ndarray): + # Already a numpy array + logger.info("Processing image from numpy array") + img = image.copy() + if img.shape[2] == 3 and img.dtype == np.uint8: + if img[0,0,0] > img[0,0,2]: # Simple BGR check + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + else: + raise ValueError(f"Unsupported image type: {type(image)}") + + # Validate image + if img is None or img.size == 0: + raise ValueError("Image is empty or invalid") + + logger.info(f"Image shape: {img.shape}, dtype: {img.dtype}") + + # Ensure image has 3 channels + if len(img.shape) != 3 or img.shape[2] != 3: + raise ValueError(f"Image must have 3 channels, got shape {img.shape}") + + # Resize very large images + max_dim = max(img.shape[0], img.shape[1]) + if max_dim > 2000: + scale_factor = 2000 / max_dim + new_h = int(img.shape[0] * scale_factor) + new_w = int(img.shape[1] * scale_factor) + logger.warning(f"Image too large, resizing from {img.shape[:2]} to {(new_h, new_w)}") + img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA) + + # Convert to tensor + logger.info("Converting image to tensor") + img_tensor = self.img2tensor(img) + + # Process the image + logger.info("Running inference with model") + with torch.no_grad(): + try: + 
self.model.feed_data(data={'lq': img_tensor.unsqueeze(dim=0)}) + + if self.model.opt['val'].get('grids', False): + self.model.grids() + + self.model.test() + + if self.model.opt['val'].get('grids', False): + self.model.grids_inverse() + + visuals = self.model.get_current_visuals() + result = tensor2img([visuals['result']]) + except Exception as e: + logger.error(f"Error during model inference: {str(e)}") + logger.error(traceback.format_exc()) + raise + + logger.info("Image deblurred successfully") + return result + except Exception as e: + logger.error(f"Error in deblur_image: {str(e)}") + logger.error(traceback.format_exc()) + raise + + def save_image(self, image: np.ndarray, output_path: str) -> str: + """Save an image to the given path.""" + try: + # Convert to BGR for OpenCV + save_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + + # Save the image + if not os.path.isabs(output_path): + # Use the outputs directory by default + output_path = os.path.join(self.outputs_dir, output_path) + + # Ensure the parent directory exists + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + cv2.imwrite(output_path, save_img) + logger.info(f"Image saved to {output_path}") + return output_path + except Exception as e: + logger.error(f"Error saving image: {str(e)}") + logger.error(traceback.format_exc()) + raise + +def main(): + """ + Main function to test the NAFNet deblurring model. + Processes all images in the inputs directory and saves results to outputs directory. 
+ """ + try: + # Initialize the model + deblur_model = NAFNetDeblur() + + # Get the inputs directory + inputs_dir = deblur_model.inputs_dir + outputs_dir = deblur_model.outputs_dir + + # Check if there are any images in the inputs directory + input_files = [f for f in os.listdir(inputs_dir) if os.path.isfile(os.path.join(inputs_dir, f)) + and f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff'))] + + if not input_files: + logger.warning(f"No image files found in {inputs_dir}") + print(f"No image files found in {inputs_dir}. Please add some images and try again.") + return + + logger.info(f"Found {len(input_files)} images to process") + + # Process each image + for input_file in input_files: + try: + input_path = os.path.join(inputs_dir, input_file) + output_file = f"deblurred_{input_file}" + output_path = os.path.join(outputs_dir, output_file) + + logger.info(f"Processing {input_file}...") + + # Deblur the image + deblurred_img = deblur_model.deblur_image(input_path) + + # Save the result + deblur_model.save_image(deblurred_img, output_path) + + logger.info(f"Saved result to {output_path}") + + except Exception as e: + logger.error(f"Error processing {input_file}: {str(e)}") + logger.error(traceback.format_exc()) + + logger.info("Processing complete!") + + except Exception as e: + logger.error(f"Error in main function: {str(e)}") + logger.error(traceback.format_exc()) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/NAFNET/environment.yml b/NAFNET/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..386387b4410182d75d8bbb9d05ccefc332594a9f --- /dev/null +++ b/NAFNET/environment.yml @@ -0,0 +1,31 @@ +name: deblur +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - python=3.9 + - pytorch=1.11.0 + - torchvision=0.12.0 + - cudatoolkit=11.3 + - pip + - pip: + - addict==2.4.0 + - future==0.18.2 + - lmdb==1.3.0 + - opencv-python==4.5.5.64 + - Pillow==9.1.0 + - pyyaml==6.0 + - 
import torch
import numpy as np
import cv2
import tempfile
import matplotlib.pyplot as plt
from cog import BasePredictor, Path, Input, BaseModel

from basicsr.models import create_model
from basicsr.utils import img2tensor as _img2tensor, tensor2img, imwrite
from basicsr.utils.options import parse


class Predictor(BasePredictor):
    """Cog predictor wrapping three NAFNet models: denoise, deblur, stereo SR."""

    def setup(self):
        """Parse the per-task test configs and instantiate one model per task."""
        opt_path_denoise = "options/test/SIDD/NAFNet-width64.yml"
        opt_denoise = parse(opt_path_denoise, is_train=False)
        opt_denoise["dist"] = False

        opt_path_deblur = "options/test/GoPro/NAFNet-width64.yml"
        opt_deblur = parse(opt_path_deblur, is_train=False)
        opt_deblur["dist"] = False

        opt_path_stereo = "options/test/NAFSSR/NAFSSR-L_4x.yml"
        opt_stereo = parse(opt_path_stereo, is_train=False)
        opt_stereo["dist"] = False

        self.models = {
            "Image Denoising": create_model(opt_denoise),
            "Image Debluring": create_model(opt_deblur),
            "Stereo Image Super-Resolution": create_model(opt_stereo),
        }

    def predict(
        self,
        task_type: str = Input(
            choices=[
                "Image Denoising",
                "Image Debluring",
                "Stereo Image Super-Resolution",
            ],
            default="Image Debluring",
            description="Choose task type.",
        ),
        image: Path = Input(
            description="Input image. Stereo Image Super-Resolution, upload the left image here.",
        ),
        image_r: Path = Input(
            default=None,
            description="Right Input image for Stereo Image Super-Resolution. Optional, only valid for Stereo"
            " Image Super-Resolution task.",
        ),
    ) -> Path:
        """Run the selected model on the input image(s) and return the output path."""

        out_path = Path(tempfile.mkdtemp()) / "output.png"

        model = self.models[task_type]
        if task_type == "Stereo Image Super-Resolution":
            assert image_r is not None, (
                "Please provide both left and right input image for "
                "Stereo Image Super-Resolution task."
            )

            img_l = imread(str(image))
            inp_l = img2tensor(img_l)
            img_r = imread(str(image_r))
            inp_r = img2tensor(img_r)
            # NOTE(review): stereo_image_inference is not defined in the
            # visible part of this file -- presumably provided elsewhere; verify.
            stereo_image_inference(model, inp_l, inp_r, str(out_path))

        else:
            img_input = imread(str(image))
            inp = img2tensor(img_input)
            # Fix: the original created a second throwaway temp directory here
            # and rebound out_path to it, leaking the directory made above.
            single_image_inference(model, inp, str(out_path))

        return out_path


def imread(img_path):
    """Read an image file and return it as an RGB numpy array."""
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img


def img2tensor(img, bgr2rgb=False, float32=True):
    """Scale an image to [0, 1] float32 and convert it to a torch tensor."""
    img = img.astype(np.float32) / 255.0
    return _img2tensor(img, bgr2rgb=bgr2rgb, float32=float32)


def single_image_inference(model, img, save_path):
    """Run a single-image restoration model on ``img`` and write the result."""
    model.feed_data(data={"lq": img.unsqueeze(dim=0)})

    # Tiled ("grids") inference when enabled in the validation config.
    if model.opt["val"].get("grids", False):
        model.grids()

    model.test()

    if model.opt["val"].get("grids", False):
        model.grids_inverse()

    visuals = model.get_current_visuals()
    sr_img = tensor2img([visuals["result"]])
    imwrite(sr_img, save_path)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/nafssr-stereo-image-super-resolution-using/stereo-image-super-resolution-on-flickr1024-1)](https://paperswithcode.com/sota/stereo-image-super-resolution-on-flickr1024-1?p=nafssr-stereo-image-super-resolution-using) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/nafssr-stereo-image-super-resolution-using/stereo-image-super-resolution-on-flickr1024-2)](https://paperswithcode.com/sota/stereo-image-super-resolution-on-flickr1024-2?p=nafssr-stereo-image-super-resolution-using) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/nafssr-stereo-image-super-resolution-using/stereo-image-super-resolution-on-kitti2012-2x-1)](https://paperswithcode.com/sota/stereo-image-super-resolution-on-kitti2012-2x-1?p=nafssr-stereo-image-super-resolution-using) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/nafssr-stereo-image-super-resolution-using/stereo-image-super-resolution-on-kitti2012-4x)](https://paperswithcode.com/sota/stereo-image-super-resolution-on-kitti2012-4x?p=nafssr-stereo-image-super-resolution-using) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/nafssr-stereo-image-super-resolution-using/stereo-image-super-resolution-on-kitti2015-2x)](https://paperswithcode.com/sota/stereo-image-super-resolution-on-kitti2015-2x?p=nafssr-stereo-image-super-resolution-using) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/nafssr-stereo-image-super-resolution-using/stereo-image-super-resolution-on-kitti2015-4x)](https://paperswithcode.com/sota/stereo-image-super-resolution-on-kitti2015-4x?p=nafssr-stereo-image-super-resolution-using) 
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/nafssr-stereo-image-super-resolution-using/stereo-image-super-resolution-on-middlebury-1)](https://paperswithcode.com/sota/stereo-image-super-resolution-on-middlebury-1?p=nafssr-stereo-image-super-resolution-using) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/nafssr-stereo-image-super-resolution-using/stereo-image-super-resolution-on-middlebury)](https://paperswithcode.com/sota/stereo-image-super-resolution-on-middlebury?p=nafssr-stereo-image-super-resolution-using) + +## NAFNet: Nonlinear Activation Free Network for Image Restoration + +The official pytorch implementation of the paper **[Simple Baselines for Image Restoration (ECCV2022)](https://arxiv.org/abs/2204.04676)** + +#### Liangyu Chen\*, Xiaojie Chu\*, Xiangyu Zhang, Jian Sun + +>Although there have been significant advances in the field of image restoration recently, the system complexity of the state-of-the-art (SOTA) methods is increasing as well, which may hinder the convenient analysis and comparison of methods. +>In this paper, we propose a simple baseline that exceeds the SOTA methods and is computationally efficient. +>To further simplify the baseline, we reveal that the nonlinear activation functions, e.g. Sigmoid, ReLU, GELU, Softmax, etc. are **not necessary**: they could be replaced by multiplication or removed. Thus, we derive a Nonlinear Activation Free Network, namely NAFNet, from the baseline. SOTA results are achieved on various challenging benchmarks, e.g. 33.69 dB PSNR on GoPro (for image deblurring), exceeding the previous SOTA 0.38 dB with only 8.4% of its computational costs; 40.30 dB PSNR on SIDD (for image denoising), exceeding the previous SOTA 0.28 dB with less than half of its computational costs. 
+ +| NAFNet For Image Denoise | NAFNet For Image Deblur | NAFSSR For Stereo Image Super Resolution | +| :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | +| Denoise | Deblur | StereoSR([NAFSSR](https://github.com/megvii-research/NAFNet/blob/main/docs/StereoSR.md)) | + +![PSNR_vs_MACs](./figures/PSNR_vs_MACs.jpg) + +### News +**2022.08.02** The Baseline, including the pretrained models and train/test configs, are available now. + +**2022.07.03** Related work, [Improving Image Restoration by Revisiting Global Information Aggregation](https://arxiv.org/abs/2112.04491) (TLC, a.k.a TLSC in our paper) is accepted by **ECCV2022** :tada: . Code is available at https://github.com/megvii-research/TLC. + +**2022.07.03** Our [paper](https://arxiv.org/abs/2204.04676) is accepted by **ECCV2022** :tada: + +**2022.06.19** [NAFSSR](https://arxiv.org/abs/2204.08714) (as a challenge winner) is selected for an ORAL presentation at CVPR 2022, NTIRE workshop :tada: [Presentation video](https://drive.google.com/file/d/16w33zrb3UI0ZIhvvdTvGB2MP01j0zJve/view), [slides](https://data.vision.ee.ethz.ch/cvl/ntire22/slides/Chu_NAFSSR_slides.pdf) and [poster](https://data.vision.ee.ethz.ch/cvl/ntire22/posters/Chu_NAFSSR_poster.pdf) are available now. + +**2022.04.15** NAFNet based Stereo Image Super-Resolution solution ([NAFSSR](https://arxiv.org/abs/2204.08714)) won the **1st place** on the NTIRE 2022 Stereo Image Super-resolution Challenge! Training/Evaluation instructions see [here](https://github.com/megvii-research/NAFNet/blob/main/docs/StereoSR.md). 
+ +### Installation +This implementation based on [BasicSR](https://github.com/xinntao/BasicSR) which is a open source toolbox for image/video restoration tasks and [HINet](https://github.com/megvii-model/HINet) + +```python +python 3.9.5 +pytorch 1.11.0 +cuda 11.3 +``` + +``` +git clone https://github.com/megvii-research/NAFNet +cd NAFNet +pip install -r requirements.txt +python setup.py develop --no_cuda_ext +``` + +### Quick Start +* Image Denoise Colab Demo: [google colab logo](https://colab.research.google.com/drive/1dkO5AyktmBoWwxBwoKFUurIDn0m4qDXT?usp=sharing) +* Image Deblur Colab Demo: [google colab logo](https://colab.research.google.com/drive/1yR2ClVuMefisH12d_srXMhHnHwwA1YmU?usp=sharing) +* Stereo Image Super-Resolution Colab Demo: [google colab logo](https://colab.research.google.com/drive/1PkLog2imf7jCOPKq1G32SOISz0eLLJaO?usp=sharing) +* Single Image Inference Demo: + * Image Denoise: + ``` + python basicsr/demo.py -opt options/test/SIDD/NAFNet-width64.yml --input_path ./demo/noisy.png --output_path ./demo/denoise_img.png + ``` + * Image Deblur: + ``` + python basicsr/demo.py -opt options/test/REDS/NAFNet-width64.yml --input_path ./demo/blurry.jpg --output_path ./demo/deblur_img.png + ``` + * ```--input_path```: the path of the degraded image + * ```--output_path```: the path to save the predicted image + * [pretrained models](https://github.com/megvii-research/NAFNet/#results-and-pre-trained-models) should be downloaded. + * Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). 
Try out the Web Demo for single image restoration[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/chuxiaojie/NAFNet) +* Stereo Image Inference Demo: + * Stereo Image Super-resolution: + ``` + python basicsr/demo_ssr.py -opt options/test/NAFSSR/NAFSSR-L_4x.yml \ + --input_l_path ./demo/lr_img_l.png --input_r_path ./demo/lr_img_r.png \ + --output_l_path ./demo/sr_img_l.png --output_r_path ./demo/sr_img_r.png + ``` + * ```--input_l_path```: the path of the degraded left image + * ```--input_r_path```: the path of the degraded right image + * ```--output_l_path```: the path to save the predicted left image + * ```--output_r_path```: the path to save the predicted right image + * [pretrained models](https://github.com/megvii-research/NAFNet/#results-and-pre-trained-models) should be downloaded. + * Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). 
Try out the Web Demo for stereo image super-resolution[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/chuxiaojie/NAFSSR) +* Try the web demo with all three tasks here: [![Replicate](https://replicate.com/megvii-research/nafnet/badge)](https://replicate.com/megvii-research/nafnet) + +### Results and Pre-trained Models + +| name | Dataset|PSNR|SSIM| pretrained models | configs | +|:----|:----|:----|:----|:----|-----| +|NAFNet-GoPro-width32|GoPro|32.8705|0.9606|[gdrive](https://drive.google.com/file/d/1Fr2QadtDCEXg6iwWX8OzeZLbHOx2t5Bj/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1AbgG0yoROHmrRQN7dgzDvQ?pwd=so6v)|[train](./options/train/GoPro/NAFNet-width32.yml) \| [test](./options/test/GoPro/NAFNet-width32.yml)| +|NAFNet-GoPro-width64|GoPro|33.7103|0.9668|[gdrive](https://drive.google.com/file/d/1S0PVRbyTakYY9a82kujgZLbMihfNBLfC/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1g-E1x6En-PbYXm94JfI1vg?pwd=wnwh)|[train](./options/train/GoPro/NAFNet-width64.yml) \| [test](./options/test/GoPro/NAFNet-width64.yml)| +|NAFNet-SIDD-width32|SIDD|39.9672|0.9599|[gdrive](https://drive.google.com/file/d/1lsByk21Xw-6aW7epCwOQxvm6HYCQZPHZ/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1Xses38SWl-7wuyuhaGNhaw?pwd=um97)|[train](./options/train/SIDD/NAFNet-width32.yml) \| [test](./options/test/SIDD/NAFNet-width32.yml)| +|NAFNet-SIDD-width64|SIDD|40.3045|0.9614|[gdrive](https://drive.google.com/file/d/14Fht1QQJ2gMlk4N1ERCRuElg8JfjrWWR/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/198kYyVSrY_xZF0jGv9U0sQ?pwd=dton)|[train](./options/train/SIDD/NAFNet-width64.yml) \| [test](./options/test/SIDD/NAFNet-width64.yml)| +|NAFNet-REDS-width64|REDS|29.0903|0.8671|[gdrive](https://drive.google.com/file/d/14D4V4raNYIOhETfcuuLI3bGLB-OYIv6X/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1vg89ccbpIxg3mK9IONBfGg?pwd=9fas)|[train](./options/train/REDS/NAFNet-width64.yml) \| 
[test](./options/test/REDS/NAFNet-width64.yml)| +|NAFSSR-L_4x|Flickr1024|24.17|0.7589|[gdrive](https://drive.google.com/file/d/1TIdQhPtBrZb2wrBdAp9l8NHINLeExOwb/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1P8ioEuI1gwydA2Avr3nUvw?pwd=qs7a)|[train](./options/train/NAFSSR/NAFSSR-L_4x.yml) \| [test](./options/test/NAFSSR/NAFSSR-L_4x.yml)| +|NAFSSR-L_2x|Flickr1024|29.68|0.9221|[gdrive](https://drive.google.com/file/d/1SZ6bQVYTVS_AXedBEr-_mBCC-qGYHLmf/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1GS6YQSSECH8hAKhvzw6GyQ?pwd=2v3v)|[train](./options/train/NAFSSR/NAFSSR-L_2x.yml) \| [test](./options/test/NAFSSR/NAFSSR-L_2x.yml)| +|Baseline-GoPro-width32|GoPro|32.4799|0.9575|[gdrive](https://drive.google.com/file/d/14z7CxRzVkYEhFgsZg79GlPTEr3VFIGyl/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1WnFKYTAQyAQ9XuD5nlHw_Q?pwd=oieh)|[train](./options/train/GoPro/Baseline-width32.yml) \| [test](./options/test/GoPro/Baseline-width32.yml)| +|Baseline-GoPro-width64|GoPro|33.3960|0.9649|[gdrive](https://drive.google.com/file/d/1yy0oPNJjJxfaEmO0pfPW_TpeoCotYkuO/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1Fqi2T4nyF_wo4wh1QpgIGg?pwd=we36)|[train](./options/train/GoPro/Baseline-width64.yml) \| [test](./options/test/GoPro/Baseline-width64.yml)| +|Baseline-SIDD-width32|SIDD|39.8857|0.9596|[gdrive](https://drive.google.com/file/d/1NhqVcqkDcYvYgF_P4BOOfo9tuTcKDuhW/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1wkskmCRKhXq6dGa6Ns8D0A?pwd=0rin)|[train](./options/train/SIDD/Baseline-width32.yml) \| [test](./options/test/SIDD/Baseline-width32.yml)| +|Baseline-SIDD-width64|SIDD|40.2970|0.9617|[gdrive](https://drive.google.com/file/d/1wQ1HHHPhSp70_ledMBZhDhIGjZQs16wO/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1ivruGfSRGfWq5AEB8qc7YQ?pwd=t9w8)|[train](./options/train/SIDD/Baseline-width64.yml) \| [test](./options/test/SIDD/Baseline-width64.yml)| + + +### Image Restoration Tasks + +| Task | Dataset | Train/Test Instructions | Visualization Results | +| 
:----------------------------------- | :------ | :---------------------- | :----------------------------------------------------------- | +| Image Deblurring | GoPro | [link](./docs/GoPro.md) | [gdrive](https://drive.google.com/file/d/1S8u4TqQP6eHI81F9yoVR0be-DLh4cNgb/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1yNYQhznChafsbcfHO44aHQ?pwd=96ii)| +| Image Denoising | SIDD | [link](./docs/SIDD.md) | [gdrive](https://drive.google.com/file/d/1rbBYD64bfvbHOrN3HByNg0vz6gHQq7Np/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1wIubY6SeXRfZHpp6bAojqQ?pwd=hu4t)| +| Image Deblurring with JPEG artifacts | REDS | [link](./docs/REDS.md) | [gdrive](https://drive.google.com/file/d/1FwHWYPXdPtUkPqckpz-WBitpVyPuXFRi/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/17T30w5xAtBQQ2P3wawLiVA?pwd=put5) | +| Stereo Image Super-Resolution | Flickr1024+Middlebury | [link](./docs/StereoSR.md) | [gdrive](https://drive.google.com/drive/folders/1lTKe2TU7F-KcU-oaF8jqgoUwIMb6RW0w?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1kov6ivrSFy1FuToCATbyrA?pwd=q263 ) | + + +### Citations +If NAFNet helps your research or work, please consider citing NAFNet. + +``` +@article{chen2022simple, + title={Simple Baselines for Image Restoration}, + author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian}, + journal={arXiv preprint arXiv:2204.04676}, + year={2022} +} +``` +If NAFSSR helps your research or work, please consider citing NAFSSR. +``` +@InProceedings{chu2022nafssr, + author = {Chu, Xiaojie and Chen, Liangyu and Yu, Wenqing}, + title = {NAFSSR: Stereo Image Super-Resolution Using NAFNet}, + booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, + month = {June}, + year = {2022}, + pages = {1239-1248} +} +``` + +### Contact + +If you have any questions, please contact chenliangyu@megvii.com or chuxiaojie@megvii.com + +--- + +
<summary>statistics</summary>

![visitors](https://visitor-badge.glitch.me/badge?page_id=megvii-research/NAFNet)

+ diff --git a/NAFNET/run_services.py b/NAFNET/run_services.py new file mode 100644 index 0000000000000000000000000000000000000000..cb54bf7eba6dec0a651a351e4a49b68df5977797 --- /dev/null +++ b/NAFNET/run_services.py @@ -0,0 +1,63 @@ +import os +import subprocess +import time +import signal +import sys + +def run_services(): + """ + Run both API and Streamlit app as separate processes. + """ + print("Starting NAFNet services...") + + # Set environment variables to suppress warnings + os.environ["PYTHONWARNINGS"] = "ignore" + + # Define commands + api_cmd = ["python", "api.py"] + streamlit_cmd = ["streamlit", "run", "app.py", "--server.port=8501"] + + # Start API server + print("Starting API server on port 8001...") + api_process = subprocess.Popen(api_cmd) + + # Wait for API to start + print("Waiting for API to initialize (5 seconds)...") + time.sleep(5) + + # Start Streamlit app + print("Starting Streamlit app on port 8501...") + streamlit_process = subprocess.Popen(streamlit_cmd) + + print("\n" + "="*50) + print("Services started successfully!") + print("API running at: http://localhost:8001") + print("Streamlit app running at: http://localhost:8501") + print("="*50 + "\n") + + # Handle graceful shutdown on Ctrl+C + def signal_handler(sig, frame): + print("\nShutting down services...") + streamlit_process.terminate() + api_process.terminate() + + # Wait for processes to terminate + streamlit_process.wait() + api_process.wait() + + print("Services stopped.") + sys.exit(0) + + # Register signal handler + signal.signal(signal.SIGINT, signal_handler) + + try: + # Keep the script running + while True: + time.sleep(1) + except KeyboardInterrupt: + # Handle keyboard interrupt (Ctrl+C) + signal_handler(None, None) + +if __name__ == "__main__": + run_services() \ No newline at end of file diff --git a/NAFNET/setup.py b/NAFNET/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..ab13069c4e28899953026c698d094f7b6e263724 --- /dev/null +++ 
b/NAFNET/setup.py @@ -0,0 +1,182 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +#!/usr/bin/env python + +from setuptools import find_packages, setup + +import os +import subprocess +import sys +import time +import torch +from torch.utils.cpp_extension import (BuildExtension, CppExtension, + CUDAExtension) + +version_file = 'basicsr/version.py' + + +def readme(): + return '' + # with open('README.md', encoding='utf-8') as f: + # content = f.read() + # return content + + +def get_git_hash(): + + def _minimal_ext_cmd(cmd): + # construct minimal environment + env = {} + for k in ['SYSTEMROOT', 'PATH', 'HOME']: + v = os.environ.get(k) + if v is not None: + env[k] = v + # LANGUAGE is used on win32 + env['LANGUAGE'] = 'C' + env['LANG'] = 'C' + env['LC_ALL'] = 'C' + out = subprocess.Popen( + cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + return out + + try: + out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) + sha = out.strip().decode('ascii') + except OSError: + sha = 'unknown' + + return sha + + +def get_hash(): + if os.path.exists('.git'): + sha = get_git_hash()[:7] + elif os.path.exists(version_file): + try: + from basicsr.version import __version__ + sha = __version__.split('+')[-1] + except ImportError: + raise ImportError('Unable to get git version') + else: + sha = 'unknown' + + return sha + + +def write_version_py(): + content = """# GENERATED VERSION FILE +# TIME: {} +__version__ = '{}' +short_version = '{}' +version_info = ({}) +""" + sha = get_hash() + with open('VERSION', 'r') as f: + SHORT_VERSION = f.read().strip() + VERSION_INFO = ', '.join( + [x if x.isdigit() else f'"{x}"' for x in 
SHORT_VERSION.split('.')]) + VERSION = SHORT_VERSION + '+' + sha + + version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION, + VERSION_INFO) + with open(version_file, 'w') as f: + f.write(version_file_str) + + +def get_version(): + with open(version_file, 'r') as f: + exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'] + + +def make_cuda_ext(name, module, sources, sources_cuda=None): + if sources_cuda is None: + sources_cuda = [] + define_macros = [] + extra_compile_args = {'cxx': []} + + if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': + define_macros += [('WITH_CUDA', None)] + extension = CUDAExtension + extra_compile_args['nvcc'] = [ + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__', + ] + sources += sources_cuda + else: + print(f'Compiling {name} without CUDA') + extension = CppExtension + + return extension( + name=f'{module}.{name}', + sources=[os.path.join(*module.split('.'), p) for p in sources], + define_macros=define_macros, + extra_compile_args=extra_compile_args) + + +def get_requirements(filename='requirements.txt'): + return [] + here = os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(here, filename), 'r') as f: + requires = [line.replace('\n', '') for line in f.readlines()] + return requires + + +if __name__ == '__main__': + if '--no_cuda_ext' in sys.argv: + ext_modules = [] + sys.argv.remove('--no_cuda_ext') + else: + ext_modules = [ + make_cuda_ext( + name='deform_conv_ext', + module='basicsr.models.ops.dcn', + sources=['src/deform_conv_ext.cpp'], + sources_cuda=[ + 'src/deform_conv_cuda.cpp', + 'src/deform_conv_cuda_kernel.cu' + ]), + make_cuda_ext( + name='fused_act_ext', + module='basicsr.models.ops.fused_act', + sources=['src/fused_bias_act.cpp'], + sources_cuda=['src/fused_bias_act_kernel.cu']), + make_cuda_ext( + name='upfirdn2d_ext', + module='basicsr.models.ops.upfirdn2d', + 
sources=['src/upfirdn2d.cpp'], + sources_cuda=['src/upfirdn2d_kernel.cu']), + ] + + write_version_py() + setup( + name='basicsr', + version=get_version(), + description='Open Source Image and Video Super-Resolution Toolbox', + long_description=readme(), + author='Xintao Wang', + author_email='xintao.wang@outlook.com', + keywords='computer vision, restoration, super resolution', + url='https://github.com/xinntao/BasicSR', + packages=find_packages( + exclude=('options', 'datasets', 'experiments', 'results', + 'tb_logger', 'wandb')), + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + ], + license='Apache License 2.0', + setup_requires=['cython', 'numpy'], + install_requires=get_requirements(), + ext_modules=ext_modules, + cmdclass={'build_ext': BuildExtension}, + zip_safe=False)