# Source: AI-Based-Image-Deblurring-App/src/modules/iterative_enhancement.py
# Author: ganeshkumar383 - "Upload 27 files (#2)", commit ecc16d3 (verified)
"""
Iterative Enhancement Module
===========================
Progressive image enhancement system that allows multiple rounds of improvement
with different algorithms and strengths until optimal results are achieved.
"""
import cv2
import numpy as np
from typing import Dict, List, Optional, Tuple
import logging
from .color_preservation import ColorPreserver
from .traditional_filters import TraditionalFilters, BlurType
from .blur_detection import BlurDetector
from .sharpness_analysis import SharpnessAnalyzer
# Module logger, named after this module.
logger = logging.getLogger(__name__)
# Root logging is configured at INFO level as a side effect of importing this
# module (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
class IterativeEnhancer:
    """Advanced iterative image enhancement system.

    Each round analyses the current image, picks a method whose aggressiveness
    grows with the round index (see ``_select_optimal_method``), applies it,
    re-applies the original image's colours, and keeps the result only when the
    measured sharpness gain exceeds a threshold.
    """

    def __init__(self):
        # Project-local collaborators (imported at module top).
        self.blur_detector = BlurDetector()
        self.sharpness_analyzer = SharpnessAnalyzer()
        self.filters = TraditionalFilters()
        self.color_preserver = ColorPreserver()

    def progressive_enhancement(self,
                                image: np.ndarray,
                                max_iterations: int = 5,
                                target_sharpness: float = 800.0,
                                adaptive: bool = True,
                                min_improvement: float = 5.0) -> Dict:
        """
        Apply progressive enhancement until target quality is reached.

        Args:
            image: Input image.
            max_iterations: Maximum enhancement rounds (0 only analyses the image).
            target_sharpness: Stop as soon as the measured sharpness reaches this.
            adaptive: Currently unused; accepted for API compatibility.
            min_improvement: A round whose sharpness gain is not greater than
                this is discarded and ends the loop.

        Returns:
            dict: On success, keys 'enhanced_image', 'original_image',
            'iterations_performed', 'enhancement_history', 'initial_sharpness',
            'final_sharpness', 'total_improvement', 'final_analysis',
            'final_metrics' and 'target_achieved'. On failure, keys
            'enhanced_image', 'original_image', 'iterations_performed' (0)
            and 'error'.
        """
        try:
            original_image = image.copy()
            current_image = image.copy()
            enhancement_history = []
            # Sharpness of the untouched input; all improvements are measured
            # against it (not against the latest round's analysis).
            initial_sharpness = None

            for iteration in range(max_iterations):
                # Analyze current state
                analysis = self.blur_detector.comprehensive_analysis(current_image)
                sharpness = analysis['sharpness_score']
                if initial_sharpness is None:
                    initial_sharpness = sharpness
                logger.info(f"Iteration {iteration + 1}: Current sharpness = {sharpness:.1f}")

                if sharpness >= target_sharpness:
                    logger.info(f"Target sharpness achieved in {iteration + 1} iterations")
                    break

                enhancement_method = self._select_optimal_method(analysis, iteration)
                enhanced_image = self._apply_enhancement(
                    current_image, enhancement_method, iteration
                )
                # Re-apply the original colours after every round.
                enhanced_image = self.color_preserver.preserve_colors_during_enhancement(
                    original_image, enhanced_image, preservation_strength=0.9
                )

                # NOTE(review): 'sharpness' comes from comprehensive_analysis while
                # 'new_sharpness' comes from variance_of_laplacian; the comparison
                # assumes both report the same metric - confirm in BlurDetector.
                new_sharpness = self.blur_detector.variance_of_laplacian(enhanced_image)
                improvement = new_sharpness - sharpness
                accepted = bool(improvement > min_improvement)

                enhancement_history.append({
                    'iteration': iteration + 1,
                    'method': enhancement_method['name'],
                    'parameters': enhancement_method['params'],
                    'sharpness_before': float(sharpness),
                    'sharpness_after': float(new_sharpness),
                    'improvement': float(improvement),
                    'cumulative_improvement': float(new_sharpness - initial_sharpness),
                    'accepted': accepted,
                })

                if accepted:
                    current_image = enhanced_image
                    logger.info(f"Applied {enhancement_method['name']}, improvement: +{improvement:.1f}")
                else:
                    logger.info(f"Minimal improvement ({improvement:.1f}), stopping iteration")
                    break

            final_analysis = self.blur_detector.comprehensive_analysis(current_image)
            final_metrics = self.sharpness_analyzer.analyze_sharpness(current_image)
            final_sharpness = final_analysis['sharpness_score']
            if initial_sharpness is None:
                # No round ran (max_iterations <= 0): the image is unchanged.
                initial_sharpness = final_sharpness
            total_improvement = final_sharpness - initial_sharpness

            return {
                'enhanced_image': current_image,
                'original_image': original_image,
                'iterations_performed': len(enhancement_history),
                'enhancement_history': enhancement_history,
                'initial_sharpness': float(initial_sharpness),
                'final_sharpness': float(final_sharpness),
                'total_improvement': float(total_improvement),
                'final_analysis': final_analysis,
                'final_metrics': final_metrics,
                'target_achieved': bool(final_sharpness >= target_sharpness),
            }
        except Exception as e:
            # Top-level boundary: log with traceback and fall back to the input.
            logger.exception(f"Error in progressive enhancement: {e}")
            return {
                'enhanced_image': image,
                'original_image': image,
                'iterations_performed': 0,
                'error': str(e)
            }

    def _select_optimal_method(self, analysis: Dict, iteration: int) -> Dict:
        """Pick the enhancement for round ``iteration`` (0-based).

        Rounds escalate from gentle (0) to moderate (1) to targeted (2); from
        round 3 on only light unsharp masking is used.

        Returns:
            {'name': <name understood by _apply_enhancement>, 'params': {...}}
        """
        # `or` also guards against explicit None values from the analyser.
        primary_type = analysis.get('primary_type') or 'Unknown'
        sharpness = analysis.get('sharpness_score') or 0
        motion_length = analysis.get('motion_length') or 0

        if iteration == 0:
            # First iteration: gentle enhancement
            if "Motion" in primary_type and motion_length > 10:
                return {
                    'name': 'Richardson-Lucy',
                    'params': {'iterations': 5, 'strength': 0.5}
                }
            if "Defocus" in primary_type:
                return {
                    'name': 'Wiener Filter',
                    'params': {'blur_type': 'defocus', 'strength': 0.6}
                }
            return {
                'name': 'Unsharp Masking',
                'params': {'sigma': 1.0, 'amount': 0.5}
            }
        if iteration == 1:
            # Second iteration: moderate enhancement
            if sharpness < 300:
                return {
                    'name': 'Advanced Wiener',
                    'params': {'adaptive': True, 'strength': 0.8}
                }
            return {
                'name': 'Multi-scale Sharpening',
                'params': {'scales': [0.5, 1.0, 1.5], 'strength': 0.4}
            }
        if iteration == 2:
            # Third iteration: targeted enhancement
            if "Motion" in primary_type:
                return {
                    'name': 'Richardson-Lucy',
                    'params': {'iterations': 15, 'strength': 0.7}
                }
            return {
                'name': 'Gradient-based Sharpening',
                'params': {'strength': 0.6, 'preserve_edges': True}
            }
        # Later iterations: fine-tuning
        return {
            'name': 'Fine Unsharp',
            'params': {'sigma': 0.5, 'amount': 0.3, 'threshold': 0.1}
        }

    def _apply_enhancement(self, image: np.ndarray, method: Dict, iteration: int) -> np.ndarray:
        """Run the method described by ``method`` ({'name', 'params'}) on ``image``.

        Unknown or missing names fall back to ColorPreserver's default unsharp
        masking; missing 'params' means each method's defaults. Any exception
        is logged and the input image is returned unchanged. ``iteration`` is
        currently unused.
        """
        # Bound before the try so the except clause can always reference it.
        method_name = None
        try:
            method_name = method.get('name')
            params = method.get('params') or {}
            handlers = {
                'Richardson-Lucy': self._richardson_lucy_enhanced,
                'Wiener Filter': self._wiener_enhanced,
                'Advanced Wiener': self._advanced_wiener,
                'Unsharp Masking': self._unsharp_enhanced,
                'Multi-scale Sharpening': self._multiscale_sharpening,
                'Gradient-based Sharpening': self._gradient_sharpening,
                'Fine Unsharp': self._fine_unsharp,
            }
            handler = handlers.get(method_name)
            if handler is None:
                # Default fallback
                return self.color_preserver.accurate_unsharp_masking(image)
            return handler(image, params)
        except Exception as e:
            logger.error(f"Error applying {method_name}: {e}")
            return image

    def _richardson_lucy_enhanced(self, image: np.ndarray, params: Dict) -> np.ndarray:
        """Richardson-Lucy deconvolution with a PSF chosen by _create_adaptive_psf."""
        iterations = params.get('iterations', 10)
        strength = params.get('strength', 1.0)
        psf_size = 7
        sigma = 1.5 * strength  # only used if a Gaussian PSF is chosen
        psf = self._create_adaptive_psf(image, psf_size, sigma)
        return self.filters.richardson_lucy_deconvolution(image, psf, iterations)

    def _wiener_enhanced(self, image: np.ndarray, params: Dict) -> np.ndarray:
        """Wiener filtering with a fixed PSF.

        'motion' uses a 15px horizontal streak; anything else uses a 7x7
        Gaussian with sigma 1.5*strength. Higher strength assumes less noise.
        """
        blur_type_str = params.get('blur_type', 'gaussian')
        strength = params.get('strength', 1.0)

        blur_type = BlurType.GAUSSIAN
        if blur_type_str == 'motion':
            blur_type = BlurType.MOTION
        elif blur_type_str == 'defocus':
            blur_type = BlurType.DEFOCUS

        if blur_type == BlurType.MOTION:
            psf = self._create_motion_psf(15, 0)  # 15px horizontal motion
        else:
            psf = self._create_gaussian_psf(7, 1.5 * strength)

        # Lower noise assumption for stronger enhancement; the floor avoids a
        # division by zero when strength is 0.
        noise_var = 0.01 / max(strength, 1e-3)
        return self.filters.wiener_filter(image, psf, noise_var)

    def _advanced_wiener(self, image: np.ndarray, params: Dict) -> np.ndarray:
        """Wiener filtering whose PSF is re-estimated from the image when adaptive."""
        adaptive = params.get('adaptive', True)
        strength = params.get('strength', 1.0)

        if adaptive:
            analysis = self.blur_detector.comprehensive_analysis(image)
            motion_length = analysis.get('motion_length', 5)
            motion_angle = analysis.get('motion_angle', 0)
            if motion_length > 8:
                psf = self._create_motion_psf(motion_length, motion_angle)
            else:
                psf = self._create_gaussian_psf(5, 1.0)
        else:
            psf = self._create_gaussian_psf(7, 1.5)

        # Noise estimate shrinks as strength grows; kept strictly positive so
        # strength >= 2 does not produce a zero/negative noise variance.
        noise_var = max(0.005 * (2.0 - strength), 1e-4)
        return self.filters.wiener_filter(image, psf, noise_var)

    def _unsharp_enhanced(self, image: np.ndarray, params: Dict) -> np.ndarray:
        """Unsharp masking delegated to ColorPreserver."""
        sigma = params.get('sigma', 1.0)
        amount = params.get('amount', 0.5)
        return self.color_preserver.accurate_unsharp_masking(image, sigma, amount)

    def _multiscale_sharpening(self, image: np.ndarray, params: Dict) -> np.ndarray:
        """Unsharp masking summed over several Gaussian scales.

        Result is ``image + strength * mean(detail_s)`` where
        ``detail_s = image - GaussianBlur(image, sigma=s)``; i.e. each scale
        contributes ``strength / len(scales)`` exactly once. Output is uint8.
        """
        scales = params.get('scales', [0.5, 1.0, 1.5])
        strength = params.get('strength', 0.4)

        img_float = image.astype(np.float64)
        if len(scales) == 0:
            # Nothing to add; return the input in the method's output dtype.
            return np.clip(img_float, 0, 255).astype(np.uint8)

        detail = np.zeros_like(img_float)
        for sigma in scales:
            detail += img_float - cv2.GaussianBlur(img_float, (0, 0), sigma)

        enhanced = img_float + (strength / len(scales)) * detail
        return np.clip(enhanced, 0, 255).astype(np.uint8)

    def _gradient_sharpening(self, image: np.ndarray, params: Dict) -> np.ndarray:
        """Add a Sobel gradient-magnitude term per channel (any channel count).

        With preserve_edges the term is weighted by the channel's normalised
        gradient magnitude, for grayscale and multi-channel input alike.
        """
        strength = params.get('strength', 0.6)
        preserve_edges = params.get('preserve_edges', True)

        img_float = image.astype(np.float64)
        if img_float.ndim == 3:
            enhanced = np.stack(
                [self._gradient_boost(img_float[:, :, c], strength, preserve_edges)
                 for c in range(img_float.shape[2])],
                axis=2,
            )
        else:
            enhanced = self._gradient_boost(img_float, strength, preserve_edges)
        return np.clip(enhanced, 0, 255).astype(np.uint8)

    @staticmethod
    def _gradient_boost(channel: np.ndarray, strength: float, preserve_edges: bool) -> np.ndarray:
        """Return ``channel`` plus a scaled Sobel gradient-magnitude term."""
        grad_x = cv2.Sobel(channel, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(channel, cv2.CV_64F, 0, 1, ksize=3)
        gradient_mag = np.sqrt(grad_x**2 + grad_y**2)
        # NOTE(review): gradient_mag is non-negative, so for strength > 0 this
        # brightens edges rather than applying a zero-mean sharpening - confirm
        # this is the intended effect.
        boost = strength * gradient_mag * 0.1
        if preserve_edges:
            # Stronger enhancement in high-gradient areas
            boost = boost * (gradient_mag / (np.max(gradient_mag) + 1e-6))
        return channel + boost

    def _fine_unsharp(self, image: np.ndarray, params: Dict) -> np.ndarray:
        """Gentle unsharp masking that ignores detail below ``threshold * 255``."""
        sigma = params.get('sigma', 0.5)
        amount = params.get('amount', 0.3)
        threshold = params.get('threshold', 0.1)

        img_float = image.astype(np.float64)
        mask = img_float - cv2.GaussianBlur(img_float, (0, 0), sigma)
        if threshold > 0:
            # Drop small differences to avoid amplifying noise (8-bit scale).
            mask = np.where(np.abs(mask) >= threshold * 255, mask, 0)
        enhanced = img_float + amount * mask
        return np.clip(enhanced, 0, 255).astype(np.uint8)

    def _create_adaptive_psf(self, image: np.ndarray, size: int, sigma: float) -> np.ndarray:
        """Motion PSF if the analyser reports motion blur, otherwise a Gaussian PSF."""
        analysis = self.blur_detector.comprehensive_analysis(image)
        if "Motion" in (analysis.get('primary_type') or ''):
            motion_length = analysis.get('motion_length', 5)
            motion_angle = analysis.get('motion_angle', 0)
            return self._create_motion_psf(motion_length, motion_angle)
        return self._create_gaussian_psf(size, sigma)

    def _create_gaussian_psf(self, size: int, sigma: float) -> np.ndarray:
        """Square, odd-sized, normalised Gaussian PSF.

        Even sizes are bumped to the next odd size. ``sigma <= 0`` yields an
        identity (delta) PSF instead of a NaN-filled kernel.
        """
        size = int(size)
        if size % 2 == 0:
            size += 1
        center = size // 2
        if sigma <= 0:
            psf = np.zeros((size, size))
            psf[center, center] = 1.0
            return psf
        coords = np.arange(size) - center
        x, y = np.meshgrid(coords, coords)
        psf = np.exp(-(x**2 + y**2) / (2 * sigma**2))
        return psf / np.sum(psf)

    def _create_motion_psf(self, length: int, angle: float) -> np.ndarray:
        """Normalised linear motion-blur PSF centred on the kernel's middle pixel.

        Args:
            length: Streak length in pixels; non-integers are rounded and values
                below 3 are raised to 3.
            angle: Streak direction in degrees (0 runs along the columns / +x).

        Returns:
            A (2*length+1, 2*length+1) float array summing to 1.
        """
        length = max(int(round(length)), 3)
        size = 2 * length + 1
        center = length
        psf = np.zeros((size, size))

        angle_rad = np.deg2rad(angle)
        # `length` integer samples centred on 0 (for even lengths the extra
        # sample is on the positive side), so the PSF does not shift the image.
        offsets = np.arange(length) - (length - 1) // 2
        xs = np.rint(center + offsets * np.cos(angle_rad)).astype(int)
        ys = np.rint(center + offsets * np.sin(angle_rad)).astype(int)
        psf[ys, xs] = 1.0
        # Offset 0 is always present, so the sum is >= 1.
        return psf / psf.sum()
def enhance_progressively(image: np.ndarray,
                          iterations: int = 3,
                          target_sharpness: float = 800.0) -> Dict:
    """
    Run progressive enhancement on ``image`` with a freshly built IterativeEnhancer.

    Args:
        image: Input image
        iterations: Upper bound on the number of enhancement rounds
        target_sharpness: Sharpness score at which enhancement stops early

    Returns:
        dict: The result dictionary produced by
        IterativeEnhancer.progressive_enhancement
    """
    return IterativeEnhancer().progressive_enhancement(
        image,
        max_iterations=iterations,
        target_sharpness=target_sharpness,
    )