"""DeblurGAN inference wrapper: load a generator network and deblur images."""

import logging
import os
import traceback
from io import BytesIO
from typing import Optional, Tuple, Union

import cv2
import numpy as np
import torch
import yaml
from PIL import Image

from aug import get_normalize
from models.networks import get_generator
from logging_utils import setup_logger

# Configure logging
logger = setup_logger(__name__)


class DeblurGAN:
    """Wrap a DeblurGAN generator for single-image inference.

    Images are handled internally as H x W x 3 arrays in RGB channel order.
    """

    # Spatial dimensions are padded up to a multiple of this before inference.
    BLOCK_SIZE = 32
    # Images whose longest side exceeds this are downscaled before inference.
    MAX_DIM = 2000

    def __init__(self, weights_path: str = 'fpn_inception.h5', model_name: str = ''):
        """
        Initialize the DeblurGAN model.

        Args:
            weights_path: Path to model weights file. Relative paths are
                resolved against this module's directory.
            model_name: Name of the model architecture (if empty, will be
                read from config)

        Raises:
            FileNotFoundError: If the weights file does not exist.
        """
        try:
            logger.info(f"Initializing DeblurGAN with weights: {weights_path}")

            # Make paths relative to the module directory
            module_dir = os.path.dirname(os.path.abspath(__file__))
            config_path = os.path.join(module_dir, 'config/config.yaml')
            if not os.path.isabs(weights_path):
                weights_path = os.path.join(module_dir, weights_path)

            # Check if weights file exists
            if not os.path.exists(weights_path):
                error_msg = f"Weights file not found: {weights_path}"
                logger.error(error_msg)
                raise FileNotFoundError(error_msg)

            # Load configuration
            logger.info(f"Loading configuration from {config_path}")
            with open(config_path, encoding='utf-8') as cfg:
                # NOTE(review): FullLoader is kept for compatibility with the
                # existing config; prefer yaml.safe_load if the config only
                # contains plain scalars/lists/mappings.
                config = yaml.load(cfg, Loader=yaml.FullLoader)

            # Initialize model
            architecture = model_name or config['model']['g_name']
            logger.info(f"Creating model with architecture: {architecture}")
            model = get_generator(architecture)

            logger.info("Loading model weights")
            # Load onto the CPU first so that a checkpoint saved from a GPU
            # can still be read on a CPU-only host; the model is moved to
            # CUDA below when that is possible.
            checkpoint = torch.load(weights_path, map_location='cpu')
            model.load_state_dict(checkpoint['model'])

            # Try CUDA first, fall back to CPU if necessary
            try:
                self.model = model.cuda()
                self.device = 'cuda'
                logger.info("Model moved to CUDA successfully")
            except Exception as e:
                logger.warning(f"Failed to move model to CUDA. Error: {e}")
                logger.warning("Using CPU mode")
                self.model = model
                self.device = 'cpu'

            self.model.train(True)  # GAN inference uses train mode for batch norm stats
            self.normalize_fn = get_normalize()

            # Create directories for inputs and outputs
            self.inputs_dir = os.path.join(module_dir, 'inputs')
            self.outputs_dir = os.path.join(module_dir, 'outputs')

            # Ensure directories exist
            os.makedirs(self.inputs_dir, exist_ok=True)
            os.makedirs(self.outputs_dir, exist_ok=True)

            logger.info("Model initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize model: {e}")
            logger.error(traceback.format_exc())
            raise

    @staticmethod
    def _array_to_batch(x):
        """Convert an H x W x C numpy array to a 1 x C x H x W tensor."""
        x = np.transpose(x, (2, 0, 1))
        x = np.expand_dims(x, 0)
        return torch.from_numpy(x)

    def _preprocess(self, x: np.ndarray) -> Tuple:
        """Normalize and pad the input image for the model.

        Args:
            x: Image as an H x W x C array.

        Returns:
            ``((image_batch, mask_batch), h, w)`` where ``h`` and ``w`` are
            the height and width after normalization and before padding
            (used to crop the model output).
        """
        # Normalize
        x, _ = self.normalize_fn(x, x)
        mask = np.ones_like(x, dtype=np.float32)

        # Pad bottom/right up to the next multiple of BLOCK_SIZE. This always
        # adds at least one row/column, even when the size is already aligned.
        h, w, _ = x.shape
        block_size = self.BLOCK_SIZE
        min_height = (h // block_size + 1) * block_size
        min_width = (w // block_size + 1) * block_size

        pad_params = {
            'mode': 'constant',
            'constant_values': 0,
            'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0)),
        }
        x = np.pad(x, **pad_params)
        mask = np.pad(mask, **pad_params)

        batches = tuple(self._array_to_batch(arr) for arr in (x, mask))
        return batches, h, w

    @staticmethod
    def _postprocess(x: torch.Tensor) -> np.ndarray:
        """Convert the model output tensor to a uint8 H x W x C array.

        The single-item batch is mapped linearly from [-1, 1] to [0, 255].
        """
        x, = x
        x = x.detach().cpu().float().numpy()
        x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
        # Clip before the cast: out-of-range floats would otherwise wrap
        # around when converted to uint8.
        return np.clip(x, 0, 255).astype('uint8')

    @staticmethod
    def _load_image(image: Union[str, np.ndarray, bytes]) -> np.ndarray:
        """Load the supported input types into an RGB array."""
        if isinstance(image, str):
            # Image path
            logger.info(f"Loading image from path: {image}")
            img = cv2.imread(image)
            if img is None:
                raise ValueError(f"Failed to read image from {image}")
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if isinstance(image, bytes):
            # Bytes (e.g., from file upload)
            logger.info("Loading image from bytes")
            nparr = np.frombuffer(image, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if img is not None:
                return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # Try using PIL as a fallback
            with Image.open(BytesIO(image)) as pil_img:
                return np.array(pil_img.convert('RGB'))

        if isinstance(image, np.ndarray):
            # Already a numpy array. Channel order cannot be detected from
            # pixel values, so the array is taken to be RGB as documented.
            logger.info("Processing image from numpy array")
            return image.copy()

        raise ValueError(f"Unsupported image type: {type(image)}")

    def _run_model(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Run the generator, retrying once on CPU if CUDA inference fails."""
        with torch.no_grad():
            try:
                # Try to use the device that was set during initialization
                return self.model(img_batch.to(self.device))
            except Exception as e:
                # A CPU retry only makes sense when the failure was on CUDA.
                if self.device != 'cuda':
                    raise
                logger.warning(f"Error using {self.device}: {e}. Falling back to CPU.")
                torch.cuda.empty_cache()  # Free GPU memory
                cpu_batch = img_batch.to('cpu')
                self.model = self.model.to('cpu')
                self.device = 'cpu'
                return self.model(cpu_batch)

    def deblur_image(self, image: Union[str, np.ndarray, bytes]) -> np.ndarray:
        """
        Deblur an image.

        Args:
            image: Input image as a file path, numpy array, or bytes.
                Numpy arrays must be H x W x 3 in RGB channel order.

        Returns:
            Deblurred image as a uint8 RGB numpy array

        Raises:
            ValueError: If the input type is unsupported, the image cannot
                be read, is empty, or does not have exactly 3 channels.
        """
        try:
            img = self._load_image(image)

            # Validate image
            if img is None or img.size == 0:
                raise ValueError("Image is empty or invalid")

            logger.info(f"Image shape: {img.shape}, dtype: {img.dtype}")

            # Ensure image has 3 channels
            if img.ndim != 3 or img.shape[2] != 3:
                raise ValueError(f"Image must have 3 channels, got shape {img.shape}")

            # Resize very large images
            max_dim = max(img.shape[0], img.shape[1])
            if max_dim > self.MAX_DIM:
                scale_factor = self.MAX_DIM / max_dim
                new_h = int(img.shape[0] * scale_factor)
                new_w = int(img.shape[1] * scale_factor)
                logger.warning(f"Image too large, resizing from {img.shape[:2]} to {(new_h, new_w)}")
                img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

            # Process the image
            logger.info("Preprocessing image")
            (img_batch, _mask_batch), h, w = self._preprocess(img)

            logger.info("Running inference with model")
            pred = self._run_model(img_batch)

            # Get the result, cropping away the padding added in _preprocess
            logger.info("Postprocessing image")
            result = self._postprocess(pred)[:h, :w, :]

            logger.info("Image deblurred successfully")
            return result
        except Exception as e:
            logger.error(f"Error in deblur_image: {e}")
            logger.error(traceback.format_exc())
            raise

    def save_image(self, image: np.ndarray, output_path: str) -> str:
        """Save an RGB image to the given path.

        Args:
            image: Image as an RGB array.
            output_path: Destination path. Relative paths are placed under
                the outputs directory.

        Returns:
            The path the image was written to.

        Raises:
            OSError: If OpenCV reports that the image was not written.
        """
        try:
            # Convert to BGR for OpenCV
            save_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            if not os.path.isabs(output_path):
                # Use the outputs directory by default
                output_path = os.path.join(self.outputs_dir, output_path)

            # Ensure the parent directory exists
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            # cv2.imwrite signals failure through its return value.
            if not cv2.imwrite(output_path, save_img):
                raise OSError(f"Failed to write image to {output_path}")
            logger.info(f"Image saved to {output_path}")

            return output_path
        except Exception as e:
            logger.error(f"Error saving image: {e}")
            logger.error(traceback.format_exc())
            raise


def main():
    """
    Main function to test the DeblurGAN model.

    Processes all images in the inputs directory and saves results to
    outputs directory.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
    try:
        # Initialize the DeblurGAN model
        deblur_model = DeblurGAN()

        # Get the inputs directory
        inputs_dir = deblur_model.inputs_dir
        outputs_dir = deblur_model.outputs_dir

        # Check if there are any images in the inputs directory
        # (sorted so the processing order is deterministic)
        input_files = sorted(
            f for f in os.listdir(inputs_dir)
            if os.path.isfile(os.path.join(inputs_dir, f))
            and f.lower().endswith(image_extensions)
        )

        if not input_files:
            logger.warning(f"No image files found in {inputs_dir}")
            print(f"No image files found in {inputs_dir}. Please add some images and try again.")
            return

        logger.info(f"Found {len(input_files)} images to process")
        print(f"Found {len(input_files)} images to process")

        # Process each image; one failure does not stop the batch
        for input_file in input_files:
            try:
                input_path = os.path.join(inputs_dir, input_file)
                output_file = f"deblurred_{input_file}"

                print(f"Processing {input_file}...")

                # Deblur the image
                deblurred_img = deblur_model.deblur_image(input_path)

                # Save the deblurred image
                output_path = deblur_model.save_image(deblurred_img, output_file)

                print(f"✅ Saved deblurred image to {output_path}")
            except Exception as e:
                logger.error(f"Error processing {input_file}: {e}")
                print(f"❌ Failed to process {input_file}: {e}")

        print(f"\nDeblurring complete! Check {outputs_dir} for results.")

    except Exception as e:
        logger.error(f"Error in main function: {e}")
        logger.error(traceback.format_exc())
        print(f"❌ Error: {e}")


if __name__ == "__main__":
    main()