| import os |
| import cv2 |
| import numpy as np |
| import torch |
| import yaml |
| from typing import Optional, Tuple, Union |
| from io import BytesIO |
| from PIL import Image |
| import logging |
| import traceback |
|
|
| from aug import get_normalize |
| from models.networks import get_generator |
| from logging_utils import setup_logger |
|
|
| |
# Module-level logger obtained from the project's logging_utils.setup_logger helper.
logger = setup_logger(__name__)
|
|
class DeblurGAN:
    """Inference wrapper around a DeblurGAN generator network.

    Builds the generator named by ``model_name`` (or ``config['model']['g_name']``
    from ``config/config.yaml`` next to this module), restores its weights, and
    exposes helpers to deblur a single image and to save the result.
    """

    def __init__(self, weights_path: str = 'fpn_inception.h5', model_name: str = ''):
        """
        Initialize the DeblurGAN model.

        Args:
            weights_path: Path to the model weights file. Relative paths are
                resolved against this module's directory.
            model_name: Name of the model architecture. If empty, it is read from
                ``config['model']['g_name']`` in ``config/config.yaml``.

        Raises:
            FileNotFoundError: If the weights file does not exist.
            Exception: Any error while reading the config, building the model or
                loading weights is logged (with traceback) and re-raised.
        """
        try:
            logger.info(f"Initializing DeblurGAN with weights: {weights_path}")

            module_dir = os.path.dirname(os.path.abspath(__file__))
            config_path = os.path.join(module_dir, 'config/config.yaml')
            if not os.path.isabs(weights_path):
                weights_path = os.path.join(module_dir, weights_path)

            if not os.path.exists(weights_path):
                error_msg = f"Weights file not found: {weights_path}"
                logger.error(error_msg)
                raise FileNotFoundError(error_msg)

            logger.info(f"Loading configuration from {config_path}")
            with open(config_path, encoding='utf-8') as cfg:
                # The config is plain data; safe_load refuses arbitrary Python tags.
                config = yaml.safe_load(cfg)

            arch_name = model_name or config['model']['g_name']
            logger.info(f"Creating model with architecture: {arch_name}")
            model = get_generator(arch_name)

            logger.info("Loading model weights")
            # Deserialize onto the CPU first: without map_location, a checkpoint
            # saved from a GPU fails to load on a CPU-only host, which would make
            # the CPU fallback below unreachable. The model is moved to CUDA
            # afterwards when that is possible.
            checkpoint = torch.load(weights_path, map_location='cpu')
            model.load_state_dict(checkpoint['model'])

            try:
                self.model = model.cuda()
                self.device = 'cuda'
                logger.info("Model moved to CUDA successfully")
            except Exception as e:
                logger.warning(f"Failed to move model to CUDA. Error: {str(e)}")
                logger.warning("Using CPU mode")
                self.model = model
                self.device = 'cpu'

            # NOTE(review): train mode looks deliberate. Upstream DeblurGAN-v2's
            # predict script does the same so normalization layers use the
            # statistics of the current input. Confirm before switching to eval().
            self.model.train(True)
            self.normalize_fn = get_normalize()

            # Working directories for the batch demo in main(); created eagerly.
            self.inputs_dir = os.path.join(module_dir, 'inputs')
            self.outputs_dir = os.path.join(module_dir, 'outputs')
            os.makedirs(self.inputs_dir, exist_ok=True)
            os.makedirs(self.outputs_dir, exist_ok=True)

            logger.info("Model initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize model: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    @staticmethod
    def _array_to_batch(x):
        """Convert an (H, W, C) numpy array into a (1, C, H, W) tensor."""
        x = np.transpose(x, (2, 0, 1))
        x = np.expand_dims(x, 0)
        return torch.from_numpy(x)

    def _preprocess(self, x: np.ndarray) -> Tuple:
        """Normalize and pad an (H, W, 3) image for the model.

        Returns:
            ``((image_batch, mask_batch), h, w)``: two (1, C, H', W') tensors
            padded at the bottom/right, and the original height and width so
            the caller can crop the prediction back.
        """
        # normalize_fn takes an (image, target) pair; only the image is used.
        # Presumably it maps pixels into the [-1, 1] range that _postprocess
        # inverts. Confirm in aug.get_normalize.
        x, _ = self.normalize_fn(x, x)
        # The mask is padded alongside x, but deblur_image does not consume it.
        mask = np.ones_like(x, dtype=np.float32)

        h, w, _ = x.shape
        block_size = 32
        # Rounds up to the next multiple of 32 strictly above h/w, so sizes that
        # are already aligned still get one full extra block of zero padding.
        min_height = (h // block_size + 1) * block_size
        min_width = (w // block_size + 1) * block_size

        pad_params = {
            'mode': 'constant',
            'constant_values': 0,
            'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0)),
        }
        x = np.pad(x, **pad_params)
        mask = np.pad(mask, **pad_params)

        return map(self._array_to_batch, (x, mask)), h, w

    @staticmethod
    def _postprocess(x: torch.Tensor) -> np.ndarray:
        """Convert a (1, C, H, W) model output in [-1, 1] into an (H, W, C) uint8 array.

        Values are clipped to [-1, 1] first. Out-of-range outputs then saturate
        at 0/255 instead of wrapping around during the uint8 cast.
        """
        x, = x  # unpack the single-image batch; raises if the batch size != 1
        x = x.detach().cpu().float().numpy()
        x = np.clip(x, -1.0, 1.0)
        x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
        return x.astype('uint8')

    def deblur_image(
        self,
        image: Union[str, np.ndarray, bytes],
        *,
        is_bgr: Optional[bool] = None,
    ) -> np.ndarray:
        """
        Deblur an image.

        Args:
            image: Input image as a file path, a numpy array, or encoded bytes.
                File paths and bytes are decoded to RGB.
            is_bgr: Only used for numpy input. True means the 3-channel array is
                in BGR order and is converted to RGB; False means it is already
                RGB. None (default) keeps the legacy heuristic. That heuristic
                compares the B and R values of the top-left pixel of a uint8
                array and is unreliable, so prefer passing this explicitly.

        Returns:
            The deblurred RGB image as a uint8 numpy array. If the input's longer
            side exceeds 2000 px, the result has the downscaled size.

        Raises:
            ValueError: For unsupported input types, unreadable or empty
                images, or images that are not 3-channel.
        """
        try:
            if isinstance(image, str):
                logger.info(f"Loading image from path: {image}")
                img = cv2.imread(image)
                if img is None:
                    raise ValueError(f"Failed to read image from {image}")
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            elif isinstance(image, bytes):
                logger.info("Loading image from bytes")
                nparr = np.frombuffer(image, np.uint8)
                img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                if img is None:
                    # OpenCV could not decode it; let PIL try (it handles more formats).
                    pil_img = Image.open(BytesIO(image))
                    img = np.array(pil_img.convert('RGB'))
                else:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            elif isinstance(image, np.ndarray):
                logger.info("Processing image from numpy array")
                img = image.copy()
                # Check ndim before shape[2] so 2-D arrays reach the explicit
                # shape validation below instead of raising IndexError here.
                if img.ndim == 3 and img.shape[2] == 3:
                    if is_bgr is None:
                        # Legacy heuristic: uint8 only, and guarded against empty arrays.
                        swap = (
                            img.dtype == np.uint8
                            and img.size > 0
                            and img[0, 0, 0] > img[0, 0, 2]
                        )
                    else:
                        swap = bool(is_bgr)
                    if swap:
                        # BGR -> RGB; equivalent to cv2.COLOR_BGR2RGB for 3 channels.
                        img = np.ascontiguousarray(img[..., ::-1])
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")

            if img is None or img.size == 0:
                raise ValueError("Image is empty or invalid")

            logger.info(f"Image shape: {img.shape}, dtype: {img.dtype}")

            if len(img.shape) != 3 or img.shape[2] != 3:
                raise ValueError(f"Image must have 3 channels, got shape {img.shape}")

            # Cap the longer side at 2000 px to bound inference cost.
            max_dim = max(img.shape[0], img.shape[1])
            if max_dim > 2000:
                scale_factor = 2000 / max_dim
                # max(1, ...) keeps extreme aspect ratios from producing a 0-px side.
                new_h = max(1, int(img.shape[0] * scale_factor))
                new_w = max(1, int(img.shape[1] * scale_factor))
                logger.warning(f"Image too large, resizing from {img.shape[:2]} to {(new_h, new_w)}")
                img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

            logger.info("Preprocessing image")
            (img_batch, _mask_batch), h, w = self._preprocess(img)

            logger.info("Running inference with model")
            with torch.no_grad():
                try:
                    pred = self.model(img_batch.to(self.device))
                except Exception as e:
                    if self.device != 'cuda':
                        # Already on CPU: retrying on CPU would fail identically.
                        raise
                    logger.warning(f"Error using {self.device}: {str(e)}. Falling back to CPU.")
                    torch.cuda.empty_cache()
                    # The model stays on the CPU for all subsequent calls.
                    self.model = self.model.to('cpu')
                    self.device = 'cpu'
                    pred = self.model(img_batch.to('cpu'))

            logger.info("Postprocessing image")
            # Crop away the padding added in _preprocess.
            result = self._postprocess(pred)[:h, :w, :]
            logger.info("Image deblurred successfully")
            return result
        except Exception as e:
            logger.error(f"Error in deblur_image: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    def save_image(self, image: np.ndarray, output_path: str) -> str:
        """Save an RGB image to ``output_path``.

        Relative paths are resolved against ``self.outputs_dir``, and missing
        parent directories are created.

        Returns:
            The path the image was written to.

        Raises:
            OSError: If OpenCV reports that the write failed (for example, an
                unsupported extension).
        """
        try:
            # OpenCV writes BGR, and images in this class are handled as RGB.
            save_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            if not os.path.isabs(output_path):
                output_path = os.path.join(self.outputs_dir, output_path)

            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            # imwrite signals failure by returning False rather than raising.
            if not cv2.imwrite(output_path, save_img):
                raise OSError(f"Failed to write image to {output_path}")
            logger.info(f"Image saved to {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"Error saving image: {str(e)}")
            logger.error(traceback.format_exc())
            raise
|
|
def main():
    """
    Batch-deblur every supported image in the model's inputs directory.

    Each result is written to the model's outputs directory as
    ``deblurred_<original name>``. A failure on one file is reported and that
    file is skipped. Any other error is logged and printed rather than raised.
    """
    supported_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    try:
        deblur_model = DeblurGAN()
        src_dir = deblur_model.inputs_dir
        dst_dir = deblur_model.outputs_dir

        # Regular files only, filtered by (case-insensitive) image extension.
        candidates = []
        for name in os.listdir(src_dir):
            if not os.path.isfile(os.path.join(src_dir, name)):
                continue
            if name.lower().endswith(supported_exts):
                candidates.append(name)

        if not candidates:
            logger.warning(f"No image files found in {src_dir}")
            print(f"No image files found in {src_dir}. Please add some images and try again.")
            return

        count_msg = f"Found {len(candidates)} images to process"
        logger.info(count_msg)
        print(count_msg)

        for name in candidates:
            try:
                src_path = os.path.join(src_dir, name)
                result_name = f"deblurred_{name}"
                result_path = os.path.join(dst_dir, result_name)

                print(f"Processing {name}...")
                restored = deblur_model.deblur_image(src_path)
                # A relative name is resolved against outputs_dir by save_image.
                deblur_model.save_image(restored, result_name)
                print(f"✅ Saved deblurred image to {result_path}")
            except Exception as err:
                logger.error(f"Error processing {name}: {str(err)}")
                print(f"❌ Failed to process {name}: {str(err)}")

        print(f"\nDeblurring complete! Check {dst_dir} for results.")
    except Exception as err:
        logger.error(f"Error in main function: {str(err)}")
        logger.error(traceback.format_exc())
        print(f"❌ Error: {str(err)}")
|
|
# Run the batch deblurring demo only when this module is executed as a script.
if __name__ == "__main__":
    main()
|
|