jgitsolutions committed on
Commit
0f8236e
·
verified ·
1 Parent(s): 3d00790

Update app.py

Browse files

fix color issues

Files changed (1) hide show
  1. app.py +109 -67
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py - Fixed Implementation with Color Correction
2
  import os
3
  import cv2
4
  import time
@@ -10,67 +10,76 @@ import torch.nn.functional as F
10
  from PIL import Image
11
  from functools import partial
12
 
13
- # ====================== IMPROVED ARTIFACT MITIGATION ======================
14
- def balance_color_channels(image):
15
- """Fix color channel imbalance causing yellow/green tint"""
16
- b, g, r = cv2.split(image)
 
 
17
 
18
- # Calculate channel means to detect imbalance
19
- b_mean, g_mean, r_mean = np.mean(b), np.mean(g), np.mean(r)
 
20
 
21
- # Detect yellow tint (low blue, high green/red)
22
- if b_mean < (g_mean * 0.9) or b_mean < (r_mean * 0.9):
23
- # Boost blue channel
24
- b = np.clip(b * 1.2, 0, 255).astype(np.uint8)
25
- # Reduce green slightly
26
- g = np.clip(g * 0.9, 0, 255).astype(np.uint8)
27
 
28
- return cv2.merge([b, g, r])
29
-
30
- def fix_chromatic_aberration(image):
31
- """Improved color fringing reduction"""
32
- # Reduce bilateral filter strength to preserve detail
33
- filtered = cv2.bilateralFilter(image, d=3, sigmaColor=25, sigmaSpace=5)
34
- # Fix color balance
35
- return balance_color_channels(filtered)
36
 
37
- def apply_anti_ringing(img):
38
- """Reduce halo artifacts around edges"""
 
39
  gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
40
- edges = cv2.Canny(gray, 100, 200)
41
- dilated = cv2.dilate(edges, np.ones((2,2), np.uint8)) # Smaller kernel
42
 
43
- mask = cv2.GaussianBlur(dilated.astype(np.float32), (0,0), sigmaX=1.5)
44
- mask = (mask / 255.0)[:,:,np.newaxis]
 
 
45
 
46
- filtered = cv2.bilateralFilter(img, d=2, sigmaColor=20, sigmaSpace=2)
47
- return (img * (1-mask*0.7) + filtered * (mask*0.7)).astype(np.uint8)
48
-
49
- def color_normalize(image):
50
- """Normalize colors to prevent shifts"""
51
- lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
52
- l, a, b = cv2.split(lab)
53
 
54
- # Adjust a and b channels to prevent color shifts
55
- a_mean, b_mean = np.mean(a), np.mean(b)
 
56
 
57
- # Pull toward neutral color if strong cast is detected
58
- if abs(a_mean - 128) > 10 or abs(b_mean - 128) > 10:
59
- a = np.clip(a * 0.85 + 128 * 0.15, 0, 255).astype(np.uint8)
60
- b = np.clip(b * 0.85 + 128 * 0.15, 0, 255).astype(np.uint8)
61
 
62
- return cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2RGB)
 
 
 
 
 
 
 
 
63
 
64
- def hybrid_upscale(image, neural_result, blend_factor=0.65): # Reduced influence
65
- """Blend neural and traditional upscaling with color preservation"""
66
- traditional = cv2.resize(image, (neural_result.shape[1], neural_result.shape[0]),
67
- interpolation=cv2.INTER_CUBIC)
 
 
 
 
 
 
 
 
68
 
69
- # Blend with reduced neural influence to preserve original colors
70
- blended = cv2.addWeighted(neural_result, blend_factor, traditional, 1-blend_factor, 0)
 
71
 
72
- # Normalize colors
73
- return color_normalize(blended)
74
 
75
  # ====================== MODEL ARCHITECTURE ======================
76
  class SelfAttention(nn.Module):
@@ -117,6 +126,8 @@ class UltraEfficientSR(nn.Module):
117
  self.upconv2 = nn.Conv2d(64, 256, 3, padding=1)
118
  self.pixel_shuffle = nn.PixelShuffle(2)
119
  self.final = nn.Conv2d(64, 3, 3, padding=1)
 
 
120
  self.color_conv = nn.Conv2d(3, 3, 1)
121
  self._initialize_weights()
122
 
@@ -127,7 +138,7 @@ class UltraEfficientSR(nn.Module):
127
  if m.bias is not None:
128
  nn.init.zeros_(m.bias)
