File size: 6,624 Bytes
481945d 9f8a5b5 481945d 9f8a5b5 481945d 9f8a5b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
import os
import socket
import tempfile
import traceback

import axengine as axe
import cv2
import gradio as gr
import numpy as np
from PIL import Image, ImageEnhance
# ==============================
# DeOldify colorization: model loading and inference (axengine)
# ==============================
def init_DeOldifymodel(DeOldifyStable_path="../model/colorize_stable.axmodel",
                       DeOldifyArtistic_path="../model/colorize_artistic.axmodel"):
    """Create axengine inference sessions for both DeOldify colorization models.

    Args:
        DeOldifyStable_path: path to the "stable" ``.axmodel`` file.
        DeOldifyArtistic_path: path to the "artistic" ``.axmodel`` file.

    Returns:
        A two-element list ``[stable_session, artistic_session]``; callers
        pick a model by index (0 = stable, 1 = artistic).

    NOTE(review): the default paths are relative; they presumably resolve
    against the process's working directory, not this file's directory —
    confirm before launching from another directory.
    """
    model_paths = (DeOldifyStable_path, DeOldifyArtistic_path)
    # Loaded in order: stable first, then artistic.
    return [axe.InferenceSession(path) for path in model_paths]


# Load both models once at import time; request handlers index into this list.
DeOldify_sessions = init_DeOldifymodel()
def from_numpy(x):
    """Return *x* itself if it is already an ``np.ndarray``, else ``np.array(x)``.

    Despite the name, this converts *to* NumPy; it normalizes whatever the
    inference session returns into an ndarray.
    """
    if isinstance(x, np.ndarray):
        return x
    return np.array(x)
def post_process(raw_color, orig):
    """Combine the luma of *orig* with the chroma of *raw_color*.

    Both inputs are converted to YUV. The result keeps channel 0 (Y) from
    *orig* and channels 1-2 (U, V) from *raw_color*, then converts back.
    Taking Y from the full-resolution original preserves its detail while
    the model only contributes color.

    Args:
        raw_color: colorized image, same height/width as *orig*.
        orig: the original image.

    Returns:
        The merged image as a NumPy array (cv2.cvtColor output).

    NOTE(review): the caller in this file passes BGR-ordered arrays
    (cv2.imread output), so COLOR_RGB2YUV sees R and B swapped and the "Y"
    plane is not true luma. The inverse conversion uses the same convention,
    so channel order round-trips; confirm before changing, since output
    colors depend on it.
    """
    chroma_src = cv2.cvtColor(np.asarray(raw_color), cv2.COLOR_RGB2YUV)
    # Converting the (grayscale) original gives its luminance channel.
    luma_src = cv2.cvtColor(np.asarray(orig), cv2.COLOR_RGB2YUV)
    merged = luma_src.copy()
    merged[..., 1:] = chroma_src[..., 1:]
    return cv2.cvtColor(merged, cv2.COLOR_YUV2RGB)
# ImageNet normalization statistics in RGB channel order: applied to the model
# input and undone on its output. Plain float64 arrays, so the arithmetic is the
# same as NumPy's implicit promotion of the equivalent Python lists.
_DEOLDIFY_MEAN = np.array([0.485, 0.456, 0.406])
_DEOLDIFY_STD = np.array([0.229, 0.224, 0.225])
# Fixed model input size as (width, height), the order cv2.resize expects.
_DEOLDIFY_INPUT_SIZE = (512, 512)


def _preprocess(bgr_image):
    """Turn an HxWx3 BGR uint8 image into a 1x3x512x512 float32 model input."""
    resized = cv2.resize(bgr_image, _DEOLDIFY_INPUT_SIZE)
    rgb = (resized[..., ::-1] / 255.0).astype(np.float32)
    normalized = ((rgb - _DEOLDIFY_MEAN) / _DEOLDIFY_STD).astype(np.float32)
    # HWC -> NCHW. Make the *final* tensor C-contiguous; a transposed view has
    # non-standard strides, and how the runtime treats those is not visible here.
    return np.ascontiguousarray(normalized.transpose(2, 0, 1)[np.newaxis])


