wli1995 commited on
Commit
481945d
·
verified ·
1 Parent(s): e44b144

Upload gradio demo

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. README.md +19 -0
  3. assert/gradio_demo.JPG +3 -0
  4. python/gradio_demo.py +178 -0
.gitattributes CHANGED
@@ -36,3 +36,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
36
  image/sr_colorize.jpg filter=lfs diff=lfs merge=lfs -text
37
  model/colorize_stable.axmodel filter=lfs diff=lfs merge=lfs -text
38
  model/colorize_artistic.axmodel filter=lfs diff=lfs merge=lfs -text
 
 
36
  image/sr_colorize.jpg filter=lfs diff=lfs merge=lfs -text
37
  model/colorize_stable.axmodel filter=lfs diff=lfs merge=lfs -text
38
  model/colorize_artistic.axmodel filter=lfs diff=lfs merge=lfs -text
39
+ assert/gradio_demo.JPG filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -56,6 +56,25 @@ Input Data:
56
  | `-- 1850Geography.jpg
57
  ```
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  #### Inference with AX650 Host, such as M4N-Dock(爱芯派Pro)
60
 
61
  ```
 
56
  | `-- 1850Geography.jpg
57
  ```
58
 
59
+ #### Inference with M.2 Accelerator card
60
+ ```
61
+ $python3 gradio_demo.py
62
+ [INFO] Available providers: ['AXCLRTExecutionProvider']
63
+ [INFO] Using provider: AXCLRTExecutionProvider
64
+ [INFO] SOC Name: AX650N
65
+ [INFO] VNPU type: VNPUType.DISABLED
66
+ [INFO] Compiler version: 5.0-patch1 2295293f
67
+ [INFO] Using provider: AXCLRTExecutionProvider
68
+ [INFO] SOC Name: AX650N
69
+ [INFO] VNPU type: VNPUType.DISABLED
70
+ [INFO] Compiler version: 5.0-patch1 2295293f
71
+ * Running on local URL: http://0.0.0.0:7860
72
+ * To create a public link, set `share=True` in `launch()`.
73
+ ```
74
+ Then use the M.2 Accelerator card IP instead of the 0.0.0.0, and use chrome open the URL: http://[your ip]:7860
75
+
76
+ ![gradio demo](./assert/gradio_demo.JPG)
77
+
78
  #### Inference with AX650 Host, such as M4N-Dock(爱芯派Pro)
79
 
80
  ```
assert/gradio_demo.JPG ADDED

Git LFS Details

  • SHA256: 8322f39596ee31bb81d042bb94f436afb6537199d64483848642c953ebb161a5
  • Pointer size: 131 Bytes
  • Size of remote file: 135 kB