129
 
130
- # Initialize color conv as identity for better color preservation
131
  with torch.no_grad():
132
  identity = torch.eye(3).reshape(3, 3, 1, 1)
133
  self.color_conv.weight.copy_(identity)
@@ -156,10 +167,20 @@ class UltraEfficientSR(nn.Module):
156
 
157
  # ====================== PROCESSING PIPELINE ======================
158
  def process_tile(model, tile, scale_factor):
 
 
 
 
159
  tile_tensor = torch.tensor(tile/255.0, dtype=torch.float32).permute(2,0,1).unsqueeze(0)
160
  with torch.no_grad():
161
  output = model(tile_tensor, scale_factor)
162
- return output.squeeze().permute(1,2,0).clamp(0,1).numpy() * 255
 
 
 
 
 
 
163
 
164
  def create_pyramid_weights(h, w):
165
  y = np.linspace(0, 1, h)
@@ -186,7 +207,7 @@ def process_image_with_tiling(model, image, scale_factor, tile_size=256, overlap
186
  out_y1, out_x1 = y1*scale_factor, x1*scale_factor
187
  out_y2, out_x2 = y2*scale_factor, x2*scale_factor
188
 
189
- # Fix weights size to match processed tile shape
190
  weights = create_pyramid_weights(processed.shape[0], processed.shape[1])
191
 
192
  output[out_y1:out_y2, out_x1:out_x2] += processed * weights
@@ -229,31 +250,39 @@ class CPUUpscaler:
229
 
230
  if image_np.shape[2] == 4:
231
  image_np = image_np[:,:,:3]
232
-
 
 
 
 
 
 
233
  # Processing setup
234
  threads_used = self.energy_ctrl.adjust_processing(image_np.size)
235
  tile_size = self._calculate_optimal_tile_size(image_np)
236
 
 
 
 
237
  # Core processing
238
  if max(image_np.shape[:2]) > tile_size:
239
  output = process_image_with_tiling(self.model, image_np, scale_factor, tile_size)
240
  else:
241
  output = process_tile(self.model, image_np, scale_factor)
242
 
243
- # Enhanced artifact mitigation pipeline
244
- output = fix_chromatic_aberration(output)
245
- output = apply_anti_ringing(output)
246
- output = cv2.edgePreservingFilter(output, flags=cv2.NORMCONV_FILTER,
247
- sigma_s=40, sigma_r=0.3)
248
- output = hybrid_upscale(image_np, output)
 
 
 
 
249
 
250
- # Final color correction for extreme cases
251
- hsv = cv2.cvtColor(output, cv2.COLOR_RGB2HSV)
252
- h, s, v = cv2.split(hsv)
253
- # Reduce saturation in yellow regions
254
- yellow_mask = np.logical_and(h > 20, h < 35)
255
- s[yellow_mask] = s[yellow_mask] * 0.7
256
- output = cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2RGB)
257
 
258
  # Metrics
259
  metrics = {
@@ -261,10 +290,23 @@ class CPUUpscaler:
261
  "input_resolution": f"{image_np.shape[1]}x{image_np.shape[0]}",
262
  "output_resolution": f"{output.shape[1]}x{output.shape[0]}",
263
  "threads_used": threads_used,
264
- "tile_size": tile_size
 
265
  }
266
 
267
  return Image.fromarray(output), metrics
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
  # ====================== GRADIO INTERFACE ======================
270
  def create_interface():
 
1
+ # app.py - Complete Color-Corrected Implementation
2
  import os
3
  import cv2
4
  import time
 
10
  from PIL import Image
11
  from functools import partial
12
 
13
# ====================== COLOR PRESERVATION FUNCTIONS ======================
def preserve_original_colors(original, processed):
    """Graft the original image's chroma onto the processed luminance.

    Both images are moved into LAB space; the L (lightness) plane of
    *processed* is kept, while the a/b (chroma) planes of *original* are
    resized to the processed resolution and merged back in.  This keeps
    the upscaler's detail but prevents it from shifting colors.

    Args:
        original: HxWx3 uint8 RGB reference image (any resolution).
        processed: HxWx3 uint8 RGB model output.

    Returns:
        HxWx3 uint8 RGB image at the processed resolution.
    """
    lab_src = cv2.cvtColor(original, cv2.COLOR_RGB2LAB)
    lab_dst = cv2.cvtColor(processed, cv2.COLOR_RGB2LAB)

    # Keep only the lightness plane of the processed image...
    lum = cv2.split(lab_dst)[0]
    # ...and only the chroma planes of the original.
    _, chroma_a, chroma_b = cv2.split(lab_src)

    # Chroma planes must match the (usually larger) processed size.
    height, width = lum.shape[:2]
    chroma_a = cv2.resize(chroma_a, (width, height), interpolation=cv2.INTER_LINEAR)
    chroma_b = cv2.resize(chroma_b, (width, height), interpolation=cv2.INTER_LINEAR)

    merged = cv2.merge([lum, chroma_a, chroma_b])
    return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
 
 
 
 
 
