jgitsolutions committed on
Commit
3d00790
·
verified ·
1 Parent(s): fdfa1e6

Update app.py

Browse files

yellow tint fix

Files changed (1) hide show
  1. app.py +71 -17
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py - Final Corrected Implementation
2
  import os
3
  import cv2
4
  import time
@@ -10,29 +10,67 @@ import torch.nn.functional as F
10
  from PIL import Image
11
  from functools import partial
12
 
13
- # ====================== ARTIFACT MITIGATION FUNCTIONS ======================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def fix_chromatic_aberration(image):
15
- """Align RGB channels to reduce color fringing"""
16
- return cv2.bilateralFilter(image, d=5, sigmaColor=50, sigmaSpace=10)
 
 
 
17
 
18
  def apply_anti_ringing(img):
19
- """Reduce halo/ringing artifacts around edges"""
20
  gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
21
  edges = cv2.Canny(gray, 100, 200)
22
- dilated = cv2.dilate(edges, np.ones((3,3), np.uint8))
23
 
24
- mask = cv2.GaussianBlur(dilated.astype(np.float32), (0,0), sigmaX=2)
25
  mask = (mask / 255.0)[:,:,np.newaxis]
26
 
27
- filtered = cv2.bilateralFilter(img, d=3, sigmaColor=25, sigmaSpace=3)
28
- return (img * (1-mask) + filtered * mask).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- def hybrid_upscale(image, neural_result, blend_factor=0.8):
31
- """Blend neural and traditional upscaling"""
32
- h, w = image.shape[:2]
33
  traditional = cv2.resize(image, (neural_result.shape[1], neural_result.shape[0]),
34
  interpolation=cv2.INTER_CUBIC)
35
- return cv2.addWeighted(neural_result, blend_factor, traditional, 1-blend_factor, 0)
 
 
 
 
 
36
 
37
  # ====================== MODEL ARCHITECTURE ======================
38
  class SelfAttention(nn.Module):
@@ -88,6 +126,13 @@ class UltraEfficientSR(nn.Module):
88
  nn.init.kaiming_normal_(m.weight, mode='fan_out')
89
  if m.bias is not None:
90
  nn.init.zeros_(m.bias)
 
 
 
 
 
 
 
91
 
92
  def forward(self, x, scale_factor=2):
93
  x = self.initial(x)
