lucky0146 commited on
Commit
fea84be
·
verified ·
1 Parent(s): fa500fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -147
app.py CHANGED
@@ -2,76 +2,85 @@ import os
2
  import sys
3
  import torch
4
  import gradio as gr
5
- from PIL import Image
6
  import numpy as np
7
  import requests
8
- from io import BytesIO
9
- from basicsr.utils import imwrite
10
- from torchvision.transforms.functional import normalize
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Clone the CodeFormer repository if it doesn't exist
13
  if not os.path.exists('CodeFormer'):
14
- !git clone https://github.com/sczhou/CodeFormer.git
15
- !pip install -r CodeFormer/requirements.txt
16
- !pip install basicsr
17
- !pip install facexlib
18
- !pip install gradio>=3.25.0
19
- !pip install realesrgan
20
- !python CodeFormer/basicsr/setup.py develop
21
 
22
- # Add the CodeFormer directory to the system path
23
  sys.path.append('CodeFormer')
24
 
25
- # Import necessary modules from CodeFormer
26
- from CodeFormer.basicsr.archs.codeformer_arch import CodeFormer
27
- from CodeFormer.basicsr.utils.registry import ARCH_REGISTRY
28
- from CodeFormer.facelib.utils.face_restoration_helper import FaceRestoreHelper
29
- from CodeFormer.facelib.detection.retinaface import retinaface
30
 
31
- # Function to download model weights
32
- def download_model_weights():
33
- if not os.path.exists('CodeFormer/weights'):
34
- os.makedirs('CodeFormer/weights/CodeFormer', exist_ok=True)
35
- os.makedirs('CodeFormer/weights/facelib', exist_ok=True)
36
-
37
- # Download CodeFormer weights
38
- codeformer_weight_path = 'CodeFormer/weights/CodeFormer/codeformer.pth'
39
- if not os.path.exists(codeformer_weight_path):
40
- print('Downloading CodeFormer weights...')
41
- url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
42
- r = requests.get(url, allow_redirects=True)
43
- with open(codeformer_weight_path, 'wb') as f:
44
- f.write(r.content)
45
-
46
- # Download detection model weights
47
- detection_model_path = 'CodeFormer/weights/facelib/detection_Resnet50_Final.pth'
48
- if not os.path.exists(detection_model_path):
49
- print('Downloading face detection model weights...')
50
- url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth'
51
- r = requests.get(url, allow_redirects=True)
52
- with open(detection_model_path, 'wb') as f:
53
- f.write(r.content)
54
 
55
- # Download parsing model weights
56
- parsing_model_path = 'CodeFormer/weights/facelib/parsing_parsenet.pth'
57
- if not os.path.exists(parsing_model_path):
58
- print('Downloading face parsing model weights...')
59
- url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
60
- r = requests.get(url, allow_redirects=True)
61
- with open(parsing_model_path, 'wb') as f:
62
- f.write(r.content)
 
 
 
 
 
63
 
64
- # Load CodeFormer model
65
- def load_codeformer_model():
66
- # Force to use CPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  device = torch.device('cpu')
68
- print(f'Running on device: {device}')
69
-
70
- # Download model weights if they don't exist
71
- download_model_weights()
72
 
73
- # Load CodeFormer model
74
- codeformer_net = ARCH_REGISTRY.get('CodeFormer')(
75
  dim_embd=512,
76
  codebook_size=1024,
77
  n_head=8,
@@ -79,17 +88,18 @@ def load_codeformer_model():
79
  connect_list=['32', '64', '128', '256']
80
  ).to(device)
81
 
 
82
  ckpt_path = 'CodeFormer/weights/CodeFormer/codeformer.pth'
83
  checkpoint = torch.load(ckpt_path, map_location=device)
84
 
85
  if 'params_ema' in checkpoint:
86
- codeformer_net.load_state_dict(checkpoint['params_ema'])
87
  else:
88
- codeformer_net.load_state_dict(checkpoint['params'])
89
 
90
- codeformer_net.eval()
91
 
92
- # Setup face restoration helper
93
  face_helper = FaceRestoreHelper(
94
  upscale_factor=1,
95
  face_size=512,
@@ -100,128 +110,110 @@ def load_codeformer_model():
100
  device=device
101
  )
