# Source: Sebastiankay's Hugging Face Space — app.py (commit f4ba661, verified)
import os
import sys
import json
import re
import time
import random
import requests
import gradio as gr
import spaces
import base64
from datetime import datetime
from pathlib import Path
from huggingface_hub import login, hf_hub_download
from PIL import Image
from io import BytesIO
sys.path.append("CodeFormer")
import cv2
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from basicsr.utils import imwrite, img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.utils.misc import is_gray
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.realesrgan_utils import RealESRGANer
from basicsr.utils.registry import ARCH_REGISTRY
from modules.weights_downloads import download_weights
# Pick the compute device once at startup; CUDA when a GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Project directory layout, expressed with pathlib for consistency.
BASE_DIR = Path(__file__).resolve().parent
RES = BASE_DIR / "_res"
ASSETS = RES / "assets"
EXAMPLES = BASE_DIR / "examples"
IMAGE_CACHE = BASE_DIR / "image_cache"

# The download cache must exist before Gradio is allowed to serve from it.
IMAGE_CACHE.mkdir(exist_ok=True)

# Whitelist the asset directories so Gradio can serve them statically.
gr.set_static_paths(paths=[RES, IMAGE_CACHE, ASSETS])

# Load the Space's custom CSS/JS that is injected into the Blocks app.
custom_css = (RES / "_custom.css").read_text()
custom_js = (RES / "_custom.js").read_text()

# Extra <head> markup: Font Awesome icons and the dotLottie player component.
custom_head = """
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.9.0/css/all.min.css"/>
<script src="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.9.0/js/all.min.js"></script>
<script src="https://unpkg.com/@dotlottie/player-component@latest/dist/dotlottie-player.mjs" type="module"></script>
"""

title = "Fotorestauration Verbesserung & Upscaling by CodeFormer"
title_html = """
<h1>Fotorestauration</h1>
<h3>Verbesserung & Upscaling <span>by CodeFormer</span></h3>
"""

# Soft purple theme over a custom dark neutral palette.
theme = gr.themes.Soft(
    primary_hue="purple",
    radius_size="sm",
    neutral_hue=gr.themes.Color(c100="#a6adc8", c200="#9399b2", c300="#7f849c", c400="#6c7086", c50="#cdd6f4", c500="#585b70", c600="#45475a", c700="#313244", c800="#1e1e2e", c900="#181825", c950="#11111b"),
)

# Debug aid: dump installed package versions into the Space logs.
os.system("pip freeze")