def _first_output(outputs):
    """Return the first result of ``session.run`` as an ndarray.

    Only the first output is used by this pipeline. Reducing a multi-output
    result to its first element keeps the later ``squeeze`` working on an
    array rather than a Python list.

    Raises:
        RuntimeError: if the session returned an empty output sequence.
    """
    if isinstance(outputs, (list, tuple)):
        if not outputs:
            raise RuntimeError("Model returned no outputs")
        outputs = outputs[0]
    return from_numpy(outputs)


def _to_uint8_rgb(nchw_output):
    """Undo input normalization on a 1xCxHxW output; return an HxWxC uint8 image.

    Channel order matches the model input (RGB).
    """
    hwc = np.transpose(nchw_output.squeeze(0), (1, 2, 0))
    denormalized = (hwc * _DEOLDIFY_STD + _DEOLDIFY_MEAN).astype(np.float32)
    return np.clip(denormalized * 255.0, 0, 255.0).astype(np.uint8)


def colorize_with_model(img_path, session):
    """Colorize the image at *img_path* with one DeOldify inference session.

    Steps:
        1. Read the image with OpenCV (BGR).
        2. Resize it to the fixed 512x512 model input and normalize with
           ImageNet mean/std in RGB order.
        3. Run the session and de-normalize its first output.
        4. Resize that output back to the original size.
        5. Merge its chroma with the original's luma via ``post_process``.

    Args:
        img_path: path to an image file readable by ``cv2.imread``.
        session: axengine inference session for a colorization model.

    Returns:
        uint8 array of shape (h, w, 3) in BGR order, suitable for
        ``cv2.imwrite``, where (h, w) is the input image's size.

    Raises:
        ValueError: if the file cannot be read or decoded as an image.
        RuntimeError: if the model returns no outputs.
    """
    ori_image = cv2.imread(img_path)
    if ori_image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f"无法读取图片: {img_path}")
    h, w = ori_image.shape[:2]

    input_name = session.get_inputs()[0].name
    output_names = [x.name for x in session.get_outputs()]

    model_input = _preprocess(ori_image)
    raw_outputs = session.run(output_names, {input_name: model_input})
    color_rgb = _to_uint8_rgb(_first_output(raw_outputs))

    # Flip back to BGR and restore the original resolution so post_process can
    # take the full-resolution luminance from the original image.
    color_bgr = cv2.resize(color_rgb[..., ::-1], (w, h))
    return post_process(color_bgr, ori_image)
def colorize_image(input_img_path: str, model_name: str, progress=gr.Progress()):
    """Colorize an uploaded image and save the result to a new temp JPEG.

    Args:
        input_img_path: filesystem path of the uploaded image.
        model_name: ``"colorize_stable"`` selects the stable model; any other
            value selects the artistic model.
        progress: Gradio progress tracker (default lets Gradio inject one).

    Returns:
        Path of the written JPEG file, unique per call.

    Raises:
        gr.Error: if no image path was given or the result cannot be saved.
    """
    if not input_img_path:
        raise gr.Error("未上传图片")
    progress(0.3, desc="加载图像...")
    # "colorize_stable" -> stable model; anything else falls back to artistic.
    if model_name == "colorize_stable":
        session = DeOldify_sessions[0]
    else:
        session = DeOldify_sessions[1]
    out = colorize_with_model(input_img_path, session)

    progress(0.9, desc="保存结果...")
    # Use a unique file per request: a fixed name in the shared temp directory
    # lets concurrent requests (or other processes) overwrite each other's results.
    fd, output_path = tempfile.mkstemp(prefix="colorized_", suffix=".jpg")
    os.close(fd)
    # cv2.imwrite signals failure via its return value, not an exception.
    if not cv2.imwrite(output_path, out):
        os.remove(output_path)
        raise gr.Error("保存结果失败")
    progress(1.0, desc="完成!")
    return output_path
