|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import cv2 |
|
|
from PIL import Image |
|
|
import io |
|
|
|
|
|
class AutoWhiteBalance: |
|
|
"""自動白平衡處理類""" |
|
|
|
|
|
def __init__(self): |
|
|
self.methods = { |
|
|
"灰世界算法 (Gray World)": "gray_world", |
|
|
"白斑算法 (White Patch)": "white_patch", |
|
|
"簡單平均 (Simple Average)": "simple_avg", |
|
|
"直方圖拉伸 (Histogram Stretch)": "histogram_stretch" |
|
|
} |
|
|
|
|
|
def pil_to_cv2(self, pil_image): |
|
|
"""將PIL圖像轉換為OpenCV格式""" |
|
|
|
|
|
if pil_image.mode == 'RGBA': |
|
|
pil_image = pil_image.convert('RGB') |
|
|
elif pil_image.mode != 'RGB': |
|
|
pil_image = pil_image.convert('RGB') |
|
|
|
|
|
|
|
|
image_np = np.array(pil_image) |
|
|
|
|
|
image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR) |
|
|
return image_bgr |
|
|
|
|
|
def cv2_to_pil(self, cv2_image): |
|
|
"""將OpenCV格式轉換為PIL圖像""" |
|
|
|
|
|
image_rgb = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB) |
|
|
|
|
|
pil_image = Image.fromarray(image_rgb.astype(np.uint8)) |
|
|
return pil_image |
|
|
|
|
|
def gray_world_algorithm(self, image): |
|
|
"""灰世界算法 - 假設圖像的平均顏色應該是灰色""" |
|
|
|
|
|
avg_b = np.mean(image[:, :, 0]) |
|
|
avg_g = np.mean(image[:, :, 1]) |
|
|
avg_r = np.mean(image[:, :, 2]) |
|
|
|
|
|
|
|
|
avg_gray = (avg_b + avg_g + avg_r) / 3 |
|
|
|
|
|
|
|
|
scale_b = avg_gray / max(avg_b, 1e-6) |
|
|
scale_g = avg_gray / max(avg_g, 1e-6) |
|
|
scale_r = avg_gray / max(avg_r, 1e-6) |
|
|
|
|
|
|
|
|
result = image.astype(np.float32) |
|
|
result[:, :, 0] *= scale_b |
|
|
result[:, :, 1] *= scale_g |
|
|
result[:, :, 2] *= scale_r |
|
|
|
|
|
return result |
|
|
|
|
|
def white_patch_algorithm(self, image): |
|
|
"""白斑算法 - 假設圖像中最亮的點應該是白色""" |
|
|
|
|
|
max_b = np.percentile(image[:, :, 0], 99) |
|
|
max_g = np.percentile(image[:, :, 1], 99) |
|
|
max_r = np.percentile(image[:, :, 2], 99) |
|
|
|
|
|
|
|
|
scale_b = 255.0 / max(max_b, 1e-6) |
|
|
scale_g = 255.0 / max(max_g, 1e-6) |
|
|
scale_r = 255.0 / max(max_r, 1e-6) |
|
|
|
|
|
|
|
|
result = image.astype(np.float32) |
|
|
result[:, :, 0] *= scale_b |
|
|
result[:, :, 1] *= scale_g |
|
|
result[:, :, 2] *= scale_r |
|
|
|
|
|
return result |
|
|
|
|
|
def simple_average_algorithm(self, image): |
|
|
"""簡單平均算法 - 讓RGB三通道的平均值相等""" |
|
|
|
|
|
avg_b = np.mean(image[:, :, 0]) |
|
|
avg_g = np.mean(image[:, :, 1]) |
|
|
avg_r = np.mean(image[:, :, 2]) |
|
|
|
|
|
|
|
|
reference = avg_g |
|
|
|
|
|
|
|
|
scale_b = reference / max(avg_b, 1e-6) |
|
|
scale_g = 1.0 |
|
|
scale_r = reference / max(avg_r, 1e-6) |
|
|
|
|
|
|
|
|
result = image.astype(np.float32) |
|
|
result[:, :, 0] *= scale_b |
|
|
result[:, :, 1] *= scale_g |
|
|
result[:, :, 2] *= scale_r |
|
|
|
|
|
return result |
|
|
|
|
|
def histogram_stretch_algorithm(self, image): |
|
|
"""直方圖拉伸算法""" |
|
|
result = image.astype(np.float32) |
|
|
|
|
|
for i in range(3): |
|
|
channel = result[:, :, i] |
|
|
|
|
|
|
|
|
p1 = np.percentile(channel, 1) |
|
|
p99 = np.percentile(channel, 99) |
|
|
|
|
|
|
|
|
if p99 > p1: |
|
|
channel = (channel - p1) * 255.0 / (p99 - p1) |
|
|
result[:, :, i] = np.clip(channel, 0, 255) |
|
|
|
|
|
return result |
|
|
|
|
|
def preserve_image_brightness(self, original, adjusted): |
|
|
"""保持原始圖像的整體亮度""" |
|
|
try: |
|
|
|
|
|
original_lab = cv2.cvtColor(original.astype(np.uint8), cv2.COLOR_BGR2LAB) |
|
|
original_brightness = np.mean(original_lab[:, :, 0]) |
|
|
|
|
|
|
|
|
adjusted_lab = cv2.cvtColor(np.clip(adjusted, 0, 255).astype(np.uint8), cv2.COLOR_BGR2LAB) |
|
|
adjusted_brightness = np.mean(adjusted_lab[:, :, 0]) |
|
|
|
|
|
|
|
|
if adjusted_brightness > 1e-6: |
|
|
brightness_ratio = original_brightness / adjusted_brightness |
|
|
adjusted_lab[:, :, 0] = np.clip(adjusted_lab[:, :, 0] * brightness_ratio, 0, 255) |
|
|
|
|
|
|
|
|
result = cv2.cvtColor(adjusted_lab, cv2.COLOR_LAB2BGR) |
|
|
return result.astype(np.float32) |
|
|
except: |
|
|
pass |
|
|
|
|
|
return adjusted |
|
|
|
|
|
def process_image(self, pil_image, method_name, strength, preserve_brightness, clip_values): |
|
|
"""處理圖像的主函數""" |
|
|
if pil_image is None: |
|
|
return None, "請上傳一張圖片" |
|
|
|
|
|
try: |
|
|
|
|
|
cv2_image = self.pil_to_cv2(pil_image) |
|
|
original_image = cv2_image.copy() |
|
|
|
|
|
|
|
|
method = self.methods.get(method_name, "gray_world") |
|
|
|
|
|
|
|
|
if method == "gray_world": |
|
|
adjusted = self.gray_world_algorithm(cv2_image) |
|
|
algorithm_info = "使用灰世界算法,假設圖像的平均顏色應該是中性灰" |
|
|
elif method == "white_patch": |
|
|
adjusted = self.white_patch_algorithm(cv2_image) |
|
|
algorithm_info = "使用白斑算法,假設圖像中最亮的區域應該是白色" |
|
|
elif method == "simple_avg": |
|
|
adjusted = self.simple_average_algorithm(cv2_image) |
|
|
algorithm_info = "使用簡單平均算法,平衡RGB三個通道的平均值" |
|
|
elif method == "histogram_stretch": |
|
|
adjusted = self.histogram_stretch_algorithm(cv2_image) |
|
|
algorithm_info = "使用直方圖拉伸算法,增強圖像對比度" |
|
|
else: |
|
|
adjusted = cv2_image.astype(np.float32) |
|
|
algorithm_info = "未知算法" |
|
|
|
|
|
|
|
|
if strength != 1.0: |
|
|
adjusted = original_image.astype(np.float32) * (1.0 - strength) + adjusted * strength |
|
|
|
|
|
|
|
|
if preserve_brightness: |
|
|
adjusted = self.preserve_image_brightness(original_image, adjusted) |
|
|
|
|
|
|
|
|
if clip_values: |
|
|
adjusted = np.clip(adjusted, 0, 255) |
|
|
|
|
|
|
|
|
result_pil = self.cv2_to_pil(adjusted.astype(np.uint8)) |
|
|
|
|
|
|
|
|
info = f"✅ 處理完成!\n算法:{algorithm_info}\n強度:{strength:.1f}\n保持亮度:{'是' if preserve_brightness else '否'}" |
|
|
|
|
|
return result_pil, info |
|
|
|
|
|
except Exception as e: |
|
|
return None, f"❌ 處理出錯:{str(e)}" |
|
|
|
|
|
|
|
|
wb_processor = AutoWhiteBalance() |
|
|
|
|
|
def process_white_balance(image, method, strength, preserve_brightness, clip_values): |
|
|
"""Gradio接口函數""" |
|
|
return wb_processor.process_image(image, method, strength, preserve_brightness, clip_values) |
|
|
|
|
|
|
|
|
def create_interface(): |
|
|
with gr.Blocks( |
|
|
title="🎨 自動白平衡校正工具", |
|
|
theme=gr.themes.Soft(), |
|
|
css=""" |
|
|
.gradio-container { |
|
|
max-width: 1200px !important; |
|
|
} |
|
|
.image-container { |
|
|
max-height: 600px; |
|
|
} |
|
|
""" |
|
|
) as demo: |
|
|
|
|
|
gr.Markdown(""" |
|
|
# 🎨 給GPT用的自動白平衡校正工具 |
|
|
|
|
|
上傳有色偏問題的圖片,選擇適合的算法自動校正白平衡,讓圖片恢復自然的色彩! |
|
|
|
|
|
💡 **使用建議**: |
|
|
- 🟡 **暖色偏(偏黃)**:推薦使用「灰世界算法」 |
|
|
- 🔵 **冷色偏(偏藍)**:推薦使用「簡單平均」或「白斑算法」 |
|
|
- 🌈 **色彩不鮮豔**:推薦使用「直方圖拉伸,常用」 |
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
|
|
|
gr.Markdown("### 📤 上傳圖片") |
|
|
input_image = gr.Image( |
|
|
label="選擇要處理的圖片", |
|
|
type="pil", |
|
|
height=400 |
|
|
) |
|
|
|
|
|
|
|
|
gr.Markdown("### ⚙️ 調整參數") |
|
|
|
|
|
method_dropdown = gr.Dropdown( |
|
|
choices=list(wb_processor.methods.keys()), |
|
|
value="灰世界算法 (Gray World)", |
|
|
label="白平衡算法", |
|
|
info="選擇適合的白平衡校正算法" |
|
|
) |
|
|
|
|
|
strength_slider = gr.Slider( |
|
|
minimum=0.0, |
|
|
maximum=2.0, |
|
|
value=1.0, |
|
|
step=0.1, |
|
|
label="調整強度", |
|
|
info="控制校正效果的強度(0=不調整,1=標準,2=強化)" |
|
|
) |
|
|
|
|
|
preserve_brightness_checkbox = gr.Checkbox( |
|
|
value=True, |
|
|
label="保持原始亮度", |
|
|
info="保持圖片的整體明暗度不變" |
|
|
) |
|
|
|
|
|
clip_values_checkbox = gr.Checkbox( |
|
|
value=True, |
|
|
label="裁剪數值範圍", |
|
|
info="避免過度調整造成的色彩異常" |
|
|
) |
|
|
|
|
|
|
|
|
process_btn = gr.Button( |
|
|
"🚀 開始處理", |
|
|
variant="primary", |
|
|
size="lg" |
|
|
) |
|
|
|
|
|
with gr.Column(scale=1): |
|
|
|
|
|
gr.Markdown("### 📥 處理結果") |
|
|
output_image = gr.Image( |
|
|
label="處理後的圖片", |
|
|
type="pil", |
|
|
height=400 |
|
|
) |
|
|
|
|
|
|
|
|
process_info = gr.Textbox( |
|
|
label="處理信息", |
|
|
lines=4, |
|
|
interactive=False |
|
|
) |
|
|
|
|
|
|
|
|
gr.Markdown(""" |
|
|
### 💾 保存圖片 |
|
|
右鍵點擊處理後的圖片,選擇「另存圖片為...」即可保存到本地。 |
|
|
""") |
|
|
|
|
|
|
|
|
process_btn.click( |
|
|
fn=process_white_balance, |
|
|
inputs=[input_image, method_dropdown, strength_slider, preserve_brightness_checkbox, clip_values_checkbox], |
|
|
outputs=[output_image, process_info] |
|
|
) |
|
|
|
|
|
|
|
|
for component in [method_dropdown, strength_slider, preserve_brightness_checkbox, clip_values_checkbox]: |
|
|
component.change( |
|
|
fn=process_white_balance, |
|
|
inputs=[input_image, method_dropdown, strength_slider, preserve_brightness_checkbox, clip_values_checkbox], |
|
|
outputs=[output_image, process_info] |
|
|
) |
|
|
|
|
|
|
|
|
gr.Markdown(""" |
|
|
--- |
|
|
### 📚 算法說明 |
|
|
|
|
|
- **灰世界算法**:適合校正整體色偏,特別是暖色偏(偏黃/橙) |
|
|
- **白斑算法**:適合有明顯白色或亮色區域的圖片 |
|
|
- **簡單平均**:溫和的校正方式,適合輕微色偏 |
|
|
- **直方圖拉伸**:增強對比度,讓色彩更鮮豔 |
|
|
|
|
|
### 🔧 技術支持 |
|
|
基於OpenCV和PIL開發,支持常見的圖片格式(JPG、PNG、WEBP等) |
|
|
""") |
|
|
|
|
|
return demo |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
demo = create_interface() |
|
|
demo.launch( |
|
|
server_name="0.0.0.0", |
|
|
server_port=7860, |
|
|
share=True |
|
|
) |