lucky0146 committed on
Commit
bbdeecf
·
verified ·
1 Parent(s): 570819f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -49
app.py CHANGED
@@ -6,6 +6,9 @@ import cv2
6
  import numpy as np
7
  from PIL import Image
8
  import urllib.request
 
 
 
9
 
10
  # Function to download a file from a URL
11
  def download_file(url, dest):
@@ -14,65 +17,56 @@ def download_file(url, dest):
14
  urllib.request.urlretrieve(url, dest)
15
  print(f"Downloaded {dest}")
16
 
17
- # Download pretrained model and necessary files
18
  def setup_environment():
19
  # Download CodeFormer pretrained model
20
  model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
21
- model_path = "weights/codeformer.pth"
22
  download_file(model_url, model_path)
23
 
24
- # Download facexlib detection models
25
- retinaface_url = "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth"
26
- retinaface_path = "weights/detection_Resnet50_Final.pth"
27
- download_file(retinaface_url, retinaface_path)
 
 
 
 
 
28
 
29
  # Load CodeFormer model
30
  def load_codeformer():
31
  setup_environment()
32
- from codeformer_arch import CodeFormer
33
- model_path = "weights/codeformer.pth"
34
- net = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layer=9, connect_list=['32', '64', '128', '256']).to('cpu')
35
- checkpoint = torch.load(model_path, map_location='cpu')
36
- net.load_state_dict(checkpoint)
37
- net.eval()
38
- return net
39
-
40
- # Image processing utilities (mimicking basicsr.utils)
41
- def img2tensor(img, bgr2rgb=True, float32=True):
42
- if bgr2rgb:
43
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
44
- img = torch.from_numpy(img.transpose(2, 0, 1)).float()
45
- if float32:
46
- img = img / 255.0
47
- return img
48
-
49
- def tensor2img(tensor, rgb2bgr=True, min_max=(-1, 1)):
50
- tensor = tensor.squeeze().float().cpu().clamp_(*min_max)
51
- tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) * 255.0
52
- img = tensor.numpy().transpose(1, 2, 0).astype(np.uint8)
53
- if rgb2bgr:
54
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
55
- return img
56
 
57
  # Inference function
58
- def enhance_image(image, fidelity_weight=0.5):
59
- from facexlib.utils.face_restoration_helper import FaceRestoreHelper
60
-
61
- # Load model
62
- net = load_codeformer()
63
-
64
  # Convert PIL image to OpenCV format
65
- img = np.array(image)
66
  img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
67
-
68
  # Initialize face helper
69
- face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', device='cpu')
 
 
 
 
 
 
 
70
  face_helper.clean_all()
71
  face_helper.read_image(img)
72
  face_helper.get_face_landmarks_5()
73
  face_helper.align_warp_face()
74
-
75
- # Enhance face with CodeFormer
 
 
 
76
  for cropped_face in face_helper.cropped_faces:
77
  cropped_face_t = img2tensor(cropped_face, bgr2rgb=True, float32=True)
78
  with torch.no_grad():
@@ -80,11 +74,21 @@ def enhance_image(image, fidelity_weight=0.5):
80
  restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
81
  restored_face = restored_face.astype('uint8')
82
  face_helper.add_restored_face(restored_face)
83
-
84
- # Get final restored image
85
  face_helper.get_inverse_affine(None)
86
  restored_img = face_helper.paste_faces_to_input_image()
87
-
 
 
 
 
 
 
 
 
 
 
88
  # Convert back to PIL for Gradio
89
  restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
90
  return Image.fromarray(restored_img)
@@ -93,17 +97,19 @@ def enhance_image(image, fidelity_weight=0.5):
93
  with gr.Blocks() as demo:
94
  gr.Markdown("# CodeFormer Face Restoration (CPU)")
95
  gr.Markdown("Upload an image to enhance faces using CodeFormer. Runs on CPU in Hugging Face Spaces.")
96
-
97
  with gr.Row():
98
  input_image = gr.Image(type="pil", label="Input Image")
99
  output_image = gr.Image(type="pil", label="Enhanced Image")
100
-
101
- fidelity_slider = gr.Slider(0, 1, value=0.5, step=0.1, label="Fidelity Weight (0 = more restoration, 1 = more original)")
 
 
102
  submit_btn = gr.Button("Enhance")
103
-
104
  submit_btn.click(
105
  fn=enhance_image,
106
- inputs=[input_image, fidelity_slider],
107
  outputs=output_image
108
  )
109
 
 
6
  import numpy as np
7
  from PIL import Image
8
  import urllib.request
9
+ from basicsr.utils import img2tensor, tensor2img
10
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
11
+ from codeformer_arch import CodeFormer
12
 
13
  # Function to download a file from a URL
14
  def download_file(url, dest):
 