python/gradio_demo.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import tempfile
4
+ from PIL import Image, ImageEnhance
5
+ import numpy as np
6
+ import axengine as axe
7
+ import cv2
8
+
9
+ # ==============================
10
+ # 模拟上色函数(请替换为你的实际模型)
11
+ # ==============================
12
+ def init_DeOldifymodel(DeOldifyStable_path="../model/colorize_stable.axmodel",
13
+ DeOldifyArtistic_path="../model/colorize_artistic.axmodel"):
14
+
15
+ DeOldifyStable_session = axe.InferenceSession(DeOldifyStable_path)
16
+ DeOldifyArtistic_session = axe.InferenceSession(DeOldifyArtistic_path)
17
+
18
+ return [DeOldifyStable_session, DeOldifyArtistic_session]
19
+
20
+ DeOldify_sessions=init_DeOldifymodel()
21
+
22
+ def from_numpy(x):
23
+ return x if isinstance(x, np.ndarray) else np.array(x)
24
+
25
+ def post_process(raw_color, orig):
26
+ color_np = np.asarray(raw_color)
27
+ orig_np = np.asarray(orig)
28
+ color_yuv = cv2.cvtColor(color_np, cv2.COLOR_RGB2YUV)
29
+ # do a black and white transform first to get better luminance values
30
+ orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_RGB2YUV)
31
+ hires = np.copy(orig_yuv)
32
+ hires[:, :, 1:3] = color_yuv[:, :, 1:3]
33
+ final = cv2.cvtColor(hires, cv2.COLOR_YUV2RGB)
34
+ return final
35
+
36
+ def colorize_with_model(img_path, session):
37
+ output_names = [x.name for x in session.get_outputs()]
38
+ input_name = session.get_inputs()[0].name
39
+
40
+ ori_image = cv2.imread(img_path)
41
+ h, w = ori_image.shape[:2]
42
+ image = cv2.resize(ori_image, (512, 512))
43
+ image = (image[..., ::-1] /255.0).astype(np.float32)
44
+
45
+ mean = [0.485, 0.456, 0.406]
46
+ std = [0.229, 0.224, 0.225]
47
+ image = ((image - mean) / std).astype(np.float32)
48
+
49
+ #image = (image /1.0).astype(np.float32)
50
+ image = np.transpose(np.expand_dims(np.ascontiguousarray(image), axis=0), (0,3,1,2))
51
+
52
+ # Use the model to generate super-resolved images
53
+ sr = session.run(output_names, {input_name: image})
54
+
55
+ if isinstance(sr, (list, tuple)):
56
+ sr = from_numpy(sr[0]) if len(sr) == 1 else [from_numpy(x) for x in sr]
57
+ else:
58
+ sr = from_numpy(sr)
59
+
60
+ #sr_y_image = imgproc.array_to_image(sr)
61
+ sr = np.transpose(sr.squeeze(0), (1,2,0))
62
+ sr = (sr*std + mean).astype(np.float32)
63
+
64
+ # Save image
65
+ ndarr = np.clip((sr*255.0), 0, 255.0).astype(np.uint8)
66
+ ndarr = cv2.resize(ndarr[..., ::-1], (w, h))
67
+ out_image = post_process(ndarr, ori_image)
68
+
69
+ return out_image
70
+
71
+ def colorize_image(input_img_path: str, model_name: str, progress=gr.Progress()):
72
+ if not input_img_path:
73
+ raise gr.Error("未上传图片")
74
+
75
+ # 加载图像
76
+ progress(0.3, desc="加载图像...")
77
+
78
+ # 根据模型选择调用不同函数
79
+ if model_name == "colorize_stable":
80
+ session = DeOldify_sessions[0]
81
+ else:
82
+ session = DeOldify_sessions[1]
83
+ out = colorize_with_model(input_img_path, session)
84
+
85
+ progress(0.9, desc="保存结果...")
86
+
87
+ # 保存到临时文件
88
+ output_path = os.path.join(tempfile.gettempdir(), "colorized_output.jpg")
89
+ cv2.imwrite(output_path, out)
90
+
91
+ progress(1.0, desc="完成!")
92
+ return output_path
93
+
94
+
95
+ # ==============================
96
+ # Gradio 界面
97
+ # ==============================
98
+ custom_css = """
99
+ body, .gradio-container {
100
+ font-family: 'Microsoft YaHei', 'PingFang SC', 'Helvetica Neue', Arial, sans-serif;
101
+ }
102
+ .model-buttons .wrap {
103
+ display: flex;
104
+ gap: 10px;
105
+ }
106
+ .model-buttons .wrap label {
107
+ background-color: #f0f0f0;
108
+ padding: 10px 20px;
109
+ border-radius: 8px;
110
+ cursor: pointer;
111
+ text-align: center;
112
+ font-weight: 600;
113
+ border: 2px solid transparent;
114
+ flex: 1;
115
+ }
116
+ .model-buttons .wrap label:hover {
117
+ background-color: #e0e0e0;
118
+ }
119
+ .model-buttons .wrap input[type="radio"]:checked + label {
120
+ background-color: #4CAF50;
121
+ color: white;
122
+ border-color: #45a049;
123
+ }
124
+ """
125
+
126
+ with gr.Blocks(title="AI 图片上色工具") as demo:
127
+ gr.Markdown("## 🎨 AI 黑白图片自动上色演示")
128
+
129
+ with gr.Row(equal_height=True):
130
+ # 左侧:输入区
131
+ with gr.Column(scale=1, min_width=300):
132
+ gr.Markdown("### 📤 输入")
133
+ input_image = gr.Image(
134
+ type="filepath",
135
+ label="上传黑白/灰度图片",
136
+ sources=["upload"],
137
+ height=300
138
+ )
139
+
140
+ gr.Markdown("### 🔧 选择上色模型")
141
+ model_choice = gr.Radio(
142
+ choices=["colorize_stable", "colorize_artistic"],
143
+ value="colorize_stable",
144
+ label=None,
145
+ elem_classes="model-buttons"
146
+ )
147
+
148
+ run_btn = gr.Button("🚀 开始上色", variant="primary")
149
+
150
+ # 右侧:输出区
151
+ with gr.Column(scale=1, min_width=600):
152
+ gr.Markdown("### 🖼️ 上色结果")
153
+ output_image = gr.Image(
154
+ label="上色后图片",
155
+ interactive=False,
156
+ height=600
157
+ )
158
+ download_btn = gr.File(label="📥 下载上色图片")
159
+
160
+ # 绑定事件
161
+ def on_colorize(img_path, model, progress=gr.Progress()):
162
+ if img_path is None:
163
+ raise gr.Error("请先上传图片!")
164
+ try:
165
+ result_path = colorize_image(img_path, model, progress=progress)
166
+ return result_path, result_path
167
+ except Exception as e:
168
+ raise gr.Error(f"处理失败: {str(e)}")
169
+
170
+ run_btn.click(
171
+ fn=on_colorize,
172
+ inputs=[input_image, model_choice],
173
+ outputs=[output_image, download_btn]
174
+ )
175
+
176
+ # 启动
177
+ if __name__ == "__main__":
178
+ demo.launch(server_name="0.0.0.0", server_port=7860, theme=gr.themes.Soft(), css=custom_css)