| """
|
| CNN Deblurring Module - Deep Learning Based Image Enhancement
|
| ============================================================
|
|
|
| CNN inference system for image deblurring with TensorFlow/Keras.
|
| Includes model architecture, training utilities, and inference pipeline.
|
| """
|
|
|
| import cv2
|
| import numpy as np
|
| import os
|
| import logging
|
| from typing import Optional, Tuple, List
|
| import pickle
|
|
|
|
|
# These variables are read by TensorFlow when it is imported, so they are set
# before the `import tensorflow` line below on purpose.
# TF_CPP_MIN_LOG_LEVEL='2' asks the native runtime to hide INFO and WARNING
# messages; TF_ENABLE_ONEDNN_OPTS='0' turns the oneDNN custom ops off
# (see the TensorFlow documentation for both switches).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model

# Limit TensorFlow's Python-side loggers (current and compat.v1) to errors.
tf.get_logger().setLevel('ERROR')
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# NOTE: basicConfig() configures the *root* logger as an import side effect.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by every class and function in this file.
logger = logging.getLogger(__name__)
|
|
|
class CNNDeblurModel:
    """CNN-based deblurring model with a U-Net like encoder-decoder.

    Bundles a Keras network with the pre/post-processing needed to run it on
    OpenCV images (BGR, 8-bit) and with a classical sharpening fallback that is
    used whenever no trained weights are available.

    Attributes:
        input_shape: (height, width, channels) the network is built for.
        model: The Keras model, or None until it has been built or loaded.
        is_trained: True once trained weights are present (set by load_model
            or by whoever trains the network).
        training_history: Slot for the history object of the last training run.
        model_path: Default location of the saved model file.
        dataset_path: Default folder for training images and cached datasets.
    """

    def __init__(self, input_shape: Tuple[int, int, int] = (256, 256, 3)):
        self.input_shape = input_shape
        self.model = None
        self.is_trained = False
        self.training_history = None
        self.model_path = "models/cnn_deblur_model.h5"
        self.dataset_path = "data/training_dataset"

    @staticmethod
    def _conv_pair(tensor, filters: int):
        """Apply two 3x3, same-padded ReLU convolutions with `filters` channels."""
        for _ in range(2):
            tensor = layers.Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        return tensor

    def build_model(self) -> "Optional[Model]":
        """
        Build CNN deblurring model with U-Net like architecture

        Three encoder stages (64/128/256 filters, each followed by 2x2 max
        pooling), a 512-filter bottleneck and three decoder stages that
        upsample, concatenate the matching encoder output and convolve again.
        The model is compiled with Adam, MSE loss and MAE/MSE metrics and
        stored in ``self.model``.

        Returns:
            keras.Model: Compiled CNN model, or None if construction failed
            (the error is logged, not raised).
        """
        try:
            # Three 2x poolings followed by three 2x upsamplings only line up
            # with the skip connections when height and width divide by 8.
            for dim in self.input_shape[:2]:
                if isinstance(dim, int) and dim % 8:
                    raise ValueError(
                        f"input height/width must be divisible by 8, got {self.input_shape}"
                    )

            inputs = keras.Input(shape=self.input_shape)

            # Encoder: remember each stage's output for the skip connections.
            skips = []
            x = inputs
            for filters in (64, 128, 256):
                x = self._conv_pair(x, filters)
                skips.append(x)
                x = layers.MaxPooling2D(pool_size=(2, 2))(x)

            # Bottleneck
            x = self._conv_pair(x, 512)

            # Decoder: mirror the encoder, deepest skip connection first.
            for filters, skip in zip((256, 128, 64), reversed(skips)):
                x = layers.UpSampling2D(size=(2, 2))(x)
                x = layers.Conv2D(filters, 2, activation='relu', padding='same')(x)
                x = layers.concatenate([skip, x], axis=3)
                x = self._conv_pair(x, filters)

            # Sigmoid keeps the 3-channel output in [0, 1], matching the
            # /255 normalisation used for inputs and targets.
            outputs = layers.Conv2D(3, 1, activation='sigmoid')(x)

            model = Model(inputs=inputs, outputs=outputs)
            model.compile(
                optimizer='adam',
                loss='mse',
                metrics=['mae', 'mse']
            )

            self.model = model
            logger.info("CNN model built successfully")
            return model

        except Exception as e:
            logger.error(f"Error building CNN model: {e}")
            return None

    def load_model(self, model_path: str) -> bool:
        """
        Load pre-trained model from file

        When the file is missing or cannot be loaded, a fresh (untrained)
        network is built instead so that ``self.model`` is still usable.

        Args:
            model_path: Path to saved model

        Returns:
            bool: True only if the model was loaded from `model_path`
        """
        try:
            if not os.path.exists(model_path):
                logger.warning(f"Model file not found: {model_path}")
                self.build_model()
                return False

            self.model = keras.models.load_model(model_path)
            self.is_trained = True
            logger.info(f"Model loaded from {model_path}")
            return True

        except Exception as e:
            logger.error(f"Error loading model: {e}")
            self.build_model()
            return False

    def save_model(self, model_path: str) -> bool:
        """
        Save current model to file

        Args:
            model_path: Path to save model; missing parent folders are created

        Returns:
            bool: Success status (False if there is no model or saving failed)
        """
        try:
            if self.model is None:
                logger.error("No model to save")
                return False

            parent = os.path.dirname(model_path)
            if parent:
                os.makedirs(parent, exist_ok=True)

            self.model.save(model_path)
            logger.info(f"Model saved to {model_path}")
            return True

        except Exception as e:
            logger.error(f"Error saving model: {e}")
            return False

    def preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess image for CNN input with color preservation

        Converts the image to 3-channel RGB, resizes it to the network input
        size with bicubic interpolation, scales it by 1/255 and adds a batch
        axis.

        Args:
            image: Input image in OpenCV layout: BGR, BGRA or single-channel
                grayscale. Values are assumed to be 8-bit (0-255).

        Returns:
            np.ndarray: float32 array of shape (1, height, width, 3), or an
            empty array if the image could not be processed.
        """
        try:
            channels = 1 if image.ndim == 2 else image.shape[2] if image.ndim == 3 else None
            if channels == 3:
                rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            elif channels == 1:
                rgb_image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            elif channels == 4:
                # Alpha is dropped; the network only has colour inputs.
                rgb_image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
            else:
                raise ValueError(f"unsupported image shape {image.shape}")

            # cv2.resize expects (width, height), input_shape is (height, width, ...).
            resized = cv2.resize(rgb_image,
                                 (self.input_shape[1], self.input_shape[0]),
                                 interpolation=cv2.INTER_CUBIC)

            normalized = resized.astype(np.float64) / 255.0
            batched = np.expand_dims(normalized, axis=0)
            return batched.astype(np.float32)

        except Exception as e:
            logger.error(f"Error preprocessing image: {e}")
            return np.array([])

    def postprocess_image(self, output: np.ndarray, original_shape: Tuple[int, int]) -> np.ndarray:
        """
        Postprocess CNN output to original image format with color preservation

        Args:
            output: CNN model output in [0, 1], with or without a batch axis
            original_shape: Original image shape (height, width)

        Returns:
            np.ndarray: uint8 image in BGR format at the original size; an
            all-black image of that size if post-processing failed.
        """
        try:
            if output.ndim == 4:
                output = output[0]  # drop the batch axis

            pixels = np.round(np.clip(output * 255.0, 0, 255)).astype(np.uint8)

            resized = cv2.resize(pixels,
                                 (original_shape[1], original_shape[0]),
                                 interpolation=cv2.INTER_CUBIC)

            return cv2.cvtColor(resized, cv2.COLOR_RGB2BGR)

        except Exception as e:
            logger.error(f"Error postprocessing image: {e}")
            return np.zeros((*original_shape, 3), dtype=np.uint8)

    def enhance_image(self, image: np.ndarray) -> np.ndarray:
        """
        Enhance image using CNN model

        Uses the network when trained weights are available; otherwise (or when
        inference raises) falls back to classical sharpening.

        Args:
            image: Input blurry image (BGR format)

        Returns:
            np.ndarray: Enhanced image (BGR format). The input itself is
            returned if it cannot be preprocessed.
        """
        try:
            if self.model is None:
                logger.warning("No model available, building new model")
                self.build_model()

            # Decide before preprocessing: the fallback works on the original
            # image, so resizing/normalising for an untrained net is wasted.
            if not self.is_trained:
                logger.info("Using fallback enhancement (model not trained)")
                return self._fallback_enhancement(image)

            original_shape = image.shape[:2]

            preprocessed = self.preprocess_image(image)
            if preprocessed.size == 0:
                logger.error("Failed to preprocess image")
                return image

            enhanced = self.model.predict(preprocessed, verbose=0)
            result = self.postprocess_image(enhanced, original_shape)

            logger.info("CNN enhancement completed")
            return result

        except Exception as e:
            logger.error(f"Error in CNN enhancement: {e}")
            return self._fallback_enhancement(image)

    def _fallback_enhancement(self, image: np.ndarray) -> np.ndarray:
        """
        Fallback enhancement when CNN model is not available - preserves original colors

        Blends the image (70%) with a mildly sharpened copy (30%). The
        sharpening kernel sums to 1, so overall brightness is unchanged.

        Args:
            image: Input image (any channel count)

        Returns:
            np.ndarray: uint8 image sharpened with color-preserving traditional
            methods; the input itself if sharpening failed.
        """
        try:
            img_float = image.astype(np.float64)

            kernel_sharpen = np.array([[-0.1, -0.1, -0.1],
                                       [-0.1, 1.8, -0.1],
                                       [-0.1, -0.1, -0.1]])

            # filter2D filters every channel independently, so no per-channel
            # loop is needed and grayscale input works as well.
            sharpened = cv2.filter2D(img_float, -1, kernel_sharpen)
            # OpenCV drops a trailing singleton channel axis; restore it so the
            # blend below does not broadcast to the wrong shape.
            sharpened = sharpened.reshape(img_float.shape)

            result = 0.7 * img_float + 0.3 * sharpened
            result = np.clip(result, 0, 255).astype(np.uint8)

            logger.info("Color-preserving fallback enhancement applied")
            return result

        except Exception as e:
            logger.error(f"Error in fallback enhancement: {e}")
            return image
|
|
|
class CNNTrainer:
    """Training utilities for CNN deblurring model

    The trainer keeps no configuration of its own: dataset/model paths, the
    input shape and the Keras network are read from the wrapped
    ``CNNDeblurModel`` (``self.model``), and training results
    (``training_history``, ``is_trained``) are written back to it.
    """

    _BLUR_TYPES = ('gaussian', 'motion', 'defocus')
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')

    def __init__(self, model: "CNNDeblurModel"):
        self.model = model

    def create_synthetic_data(self, clean_images: List[np.ndarray],
                              blur_types: Optional[List[str]] = None) -> Tuple[np.ndarray, np.ndarray]:
        """
        Create synthetic training data by applying blur to clean images

        One blur type is drawn at random per image. Unknown type names are
        treated like 'defocus'.

        Args:
            clean_images: List of clean images
            blur_types: Types of blur to apply (default: gaussian, motion, defocus)

        Returns:
            tuple: (blurred_images, clean_images) for training; two empty
            arrays if an error occurred.
        """
        if blur_types is None:
            blur_types = list(self._BLUR_TYPES)

        blurred_batch = []
        clean_batch = []

        try:
            for clean_img in clean_images:
                blur_type = np.random.choice(blur_types)

                if blur_type == 'gaussian':
                    kernel_size = int(np.random.randint(5, 15))
                    if kernel_size % 2 == 0:
                        kernel_size += 1  # GaussianBlur needs an odd kernel size
                    blurred = cv2.GaussianBlur(clean_img, (kernel_size, kernel_size), 0)
                elif blur_type == 'motion':
                    length = np.random.randint(5, 20)
                    angle = np.random.randint(0, 180)
                    kernel = self._create_motion_kernel(length, angle)
                    blurred = cv2.filter2D(clean_img, -1, kernel)
                else:
                    # Defocus is approximated by a sigma-driven Gaussian.
                    sigma = np.random.uniform(1, 5)
                    blurred = cv2.GaussianBlur(clean_img, (0, 0), sigma)

                blurred_batch.append(blurred)
                clean_batch.append(clean_img)

            return np.array(blurred_batch), np.array(clean_batch)

        except Exception as e:
            logger.error(f"Error creating synthetic data: {e}")
            return np.array([]), np.array([])

    def _create_motion_kernel(self, length: int, angle: float) -> np.ndarray:
        """Create motion blur kernel

        Draws a straight line through the centre of a `length` x `length`
        matrix at `angle` degrees and normalises it to sum to 1.

        Raises:
            ValueError: If `length` is smaller than 1.
        """
        length = int(length)
        if length < 1:
            raise ValueError(f"motion kernel length must be >= 1, got {length}")

        kernel = np.zeros((length, length))
        center = length // 2

        cos_val = np.cos(np.radians(angle))
        sin_val = np.sin(np.radians(angle))

        for i in range(length):
            offset = i - center
            y = int(center + offset * sin_val)
            x = int(center + offset * cos_val)
            # For even lengths the far end of the line can land one cell
            # outside the matrix; such points are skipped.
            if 0 <= y < length and 0 <= x < length:
                kernel[y, x] = 1

        # The centre cell (offset 0) is always set, so the sum is never zero.
        return kernel / kernel.sum()

    def _blur_with_noise(self, image: np.ndarray, noise_sigma: float) -> np.ndarray:
        """Apply one randomly chosen blur plus Gaussian noise; returns uint8."""
        blur_type = np.random.choice(self._BLUR_TYPES)

        if blur_type == 'motion':
            length = np.random.randint(5, 25)
            angle = np.random.randint(0, 180)
            blurred = cv2.filter2D(image, -1, self._create_motion_kernel(length, angle))
        else:
            # 'gaussian' and 'defocus' differ only in their sigma range.
            low, high = (0.5, 3.0) if blur_type == 'gaussian' else (1.0, 4.0)
            blurred = cv2.GaussianBlur(image, (0, 0), np.random.uniform(low, high))

        noise = np.random.normal(0, noise_sigma, blurred.shape).astype(np.float32)
        return np.clip(blurred.astype(np.float32) + noise, 0, 255).astype(np.uint8)

    def _history_path(self) -> str:
        """Return where the training history pickle is stored.

        Derived from the model path by swapping its extension for
        ``_history.pkl``, so it can never equal the model file itself.
        """
        return os.path.splitext(self.model.model_path)[0] + '_history.pkl'

    def _load_user_images(self) -> List[np.ndarray]:
        """Load user's training images from training_dataset folder

        Every readable image file in the wrapped model's ``dataset_path`` is
        resized to the model's input size. Returns an empty list if the folder
        does not exist or cannot be read.
        """
        dataset_dir = self.model.dataset_path
        user_images = []

        try:
            if not os.path.isdir(dataset_dir):
                return user_images

            target_size = (self.model.input_shape[1], self.model.input_shape[0])

            for filename in os.listdir(dataset_dir):
                if not filename.lower().endswith(self._IMAGE_EXTENSIONS):
                    continue
                image_path = os.path.join(dataset_dir, filename)
                try:
                    image = cv2.imread(image_path)
                    if image is not None:  # imread returns None for unreadable files
                        user_images.append(cv2.resize(image, target_size))
                        logger.info(f"Loaded user image: {filename}")
                except Exception as e:
                    logger.warning(f"Failed to load {filename}: {e}")

            logger.info(f"Loaded {len(user_images)} user training images")
            return user_images

        except Exception as e:
            logger.error(f"Error loading user images: {e}")
            return []

    def create_training_dataset(self, num_samples: int = 1000, save_dataset: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        """
        Create comprehensive training dataset with various blur types
        Incorporates user's real training images from data/training_dataset/

        Each user image contributes three differently degraded copies; the rest
        of `num_samples` is filled with degraded synthetic images. Both arrays
        are float32 scaled to [0, 1].

        Args:
            num_samples: Number of training samples to generate
            save_dataset: Whether to save dataset to disk (empty datasets are
                never saved)

        Returns:
            Tuple[np.ndarray, np.ndarray]: Blurred images and clean targets;
            two empty arrays if an error occurred.
        """
        try:
            logger.info(f"Creating training dataset with {num_samples} samples...")

            dataset_dir = self.model.dataset_path
            os.makedirs(dataset_dir, exist_ok=True)

            user_images = self._load_user_images()

            all_blurred = []
            all_clean = []

            if user_images:
                logger.info(f"Processing {len(user_images)} user training images...")
                for user_img in user_images:
                    for _ in range(3):
                        all_blurred.append(self._blur_with_noise(user_img, 3))
                        all_clean.append(user_img)

            remaining_samples = max(0, num_samples - len(all_blurred))
            if remaining_samples > 0:
                logger.info(f"Generating {remaining_samples} synthetic training samples...")

                batch_size = 50
                num_batches = (remaining_samples + batch_size - 1) // batch_size

                for batch_idx in range(num_batches):
                    current_batch_size = min(batch_size, remaining_samples - batch_idx * batch_size)

                    clean_batch = self._generate_clean_images(current_batch_size)
                    all_blurred.extend(self._blur_with_noise(img, 5) for img in clean_batch)
                    all_clean.extend(clean_batch)

                    if (batch_idx + 1) % 5 == 0:
                        logger.info(f"Generated batch {batch_idx + 1}/{num_batches}")

            blurred_dataset = np.array(all_blurred).astype(np.float32) / 255.0
            clean_dataset = np.array(all_clean).astype(np.float32) / 255.0

            logger.info(f"Dataset created: {blurred_dataset.shape} blurred, {clean_dataset.shape} clean")

            # An empty cache would be picked up as "existing dataset" later and
            # make every following training run fail, so never write one.
            if save_dataset and len(blurred_dataset) > 0:
                np.save(os.path.join(dataset_dir, 'blurred_images.npy'), blurred_dataset)
                np.save(os.path.join(dataset_dir, 'clean_images.npy'), clean_dataset)
                logger.info(f"Dataset saved to {dataset_dir}")

            return blurred_dataset, clean_dataset

        except Exception as e:
            logger.error(f"Error creating training dataset: {e}")
            return np.array([]), np.array([])

    def _generate_clean_images(self, num_images: int) -> List[np.ndarray]:
        """Generate synthetic clean images for training

        Each image has a random solid background, 3-7 random filled rectangles,
        circles or lines and, with probability 0.5, a short random text.
        """
        height, width = self.model.input_shape[0], self.model.input_shape[1]
        glyphs = list('ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789')
        clean_images = []

        for _ in range(num_images):
            img = np.zeros((height, width, 3), dtype=np.uint8)
            img[:] = np.random.randint(0, 255, 3)  # random background colour

            # Coordinates/thickness are cast to plain int before they are
            # handed to the OpenCV drawing functions.
            for _ in range(np.random.randint(3, 8)):
                shape_type = np.random.choice(['rectangle', 'circle', 'line'])
                color = np.random.randint(0, 255, 3).tolist()

                if shape_type == 'rectangle':
                    # Top-left in the upper-left quadrant, bottom-right in the
                    # lower-right one, so the rectangle is never degenerate.
                    pt1 = (int(np.random.randint(0, width // 2)),
                           int(np.random.randint(0, height // 2)))
                    pt2 = (int(np.random.randint(width // 2, width)),
                           int(np.random.randint(height // 2, height)))
                    cv2.rectangle(img, pt1, pt2, color, -1)

                elif shape_type == 'circle':
                    center = (int(np.random.randint(0, width)),
                              int(np.random.randint(0, height)))
                    radius = int(np.random.randint(10, 50))
                    cv2.circle(img, center, radius, color, -1)

                else:
                    pt1 = (int(np.random.randint(0, width)),
                           int(np.random.randint(0, height)))
                    pt2 = (int(np.random.randint(0, width)),
                           int(np.random.randint(0, height)))
                    thickness = int(np.random.randint(1, 5))
                    cv2.line(img, pt1, pt2, color, thickness)

            if np.random.random() > 0.5:
                text = ''.join(np.random.choice(glyphs, np.random.randint(3, 8)))
                font = cv2.FONT_HERSHEY_SIMPLEX
                font_scale = float(np.random.uniform(0.5, 2.0))
                color = np.random.randint(0, 255, 3).tolist()
                thickness = int(np.random.randint(1, 3))
                # Baseline at least 20 px from the top where the image is tall
                # enough; the min() keeps randint valid for tiny images.
                position = (int(np.random.randint(0, width // 2)),
                            int(np.random.randint(min(20, height - 1), height)))
                cv2.putText(img, text, position, font, font_scale, color, thickness)

            clean_images.append(img)

        return clean_images

    def load_existing_dataset(self) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Load existing dataset from disk

        Returns:
            The cached (blurred, clean) arrays from the wrapped model's
            ``dataset_path``, or (None, None) if either file is missing or
            loading failed.
        """
        try:
            dataset_dir = self.model.dataset_path
            blurred_path = os.path.join(dataset_dir, 'blurred_images.npy')
            clean_path = os.path.join(dataset_dir, 'clean_images.npy')

            if not (os.path.exists(blurred_path) and os.path.exists(clean_path)):
                logger.info("No existing dataset found")
                return None, None

            blurred_data = np.load(blurred_path)
            clean_data = np.load(clean_path)
            logger.info(f"Loaded existing dataset: {blurred_data.shape} samples")
            return blurred_data, clean_data

        except Exception as e:
            logger.error(f"Error loading existing dataset: {e}")
            return None, None

    def train_model(self,
                    epochs: int = 20,
                    batch_size: int = 16,
                    validation_split: float = 0.2,
                    use_existing_dataset: bool = True,
                    num_training_samples: int = 1000) -> bool:
        """
        Train the CNN model with comprehensive dataset

        Builds the network if necessary, loads or creates the dataset, fits
        with early stopping / LR reduction / checkpointing on ``val_loss``,
        then saves the model and a pickle of the training history next to it.
        On success the wrapped model's ``training_history`` and ``is_trained``
        are updated.

        Args:
            epochs: Number of training epochs
            batch_size: Training batch size
            validation_split: Fraction of data for validation
            use_existing_dataset: Whether to use existing saved dataset
            num_training_samples: Number of samples to generate if creating new dataset

        Returns:
            bool: Training success status (errors are logged, not raised)
        """
        try:
            logger.info("Starting CNN model training...")

            wrapper = self.model
            if wrapper.model is None:
                wrapper.build_model()
            if wrapper.model is None:
                logger.error("Failed to build model for training")
                return False

            blurred_data = clean_data = None
            if use_existing_dataset:
                blurred_data, clean_data = self.load_existing_dataset()
                # A cache created for another input size cannot be fed to this
                # network; regenerate instead of failing inside fit().
                if (blurred_data is not None
                        and tuple(blurred_data.shape[1:]) != tuple(wrapper.input_shape)):
                    logger.warning(
                        f"Existing dataset shape {blurred_data.shape[1:]} does not match "
                        f"model input {tuple(wrapper.input_shape)}; regenerating"
                    )
                    blurred_data = clean_data = None

            if blurred_data is None or clean_data is None:
                logger.info("Creating new dataset...")
                blurred_data, clean_data = self.create_training_dataset(num_training_samples)

            if len(blurred_data) == 0:
                logger.error("Failed to create/load training dataset")
                return False

            logger.info(f"Training on {len(blurred_data)} samples")

            # ModelCheckpoint writes during training, so the folder must exist.
            model_dir = os.path.dirname(wrapper.model_path)
            if model_dir:
                os.makedirs(model_dir, exist_ok=True)

            callbacks = [
                keras.callbacks.EarlyStopping(
                    monitor='val_loss',
                    patience=5,
                    restore_best_weights=True
                ),
                keras.callbacks.ReduceLROnPlateau(
                    monitor='val_loss',
                    factor=0.5,
                    patience=3,
                    min_lr=1e-7
                ),
                keras.callbacks.ModelCheckpoint(
                    filepath=wrapper.model_path,
                    monitor='val_loss',
                    save_best_only=True,
                    save_weights_only=False
                )
            ]

            history = wrapper.model.fit(
                blurred_data, clean_data,
                epochs=epochs,
                batch_size=batch_size,
                validation_split=validation_split,
                callbacks=callbacks,
                verbose=1
            )
            wrapper.training_history = history

            wrapper.save_model(wrapper.model_path)
            wrapper.is_trained = True

            with open(self._history_path(), 'wb') as f:
                pickle.dump(history.history, f)

            logger.info("Training completed successfully!")
            logger.info(f"Model saved to: {wrapper.model_path}")

            final_loss = history.history['loss'][-1]
            logger.info(f"Final training loss: {final_loss:.4f}")
            # val_loss is absent when validation_split is 0.
            val_losses = history.history.get('val_loss')
            if val_losses:
                logger.info(f"Final validation loss: {val_losses[-1]:.4f}")

            return True

        except Exception as e:
            logger.error(f"Error during training: {e}")
            return False

    def evaluate_model(self, test_images: Optional[np.ndarray] = None,
                       test_targets: Optional[np.ndarray] = None) -> dict:
        """
        Evaluate model performance on test data

        Args:
            test_images: Test images (if None, creates synthetic test set)
            test_targets: Test targets (if None, creates synthetic test set)

        Returns:
            dict: Evaluation metrics with keys 'loss', 'mae' and 'mse'; empty
            if the model is untrained, no test data is available or
            evaluation failed.
        """
        try:
            wrapper = self.model
            if wrapper.model is None or not wrapper.is_trained:
                logger.error("Model not trained. Train the model first.")
                return {}

            if test_images is None or test_targets is None:
                logger.info("Creating test dataset...")
                test_images, test_targets = self.create_training_dataset(num_samples=100, save_dataset=False)

            if len(test_images) == 0:
                logger.error("No test data available for evaluation")
                return {}

            # Order follows compile(): loss first, then metrics=['mae', 'mse'].
            results = wrapper.model.evaluate(test_images, test_targets, verbose=0)
            metrics = dict(zip(('loss', 'mae', 'mse'), results))

            logger.info("Model Evaluation Results:")
            for metric, value in metrics.items():
                logger.info(f" {metric}: {value:.4f}")

            return metrics

        except Exception as e:
            logger.error(f"Error during evaluation: {e}")
            return {}