102
 
103
- return codeformer_net, face_helper
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
- # Process the image with CodeFormer
106
- def process_image(image, w=0.5, has_aligned=False):
 
 
 
107
  device = torch.device('cpu')
108
- codeformer_net, face_helper = load_codeformer_model()
109
 
110
- # Convert the input image to numpy array
111
  if isinstance(image, Image.Image):
112
  img = np.array(image)
113
  else:
114
  img = image
115
-
 
116
  if has_aligned:
117
- # The input image is already a cropped and aligned face
118
  face_helper.is_gray = len(img.shape) == 2 or (len(img.shape) == 3 and img.shape[2] == 1)
119
  if face_helper.is_gray:
120
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
121
- # Prepare the face for processing
122
  face_helper.cropped_faces = [img]
123
  else:
124
- # Detect and crop faces from the input image
125
  face_helper.clean_all()
126
  face_helper.read_image(img)
127
  face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
128
  face_helper.align_warp_face()
129
 
130
- # If no face is detected
131
  if len(face_helper.cropped_faces) == 0:
132
- return image, "No face detected. Please try another image."
133
 
134
- # CodeFormer inference
135
  for idx, cropped_face in enumerate(face_helper.cropped_faces):
136
- # Prepare the image for the model
137
  cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
138
  normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
139
  cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
140
 
141
  try:
142
  with torch.no_grad():
143
- output = codeformer_net(cropped_face_t, w=w, adain=True)[0]
144
- # Convert tensor to image
145
  restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
 
 
146
  del output
147
- torch.cuda.empty_cache()
148
- except Exception as error:
149
- print(f'Error: {error}')
150
  restored_face = cropped_face
151
 
152
- # Save the restored face
153
  face_helper.add_restored_face(restored_face)
154
 
155
- # Get the final result
156
  if not has_aligned:
157
- # Paste the restored faces back to the original image
158
  face_helper.get_inverse_affine(None)
159
  restored_img = face_helper.paste_faces_to_input_image()
160
- restored_img = Image.fromarray(restored_img)
161
  else:
162
- restored_img = Image.fromarray(face_helper.restored_faces[0])
163
 
164
- return restored_img, "Face successfully restored."
165
-
166
- # Helper functions for image conversion
167
- def img2tensor(img, bgr2rgb=True, float32=True):
168
- img = img.astype(np.float32) if float32 else img
169
- if bgr2rgb:
170
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
171
- img = torch.from_numpy(img.transpose(2, 0, 1))
172
- return img
173
-
174
- def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
175
- tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
176
- tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
177
- n_dim = tensor.dim()
178
- if n_dim == 3:
179
- img_np = tensor.numpy()
180
- img_np = img_np.transpose(1, 2, 0)
181
- if rgb2bgr:
182
- img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
183
- elif n_dim == 2:
184
- img_np = tensor.numpy()
185
- else:
186
- raise TypeError(f'Only support 3D and 2D tensor. But got {n_dim}D tensor.')
187
- if out_type == np.uint8:
188
- img_np = (img_np * 255.0).round().astype(np.uint8)
189
- return img_np
190
-
191
- # Create a Gradio interface for the app
192
- def create_gradio_interface():
193
- with gr.Blocks(title="CodeFormer Face Restoration (CPU Version)") as app:
194
- gr.Markdown("# CodeFormer Face Restoration (CPU Version)")
195
- gr.Markdown("Upload a photo with faces to restore the quality. This model runs on CPU, so it might take a few minutes to process.")
196
-
197
- with gr.Row():
198
- with gr.Column():
199
- input_image = gr.Image(label="Input Image", type="pil")
200
- w_slider = gr.Slider(0, 1, value=0.5, step=0.1, label="Fidelity Weight (0: more quality, 1: more identity)")
201
- aligned_checkbox = gr.Checkbox(label="Input is an already aligned face", value=False)
202
- process_button = gr.Button("Restore Face")
203
-
204
- with gr.Column():
205
- output_image = gr.Image(label="Restored Image")
206
- output_text = gr.Textbox(label="Status")
207
-
208
- process_button.click(
209
- fn=process_image,
210
- inputs=[input_image, w_slider, aligned_checkbox],
211
- outputs=[output_image, output_text]
212
- )
213
-
214
- gr.Markdown("Note: Lower fidelity weight (w) values create higher-quality results with more modifications, while higher values preserve more of the original identity.")
215
-
216
- return app
217
-
218
- # Import CV2 only when needed to avoid issues
219
- import cv2
220
 
