lingkoai commited on
Commit
49a7cc4
·
verified ·
1 Parent(s): b042958

Create model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +63 -0
model_utils.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model_utils.py
2
+ import os
3
+ import io
4
+ import numpy as np
5
+ from PIL import Image
6
+ import onnxruntime as ort
7
+
8
+ # 若需要從 Hugging Face Hub 下載模型,可在啟動時把檔案放到 models/ 或設定環境變數 COLORIZER_ONNX
9
+ def maybe_download_from_hf(hf_path, local_path):
10
+ # placeholder: 若你要自動下載,可在這裡加入 huggingface_hub.hf_hub_download 的邏輯
11
+ # 目前預期模型已存在於 local_path(例如 models/colorizer.onnx)
12
+ return local_path
13
+
14
+ class ONNXColorizer:
15
+ def __init__(self, onnx_path="models/colorizer.onnx", provider="CPUExecutionProvider"):
16
+ self.onnx_path = maybe_download_from_hf(None, onnx_path)
17
+ # 嘗試建立 session,若失敗會拋例外
18
+ self.session = ort.InferenceSession(self.onnx_path, providers=[provider])
19
+ # 自動讀取第一個 input name 與 output name(若 export 有不同命名,請手動改)
20
+ self.input_name = self.session.get_inputs()[0].name
21
+ self.output_name = self.session.get_outputs()[0].name
22
+
23
+ def _preprocess(self, pil_img, resolution=512):
24
+ """把 PIL 圖轉成模型輸入(NCHW, float32, 0-1)"""
25
+ img = pil_img.convert("RGB").resize((resolution, resolution), Image.LANCZOS)
26
+ arr = np.array(img).astype(np.float32) / 255.0
27
+ # HWC -> NCHW
28
+ arr = np.transpose(arr, (2, 0, 1))[None, ...]
29
+ return arr
30
+
31
+ def _postprocess(self, out_arr):
32
+ """把模型輸出轉回 PIL Image(假設輸出為 NCHW float 0-1)"""
33
+ if isinstance(out_arr, list):
34
+ out_arr = out_arr[0]
35
+ # NCHW -> HWC
36
+ img = np.clip(out_arr[0].transpose(1, 2, 0) * 255.0, 0, 255).astype("uint8")
37
+ return Image.fromarray(img)
38
+
39
+ def generate(self, pil_img, style="Handpainted", resolution=512, high_res=False, draft=None):
40
+ """
41
+ 簡單的 generate 介面:
42
+ - pil_img: PIL Image(線稿)
43
+ - style: 目前僅作為參考字串,若 ONNX 支援 style token 可擴充
44
+ - resolution: 512 或 1024
45
+ - high_res: 若 True 可觸發不同流程(此範例以同一模型處理)
46
+ - draft: 若提供草稿,可作為融合輸入(需 ONNX export 支援)
47
+ """
48
+ inp = self._preprocess(pil_img, resolution=resolution)
49
+ # 若模型需要多個輸入(例如 draft 或 style token),需在這裡擴充 feed dict
50
+ feed = {self.input_name: inp}
51
+ outputs = self.session.run(None, feed)
52
+ return self._postprocess(outputs)
53
+
54
+ @staticmethod
55
+ def load_session_with_quantized(onnx_path, quantized_path="models/colorizer_int8.onnx"):
56
+ """
57
+ 若你已經做過 ONNX 動態量化並把檔案放到 quantized_path,可用此方法載入量化模型。
58
+ 量化步驟建議在 Colab 或本地先完成,再把量化檔上傳到 repo。
59
+ """
60
+ if os.path.exists(quantized_path):
61
+ return ort.InferenceSession(quantized_path, providers=["CPUExecutionProvider"])
62
+ # fallback to original
63
+ return ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])