wli1995 commited on
Commit
a021116
·
verified ·
1 Parent(s): f3ecaba

Upload gradio demo

Browse files
.gitattributes CHANGED
@@ -43,3 +43,4 @@ model_convert/axmodel/espcn_x2_T9_2k.axmodel filter=lfs diff=lfs merge=lfs -text
43
  model_convert/axmodel/620/edsr_x2_small_1.axmodel filter=lfs diff=lfs merge=lfs -text
44
  model_convert/axmodel/620/edsr_x2_small_2.axmodel filter=lfs diff=lfs merge=lfs -text
45
  *.axmodel filter=lfs diff=lfs merge=lfs -text
 
 
43
  model_convert/axmodel/620/edsr_x2_small_1.axmodel filter=lfs diff=lfs merge=lfs -text
44
  model_convert/axmodel/620/edsr_x2_small_2.axmodel filter=lfs diff=lfs merge=lfs -text
45
  *.axmodel filter=lfs diff=lfs merge=lfs -text
46
+ video/gradio.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -38,28 +38,17 @@ Download all files from this repository to the device
38
 
39
  ```
40
 
41
- root@ax650:~/SuperResolution# tree
42
  .
43
- |-- model_convert
44
- | -- axmodel
45
- | `-- edsr_baseline_x2_1.axmodel
46
- | `-- espcn_x2_T9.axmodel
47
- | -- onnx
48
- | `-- edsr_baseline_x2_1.onnx
49
- | `-- espcn_x2_T9.onnx
50
- | `-- build_config_edsr.json
51
- | `-- build_config_espcn.json
52
- |-- python
53
- | `-- run_onnx.py
54
- | `-- run_axmodel.py
55
- | `-- common.py
56
- | `-- imgproc.py
57
- |-- video
58
- | `-- test_1920x1080.mp4
59
- | `-- 1.png
60
- | `-- 2.png
61
-
62
 
 
63
  ```
64
 
65
  ### Requirements
@@ -98,6 +87,7 @@ Average time: 0.373 seconds for each frame
98
  ```
99
 
100
  The output file in `experiment/test_1920x1080_x2.avi`
 
101
 
102
  Output Data:
103
 
@@ -106,4 +96,25 @@ Output Data:
106
  │   └── test_1920x1080_x2.avi
107
  ```
108
 
109
- ![Example Image](video/2.png)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  ```
40
 
41
+ root@ax650:~/SuperResolution# tree -L 1
42
  .
43
+ ├── assert
44
+ ├── config.json
45
+ ├── experiment
46
+ ├── model_convert
47
+ ├── python
48
+ ├── README.md
49
+ └── video
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ 6 directories, 2 files
52
  ```
53
 
54
  ### Requirements
 
87
  ```
88
 
89
  The output file in `experiment/test_1920x1080_x2.avi`
90
+ ![Example Image](video/2.png)
91
 
92
  Output Data:
93
 
 
96
  │   └── test_1920x1080_x2.avi
