# Image-Debluring / DeblurGanV2 / deblur_module.py
# Last commit 61d360d by sayed99: "initialized both deblurer"
import os
import cv2
import numpy as np
import torch
import yaml
from typing import Optional, Tuple, Union
from io import BytesIO
from PIL import Image
import logging
import traceback
from aug import get_normalize
from models.networks import get_generator
from logging_utils import setup_logger
# Configure logging: module-level logger shared by DeblurGAN and main(),
# created through the project helper logging_utils.setup_logger.
logger = setup_logger(__name__)
class DeblurGAN:
    """DeblurGAN-v2 generator wrapper for deblurring single RGB images.

    On construction the generator is built from ``config/config.yaml`` (or an
    explicit architecture name), and its weights are restored from a checkpoint.
    It is placed on CUDA when possible, otherwise on the CPU. ``inputs`` and
    ``outputs`` directories are created next to this module.
    """

    def __init__(self, weights_path: str = 'fpn_inception.h5', model_name: str = ''):
        """
        Initialize the DeblurGAN model.

        Args:
            weights_path: Path to the checkpoint file. Relative paths are
                resolved against this module's directory. The generator state
                dict is read from the checkpoint's ``'model'`` entry.
            model_name: Name of the model architecture. If empty, it is read
                from ``config['model']['g_name']``.

        Raises:
            FileNotFoundError: If the weights file does not exist.
            Exception: Any error while reading the config, building the
                generator or loading its weights is logged and re-raised.
        """
        try:
            logger.info(f"Initializing DeblurGAN with weights: {weights_path}")
            # Resolve config and (relative) weights paths against this module's directory.
            module_dir = os.path.dirname(os.path.abspath(__file__))
            config_path = os.path.join(module_dir, 'config/config.yaml')
            if not os.path.isabs(weights_path):
                weights_path = os.path.join(module_dir, weights_path)

            if not os.path.exists(weights_path):
                error_msg = f"Weights file not found: {weights_path}"
                logger.error(error_msg)
                raise FileNotFoundError(error_msg)

            logger.info(f"Loading configuration from {config_path}")
            with open(config_path, encoding='utf-8') as cfg:
                config = yaml.load(cfg, Loader=yaml.FullLoader)

            arch_name = model_name or config['model']['g_name']
            logger.info(f"Creating model with architecture: {arch_name}")
            model = get_generator(arch_name)

            logger.info("Loading model weights")
            # Deserialize onto the CPU first. A checkpoint saved from a GPU run
            # would otherwise fail to load on a CUDA-less machine before the
            # CPU fallback below is ever reached. torch.load unpickles the file,
            # so only load trusted checkpoints.
            checkpoint = torch.load(weights_path, map_location='cpu')
            model.load_state_dict(checkpoint['model'])

            # Try CUDA first, fall back to CPU if necessary.
            try:
                self.model = model.cuda()
                self.device = 'cuda'
                logger.info("Model moved to CUDA successfully")
            except Exception as e:
                logger.warning(f"Failed to move model to CUDA. Error: {str(e)}")
                logger.warning("Using CPU mode")
                self.model = model
                self.device = 'cpu'

            # Deliberate: GAN inference runs in train mode so norm layers use the
            # statistics of the current input rather than running averages.
            self.model.train(True)
            self.normalize_fn = get_normalize()

            # Working directories for inputs and outputs, next to this module.
            self.inputs_dir = os.path.join(module_dir, 'inputs')
            self.outputs_dir = os.path.join(module_dir, 'outputs')
            os.makedirs(self.inputs_dir, exist_ok=True)
            os.makedirs(self.outputs_dir, exist_ok=True)

            logger.info("Model initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize model: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    @staticmethod
    def _array_to_batch(x):
        """Convert an HxWxC numpy array into a 1xCxHxW torch tensor."""
        x = np.transpose(x, (2, 0, 1))
        x = np.expand_dims(x, 0)
        return torch.from_numpy(x)

    def _preprocess(self, x: np.ndarray) -> Tuple:
        """Normalize and zero-pad an HxWxC image for the model.

        Both spatial dimensions are padded at the bottom/right up to the next
        multiple of 32 strictly greater than the original size. A dimension
        that is already a multiple of 32 still gains a full extra block, as in
        the upstream implementation.

        Returns:
            ``(map(batch, (image, mask)), h, w)``. The map yields the padded
            image and an all-ones float32 mask (zero in the padding), each as
            a 1xCxHxW tensor. ``h`` and ``w`` are the unpadded height and width,
            used later to crop the prediction.
        """
        x, _ = self.normalize_fn(x, x)
        mask = np.ones_like(x, dtype=np.float32)

        h, w, _ = x.shape
        block_size = 32
        min_height = (h // block_size + 1) * block_size
        min_width = (w // block_size + 1) * block_size
        pad_params = {
            'mode': 'constant',
            'constant_values': 0,
            'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0)),
        }
        x = np.pad(x, **pad_params)
        mask = np.pad(mask, **pad_params)
        return map(self._array_to_batch, (x, mask)), h, w

    @staticmethod
    def _postprocess(x: torch.Tensor) -> np.ndarray:
        """Convert a 1xCxHxW model output into an HxWxC ``uint8`` array.

        Values are mapped linearly from the expected [-1, 1] range to
        [0, 255] and then clipped. Anything outside the range saturates
        instead of wrapping or hitting an undefined float-to-uint8 cast.
        """
        x, = x  # drop the batch dimension (batch size must be 1)
        x = x.detach().cpu().float().numpy()
        x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
        return np.clip(x, 0.0, 255.0).astype('uint8')

    @staticmethod
    def _load_image(image: Union[str, np.ndarray, bytes]) -> np.ndarray:
        """Decode a file path, encoded bytes or array into a numpy image.

        Path and bytes inputs are returned as RGB. Arrays are copied and only
        converted under the heuristic noted below. The result is not validated;
        the caller checks emptiness and channel count.

        Raises:
            ValueError: If the type is unsupported or the path cannot be read.
        """
        if isinstance(image, str):
            logger.info(f"Loading image from path: {image}")
            img = cv2.imread(image)
            if img is None:
                raise ValueError(f"Failed to read image from {image}")
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if isinstance(image, bytes):
            # Bytes, e.g. from a file upload.
            logger.info("Loading image from bytes")
            img = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
            if img is None:
                # OpenCV could not decode the buffer; try PIL as a fallback.
                with Image.open(BytesIO(image)) as pil_img:
                    return np.array(pil_img.convert('RGB'))
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if isinstance(image, np.ndarray):
            logger.info("Processing image from numpy array")
            img = image.copy()
            # Check ndim/size first so 2-D or empty arrays reach the caller's
            # validation (ValueError) instead of raising IndexError here.
            if img.ndim == 3 and img.size > 0 and img.shape[2] == 3 and img.dtype == np.uint8:
                # NOTE(review): this guesses BGR vs RGB from a single pixel
                # (channel 0 > channel 2 at (0, 0)). It will mis-convert some
                # RGB inputs and miss some BGR ones. Kept for backward
                # compatibility; consider an explicit colour-order argument.
                if img[0, 0, 0] > img[0, 0, 2]:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            return img

        raise ValueError(f"Unsupported image type: {type(image)}")

    def deblur_image(self, image: Union[str, np.ndarray, bytes]) -> np.ndarray:
        """
        Deblur an image.

        Args:
            image: Input image as a file path, a numpy array, or encoded image
                bytes (e.g. from a file upload).

        Returns:
            Deblurred image as an HxWx3 ``uint8`` array. If the input's longer
            side exceeds 2000 px it is downscaled first, so the result has the
            downscaled size.

        Raises:
            ValueError: If the input type is unsupported, the path cannot be
                read, or the image is empty or not 3-channel. Decoding,
                inference and other errors are logged and re-raised.
        """
        try:
            img = self._load_image(image)

            if img.size == 0:
                raise ValueError("Image is empty or invalid")
            logger.info(f"Image shape: {img.shape}, dtype: {img.dtype}")
            if len(img.shape) != 3 or img.shape[2] != 3:
                raise ValueError(f"Image must have 3 channels, got shape {img.shape}")

            # Downscale very large images so the longer side is at most 2000 px.
            max_dim = max(img.shape[0], img.shape[1])
            if max_dim > 2000:
                scale_factor = 2000 / max_dim
                new_h = int(img.shape[0] * scale_factor)
                new_w = int(img.shape[1] * scale_factor)
                logger.warning(f"Image too large, resizing from {img.shape[:2]} to {(new_h, new_w)}")
                img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

            logger.info("Preprocessing image")
            (img_batch, _mask_batch), h, w = self._preprocess(img)

            logger.info("Running inference with model")
            with torch.no_grad():
                try:
                    pred = self.model(img_batch.to(self.device))
                except Exception as e:
                    if self.device != 'cuda':
                        # Already on the CPU: a retry would repeat the same failure.
                        raise
                    # CUDA failed (e.g. out of memory): permanently fall back to CPU.
                    logger.warning(f"Error using {self.device}: {str(e)}. Falling back to CPU.")
                    torch.cuda.empty_cache()  # Free GPU memory
                    self.model = self.model.to('cpu')
                    self.device = 'cpu'
                    pred = self.model(img_batch.to('cpu'))

            logger.info("Postprocessing image")
            # Crop away the padding added by _preprocess.
            result = self._postprocess(pred)[:h, :w, :]
            logger.info("Image deblurred successfully")
            return result
        except Exception as e:
            logger.error(f"Error in deblur_image: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    def save_image(self, image: np.ndarray, output_path: str) -> str:
        """Save an RGB image to the given path.

        Args:
            image: 3-channel RGB array. It is converted to BGR for OpenCV.
            output_path: Destination file. Relative paths are resolved against
                ``self.outputs_dir``. Missing parent directories are created.

        Returns:
            The path the image was written to.

        Raises:
            OSError: If ``cv2.imwrite`` reports that the write failed. Other
                errors are logged and re-raised.
        """
        try:
            save_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            if not os.path.isabs(output_path):
                # Use the outputs directory by default.
                output_path = os.path.join(self.outputs_dir, output_path)
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            # cv2.imwrite can report failure (e.g. an unwritable path) by
            # returning False instead of raising; do not let that pass silently.
            if not cv2.imwrite(output_path, save_img):
                raise OSError(f"cv2.imwrite failed to write {output_path}")
            logger.info(f"Image saved to {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"Error saving image: {str(e)}")
            logger.error(traceback.format_exc())
            raise
def main():
    """
    Batch-deblur every image in the model's inputs directory.

    Creates a DeblurGAN with its default weights. Each supported image file
    found directly in ``inputs_dir`` is processed in sorted order and written
    to ``outputs_dir`` as ``deblurred_<name>``. A failure on one image is
    logged and reported without stopping the batch. Any other error is logged
    and printed rather than raised.
    """
    try:
        # Initialize the DeblurGAN model.
        deblur_model = DeblurGAN()
        inputs_dir = deblur_model.inputs_dir
        outputs_dir = deblur_model.outputs_dir

        image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
        # Sort so processing order is deterministic (os.listdir order is arbitrary).
        input_files = sorted(
            f for f in os.listdir(inputs_dir)
            if os.path.isfile(os.path.join(inputs_dir, f)) and f.lower().endswith(image_exts)
        )

        if not input_files:
            logger.warning(f"No image files found in {inputs_dir}")
            print(f"No image files found in {inputs_dir}. Please add some images and try again.")
            return

        logger.info(f"Found {len(input_files)} images to process")
        print(f"Found {len(input_files)} images to process")

        for input_file in input_files:
            try:
                input_path = os.path.join(inputs_dir, input_file)
                output_file = f"deblurred_{input_file}"
                print(f"Processing {input_file}...")
                deblurred_img = deblur_model.deblur_image(input_path)
                # Report the path save_image actually wrote to.
                output_path = deblur_model.save_image(deblurred_img, output_file)
                print(f"✅ Saved deblurred image to {output_path}")
            except Exception as e:
                logger.error(f"Error processing {input_file}: {str(e)}")
                print(f"❌ Failed to process {input_file}: {str(e)}")

        print(f"\nDeblurring complete! Check {outputs_dir} for results.")
    except Exception as e:
        logger.error(f"Error in main function: {str(e)}")
        logger.error(traceback.format_exc())
        print(f"❌ Error: {str(e)}")
if __name__ == '__main__':
    # Script entry point: deblur everything in the inputs directory.
    main()