Upload 5 files
Browse files- app.py +252 -0
- onnx2mnn2.bat +121 -0
- pth2onnx.bat +36 -0
- pth2onnx.py +149 -0
- requirements.txt +6 -0
app.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import requests
|
| 3 |
+
import os
|
| 4 |
+
import subprocess
|
| 5 |
+
from typing import Union
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from pth2onnx import convert_pth_to_onnx
|
| 8 |
+
from urllib.parse import urlparse
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
# Global switch: when True, print_log writes progress lines to the terminal.
log_to_terminal = True

# Global counter used to give each conversion task a unique id.
task_counter = 0

def print_log(task_id, stage, status):
    """Print a single progress line for *task_id* when terminal logging is enabled."""
    if not log_to_terminal:
        return
    print(f"任务{task_id}: [{status}] {stage}")
|
| 21 |
+
|
| 22 |
+
def convertmnn(onnx_path: str, mnn_path: str, fp16=False):
    """Convert an ONNX model to MNN using the `mnnconvert` CLI bundled with the MNN package.

    Raises subprocess.CalledProcessError when the converter exits non-zero
    (check=True).
    """
    cmd = [
        'mnnconvert',
        '-f', 'ONNX',
        '--modelFile', onnx_path,
        '--MNNModel', mnn_path,
        '--bizCode', 'biz',
        '--info',
        '--detectSparseSpeedUp',
    ]
    if fp16:
        cmd += ['--fp16']
    subprocess.run(cmd, check=True)
