Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,6 +6,9 @@ import cv2
|
|
| 6 |
import numpy as np
|
| 7 |
from PIL import Image
|
| 8 |
import urllib.request
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
# Function to download a file from a URL
|
| 11 |
def download_file(url, dest):
|
|
@@ -14,65 +17,56 @@ def download_file(url, dest):
|
|
| 14 |
urllib.request.urlretrieve(url, dest)
|
| 15 |
print(f"Downloaded {dest}")
|
| 16 |
|
| 17 |
-
# Download pretrained
|
| 18 |
def setup_environment():
|
| 19 |
# Download CodeFormer pretrained model
|
| 20 |
model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
| 21 |
-
model_path = "weights/codeformer.pth"
|
| 22 |
download_file(model_url, model_path)
|
| 23 |
|
| 24 |
-
# Download
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
download_file(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# Load CodeFormer model
|
| 30 |
def load_codeformer():
|
| 31 |
setup_environment()
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
net.eval()
|
| 38 |
-
return net
|
| 39 |
-
|
| 40 |
-
# Image processing utilities (mimicking basicsr.utils)
|
| 41 |
-
def img2tensor(img, bgr2rgb=True, float32=True):
    """Convert an HWC numpy image to a CHW float torch tensor.

    Mirrors ``basicsr.utils.img2tensor``: optionally converts BGR to RGB,
    moves channels first, and — when ``float32`` is True — scales pixel
    values from [0, 255] into [0, 1].
    """
    if bgr2rgb:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # HWC -> CHW, then hand off to torch as a float tensor.
    chw = img.transpose(2, 0, 1)
    tensor = torch.from_numpy(chw).float()
    return tensor / 255.0 if float32 else tensor
|
| 48 |
-
|
| 49 |
-
def tensor2img(tensor, rgb2bgr=True, min_max=(-1, 1)):
    """Convert a (B)CHW float tensor to an HWC uint8 numpy image.

    Mirrors ``basicsr.utils.tensor2img``: squeezes the batch dimension,
    clamps to ``min_max``, rescales that range to [0, 255], and optionally
    converts RGB to BGR for OpenCV.

    Fixes over the previous version:
    - ``clamp`` is now out-of-place. The old ``clamp_`` mutated the
      CALLER's tensor whenever the input was already a float32 CPU tensor
      (``squeeze``/``float``/``cpu`` all return views or ``self`` then).
    - ``detach()`` is applied first so tensors that require grad no longer
      raise on ``.numpy()``.
    """
    t = tensor.detach().squeeze().float().cpu().clamp(*min_max)
    # Map [min, max] -> [0, 255]; uint8 cast truncates toward zero.
    t = (t - min_max[0]) / (min_max[1] - min_max[0]) * 255.0
    img = t.numpy().transpose(1, 2, 0).astype(np.uint8)
    if rgb2bgr:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    return img
|
| 56 |
|
| 57 |
# Inference function
|
| 58 |
-
def enhance_image(
|
| 59 |
-
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
| 60 |
-
|
| 61 |
-
# Load model
|
| 62 |
-
net = load_codeformer()
|
| 63 |
-
|
| 64 |
# Convert PIL image to OpenCV format
|
| 65 |
-
img = np.array(
|
| 66 |
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
| 67 |
-
|
| 68 |
# Initialize face helper
|
| 69 |
-
face_helper = FaceRestoreHelper(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
face_helper.clean_all()
|
| 71 |
face_helper.read_image(img)
|
| 72 |
face_helper.get_face_landmarks_5()
|
| 73 |
face_helper.align_warp_face()
|
| 74 |
-
|
| 75 |
-
#
|
|
|
|
|
|
|
|
|
|
| 76 |
for cropped_face in face_helper.cropped_faces:
|
| 77 |
cropped_face_t = img2tensor(cropped_face, bgr2rgb=True, float32=True)
|
| 78 |
with torch.no_grad():
|
|
@@ -80,11 +74,21 @@ def enhance_image(image, fidelity_weight=0.5):
|
|
| 80 |
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
| 81 |
restored_face = restored_face.astype('uint8')
|
| 82 |
face_helper.add_restored_face(restored_face)
|
| 83 |
-
|
| 84 |
-
# Get
|
| 85 |
face_helper.get_inverse_affine(None)
|
| 86 |
restored_img = face_helper.paste_faces_to_input_image()
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
# Convert back to PIL for Gradio
|
| 89 |
restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
|
| 90 |
return Image.fromarray(restored_img)
|
|
@@ -93,17 +97,19 @@ def enhance_image(image, fidelity_weight=0.5):
|
|
| 93 |
with gr.Blocks() as demo:
|
| 94 |
gr.Markdown("# CodeFormer Face Restoration (CPU)")
|
| 95 |
gr.Markdown("Upload an image to enhance faces using CodeFormer. Runs on CPU in Hugging Face Spaces.")
|
| 96 |
-
|
| 97 |
with gr.Row():
|
| 98 |
input_image = gr.Image(type="pil", label="Input Image")
|
| 99 |
output_image = gr.Image(type="pil", label="Enhanced Image")
|
| 100 |
-
|
| 101 |
-
fidelity_slider = gr.Slider(0, 1, value=0.5, step=0.
|
|
|
|
|
|
|
| 102 |
submit_btn = gr.Button("Enhance")
|
| 103 |
-
|
| 104 |
submit_btn.click(
|
| 105 |
fn=enhance_image,
|
| 106 |
-
inputs=[input_image, fidelity_slider],
|
| 107 |
outputs=output_image
|
| 108 |
)
|
| 109 |
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
from PIL import Image
|
| 8 |
import urllib.request
|
| 9 |
+
from basicsr.utils import img2tensor, tensor2img
|
| 10 |
+
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
| 11 |
+
from codeformer_arch import CodeFormer
|
| 12 |
|
| 13 |
# Function to download a file from a URL
|
| 14 |
def download_file(url, dest):
|
|
|
|
| 17 |
urllib.request.urlretrieve(url, dest)
|
| 18 |
print(f"Downloaded {dest}")
|
| 19 |
|
| 20 |
+
# Download pretrained models
|
| 21 |
def setup_environment():
    """Download the pretrained weights this app needs, if not present.

    Fetches three files: the CodeFormer restoration weights, the facelib
    face-detection weights, and the Real-ESRGAN weights used for optional
    background upsampling.

    Fixes over the previous version:
    - Creates the destination directories first; ``urlretrieve`` (used by
      ``download_file``) cannot create ``weights/CodeFormer/`` itself.
    - Skips files that already exist, so repeated calls — this runs once
      per inference request via ``load_codeformer()`` — do not re-download
      hundreds of MB of weights each time.
    """
    import os  # local import keeps this edit self-contained

    downloads = [
        # CodeFormer pretrained model
        ("https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
         "weights/CodeFormer/codeformer.pth"),
        # facelib model (for face detection)
        ("https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/facelib.pth",
         "weights/facelib.pth"),
        # Real-ESRGAN model for background upsampling (optional)
        ("https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/RealESRGAN_x4plus.pth",
         "weights/RealESRGAN_x4plus.pth"),
    ]
    for url, dest in downloads:
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        if not os.path.exists(dest):
            download_file(url, dest)
|
| 36 |
|
| 37 |
# Load CodeFormer model
|
| 38 |
def load_codeformer():
    """Build the CodeFormer network, load pretrained weights, run on CPU.

    Returns the model in eval mode. The built model is cached on the
    function object so repeated calls (one per inference request in this
    app) do not re-read the checkpoint from disk each time.
    """
    cached = getattr(load_codeformer, "_model", None)
    if cached is not None:
        return cached

    setup_environment()
    model = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layer=9, connect_list=['32', '64', '128', '256'])

    state = torch.load("weights/CodeFormer/codeformer.pth", map_location='cpu')
    # The official CodeFormer release checkpoint nests the weights under
    # 'params_ema'; unwrap it, but still accept a plain state dict.
    if isinstance(state, dict) and 'params_ema' in state:
        state = state['params_ema']
    model.load_state_dict(state)

    model.eval()
    model = model.to('cpu')  # Force CPU
    load_codeformer._model = model
    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
# Inference function
|
| 47 |
+
def enhance_image(input_image, fidelity_weight=0.5, background_enhance=True, face_upsample=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
# Convert PIL image to OpenCV format
|
| 49 |
+
img = np.array(input_image)
|
| 50 |
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
| 51 |
+
|
| 52 |
# Initialize face helper
|
| 53 |
+
face_helper = FaceRestoreHelper(
|
| 54 |
+
upscale_factor=1 if not face_upsample else 2,
|
| 55 |
+
face_size=512,
|
| 56 |
+
crop_ratio=(1, 1),
|
| 57 |
+
det_model='retinaface_resnet50',
|
| 58 |
+
save_ext='png',
|
| 59 |
+
device='cpu'
|
| 60 |
+
)
|
| 61 |
face_helper.clean_all()
|
| 62 |
face_helper.read_image(img)
|
| 63 |
face_helper.get_face_landmarks_5()
|
| 64 |
face_helper.align_warp_face()
|
| 65 |
+
|
| 66 |
+
# Load CodeFormer model
|
| 67 |
+
net = load_codeformer()
|
| 68 |
+
|
| 69 |
+
# Enhance face
|
| 70 |
for cropped_face in face_helper.cropped_faces:
|
| 71 |
cropped_face_t = img2tensor(cropped_face, bgr2rgb=True, float32=True)
|
| 72 |
with torch.no_grad():
|
|
|
|
| 74 |
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
| 75 |
restored_face = restored_face.astype('uint8')
|
| 76 |
face_helper.add_restored_face(restored_face)
|
| 77 |
+
|
| 78 |
+
# Get restored image
|
| 79 |
face_helper.get_inverse_affine(None)
|
| 80 |
restored_img = face_helper.paste_faces_to_input_image()
|
| 81 |
+
|
| 82 |
+
# Background enhancement with Real-ESRGAN (optional)
|
| 83 |
+
if background_enhance:
|
| 84 |
+
from realesrgan import RealESRGANer
|
| 85 |
+
upsampler = RealESRGANer(
|
| 86 |
+
scale=4,
|
| 87 |
+
model_path="weights/RealESRGAN_x4plus.pth",
|
| 88 |
+
device='cpu'
|
| 89 |
+
)
|
| 90 |
+
restored_img, _ = upsampler.enhance(restored_img, outscale=4)
|
| 91 |
+
|
| 92 |
# Convert back to PIL for Gradio
|
| 93 |
restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
|
| 94 |
return Image.fromarray(restored_img)
|
|
|
|
| 97 |
with gr.Blocks() as demo:
    # Page header and usage note.
    gr.Markdown("# CodeFormer Face Restoration (CPU)")
    gr.Markdown("Upload an image to enhance faces using CodeFormer. Runs on CPU in Hugging Face Spaces.")

    # Input and output images rendered side by side.
    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")
        output_image = gr.Image(type="pil", label="Enhanced Image")

    # Inference controls (creation order determines page order).
    fidelity_slider = gr.Slider(
        0, 1, value=0.5, step=0.01,
        label="Fidelity Weight (0 = more restoration, 1 = more original)",
    )
    background_enhance = gr.Checkbox(label="Enhance Background (Real-ESRGAN)", value=True)
    face_upsample = gr.Checkbox(label="Upsample Restored Faces", value=False)
    submit_btn = gr.Button("Enhance")

    # Wire the button to the inference function; argument order matches
    # enhance_image(input_image, fidelity_weight, background_enhance, face_upsample).
    submit_btn.click(
        fn=enhance_image,
        inputs=[input_image, fidelity_slider, background_enhance, face_upsample],
        outputs=output_image,
    )
|
| 115 |
|