dseditor's picture
修正檔案
450dbd4 verified
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import io
class AutoWhiteBalance:
"""自動白平衡處理類"""
def __init__(self):
self.methods = {
"灰世界算法 (Gray World)": "gray_world",
"白斑算法 (White Patch)": "white_patch",
"簡單平均 (Simple Average)": "simple_avg",
"直方圖拉伸 (Histogram Stretch)": "histogram_stretch"
}
def pil_to_cv2(self, pil_image):
"""將PIL圖像轉換為OpenCV格式"""
# 轉換為RGB(如果是RGBA則去除alpha通道)
if pil_image.mode == 'RGBA':
pil_image = pil_image.convert('RGB')
elif pil_image.mode != 'RGB':
pil_image = pil_image.convert('RGB')
# 轉換為numpy array
image_np = np.array(pil_image)
# 轉換為BGR格式(OpenCV默認)
image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
return image_bgr
def cv2_to_pil(self, cv2_image):
"""將OpenCV格式轉換為PIL圖像"""
# 轉換回RGB
image_rgb = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB)
# 轉換為PIL圖像
pil_image = Image.fromarray(image_rgb.astype(np.uint8))
return pil_image
def gray_world_algorithm(self, image):
"""灰世界算法 - 假設圖像的平均顏色應該是灰色"""
# 計算每個通道的平均值
avg_b = np.mean(image[:, :, 0])
avg_g = np.mean(image[:, :, 1])
avg_r = np.mean(image[:, :, 2])
# 計算整體平均亮度
avg_gray = (avg_b + avg_g + avg_r) / 3
# 避免除零錯誤
scale_b = avg_gray / max(avg_b, 1e-6)
scale_g = avg_gray / max(avg_g, 1e-6)
scale_r = avg_gray / max(avg_r, 1e-6)
# 應用調整
result = image.astype(np.float32)
result[:, :, 0] *= scale_b
result[:, :, 1] *= scale_g
result[:, :, 2] *= scale_r
return result
def white_patch_algorithm(self, image):
"""白斑算法 - 假設圖像中最亮的點應該是白色"""
# 找到每個通道的最大值(排除極值)
max_b = np.percentile(image[:, :, 0], 99)
max_g = np.percentile(image[:, :, 1], 99)
max_r = np.percentile(image[:, :, 2], 99)
# 計算調整係數
scale_b = 255.0 / max(max_b, 1e-6)
scale_g = 255.0 / max(max_g, 1e-6)
scale_r = 255.0 / max(max_r, 1e-6)
# 應用調整
result = image.astype(np.float32)
result[:, :, 0] *= scale_b
result[:, :, 1] *= scale_g
result[:, :, 2] *= scale_r
return result
def simple_average_algorithm(self, image):
"""簡單平均算法 - 讓RGB三通道的平均值相等"""
# 計算每個通道的平均值
avg_b = np.mean(image[:, :, 0])
avg_g = np.mean(image[:, :, 1])
avg_r = np.mean(image[:, :, 2])
# 以綠色通道為基準(人眼對綠色最敏感)
reference = avg_g
# 計算調整係數
scale_b = reference / max(avg_b, 1e-6)
scale_g = 1.0 # 綠色通道不變
scale_r = reference / max(avg_r, 1e-6)
# 應用調整
result = image.astype(np.float32)
result[:, :, 0] *= scale_b
result[:, :, 1] *= scale_g
result[:, :, 2] *= scale_r
return result
def histogram_stretch_algorithm(self, image):
"""直方圖拉伸算法"""
result = image.astype(np.float32)
for i in range(3): # 對每個顏色通道
channel = result[:, :, i]
# 計算1%和99%百分位數,忽略極值
p1 = np.percentile(channel, 1)
p99 = np.percentile(channel, 99)
# 拉伸到0-255範圍
if p99 > p1:
channel = (channel - p1) * 255.0 / (p99 - p1)
result[:, :, i] = np.clip(channel, 0, 255)
return result
def preserve_image_brightness(self, original, adjusted):
"""保持原始圖像的整體亮度"""
try:
# 計算原始圖像的亮度
original_lab = cv2.cvtColor(original.astype(np.uint8), cv2.COLOR_BGR2LAB)
original_brightness = np.mean(original_lab[:, :, 0])
# 計算調整後圖像的亮度
adjusted_lab = cv2.cvtColor(np.clip(adjusted, 0, 255).astype(np.uint8), cv2.COLOR_BGR2LAB)
adjusted_brightness = np.mean(adjusted_lab[:, :, 0])
# 調整亮度
if adjusted_brightness > 1e-6:
brightness_ratio = original_brightness / adjusted_brightness
adjusted_lab[:, :, 0] = np.clip(adjusted_lab[:, :, 0] * brightness_ratio, 0, 255)
# 轉換回BGR
result = cv2.cvtColor(adjusted_lab, cv2.COLOR_LAB2BGR)
return result.astype(np.float32)
except:
pass
return adjusted
def process_image(self, pil_image, method_name, strength, preserve_brightness, clip_values):
"""處理圖像的主函數"""
if pil_image is None:
return None, "請上傳一張圖片"
try:
# 轉換格式
cv2_image = self.pil_to_cv2(pil_image)
original_image = cv2_image.copy()
# 獲取算法名稱
method = self.methods.get(method_name, "gray_world")
# 選擇算法
if method == "gray_world":
adjusted = self.gray_world_algorithm(cv2_image)
algorithm_info = "使用灰世界算法,假設圖像的平均顏色應該是中性灰"
elif method == "white_patch":
adjusted = self.white_patch_algorithm(cv2_image)
algorithm_info = "使用白斑算法,假設圖像中最亮的區域應該是白色"
elif method == "simple_avg":
adjusted = self.simple_average_algorithm(cv2_image)
algorithm_info = "使用簡單平均算法,平衡RGB三個通道的平均值"
elif method == "histogram_stretch":
adjusted = self.histogram_stretch_algorithm(cv2_image)
algorithm_info = "使用直方圖拉伸算法,增強圖像對比度"
else:
adjusted = cv2_image.astype(np.float32)
algorithm_info = "未知算法"
# 應用強度調整
if strength != 1.0:
adjusted = original_image.astype(np.float32) * (1.0 - strength) + adjusted * strength
# 保持亮度
if preserve_brightness:
adjusted = self.preserve_image_brightness(original_image, adjusted)
# 裁剪數值範圍
if clip_values:
adjusted = np.clip(adjusted, 0, 255)
# 轉換回PIL格式
result_pil = self.cv2_to_pil(adjusted.astype(np.uint8))
# 生成處理信息
info = f"✅ 處理完成!\n算法:{algorithm_info}\n強度:{strength:.1f}\n保持亮度:{'是' if preserve_brightness else '否'}"
return result_pil, info
except Exception as e:
return None, f"❌ 處理出錯:{str(e)}"
# 創建白平衡處理器實例
wb_processor = AutoWhiteBalance()
def process_white_balance(image, method, strength, preserve_brightness, clip_values):
"""Gradio接口函數"""
return wb_processor.process_image(image, method, strength, preserve_brightness, clip_values)
# 創建Gradio界面
def create_interface():
with gr.Blocks(
title="🎨 自動白平衡校正工具",
theme=gr.themes.Soft(),
css="""
.gradio-container {
max-width: 1200px !important;
}
.image-container {
max-height: 600px;
}
"""
) as demo:
gr.Markdown("""
# 🎨 給GPT用的自動白平衡校正工具
上傳有色偏問題的圖片,選擇適合的算法自動校正白平衡,讓圖片恢復自然的色彩!
💡 **使用建議**:
- 🟡 **暖色偏(偏黃)**:推薦使用「灰世界算法」
- 🔵 **冷色偏(偏藍)**:推薦使用「簡單平均」或「白斑算法」
- 🌈 **色彩不鮮豔**:推薦使用「直方圖拉伸,常用」
""")
with gr.Row():
with gr.Column(scale=1):
# 輸入區域
gr.Markdown("### 📤 上傳圖片")
input_image = gr.Image(
label="選擇要處理的圖片",
type="pil",
height=400
)
# 參數設置
gr.Markdown("### ⚙️ 調整參數")
method_dropdown = gr.Dropdown(
choices=list(wb_processor.methods.keys()),
value="灰世界算法 (Gray World)",
label="白平衡算法",
info="選擇適合的白平衡校正算法"
)
strength_slider = gr.Slider(
minimum=0.0,
maximum=2.0,
value=1.0,
step=0.1,
label="調整強度",
info="控制校正效果的強度(0=不調整,1=標準,2=強化)"
)
preserve_brightness_checkbox = gr.Checkbox(
value=True,
label="保持原始亮度",
info="保持圖片的整體明暗度不變"
)
clip_values_checkbox = gr.Checkbox(
value=True,
label="裁剪數值範圍",
info="避免過度調整造成的色彩異常"
)
# 處理按鈕
process_btn = gr.Button(
"🚀 開始處理",
variant="primary",
size="lg"
)
with gr.Column(scale=1):
# 輸出區域
gr.Markdown("### 📥 處理結果")
output_image = gr.Image(
label="處理後的圖片",
type="pil",
height=400
)
# 處理信息
process_info = gr.Textbox(
label="處理信息",
lines=4,
interactive=False
)
# 下載提示
gr.Markdown("""
### 💾 保存圖片
右鍵點擊處理後的圖片,選擇「另存圖片為...」即可保存到本地。
""")
# 綁定處理函數
process_btn.click(
fn=process_white_balance,
inputs=[input_image, method_dropdown, strength_slider, preserve_brightness_checkbox, clip_values_checkbox],
outputs=[output_image, process_info]
)
# 實時預覽(可選)
for component in [method_dropdown, strength_slider, preserve_brightness_checkbox, clip_values_checkbox]:
component.change(
fn=process_white_balance,
inputs=[input_image, method_dropdown, strength_slider, preserve_brightness_checkbox, clip_values_checkbox],
outputs=[output_image, process_info]
)
# 添加說明
gr.Markdown("""
---
### 📚 算法說明
- **灰世界算法**:適合校正整體色偏,特別是暖色偏(偏黃/橙)
- **白斑算法**:適合有明顯白色或亮色區域的圖片
- **簡單平均**:溫和的校正方式,適合輕微色偏
- **直方圖拉伸**:增強對比度,讓色彩更鮮豔
### 🔧 技術支持
基於OpenCV和PIL開發,支持常見的圖片格式(JPG、PNG、WEBP等)
""")
return demo
# 啟動應用
if __name__ == "__main__":
demo = create_interface()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True
)