# Uploaded by ganeshkumar383 — "Upload 27 files (#2)", commit ecc16d3 (verified)
"""
CNN Deblurring Module - Deep Learning Based Image Enhancement
============================================================
CNN inference system for image deblurring with TensorFlow/Keras.
Includes model architecture, training utilities, and inference pipeline.
"""
import cv2
import numpy as np
import os
import logging
from typing import Optional, Tuple, List
import pickle
# Configure TensorFlow to reduce verbosity
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Reduce TensorFlow logging
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' # Disable oneDNN messages
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
# Configure TensorFlow settings
tf.get_logger().setLevel('ERROR') # Only show errors
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class CNNDeblurModel:
    """CNN-based deblurring model with a U-Net style encoder-decoder.

    Bundles the Keras network (``self.model``) with its input shape, the
    paths used for persistence and training data, and a flag telling whether
    real (trained or loaded) weights are available. While ``is_trained`` is
    False, :meth:`enhance_image` uses a classical sharpening filter instead of
    running the untrained network.
    """

    def __init__(self, input_shape: Tuple[int, int, int] = (256, 256, 3)):
        # (height, width, channels) fed to the network. build_model() pools
        # three times by 2, so height and width must be multiples of 8.
        self.input_shape = input_shape
        self.model = None  # keras.Model after build_model() / load_model()
        self.is_trained = False  # True only once trained or loaded weights exist
        self.training_history = None  # keras History of the last training run
        self.model_path = "models/cnn_deblur_model.h5"
        self.dataset_path = "data/training_dataset"

    def build_model(self) -> Optional[Model]:
        """Build and compile the U-Net style deblurring network.

        Architecture: three encoder stages (two 3x3 ReLU convolutions, then
        2x2 max-pooling; 64/128/256 filters), a 512-filter bottleneck, and
        three decoder stages (2x upsampling, a 2x2 convolution, concatenation
        with the matching encoder output, two 3x3 convolutions), ending in a
        1x1 sigmoid convolution with 3 output channels. The model is compiled
        with Adam, MSE loss and MAE/MSE metrics and stored in ``self.model``.

        Returns:
            keras.Model: The compiled model, or None if building failed (the
            error is logged).
        """
        try:
            for size in (self.input_shape[0], self.input_shape[1]):
                # Three 2x poolings followed by three 2x upsamplings only line
                # up with the skip connections for multiples of 8.
                if size is not None and size % 8 != 0:
                    raise ValueError(
                        f"input height and width must be multiples of 8, got {self.input_shape}"
                    )

            def conv_pair(x, filters):
                # Two 3x3 ReLU convolutions that keep the spatial size.
                x = layers.Conv2D(filters, 3, activation='relu', padding='same')(x)
                return layers.Conv2D(filters, 3, activation='relu', padding='same')(x)

            def up_stage(x, skip, filters):
                # Upsample x2, project with a 2x2 conv, fuse the encoder skip
                # connection, then refine with two 3x3 convs.
                x = layers.UpSampling2D(size=(2, 2))(x)
                x = layers.Conv2D(filters, 2, activation='relu', padding='same')(x)
                x = layers.concatenate([skip, x], axis=3)
                return conv_pair(x, filters)

            inputs = keras.Input(shape=self.input_shape)

            # Encoder (downsampling)
            conv1 = conv_pair(inputs, 64)
            pool1 = layers.MaxPooling2D(pool_size=(2, 2))(conv1)
            conv2 = conv_pair(pool1, 128)
            pool2 = layers.MaxPooling2D(pool_size=(2, 2))(conv2)
            conv3 = conv_pair(pool2, 256)
            pool3 = layers.MaxPooling2D(pool_size=(2, 2))(conv3)

            # Bottleneck
            conv4 = conv_pair(pool3, 512)

            # Decoder (upsampling) with skip connections
            conv5 = up_stage(conv4, conv3, 256)
            conv6 = up_stage(conv5, conv2, 128)
            conv7 = up_stage(conv6, conv1, 64)

            # Sigmoid keeps the output in [0, 1], matching the input normalization.
            outputs = layers.Conv2D(3, 1, activation='sigmoid')(conv7)

            model = Model(inputs=inputs, outputs=outputs)
            model.compile(optimizer='adam', loss='mse', metrics=['mae', 'mse'])

            self.model = model
            logger.info("CNN model built successfully")
            return model
        except Exception as e:
            logger.error(f"Error building CNN model: {e}")
            return None

    def load_model(self, model_path: str) -> bool:
        """Load a saved Keras model from ``model_path``.

        On success ``self.model`` holds the loaded network and ``is_trained``
        is True. If the file is missing or cannot be loaded, a fresh untrained
        network is built as a fallback and ``is_trained`` is reset to False, so
        a previously loaded state is not mistaken for trained weights.

        Args:
            model_path: Path to the saved model file.
        Returns:
            bool: True if the file was loaded, False if the fallback was used.
        """
        if not os.path.exists(model_path):
            logger.warning(f"Model file not found: {model_path}")
        else:
            try:
                self.model = keras.models.load_model(model_path)
                self.is_trained = True
                logger.info(f"Model loaded from {model_path}")
                return True
            except Exception as e:
                logger.error(f"Error loading model: {e}")
        # Fallback: a fresh, untrained network.
        self.is_trained = False
        self.build_model()
        return False

    def save_model(self, model_path: str) -> bool:
        """Save the current Keras model to ``model_path``.

        The parent directory is created if it does not exist.

        Args:
            model_path: Destination file path.
        Returns:
            bool: True on success, False if there is no model or saving failed.
        """
        if self.model is None:
            logger.error("No model to save")
            return False
        try:
            directory = os.path.dirname(model_path)
            if directory:
                os.makedirs(directory, exist_ok=True)
            self.model.save(model_path)
            logger.info(f"Model saved to {model_path}")
            return True
        except Exception as e:
            logger.error(f"Error saving model: {e}")
            return False

    def preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """Turn an OpenCV image into a normalized network input batch.

        BGR images are converted to RGB. Grayscale (2-D or single-channel) and
        BGRA images are converted to RGB as well (alpha is dropped), so the
        network always receives 3 channels. The image is resized to the model
        input size with bicubic interpolation and scaled to [0, 1].

        Args:
            image: Input image in OpenCV channel order.
        Returns:
            np.ndarray: float32 array of shape (1, H, W, C) for the model input
            size, or an empty array if preprocessing failed.
        """
        try:
            if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
                rgb_image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            elif image.ndim == 3 and image.shape[2] == 4:
                rgb_image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
            elif image.ndim == 3 and image.shape[2] == 3:
                rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            else:
                rgb_image = image

            # cv2.resize takes (width, height).
            resized = cv2.resize(rgb_image,
                                 (self.input_shape[1], self.input_shape[0]),
                                 interpolation=cv2.INTER_CUBIC)
            normalized = resized.astype(np.float64) / 255.0
            return np.expand_dims(normalized, axis=0).astype(np.float32)
        except Exception as e:
            logger.error(f"Error preprocessing image: {e}")
            return np.array([])

    def postprocess_image(self, output: np.ndarray, original_shape: Tuple[int, int]) -> np.ndarray:
        """Convert a network output back to a BGR uint8 image of the original size.

        Args:
            output: Model output with values in [0, 1] and 3 RGB channels,
                either (H, W, 3) or a batch (N, H, W, 3) of which only the
                first item is used.
            original_shape: (height, width) to resize the result to.
        Returns:
            np.ndarray: uint8 BGR image of shape (height, width, 3), or an
            all-black image of that size if postprocessing fails.
        """
        try:
            if output.ndim == 4:
                output = output[0]
            # Clip, then round (not truncate) before the uint8 cast.
            denormalized = np.round(np.clip(output * 255.0, 0, 255)).astype(np.uint8)
            resized = cv2.resize(denormalized,
                                 (original_shape[1], original_shape[0]),
                                 interpolation=cv2.INTER_CUBIC)
            return cv2.cvtColor(resized, cv2.COLOR_RGB2BGR)
        except Exception as e:
            logger.error(f"Error postprocessing image: {e}")
            return np.zeros((*original_shape, 3), dtype=np.uint8)

    @staticmethod
    def _restore_channel_layout(bgr: np.ndarray, original: np.ndarray) -> np.ndarray:
        """Convert a 3-channel BGR result back to the channel layout of ``original``.

        Grayscale inputs get a grayscale result with the input's exact shape.
        BGRA inputs get their original alpha channel back.
        """
        if original.ndim == 2 or original.shape[2] == 1:
            return cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY).reshape(original.shape)
        if original.shape[2] == 4:
            bgra = cv2.cvtColor(bgr, cv2.COLOR_BGR2BGRA)
            bgra[:, :, 3] = original[:, :, 3]
            return bgra
        return bgr

    def enhance_image(self, image: np.ndarray) -> np.ndarray:
        """Deblur an image with the CNN, or with the fallback filter if untrained.

        Args:
            image: Input blurry image in OpenCV order (BGR, BGRA or grayscale).
        Returns:
            np.ndarray: Enhanced image with the same size and channel layout as
            the input. An empty or None input is returned unchanged.
        """
        if image is None or getattr(image, 'size', 0) == 0:
            logger.error("Empty image passed to enhance_image")
            return image

        # Untrained weights produce meaningless output, so don't build or run
        # the network; the classical filter is the documented fallback.
        if not self.is_trained or self.model is None:
            logger.info("Using fallback enhancement (model not trained)")
            return self._fallback_enhancement(image)

        try:
            original_shape = image.shape[:2]
            preprocessed = self.preprocess_image(image)
            if preprocessed.size == 0:
                logger.error("Failed to preprocess image")
                return image

            enhanced = self.model.predict(preprocessed, verbose=0)
            result = self.postprocess_image(enhanced, original_shape)
            result = self._restore_channel_layout(result, image)
            logger.info("CNN enhancement completed")
            return result
        except Exception as e:
            logger.error(f"Error in CNN enhancement: {e}")
            return self._fallback_enhancement(image)

    def _fallback_enhancement(self, image: np.ndarray) -> np.ndarray:
        """Mild, colour-preserving sharpening used when no trained CNN is available.

        A 3x3 sharpening kernel whose weights sum to 1 (so flat regions keep
        their value) is applied to every colour channel. The result is blended
        70/30 with the original and clipped to [0, 255]. Grayscale images are
        supported. For 4-channel images only the colour channels are sharpened
        and alpha is passed through unchanged.

        Args:
            image: Input image (2-D grayscale or H x W x C).
        Returns:
            np.ndarray: uint8 enhanced image, or the input unchanged on error.
        """
        try:
            kernel_sharpen = np.array([[-0.1, -0.1, -0.1],
                                       [-0.1, 1.8, -0.1],
                                       [-0.1, -0.1, -0.1]])

            img_float = image.astype(np.float64)
            alpha = None
            if img_float.ndim == 3 and img_float.shape[2] == 4:
                alpha = image[:, :, 3:]
                img_float = np.ascontiguousarray(img_float[:, :, :3])

            # filter2D filters each channel independently. reshape restores a
            # trailing singleton channel axis that OpenCV drops.
            sharpened = cv2.filter2D(img_float, -1, kernel_sharpen).reshape(img_float.shape)

            # Gentle blend of the original with its sharpened version.
            result = np.clip(0.7 * img_float + 0.3 * sharpened, 0, 255).astype(np.uint8)
            if alpha is not None:
                result = np.concatenate([result, alpha.astype(np.uint8)], axis=2)

            logger.info("Color-preserving fallback enhancement applied")
            return result
        except Exception as e:
            logger.error(f"Error in fallback enhancement: {e}")
            return image

    def train_model(self,
                    epochs: int = 20,
                    batch_size: int = 16,
                    validation_split: float = 0.2,
                    use_existing_dataset: bool = True,
                    num_training_samples: int = 1000) -> bool:
        """Train this model; a convenience wrapper around ``CNNTrainer(self).train_model``.

        See :meth:`CNNTrainer.train_model` for the arguments. On success
        ``is_trained`` is True and the trainer writes the model to
        ``model_path``.
        """
        trainer = CNNTrainer(self)
        return trainer.train_model(epochs=epochs,
                                   batch_size=batch_size,
                                   validation_split=validation_split,
                                   use_existing_dataset=use_existing_dataset,
                                   num_training_samples=num_training_samples)

    def evaluate_model(self, test_images: Optional[np.ndarray] = None,
                       test_targets: Optional[np.ndarray] = None) -> dict:
        """Evaluate this model; a convenience wrapper around ``CNNTrainer(self).evaluate_model``."""
        return CNNTrainer(self).evaluate_model(test_images, test_targets)


