File size: 11,555 Bytes
61d360d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
import os
import cv2
import numpy as np
import torch
import yaml
from typing import Optional, Tuple, Union
from io import BytesIO
from PIL import Image
import logging
import traceback

from aug import get_normalize
from models.networks import get_generator
from logging_utils import setup_logger

# Module-level logger for this file, obtained from the project's setup_logger
# helper (see logging_utils for how handlers/format are configured).
logger = setup_logger(__name__)

class DeblurGAN:
    """Inference wrapper around a DeblurGAN generator network.

    The constructor builds the generator (architecture from ``model_name`` or
    ``config/config.yaml``), loads its weights, and places it on CUDA when that
    works, otherwise on CPU. :meth:`deblur_image` turns an image (path, encoded
    bytes or array) into a deblurred RGB ``uint8`` array, and
    :meth:`save_image` writes such an array to disk.
    """

    # Padded input height/width are multiples of this value.
    _BLOCK_SIZE = 32
    # Inputs whose longest side exceeds this many pixels are downscaled first.
    _MAX_DIM = 2000

    def __init__(self, weights_path: str = 'fpn_inception.h5', model_name: str = ''):
        """
        Initialize the DeblurGAN model.

        Args:
            weights_path: Path to the checkpoint file. Relative paths are
                resolved against this module's directory.
            model_name: Generator architecture name passed to
                ``get_generator``. If empty, ``config['model']['g_name']`` from
                ``config/config.yaml`` is used.

        Raises:
            FileNotFoundError: If the weights file or the config file is missing.
            Exception: Any other failure while building/loading the model is
                logged with its traceback and re-raised.
        """
        try:
            logger.info(f"Initializing DeblurGAN with weights: {weights_path}")
            # Resolve paths against the module directory so the class works
            # regardless of the current working directory.
            module_dir = os.path.dirname(os.path.abspath(__file__))
            config_path = os.path.join(module_dir, 'config', 'config.yaml')
            if not os.path.isabs(weights_path):
                weights_path = os.path.join(module_dir, weights_path)

            if not os.path.exists(weights_path):
                error_msg = f"Weights file not found: {weights_path}"
                logger.error(error_msg)
                raise FileNotFoundError(error_msg)

            logger.info(f"Loading configuration from {config_path}")
            with open(config_path, encoding='utf-8') as cfg:
                # The config only needs plain YAML types, so safe_load suffices.
                config = yaml.safe_load(cfg)

            arch_name = model_name or config['model']['g_name']
            logger.info(f"Creating model with architecture: {arch_name}")
            model = get_generator(arch_name)

            logger.info("Loading model weights")
            # Load onto the CPU first: a checkpoint saved from a GPU would
            # otherwise fail to load on a CPU-only host. The model is moved to
            # CUDA afterwards when possible.
            checkpoint = torch.load(weights_path, map_location='cpu')
            if isinstance(checkpoint, dict) and 'model' in checkpoint:
                state_dict = checkpoint['model']
            else:
                # Also accept a bare state_dict checkpoint.
                state_dict = checkpoint
            model.load_state_dict(state_dict)

            self.model, self.device = self._place_model(model)

            # Deliberately left in train mode (as in the original code):
            # normalization layers, if any, then use the statistics of the
            # current input. Note this also keeps any dropout layers active.
            self.model.train(True)
            self.normalize_fn = get_normalize()

            self.inputs_dir = os.path.join(module_dir, 'inputs')
            self.outputs_dir = os.path.join(module_dir, 'outputs')
            os.makedirs(self.inputs_dir, exist_ok=True)
            os.makedirs(self.outputs_dir, exist_ok=True)

            logger.info("Model initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize model: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    @staticmethod
    def _place_model(model):
        """Move *model* to CUDA when available, otherwise keep it on the CPU.

        Returns:
            ``(model, device)`` where ``device`` is ``'cuda'`` or ``'cpu'``.
        """
        if torch.cuda.is_available():
            try:
                model = model.cuda()
                logger.info("Model moved to CUDA successfully")
                return model, 'cuda'
            except Exception as e:
                logger.warning(f"Failed to move model to CUDA. Error: {str(e)}")
        logger.warning("Using CPU mode")
        return model, 'cpu'

    @staticmethod
    def _array_to_batch(x):
        """Convert an ``(H, W, C)`` numpy array into a ``(1, C, H, W)`` tensor."""
        x = np.transpose(x, (2, 0, 1))
        x = np.expand_dims(x, 0)
        return torch.from_numpy(x)

    def _preprocess(self, x: np.ndarray) -> Tuple:
        """Normalize *x* and zero-pad it to a multiple of ``_BLOCK_SIZE``.

        Args:
            x: ``(H, W, C)`` image array.

        Returns:
            ``((img_batch, mask_batch), h, w)``: two ``(1, C, H', W')`` tensors
            (padded image and a mask of ones over the real pixels) plus the
            original height/width, used to crop the model output back.
        """
        x, _ = self.normalize_fn(x, x)
        mask = np.ones_like(x, dtype=np.float32)

        h, w, _ = x.shape
        block = self._BLOCK_SIZE
        # Always pads by at least one pixel per axis (even when already a
        # multiple of the block size); kept as-is so outputs do not change.
        pad_h = (h // block + 1) * block - h
        pad_w = (w // block + 1) * block - w
        pad_width = ((0, pad_h), (0, pad_w), (0, 0))
        x = np.pad(x, pad_width, mode='constant', constant_values=0)
        mask = np.pad(mask, pad_width, mode='constant', constant_values=0)

        return (self._array_to_batch(x), self._array_to_batch(mask)), h, w

    @staticmethod
    def _postprocess(x: torch.Tensor) -> np.ndarray:
        """Convert a ``(1, C, H, W)`` output tensor to an ``(H, W, C)`` uint8 array.

        Values are mapped from the expected ``[-1, 1]`` range to ``[0, 255]``.
        """
        x, = x  # drop the batch dimension (raises unless batch size is 1)
        x = x.detach().cpu().float().numpy()
        x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
        # Clip before casting: out-of-range values would otherwise wrap around.
        return np.clip(x, 0.0, 255.0).astype('uint8')

    def _load_image(self, image, bgr: bool) -> np.ndarray:
        """Decode *image* into an RGB array (no shape validation here).

        Raises:
            ValueError: If the type is unsupported or a path cannot be read.
        """
        if isinstance(image, (str, os.PathLike)):
            path = os.fspath(image)
            logger.info(f"Loading image from path: {path}")
            img = cv2.imread(path)
            if img is None:
                raise ValueError(f"Failed to read image from {path}")
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if isinstance(image, (bytes, bytearray, memoryview)):
            # Encoded image data, e.g. from a file upload.
            logger.info("Loading image from bytes")
            img = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
            if img is not None:
                return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # OpenCV could not decode it; try PIL as a fallback.
            with Image.open(BytesIO(image)) as pil_img:
                return np.array(pil_img.convert('RGB'))

        if isinstance(image, np.ndarray):
            logger.info("Processing image from numpy array")
            if bgr and image.ndim == 3 and image.shape[2] == 3:
                # Reverse the channel axis (BGR -> RGB); works for any dtype.
                return np.ascontiguousarray(image[..., ::-1])
            return image.copy()

        raise ValueError(f"Unsupported image type: {type(image)}")

    def _limit_size(self, img: np.ndarray) -> np.ndarray:
        """Downscale *img* so its longest side is at most ``_MAX_DIM`` pixels."""
        max_dim = max(img.shape[0], img.shape[1])
        if max_dim <= self._MAX_DIM:
            return img
        scale_factor = self._MAX_DIM / max_dim
        # max(1, ...) keeps extreme aspect ratios from collapsing to 0 pixels.
        new_h = max(1, int(img.shape[0] * scale_factor))
        new_w = max(1, int(img.shape[1] * scale_factor))
        logger.warning(f"Image too large, resizing from {img.shape[:2]} to {(new_h, new_w)}")
        return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

    def _run_model(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Run the generator; on a CUDA failure, move to CPU and retry once.

        Failures that happen on the CPU are re-raised immediately, since
        retrying on the same device would fail the same way.
        """
        with torch.no_grad():
            try:
                return self.model(img_batch.to(self.device))
            except Exception as e:
                if self.device != 'cuda':
                    raise
                logger.warning(f"Error using {self.device}: {str(e)}. Falling back to CPU.")
                torch.cuda.empty_cache()  # Free GPU memory
                self.model = self.model.to('cpu')
                self.device = 'cpu'
                return self.model(img_batch.to('cpu'))

    def deblur_image(self, image: Union[str, os.PathLike, np.ndarray, bytes, bytearray, memoryview],
                     *, bgr: bool = False) -> np.ndarray:
        """
        Deblur an image.

        Args:
            image: A file path (``str`` or path-like), encoded image bytes
                (``bytes``/``bytearray``/``memoryview``), or an ``(H, W, 3)``
                array. Arrays are treated as RGB unless ``bgr`` is true.
            bgr: Set to ``True`` when ``image`` is an array in OpenCV's BGR
                channel order (e.g. straight from ``cv2.imread``).

        Returns:
            The deblurred image as an ``(H, W, 3)`` RGB ``uint8`` array. Inputs
            whose longest side exceeds ``_MAX_DIM`` are downscaled first, so
            the result then has the downscaled size.

        Raises:
            ValueError: If the input type is unsupported, a path cannot be
                read, or the image is empty / not 3-channel. Decoding errors
                from OpenCV or PIL propagate unchanged.
        """
        try:
            img = self._load_image(image, bgr)

            if img is None or img.size == 0:
                raise ValueError("Image is empty or invalid")

            logger.info(f"Image shape: {img.shape}, dtype: {img.dtype}")

            if img.ndim != 3 or img.shape[2] != 3:
                raise ValueError(f"Image must have 3 channels, got shape {img.shape}")

            img = self._limit_size(img)

            logger.info("Preprocessing image")
            (img_batch, _mask_batch), h, w = self._preprocess(img)

            logger.info("Running inference with model")
            pred = self._run_model(img_batch)

            logger.info("Postprocessing image")
            # Crop away the padding added by _preprocess.
            result = self._postprocess(pred)[:h, :w, :]
            logger.info("Image deblurred successfully")
            return result
        except Exception as e:
            logger.error(f"Error in deblur_image: {str(e)}")
            logger.error(traceback.format_exc())
            raise

    def save_image(self, image: np.ndarray, output_path: str) -> str:
        """Save an RGB image to *output_path*.

        Args:
            image: RGB image array (converted to BGR for OpenCV before writing).
            output_path: Destination path. Relative paths are placed under
                ``self.outputs_dir``; parent directories are created.

        Returns:
            The path the image was written to.

        Raises:
            OSError: If ``cv2.imwrite`` reports that the write failed.
        """
        try:
            save_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            output_path = os.fspath(output_path)
            if not os.path.isabs(output_path):
                # Use the outputs directory by default
                output_path = os.path.join(self.outputs_dir, output_path)

            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            # cv2.imwrite signals failure by returning False, not by raising.
            if not cv2.imwrite(output_path, save_img):
                raise OSError(f"Failed to write image to {output_path}")
            logger.info(f"Image saved to {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"Error saving image: {str(e)}")
            logger.error(traceback.format_exc())
            raise

def main():
    """
    Deblur every image in the model's inputs directory.

    Each result is written to the outputs directory as ``deblurred_<name>``.
    Files are processed in sorted order so runs are reproducible.

    Returns:
        An exit status: 0 if every image was processed (or there was nothing
        to process), 1 if initialisation or any individual image failed.
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp')
    try:
        # Initialize the DeblurGAN model
        deblur_model = DeblurGAN()

        inputs_dir = deblur_model.inputs_dir
        outputs_dir = deblur_model.outputs_dir

        # os.listdir order is arbitrary; sort for deterministic processing.
        input_files = sorted(
            f for f in os.listdir(inputs_dir)
            if f.lower().endswith(image_exts) and os.path.isfile(os.path.join(inputs_dir, f))
        )

        if not input_files:
            logger.warning(f"No image files found in {inputs_dir}")
            print(f"No image files found in {inputs_dir}. Please add some images and try again.")
            return 0

        logger.info(f"Found {len(input_files)} images to process")
        print(f"Found {len(input_files)} images to process")

        failures = 0
        for input_file in input_files:
            try:
                input_path = os.path.join(inputs_dir, input_file)
                print(f"Processing {input_file}...")

                deblurred_img = deblur_model.deblur_image(input_path)
                # Report the path save_image actually wrote to.
                saved_path = deblur_model.save_image(deblurred_img, f"deblurred_{input_file}")

                print(f"✅ Saved deblurred image to {saved_path}")
            except Exception as e:
                # Keep going with the remaining images, but remember the failure.
                failures += 1
                logger.error(f"Error processing {input_file}: {str(e)}")
                print(f"❌ Failed to process {input_file}: {str(e)}")

        print(f"\nDeblurring complete! Check {outputs_dir} for results.")
        return 1 if failures else 0

    except Exception as e:
        logger.error(f"Error in main function: {str(e)}")
        logger.error(traceback.format_exc())
        print(f"❌ Error: {str(e)}")
        return 1

if __name__ == "__main__":
    # Use whatever main() returns as the process exit status (None -> 0), so
    # scripts and CI can detect failures.
    raise SystemExit(main())