|
|
|
|
|
def load_cnn_model(model_path: str = "models/cnn_model.h5") -> "CNNDeblurModel":
    """
    Load CNN deblurring model

    If `model_path` does not exist but a checkpoint is present at the model's
    own default location (``CNNDeblurModel.model_path``, which is where
    training saves), that checkpoint is loaded instead. Without either file
    the returned model is untrained and enhancement uses the fallback.

    Args:
        model_path: Path to model file

    Returns:
        CNNDeblurModel: Loaded model instance
    """
    model = CNNDeblurModel()

    # The default argument here differs from the path training writes to, so a
    # freshly trained model would otherwise never be found by default.
    if not os.path.exists(model_path) and os.path.exists(model.model_path):
        logger.warning(
            f"Model file not found: {model_path}; "
            f"loading trained model from {model.model_path} instead"
        )
        model_path = model.model_path

    model.load_model(model_path)
    return model
|
|
|
def enhance_with_cnn(image: np.ndarray, model_path: str = "models/cnn_model.h5") -> np.ndarray:
    """
    Enhance a single image with the CNN deblurring model.

    A model instance is obtained through load_cnn_model() on every call and
    then asked to enhance the image.

    Args:
        image: Input image
        model_path: Path to model file

    Returns:
        np.ndarray: Enhanced image
    """
    return load_cnn_model(model_path).enhance_image(image)
