File size: 2,088 Bytes
e3513f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
from pysr import SR
from pyaxdev import enum_devices, sys_init, sys_deinit, AxDeviceType
import argparse

if __name__ == '__main__':
    # Gradio demo: upload an image, run it through the super-resolution model
    # loaded via pysr.SR, and compare input vs. output in an image slider.
    import atexit

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='edsr_x2_small_1.axmodel')
    # Server address/port used to be hard-coded; exposed as flags with the
    # same defaults so existing invocations behave identically.
    parser.add_argument('--host', type=str, default='0.0.0.0')
    parser.add_argument('--port', type=int, default=7860)
    args = parser.parse_args()

    # Pick a device: prefer the host device, otherwise fall back to AXCL device 0.
    devices_info = enum_devices()
    print("可用设备:", devices_info)
    if devices_info['host']['available']:
        print("host device available")
        dev_type, devid = AxDeviceType.host_device, -1
    elif devices_info['devices']['count'] > 0:
        print("axcl device available, use device-0")
        dev_type, devid = AxDeviceType.axcl_device, 0
    else:
        raise RuntimeError("No available device")

    sys_init(dev_type, devid)
    # Register the matching teardown right after init, with the exact values
    # passed to sys_init. Previously this was registered only after the
    # blocking demo.launch() returned, so any failure in model loading, UI
    # construction or launch (e.g. port already in use) skipped sys_deinit.
    atexit.register(sys_deinit, dev_type, devid)

    sr = SR({
        'dev_type': dev_type,
        'devid': devid,
        'model_path': args.model,
    })

    def run_super_resolution(image):
        """Run the loaded SR model on *image*.

        Returns an ``(input, output)`` pair for the ImageSlider to display,
        or ``None`` when no image has been uploaded.
        """
        # gr.Image(type="numpy") delivers an (H, W, 3) array; per the Gradio
        # docs the channel order is RGB (an earlier comment claimed BGR).
        # NOTE(review): confirm which channel order SR expects.
        if image is None:
            return None
        sr_image = sr(image)
        return (image, sr_image)

    with gr.Blocks() as demo:
        gr.Markdown("## 图像超分对比 Demo")
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(label="上传图片", type="numpy")
                run_button = gr.Button("运行超分")
            with gr.Column(scale=4):
                output_comparator = gr.ImageSlider(height=800)

        run_button.click(fn=run_super_resolution, inputs=input_image, outputs=output_comparator)

    # Launch the web server; blocks until the server is shut down.
    demo.launch(server_name=args.host, server_port=args.port)