dseditor commited on
Commit
25f2cfb
·
verified ·
1 Parent(s): c9db092

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +353 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+ from PIL import Image
5
+ import io
6
+
7
+ class AutoWhiteBalance:
8
+ """自動白平衡處理類"""
9
+
10
+ def __init__(self):
11
+ self.methods = {
12
+ "灰世界算法 (Gray World)": "gray_world",
13
+ "白斑算法 (White Patch)": "white_patch",
14
+ "簡單平均 (Simple Average)": "simple_avg",
15
+ "直方圖拉伸 (Histogram Stretch)": "histogram_stretch"
16
+ }
17
+
18
+ def pil_to_cv2(self, pil_image):
19
+ """將PIL圖像轉換為OpenCV格式"""
20
+ # 轉換為RGB(如果是RGBA則去除alpha通道)
21
+ if pil_image.mode == 'RGBA':
22
+ pil_image = pil_image.convert('RGB')
23
+ elif pil_image.mode != 'RGB':
24
+ pil_image = pil_image.convert('RGB')
25
+
26
+ # 轉換為numpy array
27
+ image_np = np.array(pil_image)
28
+ # 轉換為BGR格式(OpenCV默認)
29
+ image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
30
+ return image_bgr
31
+
32
+ def cv2_to_pil(self, cv2_image):
33
+ """將OpenCV格式轉換為PIL圖像"""
34
+ # 轉換回RGB
35
+ image_rgb = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB)
36
+ # 轉換為PIL圖像
37
+ pil_image = Image.fromarray(image_rgb.astype(np.uint8))
38
+ return pil_image
39
+
40
+ def gray_world_algorithm(self, image):
41
+ """灰世界算法 - 假設圖像的平均顏色應該是灰色"""
42
+ # 計算每個通道的平均值
43
+ avg_b = np.mean(image[:, :, 0])
44
+ avg_g = np.mean(image[:, :, 1])
45
+ avg_r = np.mean(image[:, :, 2])
46
+
47
+ # 計算整體平均亮度
48
+ avg_gray = (avg_b + avg_g + avg_r) / 3
49
+
50
+ # 避免除零錯誤
51
+ scale_b = avg_gray / max(avg_b, 1e-6)
52
+ scale_g = avg_gray / max(avg_g, 1e-6)
53
+ scale_r = avg_gray / max(avg_r, 1e-6)
54
+
55
+ # 應用調整
56
+ result = image.astype(np.float32)
57
+ result[:, :, 0] *= scale_b
58
+ result[:, :, 1] *= scale_g
59
+ result[:, :, 2] *= scale_r
60
+
61
+ return result
62
+
63
+ def white_patch_algorithm(self, image):
64
+ """白斑算法 - 假設圖像中最亮的點應該是白色"""
65
+ # 找到每個通道的最大值(排除極值)
66
+ max_b = np.percentile(image[:, :, 0], 99)
67
+ max_g = np.percentile(image[:, :, 1], 99)
68
+ max_r = np.percentile(image[:, :, 2], 99)
69
+
70
+ # 計算調整係數
71
+ scale_b = 255.0 / max(max_b, 1e-6)
72
+ scale_g = 255.0 / max(max_g, 1e-6)
73
+ scale_r = 255.0 / max(max_r, 1e-6)
74
+
75
+ # 應用調整
76
+ result = image.astype(np.float32)
77
+ result[:, :, 0] *= scale_b
78
+ result[:, :, 1] *= scale_g
79
+ result[:, :, 2] *= scale_r
80
+
81
+ return result
82
+
83
+ def simple_average_algorithm(self, image):
84
+ """簡單平均算法 - 讓RGB三通道的平均值相等"""
85
+ # 計算每個通道的平均值
86
+ avg_b = np.mean(image[:, :, 0])
87
+ avg_g = np.mean(image[:, :, 1])
88
+ avg_r = np.mean(image[:, :, 2])
89
+
90
+ # 以綠色通道為基準(人眼對綠色最敏感)
91
+ reference = avg_g
92
+
93
+ # 計算調整係數
94
+ scale_b = reference / max(avg_b, 1e-6)
95
+ scale_g = 1.0 # 綠色通道不變
96
+ scale_r = reference / max(avg_r, 1e-6)
97
+
98
+ # 應用調整
99
+ result = image.astype(np.float32)
100
+ result[:, :, 0] *= scale_b
101
+ result[:, :, 1] *= scale_g
102
+ result[:, :, 2] *= scale_r
103
+
104
+ return result
105
+
106
+ def histogram_stretch_algorithm(self, image):
107
+ """直方圖拉伸算法"""
108
+ result = image.astype(np.float32)
109
+
110
+ for i in range(3): # 對每個顏色通道
111
+ channel = result[:, :, i]
112
+
113
+ # 計算1%和99%百分位數,忽略極值
114
+ p1 = np.percentile(channel, 1)
115
+ p99 = np.percentile(channel, 99)
116
+
117
+ # 拉伸到0-255範圍
118
+ if p99 > p1:
119
+ channel = (channel - p1) * 255.0 / (p99 - p1)
120
+ result[:, :, i] = np.clip(channel, 0, 255)
121
+
122
+ return result
123
+
124
+ def preserve_image_brightness(self, original, adjusted):
125
+ """保持原始圖像的整體亮度"""
126
+ try:
127
+ # 計算原始圖像的亮度
128
+ original_lab = cv2.cvtColor(original.astype(np.uint8), cv2.COLOR_BGR2LAB)
129
+ original_brightness = np.mean(original_lab[:, :, 0])
130
+
131
+ # 計算調整後圖像的亮度
132
+ adjusted_lab = cv2.cvtColor(np.clip(adjusted, 0, 255).astype(np.uint8), cv2.COLOR_BGR2LAB)
133
+ adjusted_brightness = np.mean(adjusted_lab[:, :, 0])
134
+
135
+ # 調整亮度
136
+ if adjusted_brightness > 1e-6:
137
+ brightness_ratio = original_brightness / adjusted_brightness
138
+ adjusted_lab[:, :, 0] = np.clip(adjusted_lab[:, :, 0] * brightness_ratio, 0, 255)
139
+
140
+ # 轉換回BGR
141
+ result = cv2.cvtColor(adjusted_lab, cv2.COLOR_LAB2BGR)
142
+ return result.astype(np.float32)
143
+ except:
144
+ pass
145
+
146
+ return adjusted
147
+
148
+ def process_image(self, pil_image, method_name, strength, preserve_brightness, clip_values):
149
+ """處理圖像的主函數"""
150
+ if pil_image is None:
151
+ return None, "請上傳一張圖片"
152
+
153
+ try:
154
+ # 轉換格式
155
+ cv2_image = self.pil_to_cv2(pil_image)
156
+ original_image = cv2_image.copy()
157
+
158
+ # 獲取算法名稱
159
+ method = self.methods.get(method_name, "gray_world")
160
+
161
+ # 選擇算法
162
+ if method == "gray_world":
163
+ adjusted = self.gray_world_algorithm(cv2_image)
164
+ algorithm_info = "使用灰世界算法,假設圖像的平均顏色應該是中性灰"
165
+ elif method == "white_patch":
166
+ adjusted = self.white_patch_algorithm(cv2_image)
167
+ algorithm_info = "使用白斑算法,假設圖像中最亮的區域應該是白色"
168
+ elif method == "simple_avg":
169
+ adjusted = self.simple_average_algorithm(cv2_image)
170
+ algorithm_info = "使用簡單平均算法,平衡RGB三個通道的平均值"
171
+ elif method == "histogram_stretch":
172
+ adjusted = self.histogram_stretch_algorithm(cv2_image)
173
+ algorithm_info = "使用直方圖拉伸算法,增強圖像對比度"
174
+ else:
175
+ adjusted = cv2_image.astype(np.float32)
176
+ algorithm_info = "未知算法"
177
+
178
+ # 應用強度調整
179
+ if strength != 1.0:
180
+ adjusted = original_image.astype(np.float32) * (1.0 - strength) + adjusted * strength
181
+
182
+ # 保持亮度
183
+ if preserve_brightness:
184
+ adjusted = self.preserve_image_brightness(original_image, adjusted)
185
+
186
+ # 裁剪數值範圍
187
+ if clip_values:
188
+ adjusted = np.clip(adjusted, 0, 255)
189
+
190
+ # 轉換回PIL格式
191
+ result_pil = self.cv2_to_pil(adjusted.astype(np.uint8))
192
+
193
+ # 生成處理信息
194
+ info = f"✅ 處理完成!\n算法:{algorithm_info}\n強度:{strength:.1f}\n保持亮度:{'是' if preserve_brightness else '否'}"
195
+
196
+ return result_pil, info
197
+
198
+ except Exception as e:
199
+ return None, f"❌ 處理出錯:{str(e)}"
200
+
201
+ # 創建白平衡處理器實例
202
+ wb_processor = AutoWhiteBalance()
203
+
204
+ def process_white_balance(image, method, strength, preserve_brightness, clip_values):
205
+ """Gradio接口函數"""
206
+ return wb_processor.process_image(image, method, strength, preserve_brightness, clip_values)
207
+
208
+ # 創建Gradio界面
209
+ def create_interface():
210
+ with gr.Blocks(
211
+ title="🎨 自動白平衡校正工具",
212
+ theme=gr.themes.Soft(),
213
+ css="""
214
+ .gradio-container {
215
+ max-width: 1200px !important;
216
+ }
217
+ .image-container {
218
+ max-height: 600px;
219
+ }
220
+ """
221
+ ) as demo:
222
+
223
+ gr.Markdown("""
224
+ # 🎨 自動白平衡校正工具
225
+
226
+ 上傳有色偏問題的圖片,選擇適合的算法自動校正白平衡,讓圖片恢復自然的色彩!
227
+
228
+ 💡 **使用建議**:
229
+ - 🟡 **暖色偏(偏黃)**:推薦使用「灰世界算法」
230
+ - 🔵 **冷色偏(偏藍)**:推薦使用「簡單平均」或「白斑算法」
231
+ - 🌈 **色彩不鮮豔**:推薦使用「直方圖拉伸」
232
+ """)
233
+
234
+ with gr.Row():
235
+ with gr.Column(scale=1):
236
+ # 輸入區域
237
+ gr.Markdown("### 📤 上傳圖片")
238
+ input_image = gr.Image(
239
+ label="選擇要處理的圖片",
240
+ type="pil",
241
+ height=400
242
+ )
243
+
244
+ # 參數設置
245
+ gr.Markdown("### ⚙️ 調整參數")
246
+
247
+ method_dropdown = gr.Dropdown(
248
+ choices=list(wb_processor.methods.keys()),
249
+ value="灰世界算法 (Gray World)",
250
+ label="白平衡算法",
251
+ info="選擇適合的白平衡校正算法"
252
+ )
253
+
254
+ strength_slider = gr.Slider(
255
+ minimum=0.0,
256
+ maximum=2.0,
257
+ value=1.0,
258
+ step=0.1,
259
+ label="調整強度",
260
+ info="控制校正效果的強度(0=不調整,1=標準,2=強化)"
261
+ )
262
+
263
+ preserve_brightness_checkbox = gr.Checkbox(
264
+ value=True,
265
+ label="保持原始亮度",
266
+ info="保持圖片的整體明暗度不變"
267
+ )
268
+
269
+ clip_values_checkbox = gr.Checkbox(
270
+ value=True,
271
+ label="裁剪數值範圍",
272
+ info="避免過度調整造成的色彩異常"
273
+ )
274
+
275
+ # 處理按鈕
276
+ process_btn = gr.Button(
277
+ "🚀 開始處理",
278
+ variant="primary",
279
+ size="lg"
280
+ )
281
+
282
+ with gr.Column(scale=1):
283
+ # 輸出區域
284
+ gr.Markdown("### 📥 處理結果")
285
+ output_image = gr.Image(
286
+ label="處理後的圖片",
287
+ type="pil",
288
+ height=400
289
+ )
290
+
291
+ # 處理信息
292
+ process_info = gr.Textbox(
293
+ label="處理信息",
294
+ lines=4,
295
+ interactive=False
296
+ )
297
+
298
+ # 下載提示
299
+ gr.Markdown("""
300
+ ### 💾 保存圖片
301
+ 右鍵點擊處理後的圖片,選擇「另存圖片為...」即可保存到本地。
302
+ """)
303
+
304
+ # 示例圖片
305
+ gr.Markdown("### 📸 示例圖片")
306
+ gr.Examples(
307
+ examples=[
308
+ ["example1.jpg", "灰世界算法 (Gray World)", 1.0, True, True],
309
+ ["example2.jpg", "白斑算法 (White Patch)", 0.8, True, True],
310
+ ],
311
+ inputs=[input_image, method_dropdown, strength_slider, preserve_brightness_checkbox, clip_values_checkbox],
312
+ outputs=[output_image, process_info]
313
+ )
314
+
315
+ # 綁定處理函數
316
+ process_btn.click(
317
+ fn=process_white_balance,
318
+ inputs=[input_image, method_dropdown, strength_slider, preserve_brightness_checkbox, clip_values_checkbox],
319
+ outputs=[output_image, process_info]
320
+ )
321
+
322
+ # 實時預覽(可選)
323
+ for component in [method_dropdown, strength_slider, preserve_brightness_checkbox, clip_values_checkbox]:
324
+ component.change(
325
+ fn=process_white_balance,
326
+ inputs=[input_image, method_dropdown, strength_slider, preserve_brightness_checkbox, clip_values_checkbox],
327
+ outputs=[output_image, process_info]
328
+ )
329
+
330
+ # 添加說明
331
+ gr.Markdown("""
332
+ ---
333
+ ### 📚 算法說明
334
+
335
+ - **灰世界算法**:適合校正整體色偏,特別是暖色偏(偏黃/橙)
336
+ - **白斑算法**:適合有明顯白色或亮色區域的圖片
337
+ - **簡單平均**:溫和的校正方式,適合輕微色偏
338
+ - **直方圖拉伸**:增強對比度,讓色彩更鮮豔
339
+
340
+ ### 🔧 技術支持
341
+ 基於OpenCV和PIL開發,支持常見的圖片格式(JPG、PNG、WEBP等)
342
+ """)
343
+
344
+ return demo
345
+
346
+ # 啟動應用
347
+ if __name__ == "__main__":
348
+ demo = create_interface()
349
+ demo.launch(
350
+ server_name="0.0.0.0",
351
+ server_port=7860,
352
+ share=True
353
+ )
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ opencv-python-headless>=4.8.0
3
+ Pillow>=9.0.0
4
+ numpy>=1.21.0