File size: 6,624 Bytes
481945d
 
 
 
 
 
 
9f8a5b5
481945d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f8a5b5
 
 
 
 
 
 
 
 
 
 
 
481945d
 
9f8a5b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import gradio as gr
import os
import tempfile
from PIL import Image, ImageEnhance
import numpy as np
import axengine as axe
import cv2
import socket

# ==============================
# DeOldify colorization models (.axmodel files run through axengine)
# ==============================
def init_DeOldifymodel(DeOldifyStable_path="../model/colorize_stable.axmodel",
                       DeOldifyArtistic_path="../model/colorize_artistic.axmodel"):
    """Create one axengine inference session per DeOldify model variant.

    Relative default paths are resolved against the current working directory.

    Args:
        DeOldifyStable_path: path of the "stable" model file.
        DeOldifyArtistic_path: path of the "artistic" model file.

    Returns:
        list: ``[stable_session, artistic_session]``, in that order (this file
        indexes it positionally: 0 = stable, 1 = artistic).
    """
    model_paths = (DeOldifyStable_path, DeOldifyArtistic_path)
    return [axe.InferenceSession(model_path) for model_path in model_paths]


# Both models are loaded once, at import time, and shared by every request.
DeOldify_sessions = init_DeOldifymodel()

def from_numpy(x):
    """Return *x* as a NumPy array; an existing ndarray is returned as-is (no copy)."""
    if isinstance(x, np.ndarray):
        return x
    return np.array(x)
    
def post_process(raw_color, orig):
    """Merge the luminance of *orig* with the chrominance of *raw_color*.

    Both images are converted to YUV using OpenCV's RGB<->YUV codes, so both
    are expected in RGB channel order. The Y plane of *orig* is kept and its
    U/V planes are replaced with those of *raw_color*; the merged image is
    converted back to RGB.

    Args:
        raw_color: colorized image (array-like, H x W x 3, RGB).
        orig: original image (array-like, H x W x 3, RGB); must have the same
            height/width as *raw_color* for the plane copy to succeed.

    Returns:
        np.ndarray: merged image in RGB channel order.
    """
    chroma_src = cv2.cvtColor(np.asarray(raw_color), cv2.COLOR_RGB2YUV)
    merged = cv2.cvtColor(np.asarray(orig), cv2.COLOR_RGB2YUV).copy()
    # Keep Y (channel 0) from the original, take U and V from the colorized image.
    merged[..., 1:3] = chroma_src[..., 1:3]
    return cv2.cvtColor(merged, cv2.COLOR_YUV2RGB)

def colorize_with_model(img_path, session):
    """Colorize one image file with a DeOldify inference session.

    Steps:
      1. Read the file with OpenCV (BGR, uint8) and resize it to the fixed
         512x512 model input.
      2. Convert to RGB, scale to [0, 1], apply ImageNet mean/std
         normalisation and lay it out as a contiguous 1x3x512x512 float32
         tensor.
      3. Run the model, undo the normalisation on its output, convert to
         uint8 and resize back to the original resolution.
      4. Keep the original image's luminance and take only the colour
         (chrominance) from the model output via ``post_process``.

    Args:
        img_path: path of an image file readable by ``cv2.imread``.
        session: inference session exposing ``get_inputs()``,
            ``get_outputs()`` and ``run()`` (one of ``DeOldify_sessions``).

    Returns:
        np.ndarray: uint8 image with the same height/width as the input, in
        BGR channel order (ready for ``cv2.imwrite``).

    Raises:
        ValueError: if OpenCV cannot read *img_path*.
    """
    input_name = session.get_inputs()[0].name
    output_names = [out.name for out in session.get_outputs()]

    ori_image = cv2.imread(img_path)
    if ori_image is None:
        # cv2.imread reports unreadable/unsupported files by returning None
        # instead of raising; fail here with a clear message.
        raise ValueError(f"cannot read image: {img_path}")
    h, w = ori_image.shape[:2]

    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Preprocess: BGR -> RGB, [0, 1], normalise, HWC -> NCHW.
    image = cv2.resize(ori_image, (512, 512))
    image = (image[..., ::-1] / 255.0).astype(np.float32)
    image = ((image - mean) / std).astype(np.float32)
    # Make the *final* NCHW tensor contiguous; calling ascontiguousarray
    # before the transpose would still hand the runtime a strided view.
    image = np.ascontiguousarray(image.transpose(2, 0, 1)[np.newaxis])

    outputs = session.run(output_names, {input_name: image})
    # NOTE(review): assumes the model's first output is the 1x3xHxW
    # colorized image.
    if isinstance(outputs, (list, tuple)):
        outputs = outputs[0]
    pred = from_numpy(outputs)

    # Postprocess: NCHW -> HWC (RGB), undo normalisation, back to uint8.
    pred = np.transpose(pred.squeeze(0), (1, 2, 0))
    pred = (pred * std + mean).astype(np.float32)
    color_rgb = np.clip(pred * 255.0, 0, 255.0).astype(np.uint8)
    color_rgb = cv2.resize(color_rgb, (w, h))

    # post_process splits luminance/chrominance with RGB<->YUV conversions,
    # so both of its inputs must be RGB. OpenCV images are BGR, hence the
    # conversions around the call; the result goes back to BGR for imwrite.
    orig_rgb = cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB)
    merged_rgb = post_process(color_rgb, orig_rgb)
    return cv2.cvtColor(merged_rgb, cv2.COLOR_RGB2BGR)

