Ha Trong Nguyen commited on
Commit
03eb31b
·
1 Parent(s): efe5fd6

feat: release final optimized ONNX 320x320 pipeline

Browse files
.gitattributes CHANGED
@@ -1 +1,3 @@
1
  *.pth filter=lfs diff=lfs merge=lfs -text
 
 
 
1
  *.pth filter=lfs diff=lfs merge=lfs -text
2
+ *.onnx filter=lfs diff=lfs merge=lfs -text
3
+ *.onnx.data filter=lfs diff=lfs merge=lfs -text
Dockerfile CHANGED
@@ -23,7 +23,7 @@ USER user
23
  ENV HOME=/home/user \
24
  PATH=/home/user/.local/bin:$PATH \
25
  PYTHONUNBUFFERED=1 \
26
- TF_ZIP_MODEL_PATH=/app/ZIP/checkpoints/demo_data/best_mae_0.pth \
27
  TF_HOST=0.0.0.0 \
28
  TF_PORT=7860
29
 
 
23
  ENV HOME=/home/user \
24
  PATH=/home/user/.local/bin:$PATH \
25
  PYTHONUNBUFFERED=1 \
26
+ TF_ZIP_MODEL_PATH=/app/ZIP/checkpoints/demo_data/best_mae_0_quantized.onnx \
27
  TF_HOST=0.0.0.0 \
28
  TF_PORT=7860
29
 
ZIP/checkpoints/demo_data/best_mae_0_quantized.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a91f8817cd924a404893e020c2cc68d3ea2383c0203f070429463d19f0bf67e
3
+ size 110797385
backend/config.py CHANGED
@@ -18,11 +18,10 @@ class Settings(BaseSettings):
18
  zip_model_path: str = os.path.join(
19
  os.path.dirname(__file__),
20
  "..",
21
- "..",
22
  "ZIP",
23
  "checkpoints",
24
  "demo_data",
25
- "best_mae_0.pth",
26
  )
27
  zip_model_device: str = "cpu" # "cuda" or "cpu"
28
  zip_input_size: int = 320
 
18
  zip_model_path: str = os.path.join(
19
  os.path.dirname(__file__),
20
  "..",
 
21
  "ZIP",
22
  "checkpoints",
23
  "demo_data",
24
+ "best_mae_0_quantized.onnx",
25
  )
26
  zip_model_device: str = "cpu" # "cuda" or "cpu"
27
  zip_input_size: int = 320
backend/model_service.py CHANGED
@@ -33,6 +33,7 @@ class ZIPModelService:
33
  self.device = None
34
  self.input_size = 448
35
  self._loaded = False
 
36
 
37
  @classmethod
38
  def get_instance(cls) -> "ZIPModelService":
@@ -53,12 +54,33 @@ class ZIPModelService:
53
  logger.info(f"[load_model] Device: {self.device}, Input size: {input_size}")
54
 
55
  try:
56
- from models import get_model
 
57
 
