CodeFormer / python /gradio_demo.py
wli1995's picture
Upload gradio demo
442b121 verified
raw
history blame
6.62 kB
import gradio as gr
import os
import tempfile
import numpy as np
import axengine as axe
import cv2
from utils.restoration_helper import RestoreHelper
# Settings for the shared restoration pipeline. The model paths are relative,
# so they are presumably resolved against the process's current working
# directory (i.e. the demo expects to be started from a sibling of ``model/``).
_RESTORE_HELPER_CONFIG = dict(
    upscale_factor=1,
    face_size=512,
    crop_ratio=(1, 1),
    det_model="../model/yolov5l-face.axmodel",  # face detector
    res_model="../model/codeformer.axmodel",  # CodeFormer restoration model
    bg_model="../model/realesrgan-x2.axmodel",  # RealESRGAN background upsampler
    save_ext='png',
    use_parse=True,
)

# Single pipeline instance, built at import time and shared by both UI modes.
restore_helper = RestoreHelper(**_RESTORE_HELPER_CONFIG)
def face(img_path, session):
    """Restore an image by feeding the whole frame to the restoration model.

    The image is resized to 512x512 and converted from BGR uint8 to RGB
    float32 in [-1, 1]. It is laid out as NCHW with a batch of 1 and run
    through ``session``. The result is mapped back to uint8 BGR and resized
    to the original resolution. No face detection or alignment is done, so
    the input should presumably already be a face crop.

    Args:
        img_path: Path of the input image on disk.
        session: Inference session exposing ``get_inputs()``,
            ``get_outputs()`` and ``run(output_names, feed)``. Its first input
            receives the image and its first output is read as the result.

    Returns:
        uint8 image in BGR channel order, with the same height and width as
        the input file.

    Raises:
        ValueError: If ``img_path`` cannot be decoded as an image.
    """
    output_names = [out.name for out in session.get_outputs()]
    input_name = session.get_inputs()[0].name

    ori_image = cv2.imread(img_path)
    if ori_image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Cannot read image: {img_path}")
    h, w = ori_image.shape[:2]

    # BGR uint8 -> RGB float32 normalized with mean = std = 0.5, i.e. [-1, 1].
    image = cv2.resize(ori_image, (512, 512))[..., ::-1].astype(np.float32)
    image = image / 255.0 * 2.0 - 1.0
    # HWC -> NCHW with a batch axis; made contiguous for the runtime.
    img = np.ascontiguousarray(image.transpose(2, 0, 1)[np.newaxis])

    outputs = session.run(output_names, {input_name: img})

    # First output, drop batch axis, CHW -> HWC, undo the [-1, 1] normalization.
    restored = outputs[0].squeeze(0).transpose(1, 2, 0) * 0.5 + 0.5
    ndarr = np.clip(restored * 255.0, 0, 255.0).astype(np.uint8)
    # RGB -> BGR (what cv2.imwrite expects) at the original resolution.
    return cv2.resize(np.ascontiguousarray(ndarr[..., ::-1]), (w, h))
def full_image(img_path, restore_helper=restore_helper):
    """Restore every detected face in an image and paste them back.

    The steps are:
      1. Detect faces with 5-point landmarks and align/warp each one to a crop.
      2. Run each crop through the helper's restoration session.
      3. Upsample the background.
      4. Paste the restored crops back onto the upsampled background.

    Inference is best effort per face: if the session or its postprocessing
    fails for a crop, a message is printed and the unmodified crop is pasted
    back. One bad face therefore does not abort the whole image.

    Args:
        img_path: Path of the input image on disk.
        restore_helper: Pipeline object providing detection, alignment, the
            restoration session (``rs_sessison``/``rs_input``/``rs_output``),
            background upsampling and paste-back. Defaults to the module-level
            instance.

    Returns:
        Whatever ``restore_helper.paste_faces_to_input_image`` returns. The
        caller writes it with ``cv2.imwrite``, so presumably a BGR uint8
        image (TODO confirm against RestoreHelper).

    Raises:
        ValueError: If ``img_path`` cannot be decoded as an image.
    """
    restore_helper.clean_all()
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Cannot read image: {img_path}")
    restore_helper.read_image(img)

    # Landmarks for every face (not only the central one); the returned face
    # count is not needed here.
    restore_helper.get_face_landmarks_5(
        only_center_face=False, resize=640, eye_dist_threshold=5)
    restore_helper.align_warp_face()

    for cropped_face in restore_helper.cropped_faces:
        # BGR uint8 HWC -> RGB float32 in [-1, 1], NCHW with a batch axis.
        face_rgb = cropped_face[..., ::-1].astype(np.float32)
        face_input = np.ascontiguousarray(
            ((face_rgb / 255.0) * 2.0 - 1.0).transpose(2, 0, 1)[np.newaxis])
        try:
            ort_outs = restore_helper.rs_sessison.run(
                restore_helper.rs_output,
                {restore_helper.rs_input: face_input},
            )
            # CHW -> HWC, undo the [-1, 1] normalization, RGB -> BGR.
            restored = ort_outs[0].squeeze().transpose(1, 2, 0) * 0.5 + 0.5
            restored_face = np.clip(restored[..., ::-1] * 255, 0, 255).astype(np.uint8)
        except Exception as error:
            print(f'\tFailed inference for CodeFormer: {error}')
            # Fall back to the untouched BGR crop. Rebuilding it from the
            # RGB model input would paste a color-swapped face.
            restored_face = np.clip(cropped_face, 0, 255).astype(np.uint8)
        restore_helper.add_restored_face(restored_face, cropped_face)

    # Only RealESRGAN is supported for background upsampling for now.
    bg_img = restore_helper.background_upsampling(img)
    restore_helper.get_inverse_affine(None)
    return restore_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False)
