File size: 6,517 Bytes
925fbb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
"""
Sprite Image Enhancement Module
Uses Real-ESRGAN for high-quality upscaling
"""

import cv2
import numpy as np
import torch
from PIL import Image
import os

class SpriteProcessor:
    """Enhance sprite-sheet images with Real-ESRGAN, falling back to OpenCV.

    On construction the processor tries to load a Real-ESRGAN x4 model. If
    the ``realesrgan``/``basicsr`` packages or the weights file are missing,
    ``self.model`` stays ``None`` and :meth:`enhance_image` uses a plain
    OpenCV resize + sharpen + denoise pipeline instead.

    The pixel-level helpers divide by and clip to 255, so they are written
    for 8-bit images in OpenCV channel order (BGR, optionally with alpha).
    """

    # Default location of the Real-ESRGAN x4plus weights (relative to the
    # current working directory).
    DEFAULT_MODEL_PATH = "weights/RealESRGAN_x4plus.pth"

    def __init__(self, model_path: str = DEFAULT_MODEL_PATH):
        """
        Pick a compute device and try to load the Real-ESRGAN model

        Args:
            model_path: Path to the Real-ESRGAN weights file
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_path = model_path
        self.model = None
        self._load_model()

    def _load_model(self):
        """
        Load the Real-ESRGAN model, leaving ``self.model`` as None on failure

        Loading is deliberately best-effort: any problem (missing packages,
        missing or corrupt weights) is reported on stdout and the processor
        keeps working through the OpenCV fallback.
        """
        self.model = None
        # Instances created without __init__ may lack model_path.
        model_path = getattr(self, "model_path", self.DEFAULT_MODEL_PATH)

        try:
            # Imported lazily so the module stays importable without the
            # optional Real-ESRGAN dependencies.
            from realesrgan import RealESRGANer
            from basicsr.archs.rrdbnet_arch import RRDBNet

            if not os.path.exists(model_path):
                print("Warning: Real-ESRGAN model not found, using fallback enhancement")
                return

            # RRDBNet configuration matching the RealESRGAN_x4plus weights.
            network = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                              num_block=23, num_grow_ch=32, scale=4)

            self.model = RealESRGANer(
                scale=4,
                model_path=model_path,
                model=network,
                tile=0,
                pre_pad=0,
                half=False,
                device=self.device
            )
        except Exception as e:
            # Broad on purpose: every failure mode downgrades to the fallback.
            print(f"Error loading Real-ESRGAN: {e}")
            self.model = None

    @staticmethod
    def _sharpen_kernel(strength: float = 1.0) -> np.ndarray:
        """
        Build a 3x3 sharpening kernel whose coefficients always sum to 1

        The kernel is ``identity + strength * (sharpen - identity)``, so
        ``strength`` only scales the high-pass part: 0 leaves the image
        unchanged, 1 is the classic ``[-1 .. 9 .. -1]`` kernel, and larger
        values sharpen harder without changing overall brightness.

        Args:
            strength: Sharpening strength

        Returns:
            3x3 float32 kernel
        """
        identity = np.zeros((3, 3), dtype=np.float32)
        identity[1, 1] = 1.0
        base = np.array([[-1, -1, -1],
                         [-1,  9, -1],
                         [-1, -1, -1]], dtype=np.float32)
        return identity + np.float32(strength) * (base - identity)

    def enhance_image(self, image: np.ndarray, scale: int = 4) -> np.ndarray:
        """
        Enhance image quality using Real-ESRGAN or fallback methods

        Args:
            image: Input image (BGR or BGRA)
            scale: Upscaling factor (2 or 4)

        Returns:
            Enhanced image; BGRA input always yields BGRA output
        """
        # Split off the alpha channel so only colour goes through the model.
        has_alpha = image.ndim == 3 and image.shape[2] == 4
        if has_alpha:
            bgr = np.ascontiguousarray(image[:, :, :3])
            alpha = np.ascontiguousarray(image[:, :, 3])
        else:
            bgr = image
            alpha = None

        enhanced_bgr = None
        if self.model is not None and scale > 1:
            try:
                # RealESRGANer.enhance takes and returns cv2-style BGR arrays
                # and does its own RGB conversion, so the image is passed as
                # is; pre-converting would feed the network swapped channels.
                enhanced_bgr, _ = self.model.enhance(bgr, outscale=scale)
            except Exception as e:
                print(f"Real-ESRGAN failed, using fallback: {e}")
                enhanced_bgr = None

        if enhanced_bgr is None:
            enhanced_bgr = self._fallback_enhance(bgr, scale)

        if alpha is None:
            return enhanced_bgr

        # Re-attach alpha in every case (including scale <= 1), resized to
        # the enhanced image's actual size. Nearest-neighbour keeps sprite
        # edges hard.
        out_height, out_width = enhanced_bgr.shape[:2]
        if alpha.shape[:2] != (out_height, out_width):
            alpha = cv2.resize(alpha, (out_width, out_height),
                               interpolation=cv2.INTER_NEAREST)

        return cv2.merge([enhanced_bgr, alpha])

    def _fallback_enhance(self, image: np.ndarray, scale: int) -> np.ndarray:
        """
        Fallback enhancement using OpenCV

        Args:
            image: Input BGR (or single-channel) 8-bit image
            scale: Upscaling factor

        Returns:
            Enhanced image
        """
        # Resize with high-quality interpolation
        new_width = int(image.shape[1] * scale)
        new_height = int(image.shape[0] * scale)
        enhanced = cv2.resize(image, (new_width, new_height),
                              interpolation=cv2.INTER_CUBIC)

        # Sharpen to counter the softness introduced by cubic interpolation
        enhanced = cv2.filter2D(enhanced, -1, self._sharpen_kernel(1.0))

        # Denoise; the "Colored" variant only accepts 3-channel input
        if enhanced.ndim == 2:
            enhanced = cv2.fastNlMeansDenoising(enhanced, None, 5, 7, 21)
        else:
            enhanced = cv2.fastNlMeansDenoisingColored(enhanced, None, 5, 5, 7, 21)

        return enhanced

    def sharpen_image(self, image: np.ndarray, strength: float = 1.0) -> np.ndarray:
        """
        Apply sharpening filter

        Args:
            image: Input image
            strength: Sharpening strength (0 = unchanged, 1 = default);
                brightness is preserved for every value

        Returns:
            Sharpened image
        """
        return cv2.filter2D(image, -1, self._sharpen_kernel(strength))

    def remove_blur(self, image: np.ndarray) -> np.ndarray:
        """
        Reduce blur using Wiener deconvolution with a 5x5 box PSF

        Args:
            image: Input 8-bit image (grayscale, BGR or BGRA); an alpha
                channel is copied through untouched

        Returns:
            Deblurred image with the same shape and dtype as the input
        """
        if image.size == 0:
            return image.copy()

        psf_size = 5
        noise_to_signal = 0.01  # Wiener regularisation constant K
        height, width = image.shape[:2]

        # Zero-pad the box PSF to the image size and roll its centre to the
        # origin. An uncentred PSF adds a linear phase, which would shift
        # the result circularly by psf_size // 2 pixels.
        psf = np.full((psf_size, psf_size), 1.0 / (psf_size ** 2))
        padded = np.zeros((height, width), dtype=np.float64)
        rows = min(psf_size, height)
        cols = min(psf_size, width)
        # Images smaller than the PSF get a cropped kernel.
        padded[:rows, :cols] = psf[:rows, :cols]
        half = psf_size // 2
        padded = np.roll(padded, (-half, -half), axis=(0, 1))

        # The filter is identical for every channel, so build it once.
        psf_fft = np.fft.fft2(padded)
        wiener = np.conj(psf_fft) / (np.abs(psf_fft) ** 2 + noise_to_signal)

        result = image.copy()
        # View grayscale as (H, W, 1) so one loop handles every layout;
        # writes through the view land in ``result``.
        planes = result[..., np.newaxis] if result.ndim == 2 else result

        for idx in range(min(3, planes.shape[2])):
            channel = planes[:, :, idx].astype(np.float32) / 255.0
            restored = np.fft.ifft2(np.fft.fft2(channel) * wiener).real
            planes[:, :, idx] = np.clip(restored * 255, 0, 255).astype(np.uint8)

        return result

    def enhance_contrast(self, image: np.ndarray) -> np.ndarray:
        """
        Enhance contrast using CLAHE on the lightness channel

        Args:
            image: Input 8-bit BGR or BGRA image

        Returns:
            Contrast-enhanced image; an alpha channel is passed through
        """
        has_alpha = image.ndim == 3 and image.shape[2] == 4
        bgr = np.ascontiguousarray(image[:, :, :3]) if has_alpha else image

        # Equalise only L in LAB space so colours are not distorted.
        lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
        lightness, a, b = cv2.split(lab)

        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        lightness = clahe.apply(lightness)

        enhanced = cv2.cvtColor(cv2.merge([lightness, a, b]), cv2.COLOR_LAB2BGR)

        if has_alpha:
            enhanced = np.dstack([enhanced, image[:, :, 3]])

        return enhanced