Spaces:
tester343
/
Configuration error

File size: 14,054 Bytes
83e35a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
"""
Lightweight AI Enhancement for Limited VRAM (< 4GB)
Optimized for RTX 3050 Laptop GPU
Uses efficient models with excellent quality
"""

import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import requests
from tqdm import tqdm
from typing import Optional, Dict, Any, Tuple
import warnings
warnings.filterwarnings('ignore')

# Lightweight ESRGAN Architecture
class RRDBNet_arch(nn.Module):
    """Compact ESRGAN-style generator that upsamples its input 4x.

    Despite the name, the network contains no RRDB blocks. It is one
    residual 3x3 conv (``trunk_conv``) followed by two nearest-neighbour 2x
    upsampling stages, each refined by a conv + LeakyReLU, and two output
    convs. The module attribute names double as state-dict keys, so they
    must not be renamed.

    Args:
        in_nc: Number of input channels.
        out_nc: Number of output channels.
        nf: Feature width (reduced from 64 to 32 for low VRAM).
        nb: Kept for signature compatibility (reduced from 23 to 16);
            not used by this architecture.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=32, nb=16):
        super().__init__()

        def conv3x3(cin, cout):
            # 3x3 kernel, stride 1, padding 1: preserves spatial size.
            return nn.Conv2d(cin, cout, 3, 1, 1, bias=True)

        # Creation order matches the state-dict layout and the order in which
        # parameters are randomly initialised.
        self.conv_first = conv3x3(in_nc, nf)
        self.trunk_conv = conv3x3(nf, nf)
        self.upconv1 = conv3x3(nf, nf)
        self.upconv2 = conv3x3(nf, nf)
        self.HRconv = conv3x3(nf, nf)
        self.conv_last = conv3x3(nf, out_nc)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Map ``(N, in_nc, H, W)`` to ``(N, out_nc, 4H, 4W)``."""
        feat = self.conv_first(x)
        feat = feat + self.trunk_conv(feat)  # single residual refinement
        for up_conv in (self.upconv1, self.upconv2):
            doubled = F.interpolate(feat, scale_factor=2, mode='nearest')
            feat = self.lrelu(up_conv(doubled))
        return self.conv_last(self.lrelu(self.HRconv(feat)))