|
|
|
|
|
def train_new_model(num_samples: int = 1000, epochs: int = 20,
                    input_shape: Tuple[int, int, int] = (256, 256, 3)) -> "Optional[CNNDeblurModel]":
    """
    Train a new CNN deblurring model from scratch

    Creates the output folders, trains through a CNNTrainer (which is where
    train_model/evaluate_model are defined) and prints a short report.

    Args:
        num_samples: Number of training samples to generate
        epochs: Number of training epochs
        input_shape: Input image shape

    Returns:
        CNNDeblurModel: Trained model, or None if training failed
    """
    # NOTE(review): the leading 'π' / 'β' characters in some messages below
    # look like mis-decoded emoji. The original glyphs cannot be recovered from
    # this file, so those strings are left untouched.
    print("π Training New CNN Deblurring Model")
    print("=" * 50)

    model = CNNDeblurModel(input_shape=input_shape)

    # Same folders as before ("models", "data/training_dataset"), taken from
    # the model so they stay in sync with where the trainer writes.
    os.makedirs(os.path.dirname(model.model_path) or ".", exist_ok=True)
    os.makedirs(model.dataset_path, exist_ok=True)

    trainer = CNNTrainer(model)
    success = trainer.train_model(
        epochs=epochs,
        batch_size=16,
        validation_split=0.2,
        use_existing_dataset=True,
        num_training_samples=num_samples
    )

    if not success:
        print("β Training failed!")
        return None

    # This literal was split over two lines by a mis-decoded character, which
    # left the string unterminated.
    print("✅ Training completed successfully!")

    metrics = trainer.evaluate_model()
    if metrics:
        print("π Model Performance:")
        print(f" Loss: {metrics['loss']:.4f}")
        print(f" MAE: {metrics['mae']:.4f}")
        print(f" MSE: {metrics['mse']:.4f}")

    return model
