Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import io | |
| class AutoWhiteBalance: | |
| """自動白平衡處理類""" | |
| def __init__(self): | |
| self.methods = { | |
| "灰世界算法 (Gray World)": "gray_world", | |
| "白斑算法 (White Patch)": "white_patch", | |
| "簡單平均 (Simple Average)": "simple_avg", | |
| "直方圖拉伸 (Histogram Stretch)": "histogram_stretch" | |
| } | |
| def pil_to_cv2(self, pil_image): | |
| """將PIL圖像轉換為OpenCV格式""" | |
| # 轉換為RGB(如果是RGBA則去除alpha通道) | |
| if pil_image.mode == 'RGBA': | |
| pil_image = pil_image.convert('RGB') | |
| elif pil_image.mode != 'RGB': | |
| pil_image = pil_image.convert('RGB') | |
| # 轉換為numpy array | |
| image_np = np.array(pil_image) | |
| # 轉換為BGR格式(OpenCV默認) | |
| image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR) | |
| return image_bgr | |
| def cv2_to_pil(self, cv2_image): | |
| """將OpenCV格式轉換為PIL圖像""" | |
| # 轉換回RGB | |
| image_rgb = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB) | |
| # 轉換為PIL圖像 | |
| pil_image = Image.fromarray(image_rgb.astype(np.uint8)) | |
| return pil_image | |
| def gray_world_algorithm(self, image): | |
| """灰世界算法 - 假設圖像的平均顏色應該是灰色""" | |
| # 計算每個通道的平均值 | |
| avg_b = np.mean(image[:, :, 0]) | |
| avg_g = np.mean(image[:, :, 1]) | |
| avg_r = np.mean(image[:, :, 2]) | |
| # 計算整體平均亮度 | |
| avg_gray = (avg_b + avg_g + avg_r) / 3 | |
| # 避免除零錯誤 | |
| scale_b = avg_gray / max(avg_b, 1e-6) | |
| scale_g = avg_gray / max(avg_g, 1e-6) | |
| scale_r = avg_gray / max(avg_r, 1e-6) | |
| # 應用調整 | |
| result = image.astype(np.float32) | |
| result[:, :, 0] *= scale_b | |
| result[:, :, 1] *= scale_g | |
| result[:, :, 2] *= scale_r | |
| return result | |
| def white_patch_algorithm(self, image): | |
| """白斑算法 - 假設圖像中最亮的點應該是白色""" | |
| # 找到每個通道的最大值(排除極值) | |
| max_b = np.percentile(image[:, :, 0], 99) | |
| max_g = np.percentile(image[:, :, 1], 99) | |
| max_r = np.percentile(image[:, :, 2], 99) | |
| # 計算調整係數 | |
| scale_b = 255.0 / max(max_b, 1e-6) | |
| scale_g = 255.0 / max(max_g, 1e-6) | |
| scale_r = 255.0 / max(max_r, 1e-6) | |
| # 應用調整 | |
| result = image.astype(np.float32) | |
| result[:, :, 0] *= scale_b | |
| result[:, :, 1] *= scale_g | |
| result[:, :, 2] *= scale_r | |
| return result | |
| def simple_average_algorithm(self, image): | |
| """簡單平均算法 - 讓RGB三通道的平均值相等""" | |
| # 計算每個通道的平均值 | |
| avg_b = np.mean(image[:, :, 0]) | |
| avg_g = np.mean(image[:, :, 1]) | |
| avg_r = np.mean(image[:, :, 2]) | |
| # 以綠色通道為基準(人眼對綠色最敏感) | |
| reference = avg_g | |
| # 計算調整係數 | |
| scale_b = reference / max(avg_b, 1e-6) | |
| scale_g = 1.0 # 綠色通道不變 | |
| scale_r = reference / max(avg_r, 1e-6) | |
| # 應用調整 | |
| result = image.astype(np.float32) | |
| result[:, :, 0] *= scale_b | |
| result[:, :, 1] *= scale_g | |
| result[:, :, 2] *= scale_r | |
| return result | |
| def histogram_stretch_algorithm(self, image): | |
| """直方圖拉伸算法""" | |
| result = image.astype(np.float32) | |
| for i in range(3): # 對每個顏色通道 | |
| channel = result[:, :, i] | |
| # 計算1%和99%百分位數,忽略極值 | |
| p1 = np.percentile(channel, 1) | |
| p99 = np.percentile(channel, 99) | |
| # 拉伸到0-255範圍 | |
| if p99 > p1: | |
| channel = (channel - p1) * 255.0 / (p99 - p1) | |
| result[:, :, i] = np.clip(channel, 0, 255) | |
| return result | |
| def preserve_image_brightness(self, original, adjusted): | |
| """保持原始圖像的整體亮度""" | |
| try: | |
| # 計算原始圖像的亮度 | |
| original_lab = cv2.cvtColor(original.astype(np.uint8), cv2.COLOR_BGR2LAB) | |
| original_brightness = np.mean(original_lab[:, :, 0]) | |
| # 計算調整後圖像的亮度 | |
| adjusted_lab = cv2.cvtColor(np.clip(adjusted, 0, 255).astype(np.uint8), cv2.COLOR_BGR2LAB) | |
| adjusted_brightness = np.mean(adjusted_lab[:, :, 0]) | |
| # 調整亮度 | |
| if adjusted_brightness > 1e-6: | |
| brightness_ratio = original_brightness / adjusted_brightness | |
| adjusted_lab[:, :, 0] = np.clip(adjusted_lab[:, :, 0] * brightness_ratio, 0, 255) | |
| # 轉換回BGR | |
| result = cv2.cvtColor(adjusted_lab, cv2.COLOR_LAB2BGR) | |
| return result.astype(np.float32) | |
| except: | |
| pass | |
| return adjusted | |
| def process_image(self, pil_image, method_name, strength, preserve_brightness, clip_values): | |
| """處理圖像的主函數""" | |
| if pil_image is None: | |
| return None, "請上傳一張圖片" | |
| try: | |
| # 轉換格式 | |
| cv2_image = self.pil_to_cv2(pil_image) | |
| original_image = cv2_image.copy() | |
| # 獲取算法名稱 | |
| method = self.methods.get(method_name, "gray_world") | |
| # 選擇算法 | |
| if method == "gray_world": | |
| adjusted = self.gray_world_algorithm(cv2_image) | |
| algorithm_info = "使用灰世界算法,假設圖像的平均顏色應該是中性灰" | |
| elif method == "white_patch": | |
| adjusted = self.white_patch_algorithm(cv2_image) | |
| algorithm_info = "使用白斑算法,假設圖像中最亮的區域應該是白色" | |
| elif method == "simple_avg": | |
| adjusted = self.simple_average_algorithm(cv2_image) | |
| algorithm_info = "使用簡單平均算法,平衡RGB三個通道的平均值" | |
| elif method == "histogram_stretch": | |
| adjusted = self.histogram_stretch_algorithm(cv2_image) | |
| algorithm_info = "使用直方圖拉伸算法,增強圖像對比度" | |
| else: | |
| adjusted = cv2_image.astype(np.float32) | |
| algorithm_info = "未知算法" | |
| # 應用強度調整 | |
| if strength != 1.0: | |
| adjusted = original_image.astype(np.float32) * (1.0 - strength) + adjusted * strength | |
| # 保持亮度 | |
| if preserve_brightness: | |
| adjusted = self.preserve_image_brightness(original_image, adjusted) | |
| # 裁剪數值範圍 | |
| if clip_values: | |
| adjusted = np.clip(adjusted, 0, 255) | |
| # 轉換回PIL格式 | |
| result_pil = self.cv2_to_pil(adjusted.astype(np.uint8)) | |
| # 生成處理信息 | |
| info = f"✅ 處理完成!\n算法:{algorithm_info}\n強度:{strength:.1f}\n保持亮度:{'是' if preserve_brightness else '否'}" | |
| return result_pil, info | |
| except Exception as e: | |
| return None, f"❌ 處理出錯:{str(e)}" | |
| # 創建白平衡處理器實例 | |
| wb_processor = AutoWhiteBalance() | |
| def process_white_balance(image, method, strength, preserve_brightness, clip_values): | |
| """Gradio接口函數""" | |
| return wb_processor.process_image(image, method, strength, preserve_brightness, clip_values) | |
| # 創建Gradio界面 | |
| def create_interface(): | |
| with gr.Blocks( | |
| title="🎨 自動白平衡校正工具", | |
| theme=gr.themes.Soft(), | |
| css=""" | |
| .gradio-container { | |
| max-width: 1200px !important; | |
| } | |
| .image-container { | |
| max-height: 600px; | |
| } | |
| """ | |
| ) as demo: | |
| gr.Markdown(""" | |
| # 🎨 給GPT用的自動白平衡校正工具 | |
| 上傳有色偏問題的圖片,選擇適合的算法自動校正白平衡,讓圖片恢復自然的色彩! | |
| 💡 **使用建議**: | |
| - 🟡 **暖色偏(偏黃)**:推薦使用「灰世界算法」 | |
| - 🔵 **冷色偏(偏藍)**:推薦使用「簡單平均」或「白斑算法」 | |
| - 🌈 **色彩不鮮豔**:推薦使用「直方圖拉伸,常用」 | |
| """) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| # 輸入區域 | |
| gr.Markdown("### 📤 上傳圖片") | |
| input_image = gr.Image( | |
| label="選擇要處理的圖片", | |
| type="pil", | |
| height=400 | |
| ) | |
| # 參數設置 | |
| gr.Markdown("### ⚙️ 調整參數") | |
| method_dropdown = gr.Dropdown( | |
| choices=list(wb_processor.methods.keys()), | |
| value="灰世界算法 (Gray World)", | |
| label="白平衡算法", | |
| info="選擇適合的白平衡校正算法" | |
| ) | |
| strength_slider = gr.Slider( | |
| minimum=0.0, | |
| maximum=2.0, | |
| value=1.0, | |
| step=0.1, | |
| label="調整強度", | |
| info="控制校正效果的強度(0=不調整,1=標準,2=強化)" | |
| ) | |
| preserve_brightness_checkbox = gr.Checkbox( | |
| value=True, | |
| label="保持原始亮度", | |
| info="保持圖片的整體明暗度不變" | |
| ) | |
| clip_values_checkbox = gr.Checkbox( | |
| value=True, | |
| label="裁剪數值範圍", | |
| info="避免過度調整造成的色彩異常" | |
| ) | |
| # 處理按鈕 | |
| process_btn = gr.Button( | |
| "🚀 開始處理", | |
| variant="primary", | |
| size="lg" | |
| ) | |
| with gr.Column(scale=1): | |
| # 輸出區域 | |
| gr.Markdown("### 📥 處理結果") | |
| output_image = gr.Image( | |
| label="處理後的圖片", | |
| type="pil", | |
| height=400 | |
| ) | |
| # 處理信息 | |
| process_info = gr.Textbox( | |
| label="處理信息", | |
| lines=4, | |
| interactive=False | |
| ) | |
| # 下載提示 | |
| gr.Markdown(""" | |
| ### 💾 保存圖片 | |
| 右鍵點擊處理後的圖片,選擇「另存圖片為...」即可保存到本地。 | |
| """) | |
| # 綁定處理函數 | |
| process_btn.click( | |
| fn=process_white_balance, | |
| inputs=[input_image, method_dropdown, strength_slider, preserve_brightness_checkbox, clip_values_checkbox], | |
| outputs=[output_image, process_info] | |
| ) | |
| # 實時預覽(可選) | |
| for component in [method_dropdown, strength_slider, preserve_brightness_checkbox, clip_values_checkbox]: | |
| component.change( | |
| fn=process_white_balance, | |
| inputs=[input_image, method_dropdown, strength_slider, preserve_brightness_checkbox, clip_values_checkbox], | |
| outputs=[output_image, process_info] | |
| ) | |
| # 添加說明 | |
| gr.Markdown(""" | |
| --- | |
| ### 📚 算法說明 | |
| - **灰世界算法**:適合校正整體色偏,特別是暖色偏(偏黃/橙) | |
| - **白斑算法**:適合有明顯白色或亮色區域的圖片 | |
| - **簡單平均**:溫和的校正方式,適合輕微色偏 | |
| - **直方圖拉伸**:增強對比度,讓色彩更鮮豔 | |
| ### 🔧 技術支持 | |
| 基於OpenCV和PIL開發,支持常見的圖片格式(JPG、PNG、WEBP等) | |
| """) | |
| return demo | |
| # 啟動應用 | |
| if __name__ == "__main__": | |
| demo = create_interface() | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=True | |
| ) |