|
|
import functools
import os

import cv2
import gradio as gr
import numpy as np
import onnxruntime as ort
from PIL import Image
|
|
|
|
|
|
|
|
# Location of the ONNX colorization model. The path is relative, so it is
# resolved against the process's current working directory.
MODEL_PATH: str = "colorizer.onnx"
|
|
|
|
|
# Fallback model input resolution (used for both width and height) when the
# ONNX graph does not declare static spatial dimensions.
_DEFAULT_INPUT_SIZE = 256


@functools.lru_cache(maxsize=1)
def _load_session(model_path):
    """Return a cached CPU ``InferenceSession`` for *model_path*.

    Creating a session loads and optimizes the whole graph, so it is built
    once and reused across requests instead of once per button click.
    ``lru_cache`` does not store results of calls that raise, so a failed
    load is retried on the next request.
    NOTE: a model file replaced on disk is not picked up until restart.
    """
    return ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])


def _model_input_size(session):
    """Return the ``(width, height)`` to resize inputs to before inference.

    Uses the static H/W of the model's first input when both are positive
    integers; symbolic or unknown dimensions fall back to 256x256.
    """
    shape = session.get_inputs()[0].shape
    # The preprocessing in colorize_ddcolor always builds an NCHW tensor,
    # so dims 2 and 3 are read as height and width.
    if len(shape) == 4:
        height, width = shape[2], shape[3]
        if (
            isinstance(height, int)
            and isinstance(width, int)
            and height > 0
            and width > 0
        ):
            return width, height
    return _DEFAULT_INPUT_SIZE, _DEFAULT_INPUT_SIZE


def colorize_ddcolor(input_pil):
    """Colorize a PIL image with the ONNX model stored at ``MODEL_PATH``.

    Pipeline:
    1. RGB -> resize to the model input size -> BGR -> float32 in [0, 1].
    2. Build an NCHW tensor and run the model.
    3. Read the first output as a 3-channel BGR image in [0, 1].
    4. Convert to RGB uint8 and resize back to the original size.

    Args:
        input_pil: PIL image from the Gradio input component, or None when
            nothing was uploaded.

    Returns:
        The colorized ``PIL.Image`` at the input's original size, or None
        when no image was supplied.

    Raises:
        gr.Error: if the model file is missing or any step of inference
            fails. Gradio shows the message in the UI; returning a string,
            as before, cannot be rendered by the ``gr.Image`` output.
    """
    if input_pil is None:
        return None

    if not os.path.exists(MODEL_PATH):
        raise gr.Error("錯誤:找不到 colorizer.onnx 檔案,請確認已上傳。")

    try:
        session = _load_session(MODEL_PATH)
        in_w, in_h = _model_input_size(session)

        rgb = np.asarray(input_pil.convert("RGB"))
        orig_h, orig_w = rgb.shape[:2]

        # Model input: BGR, float32 scaled to [0, 1], NCHW with batch of 1.
        bgr = cv2.cvtColor(cv2.resize(rgb, (in_w, in_h)), cv2.COLOR_RGB2BGR)
        chw = bgr.astype(np.float32).transpose(2, 0, 1) / 255.0
        # Contiguous copy so ONNX Runtime never sees a strided view.
        input_tensor = np.ascontiguousarray(chw[np.newaxis], dtype=np.float32)

        input_name = session.get_inputs()[0].name
        outputs = session.run(None, {input_name: input_tensor})

        # Use only the first model output. Cast to float32 so float16/float64
        # outputs are handled uniformly, then drop the batch axis.
        pred = np.asarray(outputs[0], dtype=np.float32)
        if pred.ndim == 4:
            pred = pred[0]
        # NOTE(review): this pipeline expects a 3-channel image; models that
        # predict only chroma (e.g. 2-channel ab) are rejected here with a
        # clear message instead of failing inside OpenCV.
        if pred.ndim != 3 or pred.shape[0] != 3:
            raise ValueError(
                f"模型輸出形狀不符預期: {tuple(np.shape(outputs[0]))}(需要 3 通道影像)"
            )

        # CHW BGR -> HWC RGB. Reversing the channel axis works for any dtype,
        # unlike cv2.cvtColor, which rejects float64/float16.
        rgb_out = np.clip(pred.transpose(1, 2, 0)[..., ::-1], 0.0, 1.0)
        out_u8 = np.ascontiguousarray((rgb_out * 255).astype(np.uint8))

        # NOTE(review): upscaling the model-resolution output loses detail
        # for large inputs; recombining with the original luminance could
        # preserve it if the model's color space allows.
        result = cv2.resize(out_u8, (orig_w, orig_h))
        return Image.fromarray(result)
    except Exception as e:
        # Log server-side, then surface the message in the Gradio UI.
        print(f"運行錯誤: {str(e)}")
        raise gr.Error(f"發生錯誤: {str(e)}") from e
|
|
|
|
|
|
|
|
with gr.Blocks(title="AI DDColor 上色工具") as demo:
    # Page header: title line followed by a one-line description.
    gr.Markdown("# 🎨 DDColor 自動上色工具 (修正顏色版)")
    gr.Markdown("這是一個專門修復『變藍色』與『填不到色』問題的版本。")

    # Two side-by-side columns: upload + trigger on the left, result on the right.
    with gr.Row():
        with gr.Column():
            input_i = gr.Image(label="上傳黑白圖片", type="pil")
            btn = gr.Button("開始生成彩色", variant="primary")
        with gr.Column():
            output_i = gr.Image(label="AI 上色結果", type="pil")

    # Clicking the button passes the uploaded PIL image to the colorizer and
    # shows the returned PIL image in the result component.
    btn.click(fn=colorize_ddcolor, inputs=[input_i], outputs=[output_i])
|
|
|
|
|
if __name__ == "__main__":
    # Listen port comes from $PORT when set to a non-empty value; an unset
    # or empty variable falls back to 7860 (int("") would crash startup).
    port = int(os.environ.get("PORT") or 7860)
    # 0.0.0.0 binds all IPv4 interfaces so the app is reachable from other hosts.
    demo.launch(server_name="0.0.0.0", server_port=port)
|
|
|