# NOTE: "Spaces: Running Running" was HuggingFace Spaces page chrome captured
# with this source; it is not part of the program.
| import gradio as gr | |
| import os | |
| from gradio.themes import default | |
| import requests | |
| import shutil | |
| import uuid | |
| from ftplib import FTP | |
| from spandrel import ImageModelDescriptor, ModelLoader | |
| import torch | |
| import subprocess | |
| import pkg_resources | |
# --- Environment report and module-level state ---
# importlib.metadata replaces the deprecated pkg_resources API for querying
# installed-distribution versions (pkg_resources is being removed from
# setuptools).
from importlib.metadata import version as _dist_version

print("Torch version:", torch.__version__)
print("Gradio version:", gr.__version__)
mnn_version = _dist_version("MNN")
print("MNN version:", mnn_version)
spandrel_version = _dist_version("spandrel")
print("Spandrel version:", spandrel_version)
pnnx_version = _dist_version("pnnx")
print("PNNX version:", pnnx_version)
# Cache of already-downloaded URLs: url -> local file path.
downloaded_files = {}
# When True, progress messages are echoed to the terminal as well as the UI log.
log_to_terminal = True
# Monotonically increasing id assigned to each conversion task.
task_counter = 0
def print_log(task_id, filename, stage, status):
    """Print a one-line progress record for a conversion task.

    Args:
        task_id: numeric id of the task (see ``task_counter``).
        filename: user-facing model name associated with the task.
        stage: free-form description of the current processing step.
        status: short state tag (e.g. start / done / error).
    """
    # Bug fix: the original discarded ``filename`` and printed a literal
    # "(unknown)" placeholder; interpolate the argument instead.
    if log_to_terminal:
        print(f"任务{task_id}: {filename}, [{status}] {stage}")
