"""Gradio demo that runs image super-resolution (SR) on an AX device and
shows the input and upscaled output side by side in an image slider."""
import argparse
import atexit

import gradio as gr
from pysr import SR
from pyaxdev import enum_devices, sys_init, sys_deinit, AxDeviceType


def _init_device(devices_info):
    """Initialise the first usable device and schedule its deinit at exit.

    The host device is preferred; otherwise AXCL device 0 is used.

    Args:
        devices_info: Result of ``enum_devices()``; this function reads
            ``devices_info['host']['available']`` and
            ``devices_info['devices']['count']``.

    Returns:
        ``(dev_type, devid)`` of the initialised device.

    Raises:
        RuntimeError: If neither a host nor an AXCL device is available.
    """
    if devices_info['host']['available']:
        print("host device available")
        dev_type, devid = AxDeviceType.host_device, -1
    elif devices_info['devices']['count'] > 0:
        print("axcl device available, use device-0")
        dev_type, devid = AxDeviceType.axcl_device, 0
    else:
        raise RuntimeError("No available device")

    sys_init(dev_type, devid)
    # Register cleanup right after a successful init (not after the blocking
    # launch()) so the device is released even if model loading or the web
    # server fails, and so cleanup always matches the device actually chosen.
    atexit.register(sys_deinit, dev_type, devid)
    return dev_type, devid


def _build_demo(sr):
    """Build the Gradio UI: an upload box + button and a before/after slider."""

    def run_super_resolution(image):
        """Return ``(original, upscaled)`` for the image slider, or None."""
        # Image comes from gr.Image(type="numpy"), i.e. a numpy array.
        # NOTE(review): Gradio normally delivers RGB channel order, although
        # this code was previously documented as BGR — confirm what SR expects.
        if image is None:
            return None
        return (image, sr(image))

    with gr.Blocks() as demo:
        gr.Markdown("## 图像超分对比 Demo")
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(label="上传图片", type="numpy")
                run_button = gr.Button("运行超分")
            with gr.Column(scale=4):
                output_comparator = gr.ImageSlider(height=800)
        run_button.click(fn=run_super_resolution, inputs=input_image,
                         outputs=output_comparator)
    return demo


def main():
    """Parse CLI args, initialise a device, load the SR model and serve the UI."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='edsr_x2_small_1.axmodel')
    parser.add_argument('--host', type=str, default='0.0.0.0',
                        help='address the Gradio server binds to')
    parser.add_argument('--port', type=int, default=7860,
                        help='port the Gradio server listens on')
    args = parser.parse_args()

    devices_info = enum_devices()
    print("可用设备:", devices_info)
    dev_type, devid = _init_device(devices_info)

    sr = SR({
        'dev_type': dev_type,
        'devid': devid,
        'model_path': args.model
    })

    demo = _build_demo(sr)
    # launch() blocks while the server runs; device cleanup is already
    # registered via atexit in _init_device.
    demo.launch(server_name=args.host, server_port=args.port)


if __name__ == '__main__':
    main()