58
- self.model = get_model(model_info_path=model_path)
59
- self.model.to(self.device)
60
- self.model.eval()
61
- self._loaded = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  if hasattr(self.model, "config"):
64
  logger.info(
@@ -108,7 +130,12 @@ class ZIPModelService:
108
  tensor = self._preprocess_image(image_rgb)
109
 
110
  with torch.no_grad():
111
- model_out = self.model(tensor)
 
 
 
 
 
112
 
113
  inference_time_ms = round((time.time() - start_time) * 1000, 1)
114
 
 
33
  self.device = None
34
  self.input_size = 448
35
  self._loaded = False
36
+ self.is_onnx = False
37
 
38
  @classmethod
39
  def get_instance(cls) -> "ZIPModelService":
 
54
  logger.info(f"[load_model] Device: {self.device}, Input size: {input_size}")
55
 
56
  try:
57
+ if model_path.endswith(".onnx"):
58
+ import onnxruntime as ort
59
 
60
+ logger.info("[load_model] Auto-activating ONNX Runtime")
61
+
62
+ sess_options = ort.SessionOptions()
63
+ sess_options.intra_op_num_threads = 2
64
+ sess_options.graph_optimization_level = (
65
+ ort.GraphOptimizationLevel.ORT_ENABLE_ALL
66
+ )
67
+
68
+ self.model = ort.InferenceSession(
69
+ model_path,
70
+ sess_options=sess_options,
71
+ providers=["CPUExecutionProvider"],
72
+ )
73
+ self.is_onnx = True
74
+ self._loaded = True
75
+ logger.info("[load_model] - ONNX model loaded successfully.")
76
+ else:
77
+ from models import get_model
78
+
79
+ self.model = get_model(model_info_path=model_path)
80
+ self.model.to(self.device)
81
+ self.model.eval()
82
+ self.is_onnx = False
83
+ self._loaded = True
84
 
85
  if hasattr(self.model, "config"):
86
  logger.info(
 
130
  tensor = self._preprocess_image(image_rgb)
131
 
132
  with torch.no_grad():
133
+ if self.is_onnx:
134
+ ort_inputs = {self.model.get_inputs()[0].name: tensor.cpu().numpy()}
135
+ ort_outs = self.model.run(None, ort_inputs)
136
+ model_out = torch.tensor(ort_outs[0])
137
+ else:
138
+ model_out = self.model(tensor)
139
 
140
  inference_time_ms = round((time.time() - start_time) * 1000, 1)
141
 
backend/requirements.txt CHANGED
@@ -21,3 +21,5 @@ scipy
21
  peft
22
  numpy
23
  PyTurboJPEG
 
 
 
21
  peft
22
  numpy
23
  PyTurboJPEG
24
+ onnx
25
+ onnxruntime
convert_to_onnx.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import subprocess
5
+
6
+ # Sửa lỗi Unicode trên Windows Terminal
7
+ sys.stdout.reconfigure(encoding='utf-8')
8
+ sys.stderr.reconfigure(encoding='utf-8')
9
+
10
+ # Đảm bảo cài đặt các thư viện cần thiết cho quá trình chuyển đổi ONNX
11
+ def install_requirements():
12
+ print("[INFO] Đang kiểm tra thư viện ONNX...")
13
+ try:
14
+ import onnx
15
+ import onnxruntime
16
+ import onnxscript
17
+ except ImportError:
18
+ print("[PROCESS] Cài đặt onnx, onnxruntime và onnxscript...")
19
+ subprocess.check_call(
20
+ [sys.executable, "-m", "pip", "install", "onnx", "onnxruntime", "onnxscript"]
21
+ )
22
+ import onnx
23
+ import onnxruntime
24
+ import onnxscript
25
+
26
+ print("[SUCCESS] Đã cài đặt xong.")
27
+
28
+
29
+ install_requirements()
30
+
31
+ import torch
32
+ from onnxruntime.quantization import quantize_dynamic, QuantType
33
+ import logging
34
+
35
+ # Thiết lập đường dẫn để import models
36
+ ZIP_PROJECT_ROOT = os.path.normpath(os.path.join(os.path.dirname(__file__), "ZIP"))
37
+ if ZIP_PROJECT_ROOT not in sys.path:
38
+ sys.path.insert(0, ZIP_PROJECT_ROOT)
39
+
40
+ try:
41
+ from models import get_model
42
+ except ImportError as e:
43
+ print(f"[ERROR] Không thể nạp module models: {e}")
44
+ sys.exit(1)
45
+
46
+
47
+ def convert_to_onnx(model_path, output_path, input_size=448):
48
+ print(f"\n[INFO] BẮT ĐẦU QUÁ TRÌNH CHUYỂN ĐỔI ONNX")
49
+ print(f"[INFO] Nguồn PyTorch: {model_path}")
50
+
51
+ if not os.path.exists(model_path):
52
+ print(f"[ERROR] Không tìm thấy file model: {model_path}")
53
+ return None
54
+
55
+ # 1. Tải mô hình PyTorch
56
+ print("[PROCESS] Đang tải mô hình PyTorch lên RAM...")
57
+ model = get_model(model_info_path=model_path)
58
+ model.eval()
59
+ model.to("cpu")
60
+ print("[SUCCESS] Tải mô hình thành công.")
61
+
62
+ # 2. Tạo Dummy Input (Ảnh giả lập)
63
+ print(
64
+ f"[INFO] Kích thước đầu vào (Input Shape): [1, 3, {input_size}, {input_size}]"
65
+ )
66
+ dummy_input = torch.randn(1, 3, input_size, input_size)
67
+
68
+ # 3. Xuất sang định dạng ONNX Float32
69
+ print("[PROCESS] Đang Compile và Export sang ONNX (Float32)...")
70
+ torch.onnx.export(
71
+ model,
72
+ dummy_input,
73
+ output_path,
74
+ export_params=True,
75
+ opset_version=18,
76
+ do_constant_folding=True,
77
+ input_names=["input"],
78
+ output_names=["output"],
79
+ dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
80
+ )
81
+ print(f"[SUCCESS] Đã xuất file ONNX gốc (Float32): {output_path}")
82
+ print(f"[INFO] Dung lượng: {os.path.getsize(output_path) / (1024*1024):.2f} MB")
83
+
84
+ return output_path
85
+
86
+
87
+ def quantize_onnx(onnx_path, quantized_path):
88
+ print("\n[INFO] BẮT ĐẦU LƯỢNG TỬ HOÁ (QUANTIZATION INT8)")
89
+ print("[INFO] Quá trình này giúp mô hình nhẹ hơn x4 lần và tối ưu cho CPU.")
90
+
91
+ try:
92
+ quantize_dynamic(
93
+ model_input=onnx_path,
94
+ model_output=quantized_path,
95
+ weight_type=QuantType.QUInt8,
96
+ )
97
+ print(f"[SUCCESS] Đã tạo file ONNX Quantized (INT8): {quantized_path}")
98
+ print(
99
+ f"[INFO] Dung lượng mới: {os.path.getsize(quantized_path) / (1024*1024):.2f} MB"
100
+ )
101
+ print(
102
+ "[INFO] Có thể sử dụng file này để deploy lên Hugging Face hoặc thiết bị biên."
103
+ )
104
+ except Exception as e:
105
+ print(f"[ERROR] Lỗi khi lượng tử hoá: {e}")
106
+
107
+
108
+ if __name__ == "__main__":
109
+ parser = argparse.ArgumentParser(
110
+ description="Chuyển đổi PyTorch Model sang ONNX và Quantize INT8"
111
+ )
112
+ parser.add_argument(
113
+ "--model",
114
+ type=str,
115
+ default="ZIP/checkpoints/demo_data/best_mae_0.pth",
116
+ help="Đường dẫn file .pth gốc",
117
+ )
118
+ parser.add_argument(
119
+ "--size",
120
+ type=int,
121
+ default=448,
122
+ help="Kích thước input_size (ví dụ: 448 hoặc 320)",
123
+ )
124
+ args = parser.parse_args()
125
+
126
+ # Tạo tên file ONNX đầu ra
127
+ base_name = os.path.splitext(args.model)[0]
128
+ onnx_fp32_path = f"{base_name}.onnx"
129
+ onnx_int8_path = f"{base_name}_quantized.onnx"
130
+
131
+ # Chạy quy trình
132
+ exported_onnx = convert_to_onnx(args.model, onnx_fp32_path, args.size)
133
+ if exported_onnx:
134
+ quantize_onnx(exported_onnx, onnx_int8_path)