|
| 28 |
+
|
| 29 |
+
def download_file(url: str, save_path: str, timeout: float = 30):
    """Download *url* and write the response body to *save_path*.

    Improvements over the original: streams the body in chunks instead of
    buffering it all in memory, raises for HTTP error statuses (the original
    silently saved error pages), and applies a request timeout so a dead
    server cannot hang the caller forever. *timeout* defaults to 30s and is
    a new, backward-compatible parameter.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(save_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
|
| 33 |
+
|
| 34 |
+
def download_file2folder(url: str, folder: str, filesize_max: int, filesize_min: int) -> Optional[str]:
    """Download a file from *url* into *folder*, enforcing size limits.

    Args:
        url: File URL to download.
        folder: Destination folder (created if missing).
        filesize_max: Maximum allowed size in bytes; the download is aborted
            and the partial file deleted when exceeded.
        filesize_min: Minimum allowed size in bytes; smaller results are
            deleted and treated as failure.

    Returns:
        The downloaded file's basename on success, or None on any failure
        (no filename in URL, HTTP error, timeout, or size out of range).
    """
    # Derive the local filename from the URL path.
    parsed_url = urlparse(url)
    filename = os.path.basename(parsed_url.path)
    if not filename:
        return None  # URL carries no usable filename

    os.makedirs(folder, exist_ok=True)
    save_path = os.path.join(folder, filename)

    try:
        # Stream the download so oversized files can be aborted early.
        with requests.get(url, stream=True, timeout=10) as response:
            response.raise_for_status()

            # Fast-reject when the server advertises an oversized body.
            total_size = int(response.headers.get('content-length', 0))
            if total_size > filesize_max:
                return None

            downloaded_size = 0
            too_large = False
            with open(save_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=8192):
                    if not chunk:  # skip keep-alive chunks
                        continue
                    downloaded_size += len(chunk)
                    if downloaded_size > filesize_max:
                        too_large = True
                        break
                    file.write(chunk)
            # BUGFIX: the original called file.close() and os.remove() inside
            # the `with` body (then the `with` exit closed the file a second
            # time). Delete only after the handle is cleanly closed.
            if too_large:
                os.remove(save_path)
                return None

        # Reject suspiciously small results (error pages, truncated files).
        if os.path.getsize(save_path) < filesize_min:
            os.remove(save_path)
            return None

        return filename

    except Exception:
        # Best-effort cleanup on any network/filesystem error; failure is
        # signalled to the caller via the None return, not an exception.
        if os.path.exists(save_path):
            os.remove(save_path)
        return None
|
| 91 |
+
|
| 92 |
+
# Renamed from process_model: internal async generator that drives one conversion task.
async def _process_model(model_input: Union[str, gr.File], tilesize: int, log_box: gr.Textbox, output_dir: str):
    """Run one model-conversion task, streaming (onnx_path, mnn_path, log) updates.

    Args:
        model_input: Either a URL/path string or an uploaded gradio file object.
        tilesize: Tile size used when exporting a .pth model to ONNX.
        log_box: Unused here; kept for signature compatibility with the caller.
        output_dir: Per-task directory receiving downloaded/converted files.

    Yields:
        (onnx_path, mnn_path, log) tuples; intermediate yields use [] as
        placeholders for artifacts that are not ready yet.
    """
    global task_counter
    task_id = task_counter
    log = ('初始化日志记录...')
    print_log(task_id, '初始化日志记录', '开始')
    yield [], [], log

    # Resolve the input into a local file path.
    if isinstance(model_input, str):  # URL or local path string
        if model_input.startswith(('http://', 'https://')):
            log += (f'正在下载模型文件: {model_input}')
            print_log(task_id, f'正在下载模型文件: {model_input}', '开始')
            yield [], [], log

            # Download into the task's output folder (size-limited).
            filename = download_file2folder(
                url=model_input,
                folder=output_dir,
                filesize_max=200*1024*1024,  # 200MB
                filesize_min=1024  # 1KB
            )
            # BUGFIX: download_file2folder returns None on failure; the
            # original passed None into os.path.join and crashed (TypeError).
            if not filename:
                log += ('模型文件下载失败')
                print_log(task_id, f'模型文件下载失败: {model_input}', '错误')
                yield [], [], log
                return
            input_path = os.path.join(output_dir, filename)
            log += (f'模型文件已下载到: {input_path}')
            print_log(task_id, f'模型文件已下载到: {input_path}', '完成')
            yield [], [], log
        else:
            input_path = model_input
            log += (f'使用本地文件: {input_path}')
            print_log(task_id, f'使用本地文件: {input_path}', '开始')
            yield [], [], log
    else:
        # Uploaded gradio file object: .name is the temp-file path.
        input_path = model_input.name
        log += (f'已上传模型文件: {input_path}')
        print_log(task_id, f'已上传模型文件: {input_path}', '开始')
        yield [], [], log

    if not input_path:
        log += ('未获得正确的模型文件')
        print_log(task_id, '未获得正确的模型文件', '错误')
        yield [], [], log
        return

    # Step 1: make sure we have an ONNX model (convert .pth when needed).
    if input_path.endswith('.onnx'):
        onnx_path = input_path
        log += ('输入已是 ONNX 模型,直接使用...')
        print_log(task_id, '输入已是 ONNX 模型,直接使用', '开始')
        yield [], [], log
    else:
        print_log(task_id, f'转换 PTH 模型为 ONNX, folder={output_dir}', '开始')
        onnx_path = convert_pth_to_onnx(input_path, tilesize=tilesize, output_folder=output_dir)
        if onnx_path:
            log += (f'成功生成ONNX模型: {onnx_path}')
            print_log(task_id, f'生成ONNX模型: {onnx_path}', '完成')
        else:
            log += ('生成ONNX模型失败')
            print_log(task_id, '生成ONNX模型', '错误')
            yield [], [], log
            return

    # Step 2: convert the ONNX model to MNN.
    output_name = os.path.splitext(os.path.basename(onnx_path))[0]
    mnn_path = os.path.join(output_dir, f'{output_name}.mnn')
    try:
        log += ('正在将 ONNX 模型转换为 MNN 格式...')
        print_log(task_id, '正在将 ONNX 模型转换为 MNN 格式', '开始')
        convertmnn(onnx_path, mnn_path)
        yield onnx_path, [], log
    except Exception as e:
        log += (f'转换 MNN 模型时出错: {str(e)}')
        print_log(task_id, f'转换 MNN 模型时出错: {str(e)}', '错误')
        yield onnx_path, [], log

    print_log(task_id, '模型转换任务完成', '完成')

    # Post-conversion sanity check: a valid MNN file must exceed 1KB.
    if os.path.exists(mnn_path) and os.path.getsize(mnn_path) > 1024:
        log += (f'MNN 模型已保存到: {mnn_path}')
    else:
        log += ('MNN 模型生成失败或文件大小不足1KB')
        mnn_path = None

    yield onnx_path, mnn_path, log
|
| 178 |
+
|
| 179 |
+
with gr.Blocks() as demo:
    gr.Markdown("# 模型转换工具")
    with gr.Row():
        with gr.Column():
            # Input side: the user supplies either a model URL or an uploaded file.
            input_type = gr.Radio(['模型链接', '上传模型文件'], label='输入类型')
            url_input = gr.Textbox(label='模型链接')
            file_input = gr.File(label='上传模型文件', visible=False)


            def show_input(input_type):
                """Toggle visibility between the URL textbox and the file uploader."""
                if input_type == '模型链接':
                    return gr.update(visible=True), gr.update(visible=False)
                else:
                    return gr.update(visible=False), gr.update(visible=True)

            input_type.change(show_input, inputs=input_type, outputs=[url_input, file_input])

            # 0 means "let the converter pick its default tile size".
            tilesize = gr.Number(label="Tilesize", value=0, precision=0)
            convert_btn = gr.Button("开始转换")
        with gr.Column():
            # Output side: streamed log text plus downloadable artifacts.
            log_box = gr.Textbox(label="转换日志", lines=10, interactive=False)
            with gr.Row():
                onnx_output = gr.File(label="ONNX 模型输出")
                mnn_output = gr.File(label="MNN 模型输出")

    async def process_model(input_type, url_input, file_input, tilesize, log_box):
        """Click handler: validate the input, create a per-task output
        directory, and relay results from the _process_model generator
        to the (onnx_output, mnn_output, log_box) components."""
        if input_type == '模型链接' and url_input:
            model_input = url_input
        elif input_type == '上传模型文件' and file_input:
            model_input = file_input
        else:
            # Report the validation error through yield so it lands in log_box.
            log = '\n请选择输入类型并提供有效的输入!'
            yield None, None, log
            return

        # Create a unique output directory per task via the global counter.
        global task_counter
        task_counter += 1
        output_dir = os.path.join(os.getcwd(), f"output_{task_counter}")
        os.makedirs(output_dir, exist_ok=True)

        onnx_path = None
        mnn_path = None
        # Drain the conversion generator; each item is a yielded tuple.
        async for result in _process_model(model_input, int(tilesize), log_box, output_dir):
            if isinstance(result, tuple) and len(result) == 3:
                # NOTE(review): intermediate 3-tuples are recorded but not
                # yielded, so the UI only updates once at the end — confirm
                # this is intended rather than streaming each update.
                onnx_path, mnn_path, log_box = result
            elif isinstance(result, tuple) and len(result) == 2:
                # NOTE(review): _process_model always yields 3-tuples, so
                # this 2-tuple branch appears unreachable.
                _, process_log = result
                yield None, None, process_log
        yield onnx_path, mnn_path, log_box

    convert_btn.click(
        process_model,
        inputs=[input_type, url_input, file_input, tilesize, log_box],
        outputs=[onnx_output, mnn_output, log_box],
        api_name="convert_model"
    )

    # Example links kept at the bottom, wrapped in a column component.
    examples_column = gr.Column(visible=True)
    with examples_column:
        examples = [
            ["模型链接", "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"],
            ["模型链接", "https://github.com/Phhofm/models/releases/download/4xNomos8kSC/4xNomos8kSC.pth"],
            ["模型链接", "https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth"],
            ["模型链接", "https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-WTP-ColorDS.pth"],
            ["模型链接", "https://github.com/Kim2091/Kim2091-Models/releases/download/2x-AnimeSharpV4/2x-AnimeSharpV4_RCAN_fp16_op17.onnx"]
        ]
        example_input = gr.Examples(examples=examples, inputs=[input_type, url_input], label='示例模型链接')

demo.launch()
|
onnx2mnn2.bat
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off
setlocal enabledelayedexpansion

rem Convert an ONNX model to a fixed-shape fp16 MNN model, then optionally
rem run it on a test image with mnnsr-ncnn.exe to verify the result.
rem (Original prompts were irrecoverably mojibake-encoded; rewritten in English.)
if "%~1"=="" (
    rem No argument: ask for the file path interactively.
    set /p onnx_path="Enter the path of the .onnx file: "
) else (
    rem Use the first argument as the file path and run from the script's folder.
    set "onnx_path=%~1"
    cd /d %~dp0

)

:main
rem Validate that the file exists and ends in .onnx.
if not exist "!onnx_path!" (
    echo Error: file does not exist, please enter the path again.
    goto loop
)

set "onnx_ext=!onnx_path:~-5!"
if /i "!onnx_ext!" neq ".onnx" (
    echo Error: file is not in .onnx format, please enter the path again.
    goto loop
)

rem Strip the directory and extension to get the base model name.
for %%f in ("!onnx_path!") do set "onnx_name=%%~nf"
echo Current working directory: %cd%
rem Run the MNN converter.
MNNConvert -f ONNX --modelFile "!onnx_path!" --MNNModel "!onnx_name!.mnn" --bizCode biz --fp16 --info --detectSparseSpeedUp

rem Check whether the converted model file exists.
if exist "!onnx_name!.mnn" (
    rem converted successfully - fall through
) else (
    echo No model file was produced, please enter the path again.
    goto loop
)



:next_step
rem Ask the user what to do next.
echo.
echo Choose the next step:
echo 1. Next loop (default)
echo 2. Test the model (using test.png)
echo 3. Enter a custom image path and test the model
set /p choice="Enter your choice (1/2/3): "
if "%choice%"=="" set "choice=1"

if "%choice%"=="1" (
    goto loop
) else if "%choice%"=="2" (
    set "test_image=test.png"
    if exist "!test_image!" (
        mnnsr-ncnn.exe -i "!test_image!" -o "!onnx_name!.png" -m "!onnx_name!.mnn" -s 0
    ) else (
        echo Error: file "!test_image!" does not exist, choose again.
        goto next_step
    )
) else if "%choice%"=="3" (
    set /p custom_image="Enter the custom image path: "
    if exist "!custom_image!" (
        mnnsr-ncnn.exe -i "!custom_image!" -o "!onnx_name!.png" -m "!onnx_name!.mnn" -s 0
    ) else (
        echo Error: file "!custom_image!" does not exist, choose again.
        goto next_step
    )
) else (
    echo Invalid choice, choose again.
    goto next_step
)

rem Check whether the generated PNG output exists.
if exist "!onnx_name!.png" (
    echo Output image generated successfully: "!onnx_name!.png"
    start "" "!onnx_name!.png"
) else (
    echo Verification failed: output image missing; decide whether to keep the model file.
    rem NOTE(review): inferred from code flow that "y" keeps the model and
    rem anything else deletes it — the original mojibake prompt could not be decoded.
    set /p delete_model="Keep the model file "!onnx_name!.mnn"? (y/n, default n): "
    if /i "!delete_model!"=="y" (
        echo Keeping the model file.
    ) else (
        del "!onnx_name!.mnn"
        echo Model file deleted.
    )
    goto loop
)

rem Ask the user what to do after a successful test.
rem BUGFIX: the original did `goto next_choice` below, but no :next_choice
rem label existed, which aborted the script. The label is added here.
:next_choice
echo.
echo Choose the next step:
echo 1. Open the output folder (default)
echo 2. Delete the model file and test image
echo 3. Next loop
set /p next_choice="Enter your choice (1/2/3): "
if "%next_choice%"=="" set "next_choice=1"

if "%next_choice%"=="1" (
    start "" .
) else if "%next_choice%"=="2" (
    del "!onnx_name!.png"
    del "!onnx_name!.mnn"
    echo Model file and test image deleted.
) else if "%next_choice%"=="3" (
    goto loop
) else (
    echo Invalid choice, choose again.
    goto next_choice
)



:loop
echo ====================================
set /p onnx_path="Enter the path of the .onnx file: "
goto main
|
pth2onnx.bat
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off
setlocal enabledelayedexpansion

rem Interactive wrapper around pth2onnx.py: validates a .pth path and
rem converts it to ONNX with fp16, looping for further inputs.
rem (Original prompts were irrecoverably mojibake-encoded; rewritten in English.)
rem Check whether an argument was passed.
if "%~1"=="" (
    rem No argument: ask for the file path interactively.
    set /p pth_path="Enter the path of the .pth file: "
) else (
    rem Use the first argument as the file path and run from the script's folder.
    set "pth_path=%~1"
    cd /d %~dp0

)

:main
rem Validate that the file exists and ends in .pth.
if not exist "!pth_path!" (
    echo Error: file does not exist, please enter the path again.
    goto loop
)

set "ext=!pth_path:~-4!"
if /i "!ext!" neq ".pth" (
    echo Error: file is not in .pth format, please enter the path again.
    goto loop
)

echo Current working directory: %cd%
rem Run pth2onnx.py (path is passed positionally).
rem python .\pth2onnx.py "!pth_path!" --fp16 --simplify --channel !channel!
python .\pth2onnx.py "!pth_path!" --fp16

:loop
echo ====================================
set /p pth_path="Enter the path of the .pth file: "
goto main
|
pth2onnx.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import torch
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import onnx
|
| 6 |
+
from spandrel import ImageModelDescriptor, ModelLoader
|
| 7 |
+
from onnxsim import simplify
|
| 8 |
+
|
| 9 |
+
def convert_pth_to_onnx(pth_path: str, onnx_path: str = None, channel: int = 0, tilesize: int = 64, use_fp16: bool = False, simplify_model: bool = False, min_size: int = 1024*1024, output_folder: str = None):
    """Load a PyTorch .pth model via Spandrel and export it to ONNX.

    Args:
        pth_path: Path to the input .pth model file.
        onnx_path: Explicit output path; when None it is derived from pth_path
            (scale suffix, grayscale/fp16 markers appended).
        channel: Number of input channels; 0 means use the model's own
            input_channels.
        tilesize: Spatial size of the dummy export input; values < 1 fall
            back to 64.
        use_fp16: Export in half precision (moved to CUDA when available).
        simplify_model: Run onnx-simplifier on the exported model.
        min_size: Minimum valid output size in bytes; smaller files are
            deleted and treated as failure.
        output_folder: Optional folder that receives the output file.

    Returns:
        The exported .onnx path on success, or "" (falsy) on any failure.
        (The original mixed `False` and `""` as failure values; unified to "".)
    """
    print(f"Loading model from: {pth_path}")
    try:
        # Use Spandrel to load the model architecture and state dict.
        model_descriptor = ModelLoader().load_from_file(pth_path)

        # Ensure it's the expected type from Spandrel.
        if not isinstance(model_descriptor, ImageModelDescriptor):
            print(f"Error: Expected ImageModelDescriptor, but got {type(model_descriptor)}")
            print("Please ensure the .pth file is compatible with Spandrel's loading mechanism.")
            return ""  # BUGFIX: was `return False`; unified failure value

        # Get the underlying torch.nn.Module and freeze train-time layers.
        torch_model = model_descriptor.model
        torch_model.eval()

    except Exception as e:
        print(f"Error loading model: {e}")
        return ""  # BUGFIX: was `return False`; unified failure value

    if channel == 0:
        channel = model_descriptor.input_channels
    if tilesize < 1:
        tilesize = 64
    example_input = torch.randn(1, channel, tilesize, tilesize)
    print("Model input channels:", channel, "tile size:", tilesize)

    if use_fp16:
        if torch.cuda.is_available():
            torch_model.cuda()
            example_input = example_input.cuda()
        else:
            # NOTE(review): fp16 export still proceeds on CPU below — confirm intended.
            print("Warning: no CUDA device")
        torch_model.half()
        example_input = example_input.half()  # half-precision dummy input
    print(f"Model loaded successfully: {type(torch_model).__name__}")

    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    if onnx_path is None:
        base_path, _ = os.path.splitext(pth_path)
        if output_folder:
            base_path = os.path.join(output_folder, os.path.basename(base_path))

        scale = model_descriptor.scale
        # Check whether the filename already carries a scale tag such as
        # "4X"/"X4" delimited by '_' or '-'; if not, append "-x{scale}".
        filename = os.path.basename(pth_path).upper()
        pattern = f'(^|[_-])({scale}X|X{scale})([_-]|$)'
        if re.search(pattern, filename):
            # BUGFIX: the original printed a literal "(unknown)" placeholder
            # where the filename belonged.
            print(f'文件名 {filename} 包含匹配模式。')
        else:
            base_path = f"{base_path}-x{scale}"

        onnx_path = base_path + ("-Grayscale" if channel == 1 else "") + ("-fp16.onnx" if use_fp16 else ".onnx")

    # Explicit onnx_path: relocate it into output_folder when one is given.
    elif output_folder:
        onnx_path = os.path.join(output_folder, onnx_path)

    print(f"output_folder: {output_folder}, onnx_path: {onnx_path}")

    try:
        # Export the model.
        torch.onnx.export(
            torch_model,                # the model instance
            example_input,              # example input tensor
            onnx_path,                  # output file path
            export_params=True,         # store trained weights in the file
            opset_version=11,           # target ONNX opset
            do_constant_folding=True,   # fold constants for optimization
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={              # batch, H, W may vary at runtime
                "input": {0: "batch_size", 2: "height", 3: "width"},
                "output": {0: "batch_size", 2: "height", 3: "width"},
            }
        )
        print(f"ONNX export successful: {onnx_path}")

        # Optional: simplify the ONNX model in place.
        if simplify_model:
            model = onnx.load(onnx_path)
            model_simplified, _ = simplify(model)
            onnx.save(model_simplified, onnx_path)
            print(f"ONNX model simplified successfully: {onnx_path}")

        # Validate the produced file by size.
        if os.path.exists(onnx_path):
            file_size = os.path.getsize(onnx_path)
            if file_size > min_size:
                return onnx_path
            # BUGFIX: the original removed unconditionally, so a missing file
            # raised FileNotFoundError into the except clause.
            os.remove(onnx_path)
        print(f"文件大小不足 {min_size} 字节,已删除无效文件")
        return ""

    except Exception as e:
        print(f"导出失败: {e}")
        return ""
|
| 124 |
+
|
| 125 |
+
if __name__ == "__main__":
    # BUGFIX: the original defined --pthpath/--onnxpath/--fp16/--simplify but
    # then read args.pth_path/args.onnx_path/args.use_fp16/args.simplify_model,
    # so the script always crashed with AttributeError. It also required
    # --pthpath while pth2onnx.bat passes the path positionally; the path is
    # now a positional argument, matching that caller. The duplicate
    # `import argparse` (already imported at module top) is removed.
    parser = argparse.ArgumentParser(description='Convert PyTorch model to ONNX model.')
    parser.add_argument('pthpath', type=str, help='Path to the PyTorch model file.')
    parser.add_argument('--onnxpath', type=str, default=None, help='Path to save the ONNX model file.')
    parser.add_argument('--channel', type=int, default=0, help='Channel parameter.')
    parser.add_argument('--tilesize', type=int, default=0, help='Tilesize parameter.')
    parser.add_argument('--fp16', action='store_true', help='Use FP16 precision.')
    parser.add_argument('--simplify', action='store_true', help='Simplify the ONNX model.')
    args = parser.parse_args()

    # Forward all CLI options (the original dropped channel/tilesize entirely).
    success = convert_pth_to_onnx(
        pth_path=args.pthpath,
        onnx_path=args.onnxpath,
        channel=args.channel,
        tilesize=args.tilesize,
        use_fp16=args.fp16,
        simplify_model=args.simplify
    )

    if success:
        print("Conversion process finished.")
    else:
        print("Conversion process failed.")
        exit(1)  # exit with error code so callers can detect failure
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spandrel
torch
pnnx
onnx
onnxsim
mnn
gradio
requests
|