Update app.py
Browse filesyellow tint fix
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# app.py -
|
| 2 |
import os
|
| 3 |
import cv2
|
| 4 |
import time
|
|
@@ -10,29 +10,67 @@ import torch.nn.functional as F
|
|
| 10 |
from PIL import Image
|
| 11 |
from functools import partial
|
| 12 |
|
| 13 |
-
# ====================== ARTIFACT MITIGATION
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
def fix_chromatic_aberration(image):
|
| 15 |
-
"""
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def apply_anti_ringing(img):
|
| 19 |
-
"""Reduce halo
|
| 20 |
gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
| 21 |
edges = cv2.Canny(gray, 100, 200)
|
| 22 |
-
dilated = cv2.dilate(edges, np.ones((
|
| 23 |
|
| 24 |
-
mask = cv2.GaussianBlur(dilated.astype(np.float32), (0,0), sigmaX=
|
| 25 |
mask = (mask / 255.0)[:,:,np.newaxis]
|
| 26 |
|
| 27 |
-
filtered = cv2.bilateralFilter(img, d=
|
| 28 |
-
return (img * (1-mask) + filtered * mask).astype(np.uint8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
def hybrid_upscale(image, neural_result, blend_factor=0.
|
| 31 |
-
"""Blend neural and traditional upscaling"""
|
| 32 |
-
h, w = image.shape[:2]
|
| 33 |
traditional = cv2.resize(image, (neural_result.shape[1], neural_result.shape[0]),
|
| 34 |
interpolation=cv2.INTER_CUBIC)
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
# ====================== MODEL ARCHITECTURE ======================
|
| 38 |
class SelfAttention(nn.Module):
|
|
@@ -88,6 +126,13 @@ class UltraEfficientSR(nn.Module):
|
|
| 88 |
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
| 89 |
if m.bias is not None:
|
| 90 |
nn.init.zeros_(m.bias)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
def forward(self, x, scale_factor=2):
|
| 93 |
x = self.initial(x)
|
|
@@ -141,8 +186,8 @@ def process_image_with_tiling(model, image, scale_factor, tile_size=256, overlap
|
|
| 141 |
out_y1, out_x1 = y1*scale_factor, x1*scale_factor
|
| 142 |
out_y2, out_x2 = y2*scale_factor, x2*scale_factor
|
| 143 |
|
| 144 |
-
weights
|
| 145 |
-
|
| 146 |
|
| 147 |
output[out_y1:out_y2, out_x1:out_x2] += processed * weights
|
| 148 |
weight_map[out_y1:out_y2, out_x1:out_x2] += weights
|
|
@@ -195,12 +240,21 @@ class CPUUpscaler:
|
|
| 195 |
else:
|
| 196 |
output = process_tile(self.model, image_np, scale_factor)
|
| 197 |
|
| 198 |
-
#
|
| 199 |
output = fix_chromatic_aberration(output)
|
| 200 |
output = apply_anti_ringing(output)
|
| 201 |
-
output = cv2.edgePreservingFilter(output, flags=cv2.NORMCONV_FILTER,
|
|
|
|
| 202 |
output = hybrid_upscale(image_np, output)
|
| 203 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
# Metrics
|
| 205 |
metrics = {
|
| 206 |
"processing_time": f"{time.time()-start_time:.2f}s",
|
|
|
|
| 1 |
+
# app.py - Fixed Implementation with Color Correction
|
| 2 |
import os
|
| 3 |
import cv2
|
| 4 |
import time
|
|
|
|
| 10 |
from PIL import Image
|
| 11 |
from functools import partial
|
| 12 |
|
| 13 |
+
# ====================== IMPROVED ARTIFACT MITIGATION ======================
|
| 14 |
+
def balance_color_channels(image):
|
| 15 |
+
"""Fix color channel imbalance causing yellow/green tint"""
|
| 16 |
+
b, g, r = cv2.split(image)
|
| 17 |
+
|
| 18 |
+
# Calculate channel means to detect imbalance
|
| 19 |
+
b_mean, g_mean, r_mean = np.mean(b), np.mean(g), np.mean(r)
|
| 20 |
+
|
| 21 |
+
# Detect yellow tint (low blue, high green/red)
|
| 22 |
+
if b_mean < (g_mean * 0.9) or b_mean < (r_mean * 0.9):
|
| 23 |
+
# Boost blue channel
|
| 24 |
+
b = np.clip(b * 1.2, 0, 255).astype(np.uint8)
|
| 25 |
+
# Reduce green slightly
|
| 26 |
+
g = np.clip(g * 0.9, 0, 255).astype(np.uint8)
|
| 27 |
+
|
| 28 |
+
return cv2.merge([b, g, r])
|
| 29 |
+
|
| 30 |
def fix_chromatic_aberration(image):
|
| 31 |
+
"""Improved color fringing reduction"""
|
| 32 |
+
# Reduce bilateral filter strength to preserve detail
|
| 33 |
+
filtered = cv2.bilateralFilter(image, d=3, sigmaColor=25, sigmaSpace=5)
|
| 34 |
+
# Fix color balance
|
| 35 |
+
return balance_color_channels(filtered)
|
| 36 |
|
| 37 |
def apply_anti_ringing(img):
|
| 38 |
+
"""Reduce halo artifacts around edges"""
|
| 39 |
gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
| 40 |
edges = cv2.Canny(gray, 100, 200)
|
| 41 |
+
dilated = cv2.dilate(edges, np.ones((2,2), np.uint8)) # Smaller kernel
|
| 42 |
|
| 43 |
+
mask = cv2.GaussianBlur(dilated.astype(np.float32), (0,0), sigmaX=1.5)
|
| 44 |
mask = (mask / 255.0)[:,:,np.newaxis]
|
| 45 |
|
| 46 |
+
filtered = cv2.bilateralFilter(img, d=2, sigmaColor=20, sigmaSpace=2)
|
| 47 |
+
return (img * (1-mask*0.7) + filtered * (mask*0.7)).astype(np.uint8)
|
| 48 |
+
|
| 49 |
+
def color_normalize(image):
|
| 50 |
+
"""Normalize colors to prevent shifts"""
|
| 51 |
+
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
|
| 52 |
+
l, a, b = cv2.split(lab)
|
| 53 |
+
|
| 54 |
+
# Adjust a and b channels to prevent color shifts
|
| 55 |
+
a_mean, b_mean = np.mean(a), np.mean(b)
|
| 56 |
+
|
| 57 |
+
# Pull toward neutral color if strong cast is detected
|
| 58 |
+
if abs(a_mean - 128) > 10 or abs(b_mean - 128) > 10:
|
| 59 |
+
a = np.clip(a * 0.85 + 128 * 0.15, 0, 255).astype(np.uint8)
|
| 60 |
+
b = np.clip(b * 0.85 + 128 * 0.15, 0, 255).astype(np.uint8)
|
| 61 |
+
|
| 62 |
+
return cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2RGB)
|
| 63 |
|
| 64 |
+
def hybrid_upscale(image, neural_result, blend_factor=0.65): # Reduced influence
|
| 65 |
+
"""Blend neural and traditional upscaling with color preservation"""
|
|
|
|
| 66 |
traditional = cv2.resize(image, (neural_result.shape[1], neural_result.shape[0]),
|
| 67 |
interpolation=cv2.INTER_CUBIC)
|
| 68 |
+
|
| 69 |
+
# Blend with reduced neural influence to preserve original colors
|
| 70 |
+
blended = cv2.addWeighted(neural_result, blend_factor, traditional, 1-blend_factor, 0)
|
| 71 |
+
|
| 72 |
+
# Normalize colors
|
| 73 |
+
return color_normalize(blended)
|
| 74 |
|
| 75 |
# ====================== MODEL ARCHITECTURE ======================
|
| 76 |
class SelfAttention(nn.Module):
|
|
|
|
| 126 |
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
| 127 |
if m.bias is not None:
|
| 128 |
nn.init.zeros_(m.bias)
|
| 129 |
+
|
| 130 |
+
# Initialize color conv as identity for better color preservation
|
| 131 |
+
with torch.no_grad():
|
| 132 |
+
identity = torch.eye(3).reshape(3, 3, 1, 1)
|
| 133 |
+
self.color_conv.weight.copy_(identity)
|
| 134 |
+
if self.color_conv.bias is not None:
|
| 135 |
+
self.color_conv.bias.zero_()
|
| 136 |
|
| 137 |
def forward(self, x, scale_factor=2):
|
| 138 |
x = self.initial(x)
|
|
|
|
| 186 |
out_y1, out_x1 = y1*scale_factor, x1*scale_factor
|
| 187 |
out_y2, out_x2 = y2*scale_factor, x2*scale_factor
|
| 188 |
|
| 189 |
+
# Fix weights size to match processed tile shape
|
| 190 |
+
weights = create_pyramid_weights(processed.shape[0], processed.shape[1])
|
| 191 |
|
| 192 |
output[out_y1:out_y2, out_x1:out_x2] += processed * weights
|
| 193 |
weight_map[out_y1:out_y2, out_x1:out_x2] += weights
|
|
|
|
| 240 |
else:
|
| 241 |
output = process_tile(self.model, image_np, scale_factor)
|
| 242 |
|
| 243 |
+
# Enhanced artifact mitigation pipeline
|
| 244 |
output = fix_chromatic_aberration(output)
|
| 245 |
output = apply_anti_ringing(output)
|
| 246 |
+
output = cv2.edgePreservingFilter(output, flags=cv2.NORMCONV_FILTER,
|
| 247 |
+
sigma_s=40, sigma_r=0.3)
|
| 248 |
output = hybrid_upscale(image_np, output)
|
| 249 |
|
| 250 |
+
# Final color correction for extreme cases
|
| 251 |
+
hsv = cv2.cvtColor(output, cv2.COLOR_RGB2HSV)
|
| 252 |
+
h, s, v = cv2.split(hsv)
|
| 253 |
+
# Reduce saturation in yellow regions
|
| 254 |
+
yellow_mask = np.logical_and(h > 20, h < 35)
|
| 255 |
+
s[yellow_mask] = s[yellow_mask] * 0.7
|
| 256 |
+
output = cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2RGB)
|
| 257 |
+
|
| 258 |
# Metrics
|
| 259 |
metrics = {
|
| 260 |
"processing_time": f"{time.time()-start_time:.2f}s",
|