import os
import io

import numpy as np
from PIL import Image
import onnxruntime as ort


def maybe_download_from_hf(hf_path, local_path):
    """Return a path to the model file.

    If local_path already exists (or hf_path is None), it is returned as-is;
    otherwise the file is fetched from the Hugging Face Hub and the cached
    path is returned. ``hf_path`` is assumed to be "repo_id/filename".
    """
    if hf_path is None or os.path.exists(local_path):
        return local_path
    from huggingface_hub import hf_hub_download  # optional dependency
    repo_id, filename = hf_path.rsplit("/", 1)
    return hf_hub_download(repo_id=repo_id, filename=filename)


class ONNXColorizer:
    def __init__(self, onnx_path="models/colorizer.onnx", provider="CPUExecutionProvider"):
        self.onnx_path = maybe_download_from_hf(None, onnx_path)
        self.session = ort.InferenceSession(self.onnx_path, providers=[provider])
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def _preprocess(self, pil_img, resolution=512):
        """Convert a PIL image to model input (NCHW, float32, 0-1)."""
        img = pil_img.convert("RGB").resize((resolution, resolution), Image.LANCZOS)
        arr = np.array(img).astype(np.float32) / 255.0
        arr = np.transpose(arr, (2, 0, 1))[None, ...]  # HWC -> NCHW, add batch dim
        return arr

    def _postprocess(self, out_arr):
        """Convert model output back to a PIL Image (assumes NCHW float 0-1)."""
        if isinstance(out_arr, list):
            out_arr = out_arr[0]  # session.run returns a list of outputs
        img = np.clip(out_arr[0].transpose(1, 2, 0) * 255.0, 0, 255).astype("uint8")
        return Image.fromarray(img)

    def generate(self, pil_img, style="Handpainted", resolution=512, high_res=False, draft=None):
        """Simple generate interface.

        - pil_img: PIL Image (the line art)
        - style: currently only a reference string; can be extended if the
          ONNX model supports style tokens
        - resolution: 512 or 1024
        - high_res: if True, a different pipeline could be triggered (this
          example uses the same model for both)
        - draft: if provided, can serve as a fusion input (requires support
          in the ONNX export)
        """
        inp = self._preprocess(pil_img, resolution=resolution)
        feed = {self.input_name: inp}
        outputs = self.session.run(None, feed)
        return self._postprocess(outputs)

    @staticmethod
    def load_session_with_quantized(onnx_path, quantized_path="models/colorizer_int8.onnx"):
        """Prefer a quantized model when one is available.

        If you have already run ONNX dynamic quantization and placed the file
        at quantized_path, this method loads the quantized model instead.
        Quantization is best done ahead of time (e.g. in Colab or locally),
        with the quantized file then uploaded to the repo.
        """
        if os.path.exists(quantized_path):
            return ort.InferenceSession(quantized_path, providers=["CPUExecutionProvider"])
        return ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
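

# A minimal sketch of the dynamic-quantization step mentioned in
# load_session_with_quantized, using onnxruntime's quantization tooling.
# The helper name and default paths are illustrative, not part of the module.
def quantize_colorizer(onnx_path="models/colorizer.onnx",
                       quantized_path="models/colorizer_int8.onnx"):
    from onnxruntime.quantization import quantize_dynamic, QuantType
    # Dynamic quantization: weights are stored as int8, activations stay float.
    quantize_dynamic(onnx_path, quantized_path, weight_type=QuantType.QInt8)
    return quantized_path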
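
# A minimal usage sketch, assuming a line-art image saved as "input.png"
# (both filenames here are illustrative).
if __name__ == "__main__":
    colorizer = ONNXColorizer()
    colorized = colorizer.generate(Image.open("input.png"), resolution=512)
    colorized.save("colorized.png")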