221
- # Main function
222
- def main():
223
- app = create_gradio_interface()
224
- app.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
225
 
226
- if __name__ == "__main__":
227
- main()
 
2
  import sys
3
  import torch
4
  import gradio as gr
 
5
  import numpy as np
6
  import requests
7
+ from PIL import Image
8
+ import cv2
9
+ import subprocess
10
+
11
+ # Function to run shell commands
12
+ def run_command(command):
13
+ process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
14
+ stdout, stderr = process.communicate()
15
+ if process.returncode != 0:
16
+ print(f"Error executing command: {command}")
17
+ print(stderr.decode())
18
+ else:
19
+ print(stdout.decode())
20
+ return process.returncode
21
 
22
+ # Clone repository and install dependencies
23
  if not os.path.exists('CodeFormer'):
24
+ run_command("git clone https://github.com/sczhou/CodeFormer.git")
25
+ run_command("pip install -r CodeFormer/requirements.txt")
26
+ run_command("pip install basicsr facexlib realesrgan opencv-python")
27
+ run_command("python CodeFormer/basicsr/setup.py develop")
 
 
 
28
 
29
+ # Add repository to path
30
  sys.path.append('CodeFormer')
31
 
32
+ # Create directories for model weights
33
+ os.makedirs('CodeFormer/weights/CodeFormer', exist_ok=True)
34
+ os.makedirs('CodeFormer/weights/facelib', exist_ok=True)
 
 
35
 
36
+ # Download model weights
37
+ def download_file(url, save_path):
38
+ if not os.path.exists(save_path):
39
+ print(f"Downloading {url} to {save_path}")
40
+ response = requests.get(url)
41
+ if response.status_code == 200:
42
+ with open(save_path, 'wb') as f:
43
+ f.write(response.content)
44
+ print(f"Downloaded {save_path}")
45
+ else:
46
+ print(f"Failed to download {url}, status code: {response.status_code}")
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
+ # Download required model weights
49
+ download_file(
50
+ "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
51
+ "CodeFormer/weights/CodeFormer/codeformer.pth"
52
+ )
53
+ download_file(
54
+ "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth",
55
+ "CodeFormer/weights/facelib/detection_Resnet50_Final.pth"
56
+ )
57
+ download_file(
58
+ "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth",
59
+ "CodeFormer/weights/facelib/parsing_parsenet.pth"
60
+ )
61
 
62
+ # Import CodeFormer modules
63
+ try:
64
+ from basicsr.archs.codeformer_arch import CodeFormer
65
+ from basicsr.utils.registry import ARCH_REGISTRY
66
+ from facelib.utils.face_restoration_helper import FaceRestoreHelper
67
+ from torchvision.transforms.functional import normalize
68
+ except ImportError:
69
+ print("Error importing CodeFormer modules. Make sure all dependencies are installed correctly.")
70
+ # Try to install missing dependencies
71
+ run_command("cd CodeFormer && pip install -e .")
72
+ from basicsr.archs.codeformer_arch import CodeFormer
73
+ from basicsr.utils.registry import ARCH_REGISTRY
74
+ from facelib.utils.face_restoration_helper import FaceRestoreHelper
75
+ from torchvision.transforms.functional import normalize
76
+
77
+ # Load model function
78
+ def load_model():
79
+ print('Loading CodeFormer model on CPU...')
80
  device = torch.device('cpu')
 
 
 
 
81
 
82
+ # Initialize model
83
+ model = ARCH_REGISTRY.get('CodeFormer')(
84
  dim_embd=512,
85
  codebook_size=1024,
86
  n_head=8,
 
88
  connect_list=['32', '64', '128', '256']
89
  ).to(device)
90
 
91
+ # Load checkpoint
92
  ckpt_path = 'CodeFormer/weights/CodeFormer/codeformer.pth'
93
  checkpoint = torch.load(ckpt_path, map_location=device)