# Fetch the CodeFormer / RealESRGAN weights before the models are built below.
download_weights()
def imread(img_path):
    """Load an image from disk and return it as an RGB numpy array.

    Args:
        img_path: Path to an image file readable by OpenCV.

    Returns:
        The image as an RGB ``numpy.ndarray``.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None instead of raising on missing/unreadable
        # files; fail loudly here rather than with an opaque cvtColor error.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def set_realesrgan():
    """Build the RealESRGAN x2 upsampler used for background/face enhancement.

    Returns:
        RealESRGANer: An upsampler wrapping an RRDBNet x2 model; runs in
        half precision when CUDA is available.
    """
    # fp16 only makes sense on GPU; CPU inference must stay fp32.
    half = torch.cuda.is_available()
    model = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=2,
    )
    return RealESRGANer(
        scale=2,
        model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth",
        model=model,
        tile=400,  # tile processing keeps memory bounded on large images
        tile_pad=40,
        pre_pad=0,
        half=half,
    )
# Shared RealESRGAN background/face upsampler, built once at import time.
upsampler = set_realesrgan()

# Build the CodeFormer restoration network and load its pretrained weights.
codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
    dim_embd=512,
    codebook_size=1024,
    n_head=8,
    n_layers=9,
    connect_list=["32", "64", "128", "256"],
).to(device)
ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth"
# map_location keeps the load working on CPU-only hosts even when the
# checkpoint was saved from CUDA tensors.
checkpoint = torch.load(ckpt_path, map_location=device)["params_ema"]
codeformer_net.load_state_dict(checkpoint)
codeformer_net.eval()

# inference() writes its result image into this directory.
os.makedirs("output", exist_ok=True)
@spaces.GPU()
def inference(image, inf_options, upscale, codeformer_fidelity):
    """Run a single CodeFormer restoration on ``image``.

    Args:
        image: Filepath of the input image (Gradio ``type="filepath"``).
        inf_options: Selected German checkbox labels controlling face
            alignment, background enhancement, and face upsampling.
        upscale: Requested output scale factor; clamped to 1..4 and reduced
            for large inputs to avoid running out of memory.
        codeformer_fidelity: CodeFormer ``w`` weight — towards 0 restores
            more aggressively, towards 1 preserves identity.

    Returns:
        ``((input_pil, restored_pil), restored_pil)`` on success — the pair
        feeds the before/after slider, the single image feeds the hidden
        download component — or ``(None, None)`` on failure.
    """
    try:
        only_center_face = False
        draw_box = False
        detection_model = "retinaface_resnet50"
        print("Inp:", image, inf_options, upscale, codeformer_fidelity)

        # Map the German checkbox labels onto booleans:
        # "Gesicht ausrichten" = align face, "Hintergrund verbessern" =
        # enhance background, "Gesicht Hochskalieren" = upsample face.
        face_align = "Gesicht ausrichten" in inf_options
        background_enhance = "Hintergrund verbessern" in inf_options
        # BUGFIX: the original read `face_upsample` before assignment and
        # raised UnboundLocalError whenever this option was deselected.
        face_upsample = "Gesicht Hochskalieren" in inf_options

        upscale = upscale if (upscale is not None and upscale > 0) else 2
        has_aligned = not face_align
        upscale = 1 if has_aligned else upscale

        img = cv2.imread(str(image), cv2.IMREAD_COLOR)
        print("\timage size:", img.shape)

        upscale = int(upscale)  # convert type to int
        if upscale > 4:  # avoid memory exceeded due to too large upscale
            upscale = 4
        if upscale > 2 and max(img.shape[:2]) > 1000:  # large image: cap at x2
            upscale = 2
        if max(img.shape[:2]) > 1500:  # very large image: no upscale, no enhancers
            upscale = 1
            background_enhance = False
            face_upsample = False

        face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model=detection_model,
            save_ext="png",
            use_parse=True,
            device=device,
        )
        bg_upsampler = upsampler if background_enhance else None
        face_upsampler = upsampler if face_upsample else None

        if has_aligned:
            # The input face is assumed already cropped and aligned.
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
            face_helper.is_gray = is_gray(img, threshold=5)
            if face_helper.is_gray:
                print("\tgrayscale input: True")
            face_helper.cropped_faces = [img]
        else:
            face_helper.read_image(img)
            # Get 5-point landmarks for each detected face.
            num_det_faces = face_helper.get_face_landmarks_5(only_center_face=only_center_face, resize=640, eye_dist_threshold=5)
            print(f"\tdetect {num_det_faces} faces")
            face_helper.align_warp_face()

        # Restore every cropped face with CodeFormer.
        for idx, cropped_face in enumerate(face_helper.cropped_faces):
            cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
            try:
                with torch.no_grad():
                    output = codeformer_net(cropped_face_t, w=codeformer_fidelity, adain=True)[0]
                restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
                del output
                torch.cuda.empty_cache()
            except RuntimeError as error:
                # Fall back to the unrestored crop so the pipeline still completes.
                print(f"Failed inference for CodeFormer: {error}")
                restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
            restored_face = restored_face.astype("uint8")
            face_helper.add_restored_face(restored_face)

        if not has_aligned:
            # Optionally upsample the background, then paste faces back in.
            bg_img = bg_upsampler.enhance(img, outscale=upscale)[0] if bg_upsampler is not None else None
            face_helper.get_inverse_affine(None)
            if face_upsample and face_upsampler is not None:
                restored_img = face_helper.paste_faces_to_input_image(
                    upsample_img=bg_img,
                    draw_box=draw_box,
                    face_upsampler=face_upsampler,
                )
            else:
                restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box)
        else:
            # Aligned input: the single restored crop is the result.
            restored_img = restored_face

        save_path = "output/out.png"
        imwrite(restored_img, save_path)

        # OpenCV works in BGR; convert before handing PIL images to Gradio.
        restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
        restored_img_pil = Image.fromarray(restored_img)
        image_pil = Image.open(image)
        # BUGFIX: the second value feeds the hidden download component; the
        # original returned the *input* image there instead of the result.
        return (image_pil, restored_img_pil), restored_img_pil
    except Exception as error:
        print("Global exception", error)
        # BUGFIX: two Gradio outputs are wired to this handler, so return a
        # pair instead of a bare None.
        return None, None
def load_examles(image, image_ready):
    """Pass an example pair straight through to the before/after slider.

    NOTE: the name keeps its original typo on purpose — the Gradio change
    wiring references it by this exact name.
    """
    return image, image_ready
# Assemble the Gradio UI: header row, input column (image + options), and
# the before/after output column, then wire up the event handlers.
with gr.Blocks(theme=theme, title=title, css=custom_css, js=custom_js, head=custom_head) as demo_photo_enhance:
    # Header: title plus German intro text; wink/heart images are served from _res.
    with gr.Row(elem_classes="row-header"):
        gr.HTML(
            f"""