def colorize_image(input_img_path: str, model_name: str, progress=gr.Progress()):
    """Run the selected restoration mode on an image and save the result.

    Despite its name this performs face restoration, not colorization. The
    name is kept so existing callers keep working.

    Args:
        input_img_path: Path of the uploaded image.
        model_name: ``"Face"`` runs ``face`` on the whole image. Any other
            value runs the detect/align/paste ``full_image`` pipeline.
        progress: Gradio progress tracker, updated at load/save/done.

    Returns:
        Path of a newly created JPEG file holding the result. Each call gets
        its own file, so queued or concurrent requests never overwrite each
        other's output. Files are left in the temp directory for Gradio to
        pick up.

    Raises:
        gr.Error: If no input path was given.
        RuntimeError: If OpenCV could not write the result file.
    """
    if not input_img_path:
        raise gr.Error("未上传图片")
    progress(0.3, desc="加载图像...")

    # Dispatch on the selected restoration mode.
    if model_name == "Face":
        out = face(input_img_path, session=restore_helper.rs_sessison)
    else:
        out = full_image(input_img_path, restore_helper=restore_helper)

    progress(0.9, desc="保存结果...")
    # A unique temp file per request instead of one shared fixed filename.
    fd, output_path = tempfile.mkstemp(prefix="restore_output_", suffix=".jpg")
    os.close(fd)
    # cv2.imwrite reports many failures via a False return instead of raising.
    if not cv2.imwrite(output_path, out):
        os.remove(output_path)
        raise RuntimeError(f"Failed to write result image: {output_path}")
    progress(1.0, desc="完成!")
    return output_path
# ==============================
# Gradio interface
# ==============================
# Custom CSS handed to demo.launch(css=...). It sets the page font stack and
# lays out the options of the radio group tagged elem_classes="model-buttons"
# as equal-width, button-like labels with a hover color.
# NOTE(review): the ":checked + label" rule only matches if the radio <input>
# is immediately followed by a sibling <label>. Verify this against the DOM
# of the installed Gradio version.
custom_css = """
body, .gradio-container {
font-family: 'Microsoft YaHei', 'PingFang SC', 'Helvetica Neue', Arial, sans-serif;
}
.model-buttons .wrap {
display: flex;
gap: 10px;
}
.model-buttons .wrap label {
background-color: #f0f0f0;
padding: 10px 20px;
border-radius: 8px;
cursor: pointer;
text-align: center;
font-weight: 600;
border: 2px solid transparent;
flex: 1;
}
.model-buttons .wrap label:hover {
background-color: #e0e0e0;
}
.model-buttons .wrap input[type="radio"]:checked + label {
background-color: #4CAF50;
color: white;
border-color: #45a049;
}
"""
with gr.Blocks(title="人脸修复工具") as demo:
    gr.Markdown("## 🎨 人脸修复演示DEMO")
    with gr.Row(equal_height=True):
        # Left column: upload and mode selection.
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 📤 输入")
            input_image = gr.Image(
                type="filepath",
                label="上传图片",
                sources=["upload"],
                height=300
            )
            gr.Markdown("### 🔧 选择修复模式")
            model_choice = gr.Radio(
                choices=["Face", "Full image"],
                value="Face",
                label=None,
                elem_classes="model-buttons"
            )
            run_btn = gr.Button("🚀 开始修复", variant="primary")
        # Right column: result preview and download.
        with gr.Column(scale=1, min_width=600):
            gr.Markdown("### 🖼️ 修复结果")
            output_image = gr.Image(
                label="修复后图片",
                interactive=False,
                height=600
            )
            download_btn = gr.File(label="📥 下载修复图片")

    # Event wiring.
    def on_colorize(img_path, model, progress=gr.Progress()):
        """Click handler for the run button.

        Validates the upload and runs ``colorize_image``. The result path is
        returned twice, once for the image preview and once for the download
        widget. Unexpected exceptions are converted to a user-facing
        ``gr.Error``, with the original exception chained as the cause.
        """
        if img_path is None:
            raise gr.Error("请先上传图片!")
        try:
            result_path = colorize_image(img_path, model, progress=progress)
        except gr.Error:
            # Already user-facing; do not wrap it in a second message prefix.
            raise
        except Exception as e:
            raise gr.Error(f"处理失败: {str(e)}") from e
        return result_path, result_path

    run_btn.click(
        fn=on_colorize,
        inputs=[input_image, model_choice],
        outputs=[output_image, download_btn]
    )
# Entry point: serve the demo when run as a script.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "theme": gr.themes.Soft(),
        "css": custom_css,
    }
    demo.launch(**launch_options)