class LightweightEnhancer:
    """Lightweight AI enhancer for <4GB VRAM.

    Pipeline: tiled super-resolution with ``RRDBNet_arch`` (falling back to
    OpenCV resizing), Haar-cascade face touch-up, and LAB color correction.
    Images exchanged as NumPy arrays use OpenCV's BGR channel order
    (``cv2.imread`` in, ``cv2.imwrite`` out).
    """

    # Output-size cap (width, height) shared by the ESRGAN path and the
    # OpenCV fallback.
    MAX_OUTPUT_SIZE = (2048, 1080)

    def __init__(self, device=None):
        """Initialize lightweight enhancer.

        Args:
            device: ``torch.device`` or device string (e.g. ``"cuda:0"``,
                ``"cpu"``). When ``None``, the first CUDA GPU is used if one
                is available, otherwise the CPU.
        """
        if device is None:
            if torch.cuda.is_available():
                self.device = torch.device('cuda:0')
                print(f"πŸš€ Using GPU: {torch.cuda.get_device_name(0)}")

                # RTX 3050 Laptop optimization
                torch.backends.cudnn.benchmark = True
                torch.cuda.set_per_process_memory_fraction(0.7)  # Use only 70% VRAM

                props = torch.cuda.get_device_properties(0)
                self.vram_gb = props.total_memory / (1024**3)
                print(f"πŸ“Š VRAM: {self.vram_gb:.1f} GB")
            else:
                self.device = torch.device('cpu')
                print("πŸ’» Using CPU (GPU not available)")
                self.vram_gb = 0
        else:
            # Normalise strings to torch.device: later code reads `.type`.
            self.device = torch.device(device)
            self.vram_gb = 4  # Assume 4GB

        # Model storage
        self.model_dir = 'models_lightweight'
        os.makedirs(self.model_dir, exist_ok=True)

        self.esrgan_model = None
        # Lazily created Haar-cascade face detector, reused across calls.
        self.face_model = None

        # Smaller tiles below 4 GB. FP16 is requested in both cases but only
        # takes effect on CUDA.
        self.tile_size = 256 if self.vram_gb < 4 else 384
        self.use_fp16 = True

    def load_lightweight_esrgan(self):
        """Build the lightweight ESRGAN and load its weights if present.

        Weights are read from ``<model_dir>/lightweight_esrgan.pth`` when that
        file exists. Otherwise the network keeps its random initialization.

        Returns:
            True on success, False on failure. On failure ``esrgan_model`` is
            reset to ``None`` so a later call can retry, rather than reusing a
            half-initialised model.
        """
        try:
            print("πŸ”„ Loading lightweight ESRGAN...")
            model = RRDBNet_arch()

            model_path = os.path.join(self.model_dir, 'lightweight_esrgan.pth')
            if os.path.exists(model_path):
                # NOTE(review): torch.load unpickles the file; only load trusted
                # checkpoints (consider weights_only=True on torch >= 1.13).
                model.load_state_dict(torch.load(model_path, map_location=self.device))
                print("βœ… Loaded pretrained lightweight model")
            else:
                print("⚠️ No pretrained model found, using random initialization")

            model = model.to(self.device)
            model.eval()

            if self.use_fp16 and self.device.type == 'cuda':
                model = model.half()
                print("βœ… Using FP16 for memory efficiency")

            # Publish the model only once it is fully initialised.
            self.esrgan_model = model
            return True

        except Exception as e:
            self.esrgan_model = None
            print(f"❌ Failed to load lightweight ESRGAN: {e}")
            return False

    def enhance_with_lightweight_esrgan(self, img):
        """Upscale ``img`` 2x with the lightweight ESRGAN, processing in tiles.

        Args:
            img: BGR ``uint8`` NumPy array (OpenCV layout) or a PIL image.

        Returns:
            BGR ``uint8`` array capped at ``MAX_OUTPUT_SIZE``. Falls back to
            ``fallback_upscale`` if the model cannot be loaded or run.
        """
        if self.esrgan_model is None and not self.load_lightweight_esrgan():
            return self.fallback_upscale(img, 2)

        try:
            img_tensor = self.img_to_tensor(img)
            result = self.process_with_tiles(img_tensor, self.esrgan_model, scale=2)
            return self.tensor_to_img(result)
        except Exception as e:
            print(f"❌ Enhancement failed: {e}")
            return self.fallback_upscale(img, 2)

    def process_with_tiles(self, img_tensor, model, scale=2, max_size=None):
        """Run ``model`` over ``img_tensor`` tile by tile and stitch the result.

        Args:
            img_tensor: ``(N, C, H, W)`` float tensor, such as the output of
                ``img_to_tensor``.
            model: Upscaling network. Its native factor may differ from
                ``scale``; each tile's output is resampled to exactly
                ``scale`` times its input size before stitching.
            scale: Integer upscaling factor of the returned image.
            max_size: Optional ``(max_width, max_height)`` output cap; defaults
                to ``MAX_OUTPUT_SIZE``. When the upscaled image would exceed
                it, the input is shrunk first so the whole image fits. This
                also bounds memory use.

        Returns:
            Float32 tensor on ``img_tensor``'s device.

        Raises:
            ValueError: if the input has zero height or width.
        """
        max_w, max_h = self.MAX_OUTPUT_SIZE if max_size is None else max_size
        _, _, h, w = img_tensor.shape
        if h == 0 or w == 0:
            raise ValueError("process_with_tiles: empty input image")

        # Enforce the cap by resizing the input. Writing full-size tiles into a
        # smaller buffer would crop the image instead of shrinking it.
        final_size = None
        target_h, target_w = h * scale, w * scale
        if target_w > max_w or target_h > max_h:
            limit_scale = min(max_w / target_w, max_h / target_h)
            out_w = max(1, int(target_w * limit_scale))
            out_h = max(1, int(target_h * limit_scale))
            print(f"  πŸ“ Limiting output to {out_w}x{out_h} (2K max)")
            h, w = max(1, out_h // scale), max(1, out_w // scale)
            img_tensor = F.interpolate(img_tensor, size=(h, w), mode='area')
            final_size = (out_h, out_w)

        # Feed tiles in the model's own dtype: FP16 after .half(), else FP32.
        first_param = next(model.parameters(), None)
        in_dtype = first_param.dtype if first_param is not None else img_tensor.dtype

        tile_size = self.tile_size
        pad = 16  # Context margin around each tile; cropped away to avoid seams.
        output = None

        with torch.no_grad():
            for y0 in range(0, h, tile_size):
                y1 = min(y0 + tile_size, h)
                py0, py1 = max(y0 - pad, 0), min(y1 + pad, h)
                for x0 in range(0, w, tile_size):
                    x1 = min(x0 + tile_size, w)
                    px0, px1 = max(x0 - pad, 0), min(x1 + pad, w)

                    patch = img_tensor[:, :, py0:py1, px0:px1].to(in_dtype)
                    pred = model(patch).float()

                    # RRDBNet_arch upsamples 4x natively. Resample each tile to
                    # exactly `scale` so tiles line up on the output grid.
                    want = ((py1 - py0) * scale, (px1 - px0) * scale)
                    if tuple(pred.shape[-2:]) != want:
                        pred = F.interpolate(pred, size=want, mode='bilinear', align_corners=False)

                    if output is None:
                        output = torch.zeros(
                            (pred.shape[0], pred.shape[1], h * scale, w * scale),
                            device=img_tensor.device,
                        )

                    # Keep only the tile's own region; drop the padded margin.
                    cy, cx = (y0 - py0) * scale, (x0 - px0) * scale
                    th, tw = (y1 - y0) * scale, (x1 - x0) * scale
                    core = pred[:, :, cy:cy + th, cx:cx + tw]
                    output[:, :, y0 * scale:y0 * scale + th, x0 * scale:x0 * scale + tw] = core

                    # Clear cache to save memory
                    if self.device.type == 'cuda':
                        torch.cuda.empty_cache()

        # Integer division above may leave us up to scale-1 pixels short.
        if final_size is not None and tuple(output.shape[-2:]) != final_size:
            output = F.interpolate(output, size=final_size, mode='bilinear', align_corners=False)

        return output

    def img_to_tensor(self, img):
        """Convert an image to a ``(1, 3, H, W)`` float RGB tensor in [0, 1].

        NumPy inputs are treated as OpenCV layouts: grayscale, BGR or BGRA.
        PIL inputs are converted via ``Image.convert('RGB')``, which is
        already RGB-ordered, so no BGR swap is applied to them.
        """
        if isinstance(img, Image.Image):
            img = np.array(img.convert('RGB'))
        else:
            if img.ndim == 3 and img.shape[2] == 1:
                img = img.reshape(img.shape[:2])
            if img.ndim == 2:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
            elif img.shape[2] == 4:
                img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
            elif img.shape[2] == 3:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Normalize to [0, 1] (assumes 8-bit input).
        img = img.astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
        return img_tensor.to(self.device)

    def tensor_to_img(self, tensor):
        """Convert a ``(1, 3, H, W)`` RGB tensor in [0, 1] to a BGR ``uint8`` array."""
        img = tensor.detach().squeeze(0).permute(1, 2, 0).float().cpu().numpy()
        img = (img * 255).clip(0, 255).astype(np.uint8)
        return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    def fallback_upscale(self, img, scale):
        """Upscale with OpenCV, capped at ``MAX_OUTPUT_SIZE``.

        Steps: bicubic resize, then a 3x3 sharpen, then a bilateral denoise.

        Args:
            img: BGR ``uint8`` array or a PIL image (converted to BGR).
            scale: Desired factor. It is reduced when the result would exceed
                the cap, which can mean downscaling large inputs.

        Returns:
            BGR ``uint8`` array.
        """
        print("  πŸ“ˆ Using optimized fallback upscaling...")

        if isinstance(img, Image.Image):
            img = cv2.cvtColor(np.array(img.convert('RGB')), cv2.COLOR_RGB2BGR)

        h, w = img.shape[:2]
        max_w, max_h = self.MAX_OUTPUT_SIZE
        target_scale = min(scale, max_w / w, max_h / h)
        # Never request a zero-sized image for extreme aspect ratios.
        new_w = max(1, int(w * target_scale))
        new_h = max(1, int(h * target_scale))

        upscaled = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)

        # Sharpening kernel; weights sum to 1, so flat regions are preserved.
        kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]], dtype=np.float64)
        upscaled = cv2.filter2D(upscaled, -1, kernel)

        # Reduce noise amplified by the sharpening step.
        upscaled = cv2.bilateralFilter(upscaled, 5, 50, 50)
        return upscaled

    def enhance_faces_lightweight(self, img):
        """Detect frontal faces with a Haar cascade and touch up each region.

        Args:
            img: BGR ``uint8`` array.

        Returns:
            A new array with the detected face regions enhanced. If no faces
            are found, or on any error, ``img`` itself is returned unmodified.
        """
        try:
            face_cascade = getattr(self, 'face_model', None)
            if face_cascade is None:
                face_cascade = cv2.CascadeClassifier(
                    cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
                self.face_model = face_cascade

            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            faces = face_cascade.detectMultiScale(gray, 1.1, 4)

            if len(faces) == 0:
                return img

            print(f"  πŸ‘€ Enhancing {len(faces)} faces...")

            # Work on a copy so the caller's array is never left half-edited.
            result = img.copy()
            for (x, y, w, h) in faces:
                # Pad the detection box by 10% of its width on every side.
                pad = int(w * 0.1)
                x_start = max(0, x - pad)
                y_start = max(0, y - pad)
                x_end = min(result.shape[1], x + w + pad)
                y_end = min(result.shape[0], y + h + pad)

                face = result[y_start:y_end, x_start:x_end]
                result[y_start:y_end, x_start:x_end] = self.enhance_face_region_lightweight(face)

            return result

        except Exception as e:
            print(f"⚠️ Face enhancement failed: {e}")
            return img

    def enhance_face_region_lightweight(self, face):
        """Enhance a BGR ``uint8`` face crop.

        Steps: bilateral denoise, CLAHE on LAB lightness, then a mild sharpen.
        """
        face = cv2.bilateralFilter(face, 9, 75, 75)

        lab = cv2.cvtColor(face, cv2.COLOR_BGR2LAB)
        l, a, b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        l = clahe.apply(l)
        face = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2BGR)

        # Subtle 4-neighbour sharpening (weights sum to 1).
        kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=np.float64)
        return cv2.filter2D(face, -1, kernel)

    @staticmethod
    def _default_output_path(image_path):
        """Return ``<root>_enhanced<ext>`` for ``image_path``.

        Only the final extension is affected; dots in directory names or in
        the stem are left alone. ``.jpg`` is used when there is no extension.
        """
        root, ext = os.path.splitext(image_path)
        return f"{root}_enhanced{ext or '.jpg'}"

    def enhance_image_pipeline(self, image_path: str, output_path: Optional[str] = None) -> str:
        """Run the full low-VRAM pipeline on a file and save the result.

        Steps: super-resolution (capped at ``MAX_OUTPUT_SIZE``), then face
        enhancement, then color correction, then a JPEG-quality-95 save.

        Args:
            image_path: Path of the image to read with ``cv2.imread``.
            output_path: Destination path. Defaults to ``<root>_enhanced<ext>``
                next to the input.

        Returns:
            The written output path. On any failure (unreadable input,
            pipeline error, or failed write), ``image_path`` is returned.
        """
        print(f"🎨 Enhancing {os.path.basename(image_path)} (Lightweight Mode)...")

        try:
            img = cv2.imread(image_path)
            if img is None:
                print(f"❌ Failed to load image: {image_path}")
                return image_path

            original_shape = img.shape[:2]
            print(f"  Original: {original_shape[1]}x{original_shape[0]}")

            # Step 1: Lightweight super resolution
            print("  πŸš€ Applying lightweight upscaling (max 2K)...")
            print(f"  πŸ“ Input: {img.shape[1]}x{img.shape[0]}")
            enhanced = self.enhance_with_lightweight_esrgan(img)

            # Step 2: Face enhancement
            print("  πŸ‘€ Enhancing faces...")
            enhanced = self.enhance_faces_lightweight(enhanced)

            # Step 3: Final color correction
            print("  🎨 Applying color correction...")
            enhanced = self.color_correction(enhanced)

            if output_path is None:
                output_path = self._default_output_path(image_path)

            # cv2.imwrite signals failure by returning False, not by raising.
            if not cv2.imwrite(output_path, enhanced, [cv2.IMWRITE_JPEG_QUALITY, 95]):
                print(f"❌ Failed to save image: {output_path}")
                return image_path

            new_shape = enhanced.shape[:2]
            print(f"  βœ… Enhanced: {new_shape[1]}x{new_shape[0]}")
            return output_path

        except Exception as e:
            print(f"❌ Pipeline failed: {e}")
            return image_path

        finally:
            # Release cached GPU memory on every exit path, not only on success.
            self.clear_memory()

    def color_correction(self, img):
        """Apply CLAHE on lightness plus a mild 10% chroma boost in LAB space.

        OpenCV stores 8-bit LAB ``a``/``b`` with the neutral point at 128. The
        boost is therefore applied around 128; scaling the raw values would
        push every neutral pixel toward red and yellow instead of increasing
        saturation.
        """
        lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        l, a, b = cv2.split(lab)

        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        l = clahe.apply(l)

        def boost(channel, factor=1.1):
            centered = channel.astype(np.float32) - 128.0
            return np.clip(np.rint(centered * factor + 128.0), 0, 255).astype(np.uint8)

        enhanced = cv2.merge([l, boost(a), boost(b)])
        return cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)

    def clear_memory(self):
        """Release PyTorch's cached CUDA memory; this is a no-op on CPU."""
        if self.device.type == 'cuda':
            # Let queued kernels finish before returning cached blocks.
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            
# Global instance: created on first call to get_lightweight_enhancer().
_lightweight_enhancer = None

def get_lightweight_enhancer():
    """Return the module-wide LightweightEnhancer, constructing it on first use.

    Later calls return the same object, so the model is loaded at most once
    per process. No locking is performed.
    """
    global _lightweight_enhancer
    if _lightweight_enhancer is not None:
        return _lightweight_enhancer
    _lightweight_enhancer = LightweightEnhancer()
    return _lightweight_enhancer