import gradio as gr import numpy as np import cv2 from PIL import Image import io class AutoWhiteBalance: """自動白平衡處理類""" def __init__(self): self.methods = { "灰世界算法 (Gray World)": "gray_world", "白斑算法 (White Patch)": "white_patch", "簡單平均 (Simple Average)": "simple_avg", "直方圖拉伸 (Histogram Stretch)": "histogram_stretch" } def pil_to_cv2(self, pil_image): """將PIL圖像轉換為OpenCV格式""" # 轉換為RGB(如果是RGBA則去除alpha通道) if pil_image.mode == 'RGBA': pil_image = pil_image.convert('RGB') elif pil_image.mode != 'RGB': pil_image = pil_image.convert('RGB') # 轉換為numpy array image_np = np.array(pil_image) # 轉換為BGR格式(OpenCV默認) image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR) return image_bgr def cv2_to_pil(self, cv2_image): """將OpenCV格式轉換為PIL圖像""" # 轉換回RGB image_rgb = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB) # 轉換為PIL圖像 pil_image = Image.fromarray(image_rgb.astype(np.uint8)) return pil_image def gray_world_algorithm(self, image): """灰世界算法 - 假設圖像的平均顏色應該是灰色""" # 計算每個通道的平均值 avg_b = np.mean(image[:, :, 0]) avg_g = np.mean(image[:, :, 1]) avg_r = np.mean(image[:, :, 2]) # 計算整體平均亮度 avg_gray = (avg_b + avg_g + avg_r) / 3 # 避免除零錯誤 scale_b = avg_gray / max(avg_b, 1e-6) scale_g = avg_gray / max(avg_g, 1e-6) scale_r = avg_gray / max(avg_r, 1e-6) # 應用調整 result = image.astype(np.float32) result[:, :, 0] *= scale_b result[:, :, 1] *= scale_g result[:, :, 2] *= scale_r return result def white_patch_algorithm(self, image): """白斑算法 - 假設圖像中最亮的點應該是白色""" # 找到每個通道的最大值(排除極值) max_b = np.percentile(image[:, :, 0], 99) max_g = np.percentile(image[:, :, 1], 99) max_r = np.percentile(image[:, :, 2], 99) # 計算調整係數 scale_b = 255.0 / max(max_b, 1e-6) scale_g = 255.0 / max(max_g, 1e-6) scale_r = 255.0 / max(max_r, 1e-6) # 應用調整 result = image.astype(np.float32) result[:, :, 0] *= scale_b result[:, :, 1] *= scale_g result[:, :, 2] *= scale_r return result def simple_average_algorithm(self, image): """簡單平均算法 - 讓RGB三通道的平均值相等""" # 計算每個通道的平均值 avg_b = np.mean(image[:, :, 0]) avg_g = np.mean(image[:, :, 1]) avg_r = np.mean(image[:, :, 2]) # 以綠色通道為基準(人眼對綠色最敏感) reference = avg_g # 計算調整係數 scale_b = reference / max(avg_b, 1e-6) scale_g = 1.0 # 綠色通道不變 scale_r = reference / max(avg_r, 1e-6) # 應用調整 result = image.astype(np.float32) result[:, :, 0] *= scale_b result[:, :, 1] *= scale_g result[:, :, 2] *= scale_r return result def histogram_stretch_algorithm(self, image): """直方圖拉伸算法""" result = image.astype(np.float32) for i in range(3): # 對每個顏色通道 channel = result[:, :, i] # 計算1%和99%百分位數,忽略極值 p1 = np.percentile(channel, 1) p99 = np.percentile(channel, 99) # 拉伸到0-255範圍 if p99 > p1: channel = (channel - p1) * 255.0 / (p99 - p1) result[:, :, i] = np.clip(channel, 0, 255) return result def preserve_image_brightness(self, original, adjusted): """保持原始圖像的整體亮度""" try: # 計算原始圖像的亮度 original_lab = cv2.cvtColor(original.astype(np.uint8), cv2.COLOR_BGR2LAB) original_brightness = np.mean(original_lab[:, :, 0]) # 計算調整後圖像的亮度 adjusted_lab = cv2.cvtColor(np.clip(adjusted, 0, 255).astype(np.uint8), cv2.COLOR_BGR2LAB) adjusted_brightness = np.mean(adjusted_lab[:, :, 0]) # 調整亮度 if adjusted_brightness > 1e-6: brightness_ratio = original_brightness / adjusted_brightness adjusted_lab[:, :, 0] = np.clip(adjusted_lab[:, :, 0] * brightness_ratio, 0, 255) # 轉換回BGR result = cv2.cvtColor(adjusted_lab, cv2.COLOR_LAB2BGR) return result.astype(np.float32) except: pass return adjusted def process_image(self, pil_image, method_name, strength, preserve_brightness, clip_values): """處理圖像的主函數""" if pil_image is None: return None, "請上傳一張圖片" try: # 轉換格式 cv2_image = self.pil_to_cv2(pil_image) original_image = cv2_image.copy() # 獲取算法名稱 method = self.methods.get(method_name, "gray_world") # 選擇算法 if method == "gray_world": adjusted = self.gray_world_algorithm(cv2_image) algorithm_info = "使用灰世界算法,假設圖像的平均顏色應該是中性灰" elif method == "white_patch": adjusted = self.white_patch_algorithm(cv2_image) algorithm_info = "使用白斑算法,假設圖像中最亮的區域應該是白色" elif method == "simple_avg": adjusted = self.simple_average_algorithm(cv2_image) algorithm_info = "使用簡單平均算法,平衡RGB三個通道的平均值" elif method == "histogram_stretch": adjusted = self.histogram_stretch_algorithm(cv2_image) algorithm_info = "使用直方圖拉伸算法,增強圖像對比度" else: adjusted = cv2_image.astype(np.float32) algorithm_info = "未知算法" # 應用強度調整 if strength != 1.0: adjusted = original_image.astype(np.float32) * (1.0 - strength) + adjusted * strength # 保持亮度 if preserve_brightness: adjusted = self.preserve_image_brightness(original_image, adjusted) # 裁剪數值範圍 if clip_values: adjusted = np.clip(adjusted, 0, 255) # 轉換回PIL格式 result_pil = self.cv2_to_pil(adjusted.astype(np.uint8)) # 生成處理信息 info = f"✅ 處理完成!\n算法:{algorithm_info}\n強度:{strength:.1f}\n保持亮度:{'是' if preserve_brightness else '否'}" return result_pil, info except Exception as e: return None, f"❌ 處理出錯:{str(e)}" # 創建白平衡處理器實例 wb_processor = AutoWhiteBalance() def process_white_balance(image, method, strength, preserve_brightness, clip_values): """Gradio接口函數""" return wb_processor.process_image(image, method, strength, preserve_brightness, clip_values) # 創建Gradio界面 def create_interface(): with gr.Blocks( title="🎨 自動白平衡校正工具", theme=gr.themes.Soft(), css=""" .gradio-container { max-width: 1200px !important; } .image-container { max-height: 600px; } """ ) as demo: gr.Markdown(""" # 🎨 給GPT用的自動白平衡校正工具 上傳有色偏問題的圖片,選擇適合的算法自動校正白平衡,讓圖片恢復自然的色彩! 💡 **使用建議**: - 🟡 **暖色偏(偏黃)**:推薦使用「灰世界算法」 - 🔵 **冷色偏(偏藍)**:推薦使用「簡單平均」或「白斑算法」 - 🌈 **色彩不鮮豔**:推薦使用「直方圖拉伸,常用」 """) with gr.Row(): with gr.Column(scale=1): # 輸入區域 gr.Markdown("### 📤 上傳圖片") input_image = gr.Image( label="選擇要處理的圖片", type="pil", height=400 ) # 參數設置 gr.Markdown("### ⚙️ 調整參數") method_dropdown = gr.Dropdown( choices=list(wb_processor.methods.keys()), value="灰世界算法 (Gray World)", label="白平衡算法", info="選擇適合的白平衡校正算法" ) strength_slider = gr.Slider( minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="調整強度", info="控制校正效果的強度(0=不調整,1=標準,2=強化)" ) preserve_brightness_checkbox = gr.Checkbox( value=True, label="保持原始亮度", info="保持圖片的整體明暗度不變" ) clip_values_checkbox = gr.Checkbox( value=True, label="裁剪數值範圍", info="避免過度調整造成的色彩異常" ) # 處理按鈕 process_btn = gr.Button( "🚀 開始處理", variant="primary", size="lg" ) with gr.Column(scale=1): # 輸出區域 gr.Markdown("### 📥 處理結果") output_image = gr.Image( label="處理後的圖片", type="pil", height=400 ) # 處理信息 process_info = gr.Textbox( label="處理信息", lines=4, interactive=False ) # 下載提示 gr.Markdown(""" ### 💾 保存圖片 右鍵點擊處理後的圖片,選擇「另存圖片為...」即可保存到本地。 """) # 綁定處理函數 process_btn.click( fn=process_white_balance, inputs=[input_image, method_dropdown, strength_slider, preserve_brightness_checkbox, clip_values_checkbox], outputs=[output_image, process_info] ) # 實時預覽(可選) for component in [method_dropdown, strength_slider, preserve_brightness_checkbox, clip_values_checkbox]: component.change( fn=process_white_balance, inputs=[input_image, method_dropdown, strength_slider, preserve_brightness_checkbox, clip_values_checkbox], outputs=[output_image, process_info] ) # 添加說明 gr.Markdown(""" --- ### 📚 算法說明 - **灰世界算法**:適合校正整體色偏,特別是暖色偏(偏黃/橙) - **白斑算法**:適合有明顯白色或亮色區域的圖片 - **簡單平均**:溫和的校正方式,適合輕微色偏 - **直方圖拉伸**:增強對比度,讓色彩更鮮豔 ### 🔧 技術支持 基於OpenCV和PIL開發,支持常見的圖片格式(JPG、PNG、WEBP等) """) return demo # 啟動應用 if __name__ == "__main__": demo = create_interface() demo.launch( server_name="0.0.0.0", server_port=7860, share=True )