File size: 3,097 Bytes
49a7cc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# model_utils.py
import os
import io
import numpy as np
from PIL import Image
import onnxruntime as ort

# 若需要從 Hugging Face Hub 下載模型,可在啟動時把檔案放到 models/ 或設定環境變數 COLORIZER_ONNX
def maybe_download_from_hf(hf_path, local_path):
    # placeholder: 若你要自動下載,可在這裡加入 huggingface_hub.hf_hub_download 的邏輯
    # 目前預期模型已存在於 local_path(例如 models/colorizer.onnx)
    return local_path

class ONNXColorizer:
    def __init__(self, onnx_path="models/colorizer.onnx", provider="CPUExecutionProvider"):
        self.onnx_path = maybe_download_from_hf(None, onnx_path)
        # 嘗試建立 session,若失敗會拋例外
        self.session = ort.InferenceSession(self.onnx_path, providers=[provider])
        # 自動讀取第一個 input name 與 output name(若 export 有不同命名,請手動改)
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def _preprocess(self, pil_img, resolution=512):
        """把 PIL 圖轉成模型輸入(NCHW, float32, 0-1)"""
        img = pil_img.convert("RGB").resize((resolution, resolution), Image.LANCZOS)
        arr = np.array(img).astype(np.float32) / 255.0
        # HWC -> NCHW
        arr = np.transpose(arr, (2, 0, 1))[None, ...]
        return arr

    def _postprocess(self, out_arr):
        """把模型輸出轉回 PIL Image(假設輸出為 NCHW float 0-1)"""
        if isinstance(out_arr, list):
            out_arr = out_arr[0]
        # NCHW -> HWC
        img = np.clip(out_arr[0].transpose(1, 2, 0) * 255.0, 0, 255).astype("uint8")
        return Image.fromarray(img)

    def generate(self, pil_img, style="Handpainted", resolution=512, high_res=False, draft=None):
        """
        簡單的 generate 介面:
        - pil_img: PIL Image(線稿)
        - style: 目前僅作為參考字串,若 ONNX 支援 style token 可擴充
        - resolution: 512 或 1024
        - high_res: 若 True 可觸發不同流程(此範例以同一模型處理)
        - draft: 若提供草稿,可作為融合輸入(需 ONNX export 支援)
        """
        inp = self._preprocess(pil_img, resolution=resolution)
        # 若模型需要多個輸入(例如 draft 或 style token),需在這裡擴充 feed dict
        feed = {self.input_name: inp}
        outputs = self.session.run(None, feed)
        return self._postprocess(outputs)

    @staticmethod
    def load_session_with_quantized(onnx_path, quantized_path="models/colorizer_int8.onnx"):
        """
        若你已經做過 ONNX 動態量化並把檔案放到 quantized_path,可用此方法載入量化模型。
        量化步驟建議在 Colab 或本地先完成,再把量化檔上傳到 repo。
        """
        if os.path.exists(quantized_path):
            return ort.InferenceSession(quantized_path, providers=["CPUExecutionProvider"])
        # fallback to original
        return ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])