17
  urllib.request.urlretrieve(url, dest)
18
  print(f"Downloaded {dest}")
19
 
20
# Download pretrained models
def setup_environment():
    """Fetch every pretrained weight file the app needs.

    Delegates to ``download_file`` for each (url, destination) pair:
    the CodeFormer restoration weights, the facelib detection weights,
    and the Real-ESRGAN background-upsampler weights (optional feature).
    """
    # One (source URL, local path) row per model; downloaded in order.
    pretrained_files = (
        # CodeFormer pretrained model
        ("https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
         "weights/CodeFormer/codeformer.pth"),
        # facelib model (for face detection)
        ("https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/facelib.pth",
         "weights/facelib.pth"),
        # Real-ESRGAN model for background upsampling (optional)
        ("https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/RealESRGAN_x4plus.pth",
         "weights/RealESRGAN_x4plus.pth"),
    )
    for url, dest in pretrained_files:
        download_file(url, dest)
36
 
37
# Load CodeFormer model
def load_codeformer():
    """Construct the CodeFormer network, load its weights, and return it
    in eval mode, pinned to CPU.

    Side effect: calls ``setup_environment()`` first so the checkpoint
    file is guaranteed to exist locally before loading.
    """
    setup_environment()
    checkpoint_path = "weights/CodeFormer/codeformer.pth"
    net = CodeFormer(
        dim_embd=512,
        codebook_size=1024,
        n_head=8,
        n_layer=9,
        connect_list=['32', '64', '128', '256'],
    )
    # NOTE(review): this assumes the checkpoint is a raw state_dict; if the
    # release file wraps the weights under a key (e.g. 'params_ema'), this
    # load would need unwrapping — confirm against the downloaded file.
    net.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    net.eval()
    return net.to('cpu')  # force CPU inference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  # Inference function
47
+ def enhance_image(input_image, fidelity_weight=0.5, background_enhance=True, face_upsample=False):
 
 
 
 
 
48
  # Convert PIL image to OpenCV format
49
+ img = np.array(input_image)
50
  img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
51
+
52
  # Initialize face helper
53
+ face_helper = FaceRestoreHelper(
54
+ upscale_factor=1 if not face_upsample else 2,
55
+ face_size=512,
56
+ crop_ratio=(1, 1),
57
+ det_model='retinaface_resnet50',
58
+ save_ext='png',
59
+ device='cpu'
60
+ )
61
  face_helper.clean_all()
62
  face_helper.read_image(img)
63
  face_helper.get_face_landmarks_5()
64
  face_helper.align_warp_face()
65
+
66
+ # Load CodeFormer model
67
+ net = load_codeformer()
68
+
69
+ # Enhance face
70
  for cropped_face in face_helper.cropped_faces:
71
  cropped_face_t = img2tensor(cropped_face, bgr2rgb=True, float32=True)
72
  with torch.no_grad():
 
74
  restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
75
  restored_face = restored_face.astype('uint8')
76
  face_helper.add_restored_face(restored_face)
77
+
78
+ # Get restored image
79
  face_helper.get_inverse_affine(None)
80
  restored_img = face_helper.paste_faces_to_input_image()
81
+
82
+ # Background enhancement with Real-ESRGAN (optional)
83
+ if background_enhance:
84
+ from realesrgan import RealESRGANer
85
+ upsampler = RealESRGANer(
86
+ scale=4,
87
+ model_path="weights/RealESRGAN_x4plus.pth",
88
+ device='cpu'
89
+ )
90
+ restored_img, _ = upsampler.enhance(restored_img, outscale=4)
91
+
92
  # Convert back to PIL for Gradio
93
  restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
94
  return Image.fromarray(restored_img)
 
97
  with gr.Blocks() as demo:
98
  gr.Markdown("# CodeFormer Face Restoration (CPU)")
99
  gr.Markdown("Upload an image to enhance faces using CodeFormer. Runs on CPU in Hugging Face Spaces.")
100
+
101
  with gr.Row():
102
  input_image = gr.Image(type="pil", label="Input Image")
103
  output_image = gr.Image(type="pil", label="Enhanced Image")
104
+
105
+ fidelity_slider = gr.Slider(0, 1, value=0.5, step=0.01, label="Fidelity Weight (0 = more restoration, 1 = more original)")
106
+ background_enhance = gr.Checkbox(label="Enhance Background (Real-ESRGAN)", value=True)
107
+ face_upsample = gr.Checkbox(label="Upsample Restored Faces", value=False)
108
  submit_btn = gr.Button("Enhance")
109
+
110
  submit_btn.click(
111
  fn=enhance_image,
112
+ inputs=[input_image, fidelity_slider, background_enhance, face_upsample],
113
  outputs=output_image
114
  )
115