def colorize_image(input_img_path: str, model_name: str, progress=gr.Progress()):
    """Colorize an uploaded image and return the path of the saved result.

    Args:
        input_img_path: path of the uploaded file (the input component uses
            ``type="filepath"``).
        model_name: ``"colorize_stable"`` selects the stable model; any other
            value selects the artistic model.
        progress: Gradio progress tracker.

    Returns:
        str: path of a newly created JPEG file in the system temp directory.
        Files are not cleaned up here.

    Raises:
        gr.Error: if no image was given or the result could not be written.
    """
    if not input_img_path:
        raise gr.Error("未上传图片")

    # Load image
    progress(0.3, desc="加载图像...")

    # Pick the session for the selected model.
    if model_name == "colorize_stable":
        session = DeOldify_sessions[0]
    else:
        session = DeOldify_sessions[1]
    out = colorize_with_model(input_img_path, session)

    progress(0.9, desc="保存结果...")

    # One unique file per request. A fixed filename would be shared by every
    # concurrent user of the server, so one request could overwrite another's
    # result before it is served.
    fd, output_path = tempfile.mkstemp(prefix="colorized_", suffix=".jpg")
    os.close(fd)
    # cv2.imwrite signals failure via its return value, not an exception.
    if not cv2.imwrite(output_path, out):
        os.remove(output_path)
        raise gr.Error("保存结果失败")

    progress(1.0, desc="完成!")
    return output_path


# ==============================
# Gradio UI
# ==============================
# Page CSS. The `.model-buttons` rules target the model selector (the Radio
# created below with elem_classes="model-buttons") and lay its options out as
# a row of button-like labels, highlighted green when checked.
# NOTE(review): this string only takes effect if it is passed to Gradio
# (a `css=` argument); confirm it is wired up. The
# `input[type="radio"]:checked + label` rule assumes the label is the next
# sibling of the input; verify against the DOM Gradio actually renders.
custom_css = """
body, .gradio-container {
    font-family: 'Microsoft YaHei', 'PingFang SC', 'Helvetica Neue', Arial, sans-serif;
}
.model-buttons .wrap {
    display: flex;
    gap: 10px;
}
.model-buttons .wrap label {
    background-color: #f0f0f0;
    padding: 10px 20px;
    border-radius: 8px;
    cursor: pointer;
    text-align: center;
    font-weight: 600;
    border: 2px solid transparent;
    flex: 1;
}
.model-buttons .wrap label:hover {
    background-color: #e0e0e0;
}
.model-buttons .wrap input[type="radio"]:checked + label {
    background-color: #4CAF50;
    color: white;
    border-color: #45a049;
}
"""

with gr.Blocks(title="AI 图片上色工具") as demo:
    gr.Markdown("## 🎨 AI 黑白图片自动上色演示")

    with gr.Row(equal_height=True):
        # Left column: input image, model selector and run button.
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 📤 输入")
            input_image = gr.Image(
                type="filepath",
                label="上传黑白/灰度图片",
                sources=["upload"],
                height=300
            )

            gr.Markdown("### 🔧 选择上色模型")
            model_choice = gr.Radio(
                choices=["colorize_stable", "colorize_artistic"],
                value="colorize_stable",
                label=None,
                elem_classes="model-buttons"
            )

            run_btn = gr.Button("🚀 开始上色", variant="primary")

        # Right column: colorized preview plus a downloadable copy.
        with gr.Column(scale=1, min_width=600):
            gr.Markdown("### 🖼️ 上色结果")
            output_image = gr.Image(
                label="上色后图片",
                interactive=False,
                height=600
            )
            download_btn = gr.File(label="📥 下载上色图片")

    # Event wiring
    def on_colorize(img_path, model, progress=gr.Progress()):
        """Click handler: colorize the upload and feed the result to both outputs.

        Returns:
            tuple[str, str]: the result path twice (preview image, download file).

        Raises:
            gr.Error: user-facing message for a missing upload or any failure.
        """
        if img_path is None:
            raise gr.Error("请先上传图片!")
        try:
            result_path = colorize_image(img_path, model, progress=progress)
        except gr.Error:
            # Already a user-facing message; don't wrap it a second time.
            raise
        except Exception as e:
            # Chain the cause so the original traceback is kept for debugging.
            raise gr.Error(f"处理失败: {str(e)}") from e
        return result_path, result_path

    run_btn.click(
        fn=on_colorize,
        inputs=[input_image, model_choice],
        outputs=[output_image, download_btn]
    )
def get_local_ip():
    """Return this machine's LAN IPv4 address, or "127.0.0.1" if it can't be found.

    Connecting a UDP socket transmits nothing; it only makes the OS choose the
    local interface that would route to the target (a public DNS server), and
    ``getsockname()`` then reports that interface's address. Any failure
    (e.g. no network) falls back to localhost.
    """
    fallback = "127.0.0.1"
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.connect(("8.8.8.8", 80))
            address = probe.getsockname()[0]
    except Exception:
        return fallback
    return address


if __name__ == "__main__":
    server_port = 7860
    server_name = "0.0.0.0"  # bind all interfaces so other LAN hosts can connect

    # Best-effort LAN address; only used for the banner below.
    local_ip = get_local_ip()

    # Print clickable URLs (most terminals turn them into links).
    print("\n" + "="*50)
    print("🌐 AI图片上色 Web UI 已启动!")
    print(f"🔗 本地访问: http://127.0.0.1:{server_port}")
    if local_ip != "127.0.0.1":
        print(f"🔗 局域网访问: http://{local_ip}:{server_port}")
    print("="*50 + "\n")

    # Launch the Gradio app. `theme` is passed to launch(), which is the
    # Gradio 6 API; there `css` is also a launch() argument, so custom_css
    # (styling for the model selector) is applied here as well.
    # NOTE(review): confirm the installed Gradio version accepts css= in launch().
    demo.launch(
        server_name=server_name,
        server_port=server_port,
        theme=gr.themes.Soft(),
        css=custom_css,
    )