# ==============================
# Gradio 界面
# ==============================
# Custom styles for the UI: a CJK-friendly font stack, and the model-choice
# radio (elem_classes="model-buttons") laid out as equal-width (flex: 1),
# rounded, bold options with a hover color and a green checked state.
# NOTE(review): the `input[type="radio"]:checked + label` rule assumes the
# <input> is the preceding sibling of its <label>; if Gradio nests the input
# inside the label, that rule never matches — confirm against the rendered DOM.
# NOTE(review): this string only takes effect when passed to Gradio as `css=`.
custom_css = """
body, .gradio-container {
font-family: 'Microsoft YaHei', 'PingFang SC', 'Helvetica Neue', Arial, sans-serif;
}
.model-buttons .wrap {
display: flex;
gap: 10px;
}
.model-buttons .wrap label {
background-color: #f0f0f0;
padding: 10px 20px;
border-radius: 8px;
cursor: pointer;
text-align: center;
font-weight: 600;
border: 2px solid transparent;
flex: 1;
}
.model-buttons .wrap label:hover {
background-color: #e0e0e0;
}
.model-buttons .wrap input[type="radio"]:checked + label {
background-color: #4CAF50;
color: white;
border-color: #45a049;
}
"""
with gr.Blocks(title="AI 图片上色工具") as demo:
    gr.Markdown("## 🎨 AI 黑白图片自动上色演示")
    with gr.Row(equal_height=True):
        # Left column: input image and model choice.
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 📤 输入")
            input_image = gr.Image(
                type="filepath",
                label="上传黑白/灰度图片",
                sources=["upload"],
                height=300
            )
            gr.Markdown("### 🔧 选择上色模型")
            model_choice = gr.Radio(
                choices=["colorize_stable", "colorize_artistic"],
                value="colorize_stable",
                label=None,
                elem_classes="model-buttons"
            )
            run_btn = gr.Button("🚀 开始上色", variant="primary")
        # Right column: colorized preview and download.
        with gr.Column(scale=1, min_width=600):
            gr.Markdown("### 🖼️ 上色结果")
            output_image = gr.Image(
                label="上色后图片",
                interactive=False,
                height=600
            )
            download_btn = gr.File(label="📥 下载上色图片")

    def on_colorize(img_path, model, progress=gr.Progress()):
        """Click handler: colorize *img_path* with *model*.

        Returns the result path twice: once for the preview image and once
        for the download widget.

        Raises:
            gr.Error: user-facing message for any failure.
        """
        if img_path is None:
            raise gr.Error("请先上传图片!")
        try:
            result_path = colorize_image(img_path, model, progress=progress)
        except gr.Error:
            # Already a user-facing message; don't wrap it a second time.
            raise
        except Exception as e:
            # Keep the full traceback in the server log; the UI shows only the message.
            traceback.print_exc()
            raise gr.Error(f"处理失败: {str(e)}") from e
        return result_path, result_path

    # Bind the button to the handler (must happen inside the Blocks context).
    run_btn.click(
        fn=on_colorize,
        inputs=[input_image, model_choice],
        outputs=[output_image, download_btn]
    )
def get_local_ip():
    """Return this host's LAN-facing IPv4 address, or "127.0.0.1" on any failure.

    Connecting a UDP socket does not actually send data; it only makes the OS
    choose the outbound interface, whose address getsockname() then reports.
    """
    fallback = "127.0.0.1"
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.connect(("8.8.8.8", 80))  # Google public DNS; only used for routing
            address, _port = probe.getsockname()
    except Exception:
        return fallback
    return address
if __name__ == "__main__":
    server_port = 7860
    # Listen on all interfaces so other machines on the LAN can connect.
    server_name = "0.0.0.0"
    local_ip = get_local_ip()

    # Print URLs (most terminals render them as clickable links).
    print("\n" + "=" * 50)
    print("🌐 AI图片上色 Web UI 已启动!")
    print(f"🔗 本地访问: http://127.0.0.1:{server_port}")
    if local_ip != "127.0.0.1":
        print(f"🔗 局域网访问: http://{local_ip}:{server_port}")
    print("=" * 50 + "\n")

    # Theme and custom CSS are both given to launch(); without css= the
    # custom_css styles are never applied.
    # NOTE(review): assumes a Gradio version whose launch() accepts theme/css
    # (as this file already relies on for theme) — confirm.
    demo.launch(
        server_name=server_name,
        server_port=server_port,
        theme=gr.themes.Soft(),
        css=custom_css,
    )