# colorizer / model_utils.py
# (Hugging Face repo header: created by lingkoai, commit 49a7cc4 verified)
# model_utils.py
import os
import io
import numpy as np
from PIL import Image
import onnxruntime as ort
# If the model should come from the Hugging Face Hub, place the file under
# models/ at startup or set the COLORIZER_ONNX environment variable.
def maybe_download_from_hf(hf_path, local_path):
    """Resolve the on-disk location of the model file.

    Placeholder hook: to auto-download from the Hugging Face Hub, add
    huggingface_hub.hf_hub_download logic here.  For now the model is
    expected to already exist at ``local_path`` (e.g. models/colorizer.onnx),
    so the path is returned unchanged and ``hf_path`` is ignored.
    """
    return local_path
class ONNXColorizer:
    """Thin wrapper around an ONNX Runtime session that colorizes line art.

    The exported graph is assumed to take a single NCHW float32 input in
    [0, 1] and produce an NCHW float32 output in [0, 1] — TODO confirm
    against the actual export if the model changes.
    """

    def __init__(self, onnx_path="models/colorizer.onnx", provider="CPUExecutionProvider"):
        # Resolve the model path (placeholder hook for HF Hub downloads).
        self.onnx_path = maybe_download_from_hf(None, onnx_path)
        # Session creation raises if the file is missing or invalid —
        # intentionally not caught here.
        self.session = ort.InferenceSession(self.onnx_path, providers=[provider])
        # Pick up the first input/output names automatically; adjust by hand
        # if the export uses different names.
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def _preprocess(self, pil_img, resolution=512):
        """Convert a PIL image to the model input (NCHW float32 in [0, 1])."""
        rgb = pil_img.convert("RGB").resize((resolution, resolution), Image.LANCZOS)
        scaled = np.array(rgb).astype(np.float32) / 255.0
        chw = scaled.transpose(2, 0, 1)   # HWC -> CHW
        return chw[None, ...]             # add batch dimension -> NCHW

    def _postprocess(self, out_arr):
        """Turn model output (assumed NCHW float in [0, 1]) into a PIL image."""
        # session.run returns a list of outputs; use the first one.
        if isinstance(out_arr, list):
            out_arr = out_arr[0]
        hwc = out_arr[0].transpose(1, 2, 0)   # drop batch, CHW -> HWC
        pixels = np.clip(hwc * 255.0, 0, 255).astype("uint8")
        return Image.fromarray(pixels)

    def generate(self, pil_img, style="Handpainted", resolution=512, high_res=False, draft=None):
        """Run the colorizer on a line-art image and return a PIL image.

        Args:
            pil_img: input PIL image (line art).
            style: reference string only for now; extend here if the ONNX
                graph supports a style token.
            resolution: working resolution (512 or 1024).
            high_res: reserved for a separate high-res pipeline; this
                example runs the same model either way.
            draft: optional draft image to fuse in (needs export support).
        """
        tensor = self._preprocess(pil_img, resolution=resolution)
        # Extend this feed dict if the model takes extra inputs
        # (draft image, style token, ...).
        results = self.session.run(None, {self.input_name: tensor})
        return self._postprocess(results)

    @staticmethod
    def load_session_with_quantized(onnx_path, quantized_path="models/colorizer_int8.onnx"):
        """Load the int8-quantized model when present, else the original.

        Run dynamic ONNX quantization offline (Colab or locally) first, then
        upload the quantized file to the repo at ``quantized_path``; this
        method prefers it whenever it exists on disk.
        """
        chosen = quantized_path if os.path.exists(quantized_path) else onnx_path
        return ort.InferenceSession(chosen, providers=["CPUExecutionProvider"])