update
Browse files- app.py +12 -10
- pth2onnx.py +16 -9
app.py
CHANGED
|
@@ -197,7 +197,8 @@ def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min:
|
|
| 197 |
os.remove(save_path)
|
| 198 |
return None
|
| 199 |
|
| 200 |
-
async def _process_model(model_input: Union[str, gr.File], tilesize: int, output_dir: str,task_id:int,fp16:bool,onnxsim:bool,opset:int):
|
|
|
|
| 201 |
|
| 202 |
log = ('初始化日志记录...\n')
|
| 203 |
print_log(task_id, '初始化日志记录', '开始')
|
|
@@ -226,7 +227,7 @@ async def _process_model(model_input: Union[str, gr.File], tilesize: int, output
|
|
| 226 |
yield [],log
|
| 227 |
else:
|
| 228 |
print_log(task_id, f'转换 PTH 模型为 ONNX, folder={output_dir}', '开始')
|
| 229 |
-
onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir,use_fp16=fp16, simplify_model=onnxsim, opset=opset)
|
| 230 |
if onnx_path:
|
| 231 |
log += ( f'成功生成ONNX模型: {onnx_path}\n')
|
| 232 |
print_log(task_id, f'生成ONNX模型: {onnx_path}', '完成')
|
|
@@ -277,12 +278,11 @@ with gr.Blocks() as demo:
|
|
| 277 |
return gr.update(visible=False), gr.update(visible=True)
|
| 278 |
|
| 279 |
input_type.change(show_input, inputs=input_type, outputs=[url_input, file_input])
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
# 添加fp16和try_run复选框
|
| 283 |
fp16 = gr.Checkbox(label="FP16", value=False)
|
| 284 |
onnxsim = gr.Checkbox(label="ONNX export simplify model", value=False)
|
| 285 |
-
|
| 286 |
try_run = gr.Checkbox(label="MNNSR test", value=False)
|
| 287 |
convert_btn = gr.Button("Run")
|
| 288 |
with gr.Column():
|
|
@@ -298,7 +298,8 @@ with gr.Blocks() as demo:
|
|
| 298 |
return gr.update(visible=False)
|
| 299 |
try_run.change(show_try_run, inputs=try_run, outputs=img_output)
|
| 300 |
|
| 301 |
-
async def process_model(input_type, url_input, file_input, tilesize, fp16, onnxsim, opset, try_run):
|
|
|
|
| 302 |
|
| 303 |
global task_counter
|
| 304 |
task_counter += 1
|
|
@@ -347,7 +348,8 @@ with gr.Blocks() as demo:
|
|
| 347 |
onnx_path = None
|
| 348 |
mnn_path = None
|
| 349 |
# 调用重命名后的函数
|
| 350 |
-
async for result in _process_model(model_input, tilesize if tilesize>0 else 64, output_dir, task_counter, fp16, onnxsim, opset):
|
|
|
|
| 351 |
if isinstance(result, tuple) and len(result) == 3:
|
| 352 |
onnx_path, mnn_path, process_log = result
|
| 353 |
yield onnx_path, mnn_path, log+process_log, None
|
|
@@ -360,7 +362,7 @@ with gr.Blocks() as demo:
|
|
| 360 |
if mnn_path:
|
| 361 |
if try_run:
|
| 362 |
print_log(task_counter, f'测试模型: {mnn_path}', '开始')
|
| 363 |
-
processed_image_np, load_time, infer_time = modelTest_for_gradio(mnn_path, "./sample.jpg",
|
| 364 |
processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
|
| 365 |
# processed_image_pil = Image.fromarray(processed_image_np)
|
| 366 |
yield onnx_path, mnn_path, log+process_log+f"MNNSR 加载模型用时 {load_time:.4f} 秒, 推理({tilesize} px)用时 {infer_time:.4f} 秒", processed_image_pil
|
|
@@ -370,7 +372,7 @@ with gr.Blocks() as demo:
|
|
| 370 |
|
| 371 |
convert_btn.click(
|
| 372 |
process_model,
|
| 373 |
-
inputs=[input_type, url_input, file_input, tilesize, fp16, onnxsim, opset, try_run],
|
| 374 |
outputs=[onnx_output, mnn_output, log_box, img_output],
|
| 375 |
api_name="convert_nmm_model"
|
| 376 |
)
|
|
|
|
| 197 |
os.remove(save_path)
|
| 198 |
return None
|
| 199 |
|
| 200 |
+
async def _process_model(model_input: Union[str, gr.File], tilesize: int, output_dir: str,task_id:int,fp16:bool,onnxsim:bool,opset:int,dynamic_axes:bool):
|
| 201 |
+
|
| 202 |
|
| 203 |
log = ('初始化日志记录...\n')
|
| 204 |
print_log(task_id, '初始化日志记录', '开始')
|
|
|
|
| 227 |
yield [],log
|
| 228 |
else:
|
| 229 |
print_log(task_id, f'转换 PTH 模型为 ONNX, folder={output_dir}', '开始')
|
| 230 |
+
onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir,use_fp16=fp16, simplify_model=onnxsim, opset=opset, dynamic_axes=dynamic_axes)
|
| 231 |
if onnx_path:
|
| 232 |
log += ( f'成功生成ONNX模型: {onnx_path}\n')
|
| 233 |
print_log(task_id, f'生成ONNX模型: {onnx_path}', '完成')
|
|
|
|
| 278 |
return gr.update(visible=False), gr.update(visible=True)
|
| 279 |
|
| 280 |
input_type.change(show_input, inputs=input_type, outputs=[url_input, file_input])
|
| 281 |
+
tilesize = gr.Number(label="Dummy input width/height, default 64", value=64, precision=0)
|
| 282 |
+
opset = gr.Number(label="ONNX export opset version, suggest 9/11/13/16/17/18", value=13, precision=0)
|
|
|
|
| 283 |
fp16 = gr.Checkbox(label="FP16", value=False)
|
| 284 |
onnxsim = gr.Checkbox(label="ONNX export simplify model", value=False)
|
| 285 |
+
dynamic_axes = gr.Checkbox(label="ONNX input apply dynamic axes", value=True)
|
| 286 |
try_run = gr.Checkbox(label="MNNSR test", value=False)
|
| 287 |
convert_btn = gr.Button("Run")
|
| 288 |
with gr.Column():
|
|
|
|
| 298 |
return gr.update(visible=False)
|
| 299 |
try_run.change(show_try_run, inputs=try_run, outputs=img_output)
|
| 300 |
|
| 301 |
+
async def process_model(input_type, url_input, file_input, tilesize, fp16, onnxsim, opset, dynamic_axes, try_run):
|
| 302 |
+
|
| 303 |
|
| 304 |
global task_counter
|
| 305 |
task_counter += 1
|
|
|
|
| 348 |
onnx_path = None
|
| 349 |
mnn_path = None
|
| 350 |
# 调用重命名后的函数
|
| 351 |
+
async for result in _process_model(model_input, tilesize if tilesize>0 else 64, output_dir, task_counter, fp16, onnxsim, opset, dynamic_axes):
|
| 352 |
+
|
| 353 |
if isinstance(result, tuple) and len(result) == 3:
|
| 354 |
onnx_path, mnn_path, process_log = result
|
| 355 |
yield onnx_path, mnn_path, log+process_log, None
|
|
|
|
| 362 |
if mnn_path:
|
| 363 |
if try_run:
|
| 364 |
print_log(task_counter, f'测试模型: {mnn_path}', '开始')
|
| 365 |
+
processed_image_np, load_time, infer_time = modelTest_for_gradio(mnn_path, "./sample.jpg", tilesize if tilesize>0 and dynamic_axes else 0, 0)
|
| 366 |
processed_image_pil = Image.fromarray(cv2.cvtColor(processed_image_np, cv2.COLOR_BGR2RGB))
|
| 367 |
# processed_image_pil = Image.fromarray(processed_image_np)
|
| 368 |
yield onnx_path, mnn_path, log+process_log+f"MNNSR 加载模型用时 {load_time:.4f} 秒, 推理({tilesize} px)用时 {infer_time:.4f} 秒", processed_image_pil
|
|
|
|
| 372 |
|
| 373 |
convert_btn.click(
|
| 374 |
process_model,
|
| 375 |
+
inputs=[input_type, url_input, file_input, tilesize, fp16, onnxsim, opset, dynamic_axes, try_run],
|
| 376 |
outputs=[onnx_output, mnn_output, log_box, img_output],
|
| 377 |
api_name="convert_nmm_model"
|
| 378 |
)
|
pth2onnx.py
CHANGED
|
@@ -6,7 +6,7 @@ import onnx
|
|
| 6 |
from spandrel import ImageModelDescriptor, ModelLoader
|
| 7 |
from onnxsim import simplify
|
| 8 |
|
| 9 |
-
def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tilesize: int = 64, use_fp16: bool=False, simplify_model: bool=False, min_size: int = 1024*1024, output_folder: str=None, opset: int = 11):
|
| 10 |
"""
|
| 11 |
Loads a PyTorch model from a .pth file using Spandrel and converts it to ONNX format.
|
| 12 |
|
|
@@ -86,6 +86,14 @@ def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tiles
|
|
| 86 |
print(f"ONNX model exporting...")
|
| 87 |
try:
|
| 88 |
# Export the model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
torch.onnx.export(
|
| 90 |
torch_model, # The model instance
|
| 91 |
example_input, # An example input tensor
|
|
@@ -95,10 +103,7 @@ def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tiles
|
|
| 95 |
do_constant_folding=True, # Whether to execute constant folding for optimization
|
| 96 |
input_names=['input'], # The model's input names
|
| 97 |
output_names=['output'], # The model's output names
|
| 98 |
-
dynamic_axes=
|
| 99 |
-
"input": {0: "batch_size", 2: "height", 3: "width"}, # Batch, H, W can vary
|
| 100 |
-
"output": {0: "batch_size", 2: "height", 3: "width"},# Batch, H, W can vary
|
| 101 |
-
}
|
| 102 |
)
|
| 103 |
print(f"ONNX model export successful: {onnx_path}")
|
| 104 |
|
|
@@ -133,16 +138,18 @@ if __name__ == "__main__":
|
|
| 133 |
parser.add_argument('--fp16', action='store_true', help='Use FP16 precision.')
|
| 134 |
parser.add_argument('--simplify', action='store_true', help='Simplify the ONNX model.')
|
| 135 |
parser.add_argument('--opset', type=int, default=11, help='ONNX opset version.')
|
|
|
|
| 136 |
args = parser.parse_args()
|
| 137 |
|
| 138 |
success = convert_pth_to_onnx(
|
| 139 |
-
pth_path=args.
|
| 140 |
-
onnx_path=args.
|
| 141 |
channel=args.channel,
|
| 142 |
tilesize=args.tilesize,
|
| 143 |
-
use_fp16=args.
|
| 144 |
-
simplify_model=args.
|
| 145 |
opset=args.opset,
|
|
|
|
| 146 |
)
|
| 147 |
|
| 148 |
if success:
|
|
|
|
| 6 |
from spandrel import ImageModelDescriptor, ModelLoader
|
| 7 |
from onnxsim import simplify
|
| 8 |
|
| 9 |
+
def convert_pth_to_onnx(pth_path: str, onnx_path: str=None, channel:int=0, tilesize: int = 64, use_fp16: bool=False, simplify_model: bool=False, min_size: int = 1024*1024, output_folder: str=None, opset: int = 11, dynamic_axes: bool = True):
|
| 10 |
"""
|
| 11 |
Loads a PyTorch model from a .pth file using Spandrel and converts it to ONNX format.
|
| 12 |
|
|
|
|
| 86 |
print(f"ONNX model exporting...")
|
| 87 |
try:
|
| 88 |
# Export the model
|
| 89 |
+
if dynamic_axes:
|
| 90 |
+
axes = {
|
| 91 |
+
"input": {2: "height", 3: "width"},
|
| 92 |
+
"output": {2: "height", 3: "width"},
|
| 93 |
+
}
|
| 94 |
+
else:
|
| 95 |
+
axes = {}
|
| 96 |
+
|
| 97 |
torch.onnx.export(
|
| 98 |
torch_model, # The model instance
|
| 99 |
example_input, # An example input tensor
|
|
|
|
| 103 |
do_constant_folding=True, # Whether to execute constant folding for optimization
|
| 104 |
input_names=['input'], # The model's input names
|
| 105 |
output_names=['output'], # The model's output names
|
| 106 |
+
dynamic_axes=axes
|
|
|
|
|
|
|
|
|
|
| 107 |
)
|
| 108 |
print(f"ONNX model export successful: {onnx_path}")
|
| 109 |
|
|
|
|
| 138 |
parser.add_argument('--fp16', action='store_true', help='Use FP16 precision.')
|
| 139 |
parser.add_argument('--simplify', action='store_true', help='Simplify the ONNX model.')
|
| 140 |
parser.add_argument('--opset', type=int, default=11, help='ONNX opset version.')
|
| 141 |
+
parser.add_argument('--fixed_axes', action='store_true', help='Use dynamic axes.')
|
| 142 |
args = parser.parse_args()
|
| 143 |
|
| 144 |
success = convert_pth_to_onnx(
|
| 145 |
+
pth_path=args.pthpath,
|
| 146 |
+
onnx_path=args.onnxpath,
|
| 147 |
channel=args.channel,
|
| 148 |
tilesize=args.tilesize,
|
| 149 |
+
use_fp16=args.fp16,
|
| 150 |
+
simplify_model=args.simplify,
|
| 151 |
opset=args.opset,
|
| 152 |
+
dynamic_axes= not args.fixed_axes,
|
| 153 |
)
|
| 154 |
|
| 155 |
if success:
|