<div class="md-header-wrapper">
{title_html}
<p>Restauriere Gesichter in Fotos die verwackelt sind oder nicht im Focus.<br/>
oder rekonstruiere Gesichter in Fotos bei denen bis zu 50% fehlen.</p>
<p><span style="font-weight: 600">LG Sebastian</span> <img id="wink" src="gradio_api/file=_res/wink.png" width="20"> gib dem Space gerne ein <img id="heart" src="gradio_api/file=_res/heart.png" width="20"> </p>
</div>
""",
            elem_classes="md-header",
        )
    with gr.Row(elem_classes="row-main"):
        # Left column: input image, run button, fidelity slider, advanced options.
        with gr.Column(scale=3):
            inp_image = gr.Image(type="filepath", label="Dein Bild", interactive=True, elem_classes="input-image", show_download_button=False, height=558)
            run_btn = gr.Button("Los", variant="primary", elem_id="run_btn", elem_classes="run-btn")
            # CodeFormer fidelity w: towards 0 restores more, towards 1 keeps identity.
            inp_factor = gr.Slider(0, 1, value=0.5, step=0.01, label="Verbesserungsfaktor", info="zu 0 verstärkt die Ausgabe, zu 1 erhält die Identität")
            with gr.Accordion("Erweiterte Optionen", open=False):
                # The German labels double as the option keys parsed in inference().
                inf_options = gr.CheckboxGroup(
                    [
                        "Gesicht ausrichten",
                        "Hintergrund verbessern",
                        "Gesicht Hochskalieren",
                    ],
                    value=["Gesicht ausrichten", "Hintergrund verbessern", "Gesicht Hochskalieren"],
                    label="Optionen",
                    info="Aktiviere oder Deaktiviere die gewünschten Funktionen",
                    interactive=True,
                    elem_classes="inp-options",
                )
                inp_scale = gr.Slider(0, 4, value=2, step=1, label="Foto Hochskalieren", info="Du kannst das Foto bis zum Faktor 4 hochskalieren")
            # Invisible holder for the example's precomputed result image.
            example_output_image = gr.Image(type="filepath", label="Ergebnis", visible=False, interactive=False)
            example = gr.Examples(
                examples=[
                    [os.path.join(EXAMPLES, "1.png"), ["Hintergrund verbessern", "Gesicht Hochskalieren"], [2], [0.7], os.path.join(EXAMPLES, "1_ready.png")],
                ],
                inputs=[inp_image, inf_options, inp_scale, inp_factor, example_output_image],
                elem_id="examples",
                label="Beispiele",
                cache_examples=False,
                run_on_click=False,
            )
        # Right column: before/after slider plus a hidden copy used for downloads.
        with gr.Column(scale=5):
            # output_image = gr.Image(type="numpy", label="Ergebnis")
            # output_image = ImageSlider(type="numpy", label="Ergebnis")
            output_image = gr.ImageSlider(label="Vorher / Nachher", type="pil", interactive=False, elem_classes="output-slider", show_download_button=False, height=800)
            hidden_output_image = gr.Image(label="Output image", show_label=False, visible=False, type="pil", format="png", show_download_button=False, show_share_button=False, interactive=False)
            with gr.Row():
                output_image_dl_btn_webp = gr.DownloadButton(label="Download als WEBP", visible=False)
                output_image_dl_btn_png = gr.DownloadButton(label="Download als PNG", visible=False)
                output_image_dl_btn_jpg = gr.DownloadButton(label="Download als JPG", visible=False)
    # Click pipeline: disable button -> clear slider -> run inference -> re-enable button.
    run_btn.click(fn=lambda: {"elem_classes":"run-btn run-btn-running", "interactive": False, "__type__": "update"}, outputs=[run_btn]).then(fn=lambda: {"value": None, "__type__": "update"}, outputs=[output_image]).then(fn=inference, inputs=[inp_image, inf_options, inp_scale, inp_factor], outputs=[output_image, hidden_output_image], scroll_to_output=True, api_name="fotoRestaurationInference").then(fn=lambda: {"elem_classes":"run-btn", "interactive": True, "__type__": "update"}, outputs=[run_btn])
    # When an example is picked, show its precomputed before/after pair directly.
    example_output_image.change(fn=load_examles, inputs=[inp_image, example_output_image], outputs=[output_image])
def create_dl_button(image):
    """Save *image* to the cache as WEBP/PNG/JPG and reveal the download buttons.

    Args:
        image: PIL image from the hidden output component, or a falsy value
            when the output was cleared.

    Returns:
        Three ``gr.update`` payloads for the WEBP, PNG, and JPG buttons.
    """
    if not image:
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
    timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    # BUGFIX: the original wrote `IMAGE_CACHE / timestamp + ".webp"`, which
    # evaluates as (Path / str) + str and raises TypeError; build the full
    # filename inside the `/` join instead.
    filename_webp = IMAGE_CACHE / f"{timestamp}.webp"
    image.save(filename_webp, "webp")
    filename_png = IMAGE_CACHE / f"{timestamp}.png"
    # BUGFIX: the original saved to `IMAGE_CACHE / filename_png`, prefixing
    # the cache directory twice.
    image.save(filename_png, "png")
    filename_jpg = IMAGE_CACHE / f"{timestamp}.jpg"
    image.save(filename_jpg, "jpeg")
    print(f"\n\nDEBUG created download buttons:\n{filename_png}\n{filename_jpg}\n\n")
    return gr.update(visible=True, value=filename_webp), gr.update(visible=True, value=filename_png), gr.update(visible=True, value=filename_jpg)

# Regenerate the download files whenever the hidden output image changes.
hidden_output_image.change(create_dl_button, inputs=[hidden_output_image], outputs=[output_image_dl_btn_webp, output_image_dl_btn_png, output_image_dl_btn_jpg])
# Script entry point: start the Gradio server.
# show_api=True exposes the named endpoint (fotoRestaurationInference) in the API docs.
if __name__ == "__main__":
    demo_photo_enhance.launch(show_api=True)