colorizer / app.py
lingkoai's picture
Update app.py
efcb3a9 verified
import os
import numpy as np
import cv2
import gradio as gr
import onnxruntime as ort
from PIL import Image
# 模型檔案路徑 (請確保檔名為 colorizer.onnx)
MODEL_PATH = "colorizer.onnx"
def colorize_ddcolor(input_pil):
if input_pil is None:
return None
if not os.path.exists(MODEL_PATH):
return "錯誤:找不到 colorizer.onnx 檔案,請確認已上傳。"
try:
# 啟動 AI 引擎
session = ort.InferenceSession(MODEL_PATH, providers=['CPUExecutionProvider'])
# 1. 預處理:將圖片轉為 RGB 並記錄原始大小
img_np = np.array(input_pil.convert('RGB'))
orig_h, orig_w = img_np.shape[:2]
# 2. 縮放與格式轉換
# 修正:很多 ONNX 模型內部其實是吃 BGR 順序,我們這裡做一次轉換
input_256 = cv2.resize(img_np, (256, 256))
input_256 = cv2.cvtColor(input_256, cv2.COLOR_RGB2BGR)
# 3. 歸一化:使用標準 AI 數值 (讓顏色填得進去的關鍵)
input_256 = input_256.astype(np.float32) / 255.0
# 4. 調整維度符合 AI 要求 (Batch, Channel, Height, Width)
input_tensor = np.transpose(input_256, (2, 0, 1)) # HWC -> CHW
input_tensor = np.expand_dims(input_tensor, axis=0) # 增加 Batch
# 5. 運行 AI 上色
input_name = session.get_inputs()[0].name
output = session.run(None, {input_name: input_tensor})
# 6. 後處理:將輸出的數據轉回圖片
output_np = np.squeeze(output)
output_np = np.transpose(output_np, (1, 2, 0)) # CHW -> HWC
# 修正顏色順序:將 BGR 轉回 RGB 供螢幕顯示
output_np = cv2.cvtColor(output_np, cv2.COLOR_BGR2RGB)
# 確保數值在 0-255 之間
output_np = np.clip(output_np, 0, 1)
output_np = (output_np * 255).astype(np.uint8)
# 7. 縮放回原始大小並輸出
result_img = cv2.resize(output_np, (orig_w, orig_h))
return Image.fromarray(result_img)
except Exception as e:
print(f"運行錯誤: {str(e)}")
return f"發生錯誤: {str(e)}"
# --- Gradio 介面部分 (這部分保持不變,確保你能正常操作) ---
with gr.Blocks(title="AI DDColor 上色工具") as demo:
gr.Markdown("# 🎨 DDColor 自動上色工具 (修正顏色版)")
gr.Markdown("這是一個專門修復『變藍色』與『填不到色』問題的版本。")
with gr.Row():
with gr.Column():
input_i = gr.Image(type="pil", label="上傳黑白圖片")
btn = gr.Button("開始生成彩色", variant="primary")
with gr.Column():
output_i = gr.Image(type="pil", label="AI 上色結果")
btn.click(colorize_ddcolor, inputs=input_i, outputs=output_i)
if __name__ == "__main__":
# 使用 7860 端口運行
demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))