|
|
|
def quick_train():
    """Run a short training session: 500 samples for 10 epochs."""
    settings = {"num_samples": 500, "epochs": 10}
    return train_new_model(**settings)
|
|
|
def full_train():
    """Run the long training session: 2000 samples for 30 epochs."""
    settings = {"num_samples": 2000, "epochs": 30}
    return train_new_model(**settings)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: train (quick / full / custom) or smoke-test
    # the module. Without a flag, a usage summary is printed.
    import argparse

    parser = argparse.ArgumentParser(description='CNN Deblurring Module')
    parser.add_argument('--train', action='store_true', help='Train the model')
    parser.add_argument('--quick-train', action='store_true', help='Quick training (500 samples, 10 epochs)')
    parser.add_argument('--full-train', action='store_true', help='Full training (2000 samples, 30 epochs)')
    parser.add_argument('--samples', type=int, default=1000, help='Number of training samples')
    parser.add_argument('--epochs', type=int, default=20, help='Number of training epochs')
    parser.add_argument('--test', action='store_true', help='Test the model')

    args = parser.parse_args()

    # NOTE(review): emoji in these messages arrived mis-decoded. They are
    # restored where the surviving characters identify them; a bare leading
    # 'π' cannot be recovered and is left as found.
    print("🎯 CNN Deblurring Module")
    print("=" * 30)

    if args.quick_train:
        print("π Quick Training Mode")
        model = quick_train()

    elif args.full_train:
        print("π Full Training Mode")
        model = full_train()

    elif args.train:
        print("π Custom Training Mode")
        model = train_new_model(num_samples=args.samples, epochs=args.epochs)

    elif args.test:
        print("🧪 Testing Mode")

        # Random noise is enough to exercise the pipeline end to end.
        test_image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)

        cnn_model = CNNDeblurModel()

        if cnn_model.load_model(cnn_model.model_path):
            # This literal was split over two lines by a mis-decoded character,
            # which left the string unterminated.
            print("✅ Loaded existing trained model")
        else:
            print("ℹ️ No trained model found, building new model")
            # load_model() already builds a fresh network when loading fails;
            # only build again if that attempt did not produce one.
            if cnn_model.model is None:
                cnn_model.build_model()

        print(f"Model input shape: {cnn_model.input_shape}")
        print(f"Model built: {cnn_model.model is not None}")
        print(f"Model trained: {cnn_model.is_trained}")

        enhanced = cnn_model.enhance_image(test_image)
        print(f"Original shape: {test_image.shape}")
        print(f"Enhanced shape: {enhanced.shape}")

        if cnn_model.is_trained:
            # evaluate_model() is defined on CNNTrainer, not on CNNDeblurModel.
            metrics = CNNTrainer(cnn_model).evaluate_model()
            if metrics:
                print("π Model Performance:")
                for metric, value in metrics.items():
                    print(f" {metric}: {value:.4f}")

    else:
        print("ℹ️ Usage options:")
        print(" --test Test existing model or build new one")
        print(" --quick-train Quick training (500 samples, 10 epochs)")
        print(" --full-train Full training (2000 samples, 30 epochs)")
        print(" --train Custom training (use --samples and --epochs)")
        print("\nExamples:")
        print(" python -m modules.cnn_deblurring --test")
        print(" python -m modules.cnn_deblurring --quick-train")
        print(" python -m modules.cnn_deblurring --train --samples 1500 --epochs 25")

    print("\n🎯 CNN deblurring module ready!")