class CNNTrainer:
    """Dataset generation and training utilities for a :class:`CNNDeblurModel`.

    ``self.model`` is the CNNDeblurModel wrapper, not the raw Keras model. The
    Keras network is ``self.model.model``. ``input_shape``, ``dataset_path``
    and ``model_path`` are read from (and written to) the wrapper.
    """

    def __init__(self, model: CNNDeblurModel):
        self.model = model  # CNNDeblurModel wrapper
        self.training_history = None  # keras History of the last train_model() call

    @property
    def input_shape(self) -> Tuple[int, int, int]:
        """Input shape (height, width, channels) of the wrapped model."""
        return self.model.input_shape

    @property
    def dataset_path(self) -> str:
        """Directory holding user training images and the saved .npy dataset."""
        return self.model.dataset_path

    @dataset_path.setter
    def dataset_path(self, value: str) -> None:
        self.model.dataset_path = value

    @property
    def model_path(self) -> str:
        """File the trained model (and its checkpoints) are written to."""
        return self.model.model_path

    @model_path.setter
    def model_path(self, value: str) -> None:
        self.model.model_path = value

    def create_synthetic_data(self, clean_images: List[np.ndarray],
                              blur_types: Optional[List[str]] = None) -> Tuple[np.ndarray, np.ndarray]:
        """Create (blurred, clean) training pairs by blurring the given clean images.

        Each image gets one blur type chosen at random from ``blur_types``.
        Any value other than 'gaussian' or 'motion' is treated as 'defocus'.

        Args:
            clean_images: Clean images (all of the same shape).
            blur_types: Blur types to sample from; defaults to
                ['gaussian', 'motion', 'defocus'].
        Returns:
            tuple: (blurred_images, clean_images) as arrays, or two empty
            arrays on error.
        """
        if blur_types is None:
            blur_types = ['gaussian', 'motion', 'defocus']

        blurred_batch = []
        clean_batch = []
        try:
            for clean_img in clean_images:
                blur_type = np.random.choice(blur_types)
                if blur_type == 'gaussian':
                    # Odd kernel size in [5, 15].
                    kernel_size = int(np.random.randint(5, 15))
                    if kernel_size % 2 == 0:
                        kernel_size += 1
                    blurred = cv2.GaussianBlur(clean_img, (kernel_size, kernel_size), 0)
                elif blur_type == 'motion':
                    length = np.random.randint(5, 20)
                    angle = np.random.randint(0, 180)
                    kernel = self._create_motion_kernel(length, angle)
                    blurred = cv2.filter2D(clean_img, -1, kernel)
                else:  # defocus, approximated with a Gaussian
                    sigma = np.random.uniform(1, 5)
                    blurred = cv2.GaussianBlur(clean_img, (0, 0), sigma)

                blurred_batch.append(blurred)
                clean_batch.append(clean_img)

            return np.array(blurred_batch), np.array(clean_batch)
        except Exception as e:
            logger.error(f"Error creating synthetic data: {e}")
            return np.array([]), np.array([])

    def _create_motion_kernel(self, length: int, angle: float) -> np.ndarray:
        """Return a normalized ``length x length`` linear motion-blur kernel.

        Pixels along a line through the kernel centre at ``angle`` degrees are
        set, and the kernel is normalized to sum to 1. Coordinates are rounded
        to the nearest pixel rather than truncated toward zero, which would
        bias the line toward lower indices and shift the blurred image.
        """
        length = max(1, int(length))
        kernel = np.zeros((length, length))
        center = length // 2
        cos_val = np.cos(np.radians(angle))
        sin_val = np.sin(np.radians(angle))

        for offset in range(-center, length - center):
            y = int(round(center + offset * sin_val))
            x = int(round(center + offset * cos_val))
            if 0 <= y < length and 0 <= x < length:
                kernel[y, x] = 1

        # The centre pixel (offset 0) is always set, so the sum is >= 1.
        return kernel / kernel.sum()

    def _load_user_images(self) -> List[np.ndarray]:
        """Load the user's clean training images from ``dataset_path``.

        Every file with a common image extension is read with OpenCV, resized
        to the model input size and converted from OpenCV's BGR order to RGB.
        RGB is the channel order preprocess_image() feeds the network at
        inference time. Unreadable files are skipped.

        Returns:
            List[np.ndarray]: 3-channel RGB images; empty if the directory does
            not exist or cannot be listed.
        """
        user_images = []
        if not os.path.exists(self.dataset_path):
            return user_images

        valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')
        target_size = (self.input_shape[1], self.input_shape[0])  # cv2 wants (width, height)

        try:
            filenames = sorted(os.listdir(self.dataset_path))
        except OSError as e:
            logger.error(f"Error loading user images: {e}")
            return []

        for filename in filenames:
            if not filename.lower().endswith(valid_extensions):
                continue
            image_path = os.path.join(self.dataset_path, filename)
            try:
                image = cv2.imread(image_path)
                if image is None:
                    continue
                resized = cv2.resize(image, target_size)
                user_images.append(cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))
                logger.info(f"Loaded user image: {filename}")
            except Exception as e:
                logger.warning(f"Failed to load {filename}: {e}")

        logger.info(f"Loaded {len(user_images)} user training images")
        return user_images

    def _apply_random_blur(self, image: np.ndarray, noise_std: float) -> np.ndarray:
        """Degrade ``image`` with a random Gaussian, motion or defocus blur plus noise.

        Args:
            image: uint8 image to degrade.
            noise_std: Standard deviation of the additive Gaussian noise, in
                0-255 intensity units.
        Returns:
            np.ndarray: Degraded uint8 image of the same shape.
        """
        blur_type = np.random.choice(['gaussian', 'motion', 'defocus'])
        if blur_type == 'gaussian':
            sigma = np.random.uniform(0.5, 3.0)
            blurred = cv2.GaussianBlur(image, (0, 0), sigma)
        elif blur_type == 'motion':
            length = np.random.randint(5, 25)
            angle = np.random.randint(0, 180)
            kernel = self._create_motion_kernel(length, angle)
            blurred = cv2.filter2D(image, -1, kernel)
        else:  # defocus, approximated with a wider Gaussian
            sigma = np.random.uniform(1.0, 4.0)
            blurred = cv2.GaussianBlur(image, (0, 0), sigma)

        # Slight sensor-like noise for realism.
        noise = np.random.normal(0, noise_std, blurred.shape).astype(np.float32)
        return np.clip(blurred.astype(np.float32) + noise, 0, 255).astype(np.uint8)

    def create_training_dataset(self, num_samples: int = 1000,
                                save_dataset: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        """Create a (blurred, clean) training dataset.

        Each user image found in ``dataset_path`` contributes 3 randomly
        degraded copies (noise std 3). The remainder up to ``num_samples`` is
        filled with synthetic shape/text images (noise std 5). If there are
        many user images, the dataset can exceed ``num_samples``.

        Args:
            num_samples: Target number of training pairs.
            save_dataset: Save the arrays as blurred_images.npy and
                clean_images.npy in ``dataset_path``.
        Returns:
            Tuple[np.ndarray, np.ndarray]: float32 blurred and clean arrays
            scaled to [0, 1], or two empty arrays on error.
        """
        try:
            logger.info(f"Creating training dataset with {num_samples} samples...")
            os.makedirs(self.dataset_path, exist_ok=True)

            all_blurred = []
            all_clean = []

            user_images = self._load_user_images()
            if user_images:
                logger.info(f"Processing {len(user_images)} user training images...")
                for user_img in user_images:
                    for _ in range(3):  # 3 differently degraded copies per user image
                        all_blurred.append(self._apply_random_blur(user_img, noise_std=3))
                        all_clean.append(user_img)

            remaining_samples = max(0, num_samples - len(all_blurred))
            if remaining_samples > 0:
                logger.info(f"Generating {remaining_samples} synthetic training samples...")
                batch_size = 50
                num_batches = (remaining_samples + batch_size - 1) // batch_size
                for batch_idx in range(num_batches):
                    current_batch_size = min(batch_size, remaining_samples - batch_idx * batch_size)
                    clean_batch = self._generate_clean_images(current_batch_size)
                    all_blurred.extend(self._apply_random_blur(img, noise_std=5) for img in clean_batch)
                    all_clean.extend(clean_batch)

                    if (batch_idx + 1) % 5 == 0:
                        logger.info(f"Generated batch {batch_idx + 1}/{num_batches}")

            blurred_dataset = np.array(all_blurred).astype(np.float32) / 255.0
            clean_dataset = np.array(all_clean).astype(np.float32) / 255.0
            logger.info(f"Dataset created: {blurred_dataset.shape} blurred, {clean_dataset.shape} clean")

            if save_dataset:
                np.save(os.path.join(self.dataset_path, 'blurred_images.npy'), blurred_dataset)
                np.save(os.path.join(self.dataset_path, 'clean_images.npy'), clean_dataset)
                logger.info(f"Dataset saved to {self.dataset_path}")

            return blurred_dataset, clean_dataset
        except Exception as e:
            logger.error(f"Error creating training dataset: {e}")
            return np.array([]), np.array([])

    def _generate_clean_images(self, num_images: int) -> List[np.ndarray]:
        """Generate synthetic clean uint8 images of random shapes, lines and text.

        Each image has a random solid background, 3-7 random filled
        rectangles, filled circles or lines, and with probability ~0.5 a random
        alphanumeric string.
        """
        height, width = self.input_shape[0], self.input_shape[1]
        alphabet = list('ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789')
        clean_images = []

        for _ in range(num_images):
            img = np.zeros((height, width, 3), dtype=np.uint8)
            img[:] = np.random.randint(0, 256, 3)  # random background colour

            for _ in range(np.random.randint(3, 8)):
                shape_type = np.random.choice(['rectangle', 'circle', 'line'])
                color = np.random.randint(0, 256, 3).tolist()

                if shape_type == 'rectangle':
                    # Top-left corner in the upper-left quadrant, bottom-right in the lower-right.
                    pt1 = (int(np.random.randint(0, max(1, width // 2))),
                           int(np.random.randint(0, max(1, height // 2))))
                    pt2 = (int(np.random.randint(width // 2, width)),
                           int(np.random.randint(height // 2, height)))
                    cv2.rectangle(img, pt1, pt2, color, -1)
                elif shape_type == 'circle':
                    center = (int(np.random.randint(0, width)), int(np.random.randint(0, height)))
                    radius = int(np.random.randint(10, 50))
                    cv2.circle(img, center, radius, color, -1)
                else:  # line
                    pt1 = (int(np.random.randint(0, width)), int(np.random.randint(0, height)))
                    pt2 = (int(np.random.randint(0, width)), int(np.random.randint(0, height)))
                    thickness = int(np.random.randint(1, 5))
                    cv2.line(img, pt1, pt2, color, thickness)

            if np.random.random() > 0.5:
                text = ''.join(np.random.choice(alphabet, np.random.randint(3, 8)))
                font = cv2.FONT_HERSHEY_SIMPLEX
                font_scale = np.random.uniform(0.5, 2.0)
                color = np.random.randint(0, 256, 3).tolist()
                thickness = int(np.random.randint(1, 3))
                # Baseline at y >= 20 (when the image is tall enough) so text is mostly visible.
                position = (int(np.random.randint(0, max(1, width // 2))),
                            int(np.random.randint(min(20, height - 1), height)))
                cv2.putText(img, text, position, font, font_scale, color, thickness)

            clean_images.append(img)

        return clean_images

    def load_existing_dataset(self) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Load a previously saved dataset from ``dataset_path``.

        A saved dataset is ignored, so that a new one gets generated, when the
        two arrays differ in shape or their per-sample shape does not match the
        model's ``input_shape``. An empty saved dataset is also ignored.

        Returns:
            (blurred, clean) arrays, or (None, None) if nothing usable exists.
        """
        try:
            blurred_path = os.path.join(self.dataset_path, 'blurred_images.npy')
            clean_path = os.path.join(self.dataset_path, 'clean_images.npy')

            if not (os.path.exists(blurred_path) and os.path.exists(clean_path)):
                logger.info("No existing dataset found")
                return None, None

            blurred_data = np.load(blurred_path)
            clean_data = np.load(clean_path)

            expected = tuple(self.input_shape)
            if blurred_data.shape != clean_data.shape or tuple(blurred_data.shape[1:]) != expected:
                logger.warning(
                    f"Ignoring saved dataset with shapes {blurred_data.shape} / {clean_data.shape}; "
                    f"expected (N, {', '.join(str(d) for d in expected)})"
                )
                return None, None

            logger.info(f"Loaded existing dataset: {blurred_data.shape} samples")
            return blurred_data, clean_data
        except Exception as e:
            logger.error(f"Error loading existing dataset: {e}")
            return None, None

    def train_model(self,
                    epochs: int = 20,
                    batch_size: int = 16,
                    validation_split: float = 0.2,
                    use_existing_dataset: bool = True,
                    num_training_samples: int = 1000) -> bool:
        """Train the wrapped CNN model.

        Builds the network if needed, then loads or creates a dataset. It
        trains with early stopping, learning-rate reduction on plateau and
        best-model checkpointing to ``model_path``. Afterwards it saves the
        final model to ``model_path`` and pickles the history next to it as
        ``<model_path without extension>_history.pkl``.

        Args:
            epochs: Maximum number of training epochs.
            batch_size: Training batch size.
            validation_split: Fraction of data held out for validation. With 0,
                the callbacks monitor training loss because Keras reports no
                ``val_loss``.
            use_existing_dataset: Reuse a compatible dataset saved in
                ``dataset_path`` if one exists.
            num_training_samples: Dataset size when a new one is generated.
        Returns:
            bool: True if training completed, False on any error (logged).
        """
        try:
            logger.info("Starting CNN model training...")

            if self.model.model is None:
                self.model.build_model()
            keras_model = self.model.model
            if keras_model is None:
                logger.error("Could not build the CNN model")
                return False

            blurred_data, clean_data = None, None
            if use_existing_dataset:
                blurred_data, clean_data = self.load_existing_dataset()
            if blurred_data is None:
                logger.info("Creating new dataset...")
                blurred_data, clean_data = self.create_training_dataset(num_training_samples)

            if len(blurred_data) == 0:
                logger.error("Failed to create/load training dataset")
                return False

            logger.info(f"Training on {len(blurred_data)} samples")

            # ModelCheckpoint and save_model write here; make sure the directory exists.
            model_dir = os.path.dirname(self.model_path)
            if model_dir:
                os.makedirs(model_dir, exist_ok=True)

            # Keras only reports val_loss when a validation split is used.
            monitor = 'val_loss' if validation_split > 0 else 'loss'
            callbacks = [
                keras.callbacks.EarlyStopping(
                    monitor=monitor,
                    patience=5,
                    restore_best_weights=True
                ),
                keras.callbacks.ReduceLROnPlateau(
                    monitor=monitor,
                    factor=0.5,
                    patience=3,
                    min_lr=1e-7
                ),
                keras.callbacks.ModelCheckpoint(
                    filepath=self.model_path,
                    monitor=monitor,
                    save_best_only=True,
                    save_weights_only=False
                )
            ]

            history = keras_model.fit(
                blurred_data, clean_data,
                epochs=epochs,
                batch_size=batch_size,
                validation_split=validation_split,
                callbacks=callbacks,
                verbose=1
            )
            self.training_history = history
            self.model.training_history = history

            self.model.save_model(self.model_path)
            self.model.is_trained = True

            # splitext (not str.replace('.h5', ...)) so a path without '.h5'
            # never resolves to the model file itself and overwrites it.
            history_path = os.path.splitext(self.model_path)[0] + '_history.pkl'
            with open(history_path, 'wb') as f:
                pickle.dump(history.history, f)

            logger.info("Training completed successfully!")
            logger.info(f"Model saved to: {self.model_path}")

            final_loss = history.history['loss'][-1]
            logger.info(f"Final training loss: {final_loss:.4f}")
            val_losses = history.history.get('val_loss')
            if val_losses:
                logger.info(f"Final validation loss: {val_losses[-1]:.4f}")

            return True
        except Exception as e:
            logger.error(f"Error during training: {e}")
            return False

    def evaluate_model(self, test_images: Optional[np.ndarray] = None,
                       test_targets: Optional[np.ndarray] = None) -> dict:
        """Evaluate the wrapped model on test data.

        Args:
            test_images: Blurred inputs scaled to [0, 1]. If this or
                ``test_targets`` is None, a 100-sample set is generated with
                create_training_dataset() without saving it.
            test_targets: Matching clean targets.
        Returns:
            dict: {'loss', 'mae', 'mse'}, or {} if the model is not trained or
            evaluation failed.
        """
        try:
            keras_model = self.model.model
            if keras_model is None or not self.model.is_trained:
                logger.error("Model not trained. Train the model first.")
                return {}

            if test_images is None or test_targets is None:
                logger.info("Creating test dataset...")
                # NOTE(review): the generated set also includes blurred copies of the
                # user images in dataset_path, which may have been used for training,
                # so these scores can be optimistic.
                test_images, test_targets = self.create_training_dataset(num_samples=100, save_dataset=False)

            if len(test_images) == 0:
                logger.error("No test data available for evaluation")
                return {}

            results = keras_model.evaluate(test_images, test_targets, verbose=0)
            # Order follows compile(): loss first, then metrics=['mae', 'mse'].
            metrics = {
                'loss': results[0],
                'mae': results[1],
                'mse': results[2]
            }

            logger.info("Model Evaluation Results:")
            for metric, value in metrics.items():
                logger.info(f" {metric}: {value:.4f}")

            return metrics
        except Exception as e:
            logger.error(f"Error during evaluation: {e}")
            return {}
# Convenience functions
def load_cnn_model(model_path: str = "models/cnn_model.h5") -> CNNDeblurModel:
    """Create a CNNDeblurModel and load a saved model into it.

    Training writes to ``CNNDeblurModel().model_path``
    ("models/cnn_deblur_model.h5"), which differs from this function's
    historical default. When the default path is used and does not exist,
    the training output path is tried instead, so trained models are found.
    An explicitly passed path is always used as given.

    Args:
        model_path: Path to the model file.
    Returns:
        CNNDeblurModel: The model instance. If nothing could be loaded, it
        holds a freshly built network that load_model() leaves untrained.
    """
    default_path = "models/cnn_model.h5"
    model = CNNDeblurModel()
    if (model_path == default_path
            and not os.path.exists(model_path)
            and os.path.exists(model.model_path)):
        logger.info(f"{default_path} not found; using trained model at {model.model_path}")
        model_path = model.model_path
    model.load_model(model_path)
    return model
def enhance_with_cnn(image: np.ndarray, model_path: str = "models/cnn_model.h5") -> np.ndarray:
    """One-shot helper: load a model from *model_path* and enhance *image* with it.

    The model is loaded again on every call. For repeated use, keep the
    instance returned by load_cnn_model() and call its enhance_image().

    Args:
        image: Input image (OpenCV channel order).
        model_path: Path to the saved model file.
    Returns:
        np.ndarray: The enhanced image.
    """
    cnn = load_cnn_model(model_path)
    enhanced = cnn.enhance_image(image)
    return enhanced
# Training utility functions
def train_new_model(num_samples: int = 1000, epochs: int = 20,
                    input_shape: Tuple[int, int, int] = (256, 256, 3)) -> Optional[CNNDeblurModel]:
    """Train a new CNN deblurring model from scratch and report its metrics.

    Args:
        num_samples: Number of training samples to generate when no saved
            dataset is reused.
        epochs: Number of training epochs.
        input_shape: Network input shape (height, width, channels).
    Returns:
        CNNDeblurModel: The trained model, or None if training failed.
    """
    print("🚀 Training New CNN Deblurring Model")
    print("=" * 50)

    model = CNNDeblurModel(input_shape=input_shape)

    # Create the directories this model actually writes to, rather than
    # hard-coded copies of their default names.
    for directory in (os.path.dirname(model.model_path), model.dataset_path):
        if directory:
            os.makedirs(directory, exist_ok=True)

    success = model.train_model(
        epochs=epochs,
        batch_size=16,
        validation_split=0.2,
        use_existing_dataset=True,
        num_training_samples=num_samples
    )

    if not success:
        print("❌ Training failed!")
        return None

    print("✅ Training completed successfully!")

    metrics = model.evaluate_model()
    if metrics:
        print("📊 Model Performance:")
        print(f" Loss: {metrics['loss']:.4f}")
        print(f" MAE: {metrics['mae']:.4f}")
        print(f" MSE: {metrics['mse']:.4f}")

    return model
def quick_train():
    """Run a short training session: 500 samples for 10 epochs.

    Returns:
        Whatever train_new_model() returns (the trained model, or None).
    """
    settings = {"num_samples": 500, "epochs": 10}
    return train_new_model(**settings)
def full_train():
    """Run a long training session: 2000 samples for 30 epochs.

    Returns:
        Whatever train_new_model() returns (the trained model, or None).
    """
    settings = {"num_samples": 2000, "epochs": 30}
    return train_new_model(**settings)
# Example usage and testing
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='CNN Deblurring Module')
parser.add_argument('--train', action='store_true', help='Train the model')
parser.add_argument('--quick-train', action='store_true', help='Quick training (500 samples, 10 epochs)')
parser.add_argument('--full-train', action='store_true', help='Full training (2000 samples, 30 epochs)')
parser.add_argument('--samples', type=int, default=1000, help='Number of training samples')
parser.add_argument('--epochs', type=int, default=20, help='Number of training epochs')
parser.add_argument('--test', action='store_true', help='Test the model')
args = parser.parse_args()
print("🎯 CNN Deblurring Module")
print("=" * 30)
if args.quick_train:
print("πŸš€ Quick Training Mode")
model = quick_train()
elif args.full_train:
print("πŸš€ Full Training Mode")
model = full_train()
elif args.train:
print(f"πŸš€ Custom Training Mode")
model = train_new_model(num_samples=args.samples, epochs=args.epochs)
elif args.test:
print("πŸ§ͺ Testing Mode")
# Create test image
test_image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
# Initialize model
cnn_model = CNNDeblurModel()
# Try to load existing model
if cnn_model.load_model(cnn_model.model_path):
print(f"βœ… Loaded existing trained model")
else:
print(f"ℹ️ No trained model found, building new model")
cnn_model.build_model()
print(f"Model input shape: {cnn_model.input_shape}")
print(f"Model built: {cnn_model.model is not None}")
print(f"Model trained: {cnn_model.is_trained}")
# Test enhancement
enhanced = cnn_model.enhance_image(test_image)
print(f"Original shape: {test_image.shape}")
print(f"Enhanced shape: {enhanced.shape}")
if cnn_model.is_trained:
# Evaluate on test data
metrics = cnn_model.evaluate_model()
if metrics:
print("πŸ“Š Model Performance:")
for metric, value in metrics.items():
print(f" {metric}: {value:.4f}")
else:
print("ℹ️ Usage options:")
print(" --test Test existing model or build new one")
print(" --quick-train Quick training (500 samples, 10 epochs)")
print(" --full-train Full training (2000 samples, 30 epochs)")
print(" --train Custom training (use --samples and --epochs)")
print("\nExamples:")
print(" python -m modules.cnn_deblurring --test")
print(" python -m modules.cnn_deblurring --quick-train")
print(" python -m modules.cnn_deblurring --train --samples 1500 --epochs 25")
print("\n🎯 CNN deblurring module ready!")