@@ -141,8 +186,8 @@ def process_image_with_tiling(model, image, scale_factor, tile_size=256, overlap
141
  out_y1, out_x1 = y1*scale_factor, x1*scale_factor
142
  out_y2, out_x2 = y2*scale_factor, x2*scale_factor
143
 
144
- weights = create_pyramid_weights(tile.shape[0]*scale_factor,
145
- tile.shape[1]*scale_factor)
146
 
147
  output[out_y1:out_y2, out_x1:out_x2] += processed * weights
148
  weight_map[out_y1:out_y2, out_x1:out_x2] += weights
@@ -195,12 +240,21 @@ class CPUUpscaler:
195
  else:
196
  output = process_tile(self.model, image_np, scale_factor)
197
 
198
- # Artifact mitigation pipeline
199
  output = fix_chromatic_aberration(output)
200
  output = apply_anti_ringing(output)
201
- output = cv2.edgePreservingFilter(output, flags=cv2.NORMCONV_FILTER, sigma_s=60, sigma_r=0.4)
 
202
  output = hybrid_upscale(image_np, output)
203
 
 
 
 
 
 
 
 
 
204
  # Metrics
205
  metrics = {
206
  "processing_time": f"{time.time()-start_time:.2f}s",
 
1
+ # app.py - Fixed Implementation with Color Correction
2
  import os
3
  import cv2
4
  import time
 
10
  from PIL import Image
11
  from functools import partial
12
 
13
+ # ====================== IMPROVED ARTIFACT MITIGATION ======================
14
def balance_color_channels(image):
    """Correct the blue deficit that shows up as a yellow/green tint.

    The whole pipeline operates on RGB arrays (every cvtColor call in this
    file uses RGB2* conversion codes), so channel 0 is red and channel 2 is
    blue. The previous implementation split the image as if it were BGR and
    consequently boosted the *red* channel — amplifying the yellow cast it
    was meant to remove. Channels are now addressed by their true RGB index.

    A tint is assumed when the blue mean falls more than 10% below either
    the green or red mean; blue is then lifted 20% and green damped 10%.

    Args:
        image: HxWx3 uint8 RGB array.

    Returns:
        HxWx3 uint8 RGB array with rebalanced channels (always a new array).
    """
    # NOTE: image is RGB, not OpenCV's default BGR ordering.
    r, g, b = image[..., 0], image[..., 1], image[..., 2]

    r_mean, g_mean, b_mean = np.mean(r), np.mean(g), np.mean(b)

    # Yellow tint manifests as blue lagging behind green and/or red.
    if b_mean < (g_mean * 0.9) or b_mean < (r_mean * 0.9):
        b = np.clip(b * 1.2, 0, 255).astype(np.uint8)  # lift blue
        g = np.clip(g * 0.9, 0, 255).astype(np.uint8)  # damp green

    return np.dstack([r, g, b])
29
+
30
def fix_chromatic_aberration(image):
    """Soften colour fringing, then rebalance the channels.

    A light bilateral filter (d=3) suppresses edge fringing while keeping
    fine detail; the result is handed to balance_color_channels to correct
    any residual tint.
    """
    # Gentle edge-aware smoothing — stronger settings would blur detail.
    smoothed = cv2.bilateralFilter(image, d=3, sigmaColor=25, sigmaSpace=5)
    return balance_color_channels(smoothed)
36
 
37
def apply_anti_ringing(img):
    """Suppress halo/ringing artifacts along strong edges.

    Builds a soft edge mask (Canny -> dilate -> Gaussian blur) and blends a
    lightly bilateral-filtered copy of the image into the original near the
    edges only; flat regions pass through untouched.
    """
    luma = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    edge_map = cv2.Canny(luma, 100, 200)
    # A small 2x2 dilation widens the edge band just enough to cover halos.
    widened = cv2.dilate(edge_map, np.ones((2, 2), np.uint8))

    soft = cv2.GaussianBlur(widened.astype(np.float32), (0, 0), sigmaX=1.5)
    soft = (soft / 255.0)[:, :, np.newaxis]

    smoothed = cv2.bilateralFilter(img, d=2, sigmaColor=20, sigmaSpace=2)
    # Cap the filtered contribution at 70%, applied only where edges are.
    weight = soft * 0.7
    return (img * (1 - weight) + smoothed * weight).astype(np.uint8)
48
+
49
def color_normalize(image):
    """Damp a global colour cast by pulling LAB chroma toward neutral.

    In OpenCV's 8-bit LAB encoding the a/b channels are centred at 128
    (neutral grey). When either channel's mean drifts more than 10 away
    from that centre, both channels are blended 85/15 toward 128 to soften
    the cast without flattening colours entirely.
    """
    lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
    l_chan, a_chan, b_chan = cv2.split(lab)

    cast_a = abs(np.mean(a_chan) - 128)
    cast_b = abs(np.mean(b_chan) - 128)
    if cast_a > 10 or cast_b > 10:
        # 85% original chroma + 15% neutral; clip before the uint8 cast.
        a_chan = np.clip(a_chan * 0.85 + 128 * 0.15, 0, 255).astype(np.uint8)
        b_chan = np.clip(b_chan * 0.85 + 128 * 0.15, 0, 255).astype(np.uint8)

    return cv2.cvtColor(cv2.merge([l_chan, a_chan, b_chan]), cv2.COLOR_LAB2RGB)
63
 
64
def hybrid_upscale(image, neural_result, blend_factor=0.65):
    """Mix the neural output with a bicubic upscale of the source image.

    blend_factor weights the neural result (0.65 keeps it dominant while
    letting the bicubic version anchor the original colours); the blend is
    then passed through color_normalize to damp any remaining colour cast.
    """
    target_w, target_h = neural_result.shape[1], neural_result.shape[0]
    bicubic = cv2.resize(image, (target_w, target_h),
                         interpolation=cv2.INTER_CUBIC)

    mixed = cv2.addWeighted(neural_result, blend_factor,
                            bicubic, 1 - blend_factor, 0)
    return color_normalize(mixed)
74
 
75
  # ====================== MODEL ARCHITECTURE ======================
76
  class SelfAttention(nn.Module):
 
126
  nn.init.kaiming_normal_(m.weight, mode='fan_out')
127
  if m.bias is not None:
128
  nn.init.zeros_(m.bias)
129
+
130
+ # Initialize color conv as identity for better color preservation
131
+ with torch.no_grad():
132
+ identity = torch.eye(3).reshape(3, 3, 1, 1)
133
+ self.color_conv.weight.copy_(identity)
134
+ if self.color_conv.bias is not None:
135
+ self.color_conv.bias.zero_()
136
 
137
  def forward(self, x, scale_factor=2):
138
  x = self.initial(x)
 
186
  out_y1, out_x1 = y1*scale_factor, x1*scale_factor
187
  out_y2, out_x2 = y2*scale_factor, x2*scale_factor
188
 
189
+ # Fix weights size to match processed tile shape
190
+ weights = create_pyramid_weights(processed.shape[0], processed.shape[1])
191
 
192
  output[out_y1:out_y2, out_x1:out_x2] += processed * weights
193
  weight_map[out_y1:out_y2, out_x1:out_x2] += weights
 
240
  else:
241
  output = process_tile(self.model, image_np, scale_factor)
242
 
243
+ # Enhanced artifact mitigation pipeline
244
  output = fix_chromatic_aberration(output)
245
  output = apply_anti_ringing(output)
246
+ output = cv2.edgePreservingFilter(output, flags=cv2.NORMCONV_FILTER,
247
+ sigma_s=40, sigma_r=0.3)
248
  output = hybrid_upscale(image_np, output)
249
 
250
+ # Final color correction for extreme cases
251
+ hsv = cv2.cvtColor(output, cv2.COLOR_RGB2HSV)
252
+ h, s, v = cv2.split(hsv)
253
+ # Reduce saturation in yellow regions
254
+ yellow_mask = np.logical_and(h > 20, h < 35)
255
+ s[yellow_mask] = s[yellow_mask] * 0.7
256
+ output = cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2RGB)
257
+
258
  # Metrics
259
  metrics = {
260
  "processing_time": f"{time.time()-start_time:.2f}s",