94
 
95
  if 'params_ema' in checkpoint:
96
+ model.load_state_dict(checkpoint['params_ema'])
97
  else:
98
+ model.load_state_dict(checkpoint['params'])
99
 
100
+ model.eval()
101
 
102
+ # Initialize face helper
103
  face_helper = FaceRestoreHelper(
104
  upscale_factor=1,
105
  face_size=512,
 
110
  device=device
111
  )
112
 
113
+ return model, face_helper
114
+
115
+ # Image conversion utilities
116
+ def img2tensor(img, bgr2rgb=True, float32=True):
117
+ img = img.astype(np.float32) if float32 else img
118
+ if bgr2rgb:
119
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
120
+ img = torch.from_numpy(img.transpose(2, 0, 1))
121
+ return img
122
+
123
+ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
124
+ tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
125
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
126
+
127
+ n_dim = tensor.dim()
128
+ if n_dim == 3:
129
+ img_np = tensor.numpy()
130
+ img_np = img_np.transpose(1, 2, 0)
131
+ if rgb2bgr:
132
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
133
+ elif n_dim == 2:
134
+ img_np = tensor.numpy()
135
+ else:
136
+ raise TypeError(f'Only support 3D and 2D tensor. But got {n_dim}D tensor.')
137
+
138
+ if out_type == np.uint8:
139
+ img_np = (img_np * 255.0).round().astype(np.uint8)
140
+
141
+ return img_np
142
 
143
+ # Process image function
144
+ def process(image, w_value=0.5, has_aligned=False):
145
+ if image is None:
146
+ return None, "Please upload an image."
147
+
148
  device = torch.device('cpu')
149
+ model, face_helper = load_model()
150
 
151
+ # Convert PIL image to numpy array if needed
152
  if isinstance(image, Image.Image):
153
  img = np.array(image)
154
  else:
155
  img = image
156
+
157
+ # Process aligned face or detect faces
158
  if has_aligned:
 
159
  face_helper.is_gray = len(img.shape) == 2 or (len(img.shape) == 3 and img.shape[2] == 1)
160
  if face_helper.is_gray:
161
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
 
162
  face_helper.cropped_faces = [img]
163
  else:
 
164
  face_helper.clean_all()
165
  face_helper.read_image(img)
166
  face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
167
  face_helper.align_warp_face()
168
 
169
+ # Check if face was detected
170
  if len(face_helper.cropped_faces) == 0:
171
+ return img, "No face detected in the image!"
172
 
173
+ # Process each face
174
  for idx, cropped_face in enumerate(face_helper.cropped_faces):
 
175
  cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
176
  normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
177
  cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
178
 
179
  try:
180
  with torch.no_grad():
181
+ output = model(cropped_face_t, w=w_value, adain=True)[0]
 
182
  restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
183
+
184
+ # Free up memory
185
  del output
186
+ except Exception as e:
187
+ print(f'Error: {e}')
 
188
  restored_face = cropped_face
189
 
 
190
  face_helper.add_restored_face(restored_face)
191
 
192
+ # Get final result
193
  if not has_aligned:
 
194
  face_helper.get_inverse_affine(None)
195
  restored_img = face_helper.paste_faces_to_input_image()
 
196
  else:
197
+ restored_img = face_helper.restored_faces[0]
198
 
199
+ # Return result as PIL image
200
+ return Image.fromarray(restored_img), "Face restoration complete!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
+ # Create Gradio interface
203
+ demo = gr.Interface(
204
+ fn=process,
205
+ inputs=[
206
+ gr.Image(type="pil", label="Input Image"),
207
+ gr.Slider(0, 1, 0.5, step=0.01, label="Fidelity Weight (w) - Lower for quality, Higher for identity"),
208
+ gr.Checkbox(label="Is input an aligned face?", value=False)
209
+ ],
210
+ outputs=[
211
+ gr.Image(type="pil", label="Restored Image"),
212
+ gr.Textbox(label="Status")
213
+ ],
214
+ title="CodeFormer Face Restoration (CPU Version)",
215
+ description="Upload a photo with faces to restore the quality. Running on CPU, so please be patient!"
216
+ )
217
 
218
+ # Launch the app
219
+ demo.launch()