# Spaces: tester343 / (Configuration error)
# File size: 19,238 Bytes - 83e35a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
"""
Advanced Image Enhancement using State-of-the-Art AI Models
Real-ESRGAN, GFPGAN, and other cutting-edge models
Optimized for NVIDIA RTX 3050
"""

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageEnhance, ImageFilter
import os
import requests
from io import BytesIO
import time
from typing import Optional, Tuple
try:
    from backend.ai_model_manager import get_ai_model_manager
    AI_MODELS_AVAILABLE = True
except ImportError:
    AI_MODELS_AVAILABLE = False
    print("⚠️ AI models not available, using lightweight enhancer")
    
from backend.lightweight_ai_enhancer import get_lightweight_enhancer
from backend.compact_ai_models import CompactAIEnhancer
from backend.ultra_compact_enhancer import get_memory_safe_enhancer

class AdvancedImageEnhancer:
    """Enhance images with AI models when possible, OpenCV filters otherwise.

    On construction one of two back ends is selected:

    * the memory-safe enhancer (CPU, a GPU with less than 6 GB of VRAM, or
      when the heavy model manager could not be imported), or
    * the full AI model manager (Real-ESRGAN plus optional GFPGAN).

    When the AI path is disabled or raises, a traditional OpenCV pipeline
    (resize, colour, denoise, sharpen, contrast, faces) is applied instead.

    Environment variables read at construction time:
        USE_AI_MODELS: "1" (default) enables the AI path.
        ENHANCE_FACES: "1" (default) enables the AI face-enhancement step of
            the full-model path.
    """

    # 8-bit OpenCV LAB images store the a/b chroma channels with an offset of
    # 128, so 128 (not 0) is the neutral, colourless value.
    _LAB_NEUTRAL = 128.0

    def __init__(self):
        """Detect the compute device, choose a back end and load its models."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"🎯 Using device: {self.device}")

        # Lightweight is the default; only a CUDA device with at least 6 GB
        # of VRAM *and* an importable model manager switches to heavy models.
        self.use_lightweight = True
        if self.device.type == 'cuda':
            props = torch.cuda.get_device_properties(0)
            vram_gb = props.total_memory / (1024**3)
            print(f"πŸ“Š VRAM: {vram_gb:.1f} GB")

            if vram_gb < 6 or not AI_MODELS_AVAILABLE:
                print("πŸš€ Using lightweight enhancer (optimized for <4GB VRAM)")
            else:
                self.use_lightweight = False

        # Exactly one of `enhancer` / `ai_manager` is populated.
        self.compact_realesrgan = None
        if self.use_lightweight:
            print("πŸš€ Using memory-safe AI enhancer (<1GB VRAM)")
            self.enhancer = get_memory_safe_enhancer()
            self.ai_manager = None
        else:
            self.ai_manager = get_ai_model_manager()
            self.enhancer = None

        # Enhancement settings
        self.use_ai_models = os.getenv('USE_AI_MODELS', '1') == '1'
        self.enhance_faces = os.getenv('ENHANCE_FACES', '1') == '1'
        self.use_anime_model = False  # Re-evaluated for every image

        # Initialize models (sets self.advanced_available)
        self._load_models()

    def _load_models(self):
        """Load the models of the selected back end.

        Sets ``self.advanced_available`` to True when the AI path can be
        used and to False when it is disabled or loading raised.
        """
        try:
            if self.use_lightweight:
                print("πŸš€ Loading lightweight AI models...")
                # Nothing is loaded here for the lightweight back end.
                self.advanced_available = True
                print("βœ… Lightweight enhancer ready")
            else:
                print("πŸš€ Loading advanced AI models...")

                if self.use_ai_models and self.ai_manager:
                    # General-purpose super resolution model
                    self.ai_manager.load_realesrgan('RealESRGAN_x4plus')

                    # Anime/comic model, pre-loaded so style switching is cheap
                    self.ai_manager.load_realesrgan('RealESRGAN_x4plus_anime_6B')

                    # Face restoration model
                    if self.enhance_faces:
                        self.ai_manager.load_gfpgan()

                    self.advanced_available = True
                    print("βœ… AI models loaded successfully")
                else:
                    print("⚠️ AI models disabled, using traditional methods")
                    self.advanced_available = False

        except Exception as e:
            print(f"⚠️ Models failed to load: {e}")
            print("⚠️ Falling back to traditional enhancement methods")
            self.advanced_available = False

    def enhance_image(self, image_path: str, output_path: Optional[str] = None) -> str:
        """Enhance one image file and write the result to disk.

        Args:
            image_path: Path of the image to read.
            output_path: Where to write the result; defaults to overwriting
                ``image_path``.

        Returns:
            ``output_path`` on success. ``image_path`` when the image could
            not be read, enhanced or written (the error is printed, not
            raised).
        """
        if output_path is None:
            output_path = image_path

        print(f"πŸš€ Enhancing image: {os.path.basename(image_path)}")

        try:
            img = cv2.imread(image_path)
            if img is None:
                print(f"❌ Failed to load image: {image_path}")
                return image_path

            # image_path is forwarded because the lightweight back end works
            # on files rather than on arrays.
            enhanced_img = self._apply_enhancement_pipeline(img, image_path)

            # cv2.imwrite reports most failures through its return value
            # rather than an exception, so it has to be checked explicitly.
            written = cv2.imwrite(output_path, enhanced_img,
                                  [cv2.IMWRITE_JPEG_QUALITY, 100])
            if not written:
                raise OSError(f"could not write image: {output_path}")

            print(f"βœ… Enhanced image saved: {os.path.basename(output_path)}")
            return output_path

        except Exception as e:
            print(f"❌ Enhancement failed: {e}")
            return image_path

    @staticmethod
    def _temp_paths(image_path: str) -> Tuple[str, str]:
        """Return ``(temp_path, enhanced_path)`` scratch names for an image.

        The suffixes are inserted before the file extension only, so dots in
        directory names or elsewhere in the file name are left untouched.
        Extension-less inputs get ``.png`` so an image writer can pick a
        format.
        """
        root, ext = os.path.splitext(image_path)
        ext = ext or '.png'
        return f"{root}_temp{ext}", f"{root}_enhanced{ext}"

    def _run_lightweight_enhancer(self, img: np.ndarray, image_path: Optional[str]) -> np.ndarray:
        """Run the file-based memory-safe enhancer on ``img``.

        The image is written to a scratch file next to ``image_path``, handed
        to ``self.enhancer.enhance_image`` and read back. Scratch files are
        removed even when the enhancer raises.

        Raises:
            ValueError: If ``image_path`` is missing.
            OSError: If the scratch file cannot be written or the result
                cannot be read back.
        """
        if not image_path:
            raise ValueError("image_path is required for the memory-safe enhancer")

        print("  πŸš€ Applying memory-safe AI enhancement...")
        temp_path, enhanced_target = self._temp_paths(image_path)
        scratch_files = {temp_path, enhanced_target}
        try:
            if not cv2.imwrite(temp_path, img):
                raise OSError(f"could not write temporary image: {temp_path}")

            enhanced_path = self.enhancer.enhance_image(temp_path, enhanced_target)
            if enhanced_path:
                scratch_files.add(enhanced_path)
            result = cv2.imread(enhanced_path)
        finally:
            for path in scratch_files:
                # Never delete the caller's source image.
                if path == image_path or not os.path.exists(path):
                    continue
                try:
                    os.remove(path)
                except OSError as e:
                    # Best-effort cleanup: a leftover scratch file must not
                    # discard an otherwise successful enhancement.
                    print(f"⚠️ Could not remove temporary file {path}: {e}")

        if result is None:
            raise OSError(f"could not read enhanced image: {enhanced_path}")

        print("  βœ… Memory-safe enhancement complete")
        if hasattr(self.enhancer, 'get_memory_usage'):
            print(f"  πŸ’Ύ Memory: {self.enhancer.get_memory_usage()}")
        return result

    def _run_full_ai_models(self, img: np.ndarray) -> np.ndarray:
        """Run super resolution, optional face enhancement and post-processing
        through ``self.ai_manager``; its memory is cleared on success and on
        failure alike."""
        try:
            # 1. AI super resolution
            print("  πŸš€ Applying AI super resolution...")
            img = self.ai_manager.enhance_image_realesrgan(
                img,
                use_anime_model=self.use_anime_model
            )

            # 2. AI face enhancement
            if self.enhance_faces:
                print("  πŸ‘€ Enhancing faces with AI...")
                img = self.ai_manager.enhance_face_gfpgan(img)

            # 3. Post-processing
            return self.ai_manager.post_process(img)
        finally:
            self.ai_manager.clear_memory()

    def _apply_enhancement_pipeline(self, img: np.ndarray, image_path: str = None) -> np.ndarray:
        """Enhance ``img`` with the AI back end, or with OpenCV as fallback.

        Args:
            img: BGR image as returned by ``cv2.imread``.
            image_path: Path the image was read from; required by the
                lightweight back end, which works on files.

        Returns:
            The enhanced image. If the AI path is unavailable or raises, the
            result of the traditional OpenCV pipeline on the original pixels.
        """
        original_img = img.copy()

        print("🎨 Applying AI-powered enhancement pipeline...")

        # Decides which Real-ESRGAN variant the full-model path uses.
        self.use_anime_model = self._detect_anime_style(img)

        if self.advanced_available and self.use_ai_models:
            try:
                if self.use_lightweight:
                    return self._run_lightweight_enhancer(img, image_path)
                return self._run_full_ai_models(img)
            except Exception as e:
                print(f"⚠️ AI enhancement failed: {e}, using fallback")
                # Discard any partial AI result.
                img = original_img

        # Fallback to traditional methods if AI models not available
        print("  πŸ“ˆ Using traditional enhancement methods...")

        img = self._apply_super_resolution_advanced(img)   # 1. resize
        img = self._enhance_colors_advanced(img)           # 2. colour
        img = self._reduce_noise_advanced(img)             # 3. denoise
        img = self._enhance_sharpness_advanced(img)        # 4. sharpen
        img = self._optimize_dynamic_range_advanced(img)   # 5. contrast
        img = self._enhance_faces_advanced(img)            # 6. faces

        return img

    def _apply_super_resolution_advanced(self, img: np.ndarray) -> np.ndarray:
        """Resize with Lanczos interpolation, then sharpen.

        The scale factor is at most 2x and is limited so the result fits in
        2048x1080; images larger than that are therefore scaled down. (The
        log text below still says "4x" and is kept as-is.) On any error the
        input is returned unchanged.
        """
        try:
            print("πŸ“ˆ Applying advanced super resolution (4x upscaling)...")

            height, width = img.shape[:2]

            # Fit inside 2048x1080, never enlarging by more than 2x.
            scale_factor = min(2048 / width, 1080 / height, 2.0)
            # Clamp to 1 px so extreme aspect ratios cannot truncate to 0.
            target_width = max(1, int(width * scale_factor))
            target_height = max(1, int(height * scale_factor))

            img = cv2.resize(img, (target_width, target_height),
                             interpolation=cv2.INTER_LANCZOS4)

            # Sharpen to counter the softness introduced by resampling.
            kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
            img = cv2.filter2D(img, -1, kernel)

            print(f"βœ… Super resolution completed: {width}x{height} β†’ {target_width}x{target_height}")

        except Exception as e:
            print(f"⚠️ Super resolution failed: {e}")

        return img

    @staticmethod
    def _scale_chroma(channel: np.ndarray, gain: float) -> np.ndarray:
        """Scale an 8-bit LAB a/b channel about its neutral value of 128.

        Multiplying the raw channel would move neutral pixels away from 128
        and tint the whole image; scaling the distance from 128 changes
        colourfulness only. The result is rounded and clipped to 0..255.
        """
        neutral = AdvancedImageEnhancer._LAB_NEUTRAL
        centred = channel.astype(np.float32) - neutral
        return np.clip(np.rint(centred * gain + neutral), 0, 255).astype(np.uint8)

    def _enhance_colors_advanced(self, img: np.ndarray) -> np.ndarray:
        """Boost local contrast (CLAHE on L), chroma (LAB a/b) and saturation.

        Returns the input unchanged if any step raises.
        """
        try:
            print("🎨 Applying advanced color enhancement...")

            lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)

            # Lightness: contrast-limited adaptive histogram equalisation
            clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
            lab[:, :, 0] = clahe.apply(lab[:, :, 0])

            # Chroma: amplify around the neutral point so greys stay grey
            lab[:, :, 1] = self._scale_chroma(lab[:, :, 1], 1.3)
            lab[:, :, 2] = self._scale_chroma(lab[:, :, 2], 1.3)

            enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)

            # Additional saturation boost (saturating at 255)
            hsv = cv2.cvtColor(enhanced, cv2.COLOR_BGR2HSV)
            hsv[:, :, 1] = cv2.convertScaleAbs(hsv[:, :, 1], alpha=1.4, beta=0)
            enhanced = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

        except Exception as e:
            print(f"⚠️ Color enhancement failed: {e}")
            enhanced = img

        return enhanced

    def _reduce_noise_advanced(self, img: np.ndarray) -> np.ndarray:
        """Apply four smoothing stages in sequence; return input on error."""
        try:
            print("🧹 Applying advanced noise reduction...")

            # 1. Bilateral filter for edge-preserving smoothing
            denoised = cv2.bilateralFilter(img, 9, 75, 75)

            # 2. Non-local means denoising
            denoised = cv2.fastNlMeansDenoisingColored(denoised, None, 10, 10, 7, 21)

            # 3. Light Gaussian blur
            denoised = cv2.GaussianBlur(denoised, (3, 3), 0)

            # 4. Edge-preserving filter
            denoised = cv2.edgePreservingFilter(denoised, flags=1, sigma_s=60, sigma_r=0.4)

        except Exception as e:
            print(f"⚠️ Noise reduction failed: {e}")
            denoised = img

        return denoised

    def _enhance_sharpness_advanced(self, img: np.ndarray) -> np.ndarray:
        """Sharpen with unsharp masking, a 3x3 kernel and a Laplacian overlay.

        Returns the input unchanged if any step raises.
        """
        try:
            print("πŸ”ͺ Applying advanced sharpness enhancement...")

            # 1. Unsharp masking
            gaussian = cv2.GaussianBlur(img, (0, 0), 2.0)
            sharpened = cv2.addWeighted(img, 1.5, gaussian, -0.5, 0)

            # 2. Edge enhancement
            kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
            sharpened = cv2.filter2D(sharpened, -1, kernel)

            # 3. Laplacian overlay. convertScaleAbs takes |x| and saturates
            #    at 255, whereas a plain uint8 cast wraps large magnitudes.
            gray = cv2.cvtColor(sharpened, cv2.COLOR_BGR2GRAY)
            laplacian = cv2.convertScaleAbs(cv2.Laplacian(gray, cv2.CV_64F))
            edges_bgr = cv2.cvtColor(laplacian, cv2.COLOR_GRAY2BGR)
            sharpened = cv2.addWeighted(sharpened, 1.0, edges_bgr, 0.3, 0)

        except Exception as e:
            print(f"⚠️ Sharpness enhancement failed: {e}")
            sharpened = img

        return sharpened

    def _optimize_dynamic_range_advanced(self, img: np.ndarray) -> np.ndarray:
        """Stretch contrast: CLAHE on L, mild chroma gain, global gain/offset.

        Returns the input unchanged if any step raises.
        """
        try:
            print("πŸ“Š Applying advanced dynamic range optimization...")

            lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)

            # Lightness contrast
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            lab[:, :, 0] = clahe.apply(lab[:, :, 0])

            # Chroma contrast, scaled about the neutral point
            lab[:, :, 1] = self._scale_chroma(lab[:, :, 1], 1.2)
            lab[:, :, 2] = self._scale_chroma(lab[:, :, 2], 1.2)

            optimized = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)

            # Additional global contrast/brightness lift
            optimized = cv2.convertScaleAbs(optimized, alpha=1.1, beta=5)

        except Exception as e:
            print(f"⚠️ Dynamic range optimization failed: {e}")
            optimized = img

        return optimized

    def _enhance_faces_advanced(self, img: np.ndarray) -> np.ndarray:
        """Detect frontal faces with a Haar cascade and enhance each region.

        ``img`` is modified in place and also returned. Errors are printed
        and leave the image as it was at the point of failure.
        """
        try:
            print("πŸ‘€ Applying advanced face enhancement...")

            # The cascade file is parsed once per instance, not once per image.
            face_cascade = getattr(self, '_face_cascade', None)
            if face_cascade is None:
                face_cascade = cv2.CascadeClassifier(
                    cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
                self._face_cascade = face_cascade

            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            faces = face_cascade.detectMultiScale(gray, 1.1, 4)

            if len(faces) > 0:
                print(f"🎭 Found {len(faces)} faces, applying enhancement...")

                for (x, y, w, h) in faces:
                    face_roi = img[y:y+h, x:x+w]
                    img[y:y+h, x:x+w] = self._enhance_face_region(face_roi)
            else:
                print("πŸ‘€ No faces detected, skipping face enhancement")

        except Exception as e:
            print(f"⚠️ Face enhancement failed: {e}")

        return img

    def _enhance_face_region(self, face_img: np.ndarray) -> np.ndarray:
        """Smooth, slightly saturate and gently sharpen one face crop.

        Returns the crop unchanged (after printing a warning) on error.
        """
        try:
            # Gentle edge-preserving smoothing
            enhanced = cv2.bilateralFilter(face_img, 5, 50, 50)

            # Gentle saturation boost
            hsv = cv2.cvtColor(enhanced, cv2.COLOR_BGR2HSV)
            hsv[:, :, 1] = cv2.convertScaleAbs(hsv[:, :, 1], alpha=1.1, beta=0)
            enhanced = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

            # Subtle sharpening (kernel sums to 1, so brightness is kept)
            kernel = np.array([[-0.5, -0.5, -0.5], [-0.5, 5, -0.5], [-0.5, -0.5, -0.5]])
            enhanced = cv2.filter2D(enhanced, -1, kernel)

        except Exception as e:
            print(f"⚠️ Face region enhancement failed: {e}")
            enhanced = face_img

        return enhanced

    def _detect_anime_style(self, img: np.ndarray) -> bool:
        """Heuristically decide whether ``img`` looks like anime/comic art.

        All three must hold: Canny edge density below 0.15, fewer than
        10000 distinct pixel colours, and Laplacian variance below 1000.
        Returns False if the analysis raises.
        """
        try:
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

            # 1. Fraction of pixels marked as edges
            edges = cv2.Canny(gray, 50, 150)
            edge_density = np.sum(edges > 0) / edges.size

            # 2. Number of distinct colours
            unique_colors = len(np.unique(img.reshape(-1, img.shape[2]), axis=0))

            # 3. Variance of the Laplacian response
            laplacian = cv2.Laplacian(gray, cv2.CV_64F)
            gradient_variance = np.var(laplacian)

            is_anime = (
                edge_density < 0.15 and
                unique_colors < 10000 and
                gradient_variance < 1000
            )

            if is_anime:
                print("  🎌 Detected anime/comic style - using specialized model")

            return is_anime

        except Exception as e:
            print(f"⚠️ Style detection failed: {e}")
            return False

    def enhance_batch(self, image_paths: list, output_dir: Optional[str] = None) -> list:
        """Enhance several images into ``output_dir``.

        Args:
            image_paths: Paths of the images to enhance.
            output_dir: Target directory (created if missing); defaults to
                ``"enhanced"``.

        Returns:
            One path per input, in order: the ``enhanced_<name>`` output on
            success, or the original path when that image failed.
        """
        if output_dir is None:
            output_dir = "enhanced"

        os.makedirs(output_dir, exist_ok=True)
        enhanced_paths = []

        print(f"🎯 Enhancing {len(image_paths)} images with advanced techniques...")

        for i, image_path in enumerate(image_paths, 1):
            print(f"πŸ“Έ Processing {i}/{len(image_paths)}: {os.path.basename(image_path)}")

            filename = os.path.basename(image_path)
            output_path = os.path.join(output_dir, f"enhanced_{filename}")

            enhanced_paths.append(self.enhance_image(image_path, output_path))

        print(f"βœ… Enhanced {len(enhanced_paths)} images with advanced techniques")
        return enhanced_paths

# Module-level singleton; stays None until the first call below
advanced_enhancer = None


def get_advanced_enhancer():
    """Return the shared AdvancedImageEnhancer, constructing it on first use."""
    global advanced_enhancer
    if advanced_enhancer is not None:
        return advanced_enhancer
    advanced_enhancer = AdvancedImageEnhancer()
    return advanced_enhancer

if __name__ == "__main__":
    # Smoke test when run as a script: constructing the enhancer runs
    # device detection and model loading; no image is processed.
    enhancer = AdvancedImageEnhancer()
    ready_message = "πŸ§ͺ Advanced Image Enhancer ready for testing!"
    print(ready_message)