97
  ```
98
 
99
+ #### Inference with M.2 Accelerator card
100
+
101
+ ```bash
102
+ $ cd python
103
+ $ python gradio_demo.py
104
+
105
+ [INFO] Available providers: ['AXCLRTExecutionProvider']
106
+ [INFO] Using provider: AXCLRTExecutionProvider
107
+ [INFO] SOC Name: AX650N
108
+ [INFO] VNPU type: VNPUType.DISABLED
109
+ [INFO] Compiler version: 4.2 6bff2f67
110
+ [INFO] Using provider: AXCLRTExecutionProvider
111
+ [INFO] SOC Name: AX650N
112
+ [INFO] VNPU type: VNPUType.DISABLED
113
+ [INFO] Compiler version: 4.2 6bff2f67
114
+ * Running on local URL: http://0.0.0.0:7860
115
+ * To create a public link, set `share=True` in `launch()`.
116
+ ```
117
+ Then use the M.2 Accelerator card IP instead of the 0.0.0.0, and use chrome open the URL: http://[your ip]:7860
118
+
119
+ ![gradio_demo](./assert/gradio_demo.jpg)
120
+
assert/gradio_demo.jpg ADDED
experiment/gradio_x2.avi ADDED
Binary file (32.2 kB). View file
 
python/__pycache__/common.cpython-313.pyc ADDED
Binary file (4.47 kB). View file
 
python/__pycache__/imgproc.cpython-313.pyc ADDED
Binary file (32.1 kB). View file
 
python/gradio_demo.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ import os
5
+ import tempfile
6
+ import time
7
+ import axengine as axe
8
+ import common
9
+ import imgproc
10
+
11
+ rgb_range=255
12
+ scale=2
13
+ def from_numpy(x):
14
+ return x if isinstance(x, np.ndarray) else np.array(x)
15
+
16
+ def quantize(img, rgb_range):
17
+ pixel_range = 255 / rgb_range
18
+ return np.round(np.clip(img * pixel_range, 0, 255)) / pixel_range
19
+
20
+ # 初始化EDSR和ESPCN模型
21
+ def init_SRmodel(EDSR_path="../model_convert/axmodel/edsr_baseline_x2_1.axmodel",
22
+ ESPCN_path="../model_convert/axmodel/espcn_x2_T9.axmodel"):
23
+
24
+ EDSR_session = axe.InferenceSession(EDSR_path)
25
+ ESPCN_session = axe.InferenceSession(ESPCN_path)
26
+
27
+ return [EDSR_session, ESPCN_session]
28
+
29
+ SR_sessions=init_SRmodel()
30
+
31
+ def EDSR_infer(frame, EDSR_session=SR_sessions[0]):
32
+ output_names = [x.name for x in EDSR_session.get_outputs()]
33
+ input_name = EDSR_session.get_inputs()[0].name
34
+
35
+ lr_y_image, = common.set_channel(frame, n_channels=3)
36
+ lr_y_image, = common.np_prepare(lr_y_image, rgb_range=rgb_range)
37
+
38
+ sr = EDSR_session.run(output_names, {input_name: lr_y_image})
39
+
40
+ if isinstance(sr, (list, tuple)):
41
+ sr = from_numpy(sr[0]) if len(sr) == 1 else [from_numpy(x) for x in sr]
42
+ else:
43
+ sr = from_numpy(sr)
44
+
45
+ sr = quantize(sr, rgb_range).squeeze(0)
46
+ normalized = sr * 255 / rgb_range
47
+ ndarr = normalized.transpose(1, 2, 0).astype(np.uint8)
48
+
49
+ return ndarr
50
+
51
+ def ESPCN_infer(frame, ESPCN_session=SR_sessions[1]):
52
+
53
+ output_names = [x.name for x in ESPCN_session.get_outputs()]
54
+ input_name = ESPCN_session.get_inputs()[0].name
55
+
56
+ lr_y_image, lr_cb_image, lr_cr_image = imgproc.preprocess_one_frame(frame)
57
+ bic_cb_image = cv2.resize(lr_cb_image,
58
+ (int(lr_cb_image.shape[1] * scale),
59
+ int(lr_cb_image.shape[0] * scale)),
60
+ interpolation=cv2.INTER_CUBIC)
61
+ bic_cr_image = cv2.resize(lr_cr_image,
62
+ (int(lr_cr_image.shape[1] * scale),
63
+ int(lr_cr_image.shape[0] * scale)),
64
+ interpolation=cv2.INTER_CUBIC)
65
+
66
+ sr = ESPCN_session.run(output_names, {input_name: lr_y_image})
67
+
68
+ if isinstance(sr, (list, tuple)):
69
+ sr = from_numpy(sr[0]) if len(sr) == 1 else [from_numpy(x) for x in sr]
70
+ else:
71
+ sr = from_numpy(sr)
72
+
73
+ ndarr = imgproc.array_to_image(sr)
74
+ sr_y_image = ndarr.astype(np.float32) / 255.0
75
+ sr_ycbcr_image = cv2.merge([sr_y_image[:, :, 0], bic_cb_image, bic_cr_image])
76
+ sr_image = imgproc.ycbcr_to_bgr(sr_ycbcr_image)
77
+ sr_image = np.clip(sr_image* 255.0, 0 , 255).astype(np.uint8)
78
+
79
+ return sr_image
80
+
81
+ # ======================
82
+ # 模拟超分辨率模型
83
+ # ======================
84
+ def EDSR_MODEL(input_data, is_video=False):
85
+
86
+ if is_video:
87
+ output_frames = []
88
+ for frame in input_data:
89
+
90
+ out = EDSR_infer(frame=frame)
91
+ output_frames.append(out)
92
+ return output_frames
93
+ else:
94
+ out = EDSR_infer(frame=input_data)
95
+ return out
96
+
97
+ def ESPCN_MODEL(input_data, is_video=False):
98
+ if is_video:
99
+ output_frames = []
100
+ for frame in input_data:
101
+ out = ESPCN_infer(frame=frame)
102
+ output_frames.append(out)
103
+ return output_frames
104
+ else:
105
+ out = ESPCN_infer(frame=input_data)
106
+ return out
107
+
108
+ # ======================
109
+ # 全局状态(单用户)
110
+ # ======================
111
+ class AppState:
112
+ def __init__(self):
113
+ self.original_img = None # 原始图(BGR, 高分辨率)
114
+ self.sr_img = None # 超分图(BGR, 高分辨率)
115
+ self.is_video = False
116
+
117
+ app_state = AppState()
118
+
119
+ # ======================
120
+ # 核心处理函数
121
+ # ======================
122
+ def process_super_resolution(input_file, model_choice):
123
+ global app_state
124
+ if input_file is None:
125
+ raise gr.Error("请先上传图片或视频!")
126
+
127
+ file_path = input_file
128
+ app_state = AppState()
129
+ info_text = ""
130
+
131
+ is_video = any(ext in file_path.lower() for ext in ['.mp4', '.avi', '.mov', '.mkv'])
132
+
133
+ if is_video:
134
+ # --- 视频处理(直接保存高分辨率)---
135
+ cap = cv2.VideoCapture(file_path)
136
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
137
+ fps = cap.get(cv2.CAP_PROP_FPS)
138
+ info_text += f"🎬 视频信息:\n- 总帧数: {total_frames}\n- 帧率: {fps:.2f} FPS\n"
139
+ frames = []
140
+ while True:
141
+ ret, frame = cap.read()
142
+ if not ret:
143
+ break
144
+ frames.append(frame)
145
+ cap.release()
146
+
147
+ model_func = EDSR_MODEL if model_choice == "EDSR_MODEL" else ESPCN_MODEL
148
+ start_time = time.time()
149
+ output_data = model_func(frames, is_video=True)
150
+ infer_time = time.time() - start_time
151
+ info_text += f"\n⏱️ 推理时间: {infer_time:.2f} 秒\n"
152
+
153
+ full_video_path = os.path.join(tempfile.gettempdir(), f"sr_video_x2.mp4")
154
+ h_out, w_out = output_data[0].shape[:2]
155
+ info_text += f"- 超分后尺寸: {w_out} x {h_out}\n"
156
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
157
+ out_video = cv2.VideoWriter(full_video_path, fourcc, fps, (w_out, h_out))
158
+ for frame in output_data:
159
+ out_video.write(frame)
160
+ out_video.release()
161
+
162
+ app_state.is_video = True
163
+
164
+ return (
165
+ gr.update(value=None, visible=False), # image_display
166
+ gr.update(visible=False), # btn_original
167
+ gr.update(visible=False), # btn_sr
168
+ gr.update(value="当前: 无", visible=False),
169
+ gr.update(value=full_video_path, visible=True),
170
+ gr.update(value=full_video_path, visible=True),
171
+ gr.update(visible=False),
172
+ info_text
173
+ )
174
+
175
+ else:
176
+ # --- 图片处理(保存原始高分辨率)---
177
+ img = cv2.imread(file_path)
178
+ if img is None:
179
+ raise gr.Error("无法读取图片!")
180
+ h, w = img.shape[:2]
181
+ info_text += f"🖼️ 图片信息:\n- 原始尺寸: {w} x {h}\n"
182
+
183
+ app_state.original_img = img.copy()
184
+ model_func = EDSR_MODEL if model_choice == "EDSR_MODEL" else ESPCN_MODEL
185
+ start_time = time.time()
186
+ sr_img = model_func(img, is_video=False)
187
+ infer_time = time.time() - start_time
188
+ info_text += f"\n⏱️ 推理时间: {infer_time:.2f} 秒\n"
189
+
190
+ h_out, w_out = sr_img.shape[:2]
191
+ info_text += f"- 超分后尺寸: {w_out} x {h_out}\n"
192
+
193
+ sr_img_path = os.path.join(tempfile.gettempdir(), f"sr_image_x2.png")
194
+ cv2.imwrite(sr_img_path, sr_img)
195
+ app_state.sr_img = sr_img
196
+
197
+ app_state.is_video = False
198
+
199
+ # 默认显示原图(高分辨率,但 UI 会限制尺寸)
200
+ return (
201
+ gr.update(value=app_state.original_img[:, :, ::-1], visible=True), # BGR→RGB
202
+ gr.update(visible=True),
203
+ gr.update(visible=True),
204
+ gr.update(value="当前: 原图", visible=True),
205
+ gr.update(visible=False),
206
+ gr.update(visible=False),
207
+ gr.update(value=sr_img_path, visible=True),
208
+ info_text
209
+ )
210
+
211
+ # ======================
212
+ # 切换显示函数(直接使用原始高分辨率图)
213
+ # ======================
214
+ def show_original():
215
+ if app_state.original_img is None:
216
+ return gr.update(), gr.update()
217
+ # OpenCV BGR → RGB
218
+ rgb_img = app_state.original_img[:, :, ::-1]
219
+ return gr.update(value=rgb_img), gr.update(value="当前: 原图")
220
+
221
+ def show_sr():
222
+ if app_state.sr_img is None:
223
+ return gr.update(), gr.update()
224
+ rgb_img = app_state.sr_img[:, :, ::-1]
225
+ return gr.update(value=rgb_img), gr.update(value="当前: 超分图")
226
+
227
+ # ======================
228
+ # Gradio UI
229
+ # ======================
230
+ with gr.Blocks(title="超分辨率可视化工具", theme=gr.themes.Soft()) as demo:
231
+ gr.Markdown("## 🚀 超分辨率模型效果可视化")
232
+ gr.Markdown("上传图片或视频,选择模型,点击箭头切换原图/超分图!")
233
+
234
+ input_file = gr.File(
235
+ label="📂 上传图片或视频",
236
+ file_types=["image", "video"],
237
+ file_count="single"
238
+ )
239
+
240
+ with gr.Row():
241
+ model_choice = gr.Radio(
242
+ choices=["EDSR_MODEL", "ESPCN_MODEL"],
243
+ value="EDSR_MODEL",
244
+ label="🔍 选择超分辨率模型"
245
+ )
246
+ run_btn = gr.Button("🚀 开始超分", variant="primary")
247
+
248
+ # 图片区:硬性限定尺寸,直接显示原始高分辨率图
249
+ with gr.Column(visible=False) as image_section:
250
+ image_label = gr.Textbox(value="当前: 原图", interactive=False, lines=1)
251
+ image_display = gr.Image(
252
+ label="🖼️ 图像显示",
253
+ width=800, # 👈 固定宽度
254
+ height=600 # 👈 固定高度
255
+ )
256
+ with gr.Row():
257
+ btn_original = gr.Button("◀ 原图")
258
+ btn_sr = gr.Button("超分图 ▶")
259
+
260
+ # 视频区:硬性限定高度
261
+ output_video_player = gr.Video(
262
+ label="▶️ 超分视频(高分辨率)",
263
+ visible=False,
264
+ height=450 # 宽度自适应,高度固定
265
+ )
266
+
267
+ with gr.Row():
268
+ download_image = gr.File(label="📥 下载超分图片(原图)", visible=False)
269
+ download_video = gr.File(label="📥 下载超分视频(完整分辨率)", visible=False)
270
+
271
+ info_box = gr.Textbox(label="📊 处理信息", lines=6, interactive=False)
272
+
273
+ run_btn.click(
274
+ fn=process_super_resolution,
275
+ inputs=[input_file, model_choice],
276
+ outputs=[
277
+ image_display,
278
+ btn_original,
279
+ btn_sr,
280
+ image_label,
281
+ output_video_player,
282
+ download_video,
283
+ download_image,
284
+ info_box
285
+ ]
286
+ )
287
+
288
+ btn_original.click(show_original, outputs=[image_display, image_label])
289
+ btn_sr.click(show_sr, outputs=[image_display, image_label])
290
+
291
+ def toggle_ui(file):
292
+ if file is None:
293
+ return (
294
+ gr.update(visible=False),
295
+ gr.update(visible=False),
296
+ gr.update(visible=False),
297
+ gr.update(visible=False)
298
+ )
299
+ if any(ext in file.lower() for ext in ['.mp4', '.avi', '.mov', '.mkv']):
300
+ return (
301
+ gr.update(visible=False),
302
+ gr.update(visible=False),
303
+ gr.update(visible=True),
304
+ gr.update(visible=True)
305
+ )
306
+ else:
307
+ return (
308
+ gr.update(visible=True),
309
+ gr.update(visible=True),
310
+ gr.update(visible=False),
311
+ gr.update(visible=False)
312
+ )
313
+
314
+ input_file.change(
315
+ fn=toggle_ui,
316
+ inputs=input_file,
317
+ outputs=[
318
+ image_section,
319
+ download_image,
320
+ output_video_player,
321
+ download_video
322
+ ]
323
+ )
324
+
325
+ if __name__ == "__main__":
326
+ demo.launch(server_name="0.0.0.0", server_port=7860)
video/gradio.png ADDED

Git LFS Details

  • SHA256: 727a7d51411a1c32525a2c42662758ae952915b62ff06c7fb18488ffbdba3b22
  • Pointer size: 132 Bytes
  • Size of remote file: 2.07 MB