Spaces:
tester343
/
Configuration error

File size: 10,485 Bytes
83e35a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
"""
Ultra Compact Image Enhancer for Extreme Memory Constraints
Designed for RTX 3050 Laptop with strict <1GB VRAM limit
"""

import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
import gc

class UltraCompactESRGAN(nn.Module):
    """Ultra lightweight ESRGAN-style super-resolution network.

    Three 3x3 convolutions with a residual skip connection, followed by a
    sub-pixel (PixelShuffle) upsampler. Kept deliberately tiny so the whole
    model plus one tile of activations fits in a few hundred MB of VRAM.
    """

    def __init__(self, scale=2, nf=24):
        """
        Args:
            scale: Spatial upscaling factor; output is ``input * scale`` in
                both dimensions.
            nf: Number of intermediate feature channels. The default (24)
                preserves the original ultra-compact footprint; larger
                values trade memory for reconstruction quality.
        """
        super().__init__()
        self.scale = scale

        self.conv1 = nn.Conv2d(3, nf, 3, 1, 1)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1)
        self.conv3 = nn.Conv2d(nf, nf, 3, 1, 1)

        # Sub-pixel convolution: emit 3*scale^2 channels, then PixelShuffle
        # rearranges them into an RGB image scale× larger in each dimension.
        self.upscale = nn.Sequential(
            nn.Conv2d(nf, 3 * scale * scale, 3, 1, 1),
            nn.PixelShuffle(scale)
        )

        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Map a (N, 3, H, W) tensor to (N, 3, H*scale, W*scale)."""
        x1 = self.act(self.conv1(x))
        x2 = self.act(self.conv2(x1))
        x3 = self.conv3(x2)
        x = x1 + x3  # Residual skip connection around conv2/conv3
        x = self.upscale(x)
        return x

class MemorySafeEnhancer:
    """Memory-safe image enhancer that guarantees <1GB VRAM usage.

    Loads an :class:`UltraCompactESRGAN` and processes images in small
    tiles so only one tile of activations is ever resident on the GPU.
    Falls back to pure-CPU interpolation for oversized inputs or on any
    model failure. Output resolution is capped at 2048x1080 ("2K limit").
    """

    def __init__(self):
        self.device = self._setup_device()
        self.model = None
        self.tile_size = 64  # Very small tiles keep peak VRAM low
        self.scale = 2  # 2x max for 2K output

        # Load eagerly so failures surface at construction time
        self._load_model()

    def _setup_device(self):
        """Pick cuda/cpu and clamp this process to 30% of total VRAM."""
        if torch.cuda.is_available():
            # Clear any existing allocations
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

            # Hard cap: this process may allocate at most 30% of VRAM
            torch.cuda.set_per_process_memory_fraction(0.3)

            device = torch.device('cuda')
            print(f"πŸš€ Using GPU: {torch.cuda.get_device_name(0)}")

            # Print available memory
            total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            print(f"πŸ“Š Total VRAM: {total:.1f}GB, Using max: {total*0.3:.1f}GB")
        else:
            device = torch.device('cpu')
            print("πŸ’» Using CPU")

        return device

    def _load_model(self):
        """Instantiate the compact model; leaves ``self.model = None`` on failure."""
        try:
            print("πŸ”„ Loading ultra-compact model...")

            self.model = UltraCompactESRGAN(scale=self.scale)
            self.model = self.model.to(self.device)
            self.model.eval()

            # Half precision on GPU halves weight + activation memory
            if self.device.type == 'cuda':
                self.model = self.model.half()

            # Report the actual parameter footprint. element_size() is
            # correct for both FP16 (GPU) and FP32 (CPU) — the old code
            # hard-coded 2 bytes and under-reported on CPU.
            param_bytes = sum(p.numel() * p.element_size()
                              for p in self.model.parameters())
            model_mb = param_bytes / (1024**2)
            print(f"βœ… Model loaded: {model_mb:.1f}MB")

        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            self.model = None

    def enhance_image(self, image_path: str, output_path: str = None) -> str:
        """Enhance a single image file with guaranteed low memory usage.

        Args:
            image_path: Path of the source image.
            output_path: Destination path; defaults to
                ``<stem>_enhanced<ext>`` next to the input.

        Returns:
            The path written on success, or ``image_path`` unchanged if the
            image could not be read or enhanced at all.
        """
        if output_path is None:
            # splitext only touches the final extension; the previous
            # str.replace('.', '_enhanced.') corrupted every dot in names
            # like "my.photo.jpg".
            root, ext = os.path.splitext(image_path)
            output_path = f"{root}_enhanced{ext}"

        print(f"🎨 Enhancing {os.path.basename(image_path)}...")

        try:
            # Read image
            img = cv2.imread(image_path)
            if img is None:
                print(f"❌ Failed to read image")
                return image_path

            h, w = img.shape[:2]
            print(f"  Input: {w}x{h}")

            # Even tiled inference thrashes on huge inputs — go CPU-only
            if h > 2048 or w > 2048:
                print("  ⚠️ Large image, using CPU fallback")
                enhanced = self._cpu_upscale(img)
            elif self.model is not None:
                enhanced = self._enhance_with_model(img)
            else:
                enhanced = self._cpu_upscale(img)

            # Enforce the 2K output cap while preserving aspect ratio
            h, w = enhanced.shape[:2]
            if w > 2048 or h > 1080:
                scale = min(2048/w, 1080/h)
                new_w = int(w * scale)
                new_h = int(h * scale)
                enhanced = cv2.resize(enhanced, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
                print(f"  πŸ“ Resizing from {w}x{h} to {new_w}x{new_h} (2K limit)")

            # Save result (the quality flag is honored for JPEG output and
            # harmlessly ignored for other formats)
            cv2.imwrite(output_path, enhanced, [cv2.IMWRITE_JPEG_QUALITY, 95])

            new_h, new_w = enhanced.shape[:2]
            print(f"  βœ… Output: {new_w}x{new_h}")

            # Force memory cleanup
            self._cleanup_memory()

            return output_path

        except Exception as e:
            print(f"  ❌ Enhancement failed: {e}")
            # Best-effort CPU fallback before giving up entirely
            try:
                img = cv2.imread(image_path)
                enhanced = self._cpu_upscale(img)
                cv2.imwrite(output_path, enhanced)
                return output_path
            except Exception:
                # Narrowed from a bare except: never swallow
                # KeyboardInterrupt/SystemExit here.
                return image_path

    def _enhance_with_model(self, img):
        """Run tiled model inference; returns the upscaled BGR uint8 array."""
        h, w = img.shape[:2]

        # Assemble the result on the CPU so the GPU only holds one tile
        output = np.zeros((h * self.scale, w * self.scale, 3), dtype=np.uint8)

        # Process in very small tiles
        tile_size = self.tile_size

        print(f"  Processing {tile_size}x{tile_size} tiles...")

        for y in range(0, h, tile_size):
            for x in range(0, w, tile_size):
                # Extract tile
                y_end = min(y + tile_size, h)
                x_end = min(x + tile_size, w)
                tile = img[y:y_end, x:x_end]

                if tile.shape[0] < 4 or tile.shape[1] < 4:
                    # Too small for the conv stack. Plain-resize instead of
                    # skipping — the old `continue` left black strips at the
                    # right/bottom edges whenever w or h % tile_size < 4.
                    tiny = cv2.resize(tile, (tile.shape[1]*self.scale, tile.shape[0]*self.scale),
                                      interpolation=cv2.INTER_CUBIC)
                    output[y*self.scale:y*self.scale+tiny.shape[0],
                           x*self.scale:x*self.scale+tiny.shape[1]] = tiny
                    continue

                try:
                    # Process tile
                    enhanced_tile = self._process_single_tile(tile)

                    # Place in output
                    out_y = y * self.scale
                    out_x = x * self.scale
                    out_y_end = out_y + enhanced_tile.shape[0]
                    out_x_end = out_x + enhanced_tile.shape[1]

                    output[out_y:out_y_end, out_x:out_x_end] = enhanced_tile

                except Exception:
                    # Per-tile failure (e.g. OOM): bicubic-upscale just this tile
                    fallback = cv2.resize(tile, (tile.shape[1]*self.scale, tile.shape[0]*self.scale), 
                                        interpolation=cv2.INTER_CUBIC)
                    out_y = y * self.scale
                    out_x = x * self.scale
                    output[out_y:out_y+fallback.shape[0], out_x:out_x+fallback.shape[1]] = fallback

                # Return tile memory to the allocator immediately
                if self.device.type == 'cuda':
                    torch.cuda.empty_cache()

        return output

    def _process_single_tile(self, tile):
        """Upscale one BGR uint8 tile through the model; returns BGR uint8."""
        # Convert to tensor
        tile_rgb = cv2.cvtColor(tile, cv2.COLOR_BGR2RGB)
        tile_norm = tile_rgb.astype(np.float32) / 255.0

        # (H, W, C) -> (1, C, H, W)
        tile_tensor = torch.from_numpy(tile_norm).permute(2, 0, 1).unsqueeze(0)
        tile_tensor = tile_tensor.to(self.device)

        # Match the model's FP16 weights on GPU
        if self.device.type == 'cuda':
            tile_tensor = tile_tensor.half()

        # Process
        with torch.no_grad():
            enhanced_tensor = self.model(tile_tensor)

        # Convert back to numpy
        enhanced = enhanced_tensor.squeeze(0).permute(1, 2, 0)
        enhanced = enhanced.cpu().float().numpy()
        enhanced = (enhanced * 255).clip(0, 255).astype(np.uint8)
        enhanced = cv2.cvtColor(enhanced, cv2.COLOR_RGB2BGR)

        # Clean up tensors
        del tile_tensor, enhanced_tensor

        return enhanced

    def _cpu_upscale(self, img):
        """CPU-only upscaling fallback (blended cubic/Lanczos + sharpening)."""
        print("  πŸ“ˆ Using CPU upscaling...")

        # High-quality CPU upscaling, never exceeding the 2K cap
        h, w = img.shape[:2]
        scale_factor = min(self.scale, 2048/w, 1080/h)
        new_w = int(w * scale_factor)
        new_h = int(h * scale_factor)

        # Blend two interpolators for a smoother result
        cubic = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        lanczos = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)

        result = cv2.addWeighted(cubic, 0.5, lanczos, 0.5, 0)

        # Mild sharpening; kernel sums to 1 so overall brightness is preserved
        kernel = np.array([[0, -1, 0], 
                          [-1, 5, -1], 
                          [0, -1, 0]], dtype=np.float32)
        result = cv2.filter2D(result, -1, kernel)

        return result

    def _cleanup_memory(self):
        """Aggressive memory cleanup after each image."""
        gc.collect()
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    def get_memory_usage(self):
        """Return a human-readable summary of current GPU memory usage."""
        if self.device.type == 'cuda':
            allocated = torch.cuda.memory_allocated() / (1024**2)
            reserved = torch.cuda.memory_reserved() / (1024**2)
            return f"Allocated: {allocated:.1f}MB, Reserved: {reserved:.1f}MB"
        return "Using CPU"

# Lazily-constructed module-level singleton shared by the simple API below.
_memory_safe_enhancer = None

def get_memory_safe_enhancer():
    """Return the shared enhancer, constructing it on first use."""
    global _memory_safe_enhancer
    if _memory_safe_enhancer is not None:
        return _memory_safe_enhancer
    _memory_safe_enhancer = MemorySafeEnhancer()
    return _memory_safe_enhancer

# Simple API
def enhance_image_safe(image_path: str, output_path: str = None) -> str:
    """Enhance image with guaranteed <1GB VRAM usage"""
    return get_memory_safe_enhancer().enhance_image(image_path, output_path)