def start_process(input_file, input_url, input2, shape0_str, shape1_str, output_type, input_suffix=".pth"):
    """Convert a .pth/.safetensors/.ckpt model to the formats in ``output_type``.

    Generator used as a Gradio event handler: it repeatedly yields
    ``(output_files, log_text)`` so the UI can stream progress while the
    (slow) conversions run.

    Args:
        input_file: path of an uploaded file (takes precedence), or None.
        input_url: http(s)/ftp/webdav URL of the model, used when no file
            was uploaded.
        input2: user-chosen base name for the produced files; derived from
            the source file name (or the task id) when empty.
        shape0_str: comma-separated 4-int input shape; all-zero disables it.
        shape1_str: comma-separated 4-int second input shape; all-zero disables it.
        output_type: list of targets among TorchScript/ONNX/Fixed/MNN/NCNN.
        input_suffix: extension used for the temporary download file.

    Yields:
        (list_of_output_file_paths, accumulated_log_string)
    """
    global task_counter
    task_counter += 1
    task_id = task_counter
    # The uploaded file wins over the pasted URL.
    input1 = input_file if input_file else input_url
    print_log(task_id, input2, input1, "input1")
    print_log(task_id, input2, str(output_type), "output_type")
    log = "转换过程非常慢,请耐心等待。显示文件列表不代表转换完成。如果未发生错误,转换结束会显示”任务完成“\n"
    output_files = []
    yield [], log
    # Derive a base name from the source file name when none was given.
    if input2 is None or input2.strip() == "":
        # Bug fix: guard against a missing source — os.path.basename(None)
        # raised TypeError before the "no valid input" path could report it.
        if input1:
            split_input = os.path.splitext(os.path.basename(input1))
            if len(split_input) > 1:
                suffix = split_input[1].split('?')[0].lower()
                if suffix not in [".pth", ".safetensors", ".ckpt"]:
                    print_log(task_id, input2, "不支持此文件的格式 suffix=" + suffix, "错误")
                    log += f"不支持此文件的格式\n"
                    # NOTE: inside a generator this return only stops
                    # iteration; the returned value never reaches the UI.
                    return [], log
                input2 = split_input[0]
    print_log(task_id, input2, "检查文件名", "开始")
    log += f"检查文件名…\n"
    yield [], log
    # Still empty (e.g. URL had no usable basename): fall back to task id.
    if input2 is None or input2.strip() == "":
        input2 = str(task_id)
        log += f"未提供文件名,使用{input2}\n"
        print_log(task_id, input2, f"未提供文件名,使用{input2}", "修正")
        yield [], log
    try:
        # Decide whether input1 is a URL or a local (uploaded) file.
        supported_protocols = ('http://', 'https://', 'ftp://', 'webdav://')
        if isinstance(input1, str) and input1.startswith(supported_protocols):
            url = input1
            if url in downloaded_files and os.path.exists(downloaded_files[url]):
                # Reuse a previous download of the same URL.
                file_path = downloaded_files[url]
                print_log(task_id, input2, "检查下载状态", "跳过下载")
                log += f"跳过下载,文件已存在: {file_path}\n"
                yield [], log
            else:
                print_log(task_id, input2, "下载文件", "开始")
                log += f"开始下载文件…\n"
                yield [], log
                # Unique local name for the downloaded model.
                file_name = str(task_id) + input_suffix
                file_path = os.path.join(os.getcwd(), file_name)
                if url.startswith('ftp://'):
                    try:
                        # Split "ftp://host/path" into host and remote path.
                        parts = url.replace('ftp://', '').split('/')
                        host = parts[0]
                        remote_file_path = '/'.join(parts[1:])
                        ftp = FTP(host)
                        ftp.login()
                        with open(file_path, 'wb') as f:
                            ftp.retrbinary('RETR ' + remote_file_path, f.write)
                        ftp.quit()
                        downloaded_files[url] = file_path
                        print_log(task_id, input2, "下载文件", "成功")
                        log += f"文件下载成功: {file_path}\n"
                        yield [], log
                    except Exception as e:
                        print_log(task_id, input2, "下载文件", f"失败 (FTP): {str(e)}")
                        log += f"FTP 文件下载失败: {str(e)}\n"
                        yield [], log
                        return
                else:
                    # NOTE(review): webdav:// is accepted above but has no
                    # handler here, so such URLs silently do nothing.
                    if url.startswith(('http://', 'https://')):
                        response = requests.get(url)
                        if response.status_code == 200:
                            with open(file_path, 'wb') as f:
                                f.write(response.content)
                            downloaded_files[url] = file_path
                            print_log(task_id, input2, "下载文件", "成功")
                            log += f"文件下载成功: {file_path}\n"
                            yield [], log
                        else:
                            print_log(task_id, input2, f"下载文件(HTTP): {response.status_code}", "失败")
                            log += f"文件下载失败,状态码: {response.status_code}\n"
                            yield [], log
                            return
        elif input1 is not None:
            # Uploaded file: use it in place.
            print("check file", input1, os.path.exists(input1))
            file_path = input1
            log += f"使用上传的文件: {file_path}\n"
            print_log(task_id, input2, "使用上传文件", "开始")
            yield [], log
        else:
            log += "未提供有效文件或地址\n"
            print_log(task_id, input2, "检查文件输入", "失败 (无有效输入)")
            yield [], log
            return
        # Reject oversized models.
        try:
            file_size = os.path.getsize(file_path) / 1024 / 1024  # bytes -> MB
            if file_size > 200:
                log += f"文件太大,建议 200MB 以内,当前文件大小为 {file_size } MB。\n"
                # Bug fix: the original concatenated a float to str here,
                # raising TypeError instead of logging the message.
                print_log(task_id, input2, f"文件太大({file_size}MB)", "失败")
                yield [], log
                return
        except Exception as e:
            log += f"获取文件大小失败: {str(e)}\n"
            print_log(task_id, input2, "检查文件大小", f"失败: {str(e)}")
            yield [], log
            return
        # Fresh folder to collect this task's outputs.
        output_folder = os.path.join(os.getcwd(), str(uuid.uuid4()))
        os.makedirs(output_folder, exist_ok=True)
        print_log(task_id, input2, "创建临时文件夹", "完成")
        log += f"创建临时文件夹: {output_folder}\n生成张量\n"
        yield [], log
        # Parse the shape strings into 4-int lists; a wrong element count
        # falls back to all zeros, and a non-integer token aborts the task.
        try:
            shape0 = [int(x) for x in shape0_str.split(',')] if shape0_str else [0, 0, 0, 0]
            if len(shape0) != 4:
                shape0 = [0, 0, 0, 0]
            shape1 = [int(x) for x in shape1_str.split(',')] if shape1_str else [0, 0, 0, 0]
            if len(shape1) != 4:
                shape1 = [0, 0, 0, 0]
        except ValueError:
            shape0 = [0, 0, 0, 0]
            shape1 = [0, 0, 0, 0]
            log += f"输入的 shape 字符串格式不正确,请使用逗号分隔的整数。shape0_str={shape0_str},shape1_str={shape1_str}\n"
            yield [], log
            return
        # Build example input tensor(s); an all-zero shape means "unused".
        print_log(task_id, input2, "生成输入张量", "开始")
        log += "生成张量…\n"
        yield [], log
        output_base = output_folder + "/" + input2
        pt_path = output_base + ".pt"
        command = f"pnnx {pt_path}"
        input_tensor0 = torch.rand(shape0) if any(shape0) else None
        input_tensor1 = torch.rand(shape1) if any(shape1) else None
        if input_tensor0 is not None and input_tensor1 is not None:
            example_input = (input_tensor0, input_tensor1)
            # Strip spaces so pnnx receives e.g. inputshape=[1,3,128,128].
            if "Fixed" in output_type:
                command = f"pnnx {pt_path} inputshape={str(shape0).replace(' ', '')} inputshape2={str(shape1).replace(' ', '')}"
        elif input_tensor0 is not None:
            example_input = input_tensor0
            if "Fixed" in output_type:
                command = f"pnnx {pt_path} inputshape={str(shape0).replace(' ', '')}"
        else:
            # May be None when both shapes are all-zero; tracing below would
            # then fail — assumed unreachable from the UI defaults.
            example_input = input_tensor1
            command = f"pnnx {pt_path}"
        input_tensor_str = ""
        if input_tensor0 is not None:
            input_tensor_str += str(input_tensor0.shape)
        else:
            input_tensor_str += "None"
        if input_tensor1 is not None:
            input_tensor_str += ", " + str(input_tensor1.shape)
        else:
            input_tensor_str += ", None"
        print_log(task_id, input2, "生成输入张量" + input_tensor_str, "完成")
        log += input_tensor_str + "\n"
        yield [], log
        # Make sure output_folder still exists (created above).
        if not os.path.exists(output_folder):
            os.makedirs(output_folder)
        print_log(task_id, input2, "加载模型", "开始")
        log += "加载模型…\n"
        yield [], log
        # Load the checkpoint with spandrel and unwrap the torch module.
        model = ModelLoader().load_from_file(file_path)
        # Make sure it is an image-to-image model.
        assert isinstance(model, ImageModelDescriptor)
        print_log(task_id, input2, "获得模型对象", "开始")
        log += "获得模型对象…\n"
        yield [], log
        # CPU inference mode (GPU intentionally not used here).
        # model.cuda().eval()
        model.eval()
        torch_model = model.model
        print_log(task_id, input2, "获得模型对象", "完成")
        yield [], log
        width_ratio = 0
        if os.path.exists(pt_path):
            print_log(task_id, input2, "转换为TorchScript模型", "跳过")
            log += "跳过转换为TorchScript模型\n"
            yield [], log
        elif "TorchScript" in output_type:
            print_log(task_id, input2, "转换为TorchScript模型", "开始")
            log += "转换为TorchScript模型…\n"
            yield [], log
            # torch.jit.trace treats a tuple as multiple example inputs.
            traced_torch_model = torch.jit.trace(torch_model, example_input)
            traced_torch_model.save(output_folder + "/" + input2 + ".pt")
            print_log(task_id, input2, "转换为TorchScript模型", "完成")
            # Run the traced model once to measure the upscale ratio.
            # Bug fix: a tuple example_input must be unpacked when calling,
            # and the ratio must use the first tensor's shape (a tuple has
            # no .shape).
            if isinstance(example_input, tuple):
                example_output = traced_torch_model(*example_input)
                reference_input = example_input[0]
            else:
                example_output = traced_torch_model(example_input)
                reference_input = example_input
            if isinstance(example_output, torch.Tensor):
                width_ratio = example_output.shape[2] / reference_input.shape[2]
                print_log(task_id, input2, "获得缩放倍率=" + str(width_ratio) + ", 输出shape=" + str(list(example_output.shape)), "完成")
                log += ("获得缩放倍率=" + str(width_ratio) + ", 输出shape=" + str(list(example_output.shape)) + "\n")
                yield [], log
            else:
                # Bug fix: wrap the type object in str() — concatenating it
                # directly raised TypeError on this error path.
                print_log(task_id, input2, "Traced torch model输出" + str(type(example_output)), "错误")
                log += "Traced torch model输出" + str(type(example_output)) + "错误\n"
                yield [], log
        scale = int(width_ratio)
        # --- ONNX export (also a prerequisite for NCNN and MNN) ---
        if "ONNX" in output_type or "NCNN" in output_type or "MNN" in output_type:
            # Append "-xN" unless the name already mentions the scale.
            if str(scale) in input2 or scale < 1:
                onnx_path = output_base + ".onnx"
            else:
                onnx_path = output_base + "-x" + str(scale) + ".onnx"
            if os.path.exists(onnx_path):
                print_log(task_id, input2, "转换为ONNX模型", "跳过")
                log += "跳过转换为ONNX模型\n"
                yield [], log
            else:
                print_log(task_id, input2, "转换为ONNX模型", "开始")
                log += "转换为ONNX模型…\n"
                yield [], log
                torch.onnx.export(torch_model, example_input, onnx_path, opset_version=17, input_names=["input"], output_names=["output"])
        # --- MNN conversion (from the ONNX file) ---
        if "MNN" in output_type:
            if str(scale) in input2 or scale < 1:
                mnn_path = output_base + ".mnn"
            else:
                mnn_path = output_base + "-x" + str(scale) + ".mnn"
            mnn_config = ""
            if "Fixed" in output_type and input_tensor0 is not None:
                # Fixed-shape config file consumed by MNNConvert.
                mnn_config = output_base + ".mnnconfig"
                with open(mnn_config, 'w') as f:
                    if input_tensor1 is not None:
                        f.write(f"input_names = input0, input1\n")
                        # Bug fix: the second dims entry used shape0 twice;
                        # it must describe the second input (shape1).
                        f.write(f"input_dims = {'x'.join(map(str, shape0))}, {'x'.join(map(str, shape1))},\n")
                    else:
                        f.write(f"input_names = input\n")
                        f.write(f"input_dims = {'x'.join(map(str, shape0))}\n")
            if os.path.exists(mnn_path):
                print_log(task_id, input2, "转换为MNN模型", "跳过")
                log += "跳过转换为MNN模型\n"
                yield [], log
            else:
                print_log(task_id, input2, "转换为MNN模型", "开始")
                log += "转换为MNN模型…\n"
                mnn_command = f"MNNConvert -f ONNX --modelFile \"{onnx_path}\" --MNNModel \"{mnn_path}\" --bizCode biz --fp16 --info --detectSparseSpeedUp"
                if mnn_config:
                    mnn_command += f" --inputConfigFile \"{mnn_config}\""
                try:
                    # Stream converter output into the UI log line by line.
                    process = subprocess.Popen(mnn_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
                    while True:
                        output = process.stdout.readline()
                        if output == '' and process.poll() is not None:
                            break
                        if output:
                            if log_to_terminal:
                                print(output.strip())
                            log += output.strip() + '\n'
                            yield [], log
                    returncode = process.poll()
                    if returncode != 0:
                        print_log(task_id, input2, f"转换为MNN模型,返回码: {returncode},命令: {mnn_command} ", "错误")
                        log += f"执行mnn命令失败,返回码: {returncode},命令: {mnn_command} \n"
                    else:
                        log += f"执行mnn命令成功: {mnn_command} \n"
                except Exception as e:
                    log += f"执行mnn命令: {mnn_command} 失败,错误信息: {str(e)}\n"
                    print_log(task_id, input2, f"转换为MNN模型,错误信息: {str(e)}", "错误")
        # --- NCNN conversion via pnnx, then packaged as a zip ---
        if "NCNN" in output_type:
            print_log(task_id, input2, "执行ncnn命令" + command, "开始")
            log += "执行ncnn命令…\n"
            yield [], log
            try:
                # Stream pnnx output into the UI log line by line.
                process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
                while True:
                    output = process.stdout.readline()
                    if output == '' and process.poll() is not None:
                        break
                    if output:
                        if log_to_terminal:
                            print(output.strip())
                        log += output.strip() + '\n'
                        yield [], log
                returncode = process.poll()
                if returncode != 0:
                    log += f"执行ncnn命令失败,返回码: {returncode},命令: {command} \n"
                    print_log(task_id, input2, f"返回码: {returncode},命令: {command} ", "失败")
                else:
                    log += f"执行ncnn命令成功: {command} \n"
            except Exception as e:
                log += f"执行ncnn命令: {command} 失败,错误信息: {str(e)}\n"
                print_log(task_id, input2, f"错误信息: {str(e)}", "错误")
            # Collect the .ncnn.bin/.ncnn.param pair pnnx produced.
            bin_files = [f for f in os.listdir(output_folder) if f.endswith('.ncnn.bin')]
            param_files = [f for f in os.listdir(output_folder) if f.endswith('.ncnn.param')]
            if bin_files and param_files:
                param_file = os.path.join(output_folder, param_files[0])
                bin_file = os.path.join(output_folder, bin_files[0])
                import zipfile
                # Archive name and the folder name inside the archive.
                zip_file_name = os.path.join(output_folder, f"models-{input2}.zip")
                zip_folder_name = f"models-{input2}"
                # Files inside follow the "x<scale>" naming convention.
                new_bin_name = f"x{scale}.bin"
                new_param_name = f"x{scale}.param"
                with zipfile.ZipFile(zip_file_name, 'w', zipfile.ZIP_DEFLATED) as zipf:
                    zipf.write(bin_file, os.path.join(zip_folder_name, new_bin_name))
                    zipf.write(param_file, os.path.join(zip_folder_name, new_param_name))
                log += f"已创建压缩包: {zip_file_name}\n"
                print_log(task_id, input2, "创建压缩包" + zip_file_name, "完成")
                yield [], log
            else:
                log += f"未找到 ncnn 文件\n"
                print_log(task_id, input2, "查找 ncnn 文件", "失败")
                yield [], log
        # Final listing of everything produced in the temp folder.
        output_files = [os.path.join(output_folder, f) for f in os.listdir(output_folder) if os.path.isfile(os.path.join(output_folder, f))]
        log += f"任务完成\n"
        print_log(task_id, input2, "执行命令", "完成")
        yield output_files, log
    except Exception as e:
        log += f"发生错误: {e}\n"
        print_log(task_id, input2, e, f"失败")
        yield [], log
# --- Gradio UI definition ---
# Layout: left column collects the model source (URL or uploaded file),
# naming and shape options; right column shows produced files and a log.
with gr.Blocks() as demo:
    gr.Markdown("文件处理界面")
    with gr.Row():
        # Left column: input widgets and action buttons.
        with gr.Column():
            # Prompt text shown above the inputs.
            # NOTE(review): this prompt says 100M but start_process enforces
            # a 200MB limit — confirm which is intended.
            gr.Markdown("请输入的url,或者上传一个文件。限制文件为小于100M的*.pth模型")
            with gr.Row():
                input1 = gr.Textbox(label="粘贴地址")
                # File-upload alternative to the URL box.
                input1_file = gr.File(label="上传文件", file_types=[".pth", ".safetensors", ".ckpt"])
            with gr.Row():
                input2 = gr.Textbox(label="自定义文件名")
                output_type = gr.Dropdown(
                    choices=["TorchScript", "ONNX", "Fixed", "MNN", "NCNN"],
                    value=["TorchScript", "ONNX", "MNN", "NCNN"],
                    multiselect=True,
                    label="模型类型",
                    info="1. 生成mnn和ncnn模型必须先生成onnx模型;2.如果选项中包含了Fixed,那么输出的onnx和mnn模型都使用固定shape的input。"
                )
            shape0_str = gr.Textbox(label="shape0 (逗号分隔的整数)", value="1,3,128,128")
            shape1_str = gr.Textbox(label="shape1 (逗号分隔的整数)", value="0,0,0,0")
            with gr.Row():
                start_button = gr.Button("开始")
                # Button to abort an in-flight conversion.
                cancel_button = gr.Button("取消")
        # Right column: produced files and the streaming log.
        with gr.Column():
            output = gr.File(label="输出文件", file_count="multiple")
            log_textbox = gr.Textbox(label="日志", lines=10, interactive=False)
    # Wire the start button to the generator; its yields stream into the
    # file list and the log textbox.
    process = start_button.click(
        fn=start_process,
        inputs=[input1_file, input1, input2, shape0_str, shape1_str, output_type],
        outputs=[output, log_textbox]
    )
    # The cancel button cancels the pending/running start_process event.
    cancel_button.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[process]
    )
    # Example rows: (file, url, name, shape0, shape1, output types).
    examples = [
        [None, "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", "", "1,3,128,128", "0,0,0,0", ["TorchScript", "ONNX", "MNN", "NCNN"]],
        [None, "https://github.com/Phhofm/models/releases/download/4xNomos8kSC/4xNomos8kSC.pth", "", "1,3,128,128", "0,0,0,0", ["TorchScript", "ONNX", "MNN", "NCNN"]],
        [None, "https://github.com/Phhofm/models/releases/download/1xDeJPG/1xDeJPG_SRFormer_light.pth", "", "1,3,128,128", "0,0,0,0", ["TorchScript", "ONNX", "MNN", "NCNN"]],
        [None, "https://objectstorage.us-phoenix-1.oraclecloud.com/n/ax6ygfvpvzka/b/open-modeldb-files/o/4x-WTP-ColorDS.pth", "", "1,3,128,128", "0,0,0,0", ["ONNX", "MNN"]],
    ]
    gr.Examples(
        examples=examples,
        inputs=[input1_file, input1, input2, shape0_str, shape1_str, output_type],
        # outputs=[output, log_textbox],
        # fn=start_process,
    )
# Serve on all interfaces; SSR disabled.
demo.launch(ssr_mode=False, server_name="0.0.0.0")