33
 
34
def fix_color_cast(img):
    """Remove a global color cast from an RGB image.

    If the image is effectively grayscale (each channel deviates little
    from the luminance plane), the function returns it as true grayscale
    replicated to three channels.  Otherwise it applies a simple
    gray-world balance: each channel is scaled so its mean matches the
    mean luminance.

    Args:
        img: HxWx3 uint8 image in RGB channel order.

    Returns:
        HxWx3 uint8 RGB image with the cast reduced (or forced grayscale).
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    gray_rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)

    # Mean absolute deviation of each channel from the luminance plane.
    gray_f = gray.astype(np.float32)
    diff_r = np.abs(img[:, :, 0].astype(np.float32) - gray_f)
    diff_g = np.abs(img[:, :, 1].astype(np.float32) - gray_f)
    diff_b = np.abs(img[:, :, 2].astype(np.float32) - gray_f)
    total_diff = (np.mean(diff_r) + np.mean(diff_g) + np.mean(diff_b)) / 3

    # Near-zero chroma: treat as a B&W image and force true grayscale.
    if total_diff < 10:  # threshold in 8-bit levels
        return gray_rgb

    # FIX: `img` is RGB here, so cv2.split yields channels in R, G, B
    # order.  The original bound them to names `b, g, r` (BGR), which was
    # misleading; the per-channel scaling is symmetric, so the output is
    # unchanged, but the names now match the data.
    r, g, b = cv2.split(img)
    r_avg, g_avg, b_avg = np.mean(r), np.mean(g), np.mean(b)

    # Gray-world balance target: the mean luminance.
    gray_avg = np.mean(gray)

    # Scale each channel toward the target; guard against empty channels.
    r = np.clip(r * (gray_avg / r_avg) if r_avg > 0 else r, 0, 255).astype(np.uint8)
    g = np.clip(g * (gray_avg / g_avg) if g_avg > 0 else g, 0, 255).astype(np.uint8)
    b = np.clip(b * (gray_avg / b_avg) if b_avg > 0 else b, 0, 255).astype(np.uint8)

    return cv2.merge([r, g, b])
64
 
65
def simple_edge_enhance(img):
    """Sharpen edges without shifting colors elsewhere.

    Detects edges with Canny, dilates them into a thin binary mask,
    builds an unsharp-masked copy of the whole image, and blends the
    sharpened pixels in only where the mask is set.  Off-edge pixels are
    left untouched, so flat color regions cannot drift.

    Args:
        img: HxWx3 uint8 RGB image.

    Returns:
        HxWx3 uint8 RGB image with edge-localized sharpening applied.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, 50, 150)
    dilated = cv2.dilate(edges, np.ones((2, 2), np.uint8))

    # Binary edge mask in {0.0, 1.0}.
    edge_mask = dilated.astype(np.float32) / 255.0

    # Unsharp mask: img + 0.5 * (img - blurred); addWeighted saturates.
    blurred = cv2.GaussianBlur(img, (0, 0), 3)
    sharpened = cv2.addWeighted(img, 1.5, blurred, -0.5, 0)

    # FIX: broadcast the single-channel mask with NumPy instead of
    # pushing a float32 (h, w, 1) array through
    # cv2.cvtColor(..., COLOR_GRAY2RGB).  GRAY2RGB only replicates the
    # channel, so broadcasting gives the identical result without the
    # fragile dtype/shape round-trip through OpenCV.
    edge_mask = edge_mask[:, :, np.newaxis]
    enhanced = img * (1 - edge_mask) + sharpened * edge_mask

    # Mask is binary, so every pixel is either `img` or `sharpened`:
    # values stay in [0, 255] and the uint8 cast is safe.
    return enhanced.astype(np.uint8)
 
83
 
84
  # ====================== MODEL ARCHITECTURE ======================
85
  class SelfAttention(nn.Module):
 
