# model_utils.py import os import io import numpy as np from PIL import Image import onnxruntime as ort # 若需要從 Hugging Face Hub 下載模型,可在啟動時把檔案放到 models/ 或設定環境變數 COLORIZER_ONNX def maybe_download_from_hf(hf_path, local_path): # placeholder: 若你要自動下載,可在這裡加入 huggingface_hub.hf_hub_download 的邏輯 # 目前預期模型已存在於 local_path(例如 models/colorizer.onnx) return local_path class ONNXColorizer: def __init__(self, onnx_path="models/colorizer.onnx", provider="CPUExecutionProvider"): self.onnx_path = maybe_download_from_hf(None, onnx_path) # 嘗試建立 session,若失敗會拋例外 self.session = ort.InferenceSession(self.onnx_path, providers=[provider]) # 自動讀取第一個 input name 與 output name(若 export 有不同命名,請手動改) self.input_name = self.session.get_inputs()[0].name self.output_name = self.session.get_outputs()[0].name def _preprocess(self, pil_img, resolution=512): """把 PIL 圖轉成模型輸入(NCHW, float32, 0-1)""" img = pil_img.convert("RGB").resize((resolution, resolution), Image.LANCZOS) arr = np.array(img).astype(np.float32) / 255.0 # HWC -> NCHW arr = np.transpose(arr, (2, 0, 1))[None, ...] return arr def _postprocess(self, out_arr): """把模型輸出轉回 PIL Image(假設輸出為 NCHW float 0-1)""" if isinstance(out_arr, list): out_arr = out_arr[0] # NCHW -> HWC img = np.clip(out_arr[0].transpose(1, 2, 0) * 255.0, 0, 255).astype("uint8") return Image.fromarray(img) def generate(self, pil_img, style="Handpainted", resolution=512, high_res=False, draft=None): """ 簡單的 generate 介面: - pil_img: PIL Image(線稿) - style: 目前僅作為參考字串,若 ONNX 支援 style token 可擴充 - resolution: 512 或 1024 - high_res: 若 True 可觸發不同流程(此範例以同一模型處理) - draft: 若提供草稿,可作為融合輸入(需 ONNX export 支援) """ inp = self._preprocess(pil_img, resolution=resolution) # 若模型需要多個輸入(例如 draft 或 style token),需在這裡擴充 feed dict feed = {self.input_name: inp} outputs = self.session.run(None, feed) return self._postprocess(outputs) @staticmethod def load_session_with_quantized(onnx_path, quantized_path="models/colorizer_int8.onnx"): """ 若你已經做過 ONNX 動態量化並把檔案放到 quantized_path,可用此方法載入量化模型。 量化步驟建議在 Colab 或本地先完成,再把量化檔上傳到 repo。 """ if os.path.exists(quantized_path): return ort.InferenceSession(quantized_path, providers=["CPUExecutionProvider"]) # fallback to original return ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])