126
  self.upconv2 = nn.Conv2d(64, 256, 3, padding=1)
127
  self.pixel_shuffle = nn.PixelShuffle(2)
128
  self.final = nn.Conv2d(64, 3, 3, padding=1)
129
+
130
+ # Identity color preserving layer
131
  self.color_conv = nn.Conv2d(3, 3, 1)
132
  self._initialize_weights()
133
 
 
138
  if m.bias is not None:
139
  nn.init.zeros_(m.bias)
140
 
141
+ # Initialize color conv with identity matrix for color preservation
142
  with torch.no_grad():
143
  identity = torch.eye(3).reshape(3, 3, 1, 1)
144
  self.color_conv.weight.copy_(identity)
 
167
 
168
  # ====================== PROCESSING PIPELINE ======================
169
def process_tile(model, tile, scale_factor):
    """Run one tile through the SR model, then restore its original chroma."""
    # Untouched copy kept as the color reference.
    reference = tile.copy()

    # HWC uint8 -> normalized NCHW float32 tensor (batch of one).
    batch = torch.tensor(tile / 255.0, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    with torch.no_grad():
        prediction = model(batch, scale_factor)

    # NCHW float -> HWC array scaled back to [0, 255].
    upscaled = prediction.squeeze().permute(1, 2, 0).clamp(0, 1).numpy() * 255

    # Re-inject the tile's original colors into the upscaled luminance.
    return preserve_original_colors(reference, upscaled.astype(np.uint8))
184
 
185
  def create_pyramid_weights(h, w):
186
  y = np.linspace(0, 1, h)
 
207
  out_y1, out_x1 = y1*scale_factor, x1*scale_factor
208
  out_y2, out_x2 = y2*scale_factor, x2*scale_factor
209
 
210
+ # Create weights for this tile
211
  weights = create_pyramid_weights(processed.shape[0], processed.shape[1])
212
 
213
  output[out_y1:out_y2, out_x1:out_x2] += processed * weights
 
250
 
251
  if image_np.shape[2] == 4:
252
  image_np = image_np[:,:,:3]
253
+
254
+ # Force grayscale for B&W images
255
+ is_grayscale = self._is_grayscale_image(image_np)
256
+ if is_grayscale:
257
+ gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
258
+ image_np = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
259
+
260
  # Processing setup
261
  threads_used = self.energy_ctrl.adjust_processing(image_np.size)
262
  tile_size = self._calculate_optimal_tile_size(image_np)
263
 
264
+ # Save original for color reference
265
+ original_img = image_np.copy()
266
+
267
  # Core processing
268
  if max(image_np.shape[:2]) > tile_size:
269
  output = process_image_with_tiling(self.model, image_np, scale_factor, tile_size)
270
  else:
271
  output = process_tile(self.model, image_np, scale_factor)
272
 
273
+ # Final color correction
274
+ output = preserve_original_colors(
275
+ cv2.resize(original_img, (output.shape[1], output.shape[0])),
276
+ output
277
+ )
278
+
279
+ # For B&W images, ensure true grayscale output
280
+ if is_grayscale:
281
+ gray = cv2.cvtColor(output, cv2.COLOR_RGB2GRAY)
282
+ output = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
283
 
284
+ # Final edge enhancement
285
+ output = simple_edge_enhance(output)
 
 
 
 
 
286
 
287
  # Metrics
288
  metrics = {
 
290
  "input_resolution": f"{image_np.shape[1]}x{image_np.shape[0]}",
291
  "output_resolution": f"{output.shape[1]}x{output.shape[0]}",
292
  "threads_used": threads_used,
293
+ "tile_size": tile_size,
294
+ "color_preservation": "Active"
295
  }
296
 
297
  return Image.fromarray(output), metrics
298
+
299
+ def _is_grayscale_image(self, img, threshold=5):
300
+ """Detect if an image is effectively grayscale"""
301
+ gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
302
+ b, g, r = cv2.split(img)
303
+
304
+ diff_r = np.abs(r.astype(np.float32) - gray.astype(np.float32))
305
+ diff_g = np.abs(g.astype(np.float32) - gray.astype(np.float32))
306
+ diff_b = np.abs(b.astype(np.float32) - gray.astype(np.float32))
307
+
308
+ total_diff = (np.mean(diff_r) + np.mean(diff_g) + np.mean(diff_b))/3
309
+ return total_diff < threshold
310
 
311
  # ====================== GRADIO INTERFACE ======================
312
  def create_interface():