lucky0146 committed on
Commit
e187be5
·
verified ·
1 Parent(s): fea84be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -213
app.py CHANGED
@@ -1,219 +1,17 @@
1
- import os
2
- import sys
3
- import torch
4
  import gradio as gr
5
- import numpy as np
6
- import requests
7
- from PIL import Image
8
  import cv2
9
- import subprocess
10
 
11
# Function to run shell commands
def run_command(command):
    """Run *command* through the shell and echo its output.

    Args:
        command: Shell command line to execute.

    Returns:
        The command's exit code (0 on success).
    """
    # NOTE: shell=True is acceptable here only because every command passed
    # in this file is a fixed string; never feed it user input.
    # subprocess.run replaces the manual Popen/communicate dance and cannot
    # leak the child's pipes.
    result = subprocess.run(command, shell=True, capture_output=True, text=True)
    if result.returncode != 0:
        print(f"Error executing command: {command}")
        print(result.stderr)
    else:
        print(result.stdout)
    return result.returncode
21
-
22
# One-time bootstrap: fetch the CodeFormer repo and install its dependencies
# (skipped when a previous run already cloned it).
if not os.path.exists('CodeFormer'):
    for _cmd in (
        "git clone https://github.com/sczhou/CodeFormer.git",
        "pip install -r CodeFormer/requirements.txt",
        "pip install basicsr facexlib realesrgan opencv-python",
        "python CodeFormer/basicsr/setup.py develop",
    ):
        run_command(_cmd)

# Make the cloned repository importable.
sys.path.append('CodeFormer')

# Ensure the weight directories exist before any download is attempted.
for _weights_dir in ('CodeFormer/weights/CodeFormer', 'CodeFormer/weights/facelib'):
    os.makedirs(_weights_dir, exist_ok=True)
35
-
36
# Download model weights
def download_file(url, save_path, timeout=60):
    """Download *url* to *save_path*, skipping files that already exist.

    Args:
        url: Source URL of the file.
        save_path: Local destination path.
        timeout: Seconds to wait for the server before giving up. The
            previous version had no timeout and could hang indefinitely
            on a stalled connection.
    """
    if os.path.exists(save_path):
        return
    print(f"Downloading {url} to {save_path}")
    response = requests.get(url, timeout=timeout)
    if response.status_code == 200:
        with open(save_path, 'wb') as f:
            f.write(response.content)
        print(f"Downloaded {save_path}")
    else:
        print(f"Failed to download {url}, status code: {response.status_code}")
47
-
48
# Fetch the pretrained weights needed for restoration (no-op when cached).
_WEIGHT_FILES = [
    ("https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
     "CodeFormer/weights/CodeFormer/codeformer.pth"),
    ("https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth",
     "CodeFormer/weights/facelib/detection_Resnet50_Final.pth"),
    ("https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth",
     "CodeFormer/weights/facelib/parsing_parsenet.pth"),
]
for _url, _dest in _WEIGHT_FILES:
    download_file(_url, _dest)
61
 
62
- # Import CodeFormer modules
63
- try:
64
- from basicsr.archs.codeformer_arch import CodeFormer
65
- from basicsr.utils.registry import ARCH_REGISTRY
66
- from facelib.utils.face_restoration_helper import FaceRestoreHelper
67
- from torchvision.transforms.functional import normalize
68
- except ImportError:
69
- print("Error importing CodeFormer modules. Make sure all dependencies are installed correctly.")
70
- # Try to install missing dependencies
71
- run_command("cd CodeFormer && pip install -e .")
72
- from basicsr.archs.codeformer_arch import CodeFormer
73
- from basicsr.utils.registry import ARCH_REGISTRY
74
- from facelib.utils.face_restoration_helper import FaceRestoreHelper
75
- from torchvision.transforms.functional import normalize
76
-
77
# Load model function
def load_model():
    """Build the CodeFormer network and a face-alignment helper on CPU.

    Returns:
        Tuple of (model, face_helper): the CodeFormer network in eval mode
        and a FaceRestoreHelper configured for 512px faces.
    """
    print('Loading CodeFormer model on CPU...')
    device = torch.device('cpu')

    net = ARCH_REGISTRY.get('CodeFormer')(
        dim_embd=512,
        codebook_size=1024,
        n_head=8,
        n_layers=9,
        connect_list=['32', '64', '128', '256'],
    ).to(device)

    # Prefer the EMA weights when the checkpoint provides them.
    state = torch.load('CodeFormer/weights/CodeFormer/codeformer.pth', map_location=device)
    weight_key = 'params_ema' if 'params_ema' in state else 'params'
    net.load_state_dict(state[weight_key])
    net.eval()

    helper = FaceRestoreHelper(
        upscale_factor=1,
        face_size=512,
        crop_ratio=(1, 1),
        det_model='retinaface_resnet50',
        save_ext='png',
        use_parse=True,
        device=device,
    )

    return net, helper
114
-
115
# Image conversion utilities
def img2tensor(img, bgr2rgb=True, float32=True):
    """Convert an HxWxC numpy image into a CxHxW torch tensor.

    Args:
        img: Image as a numpy array in HWC layout.
        bgr2rgb: Swap OpenCV's BGR channel order to RGB first.
        float32: Cast the array to float32 before conversion.

    Returns:
        torch.Tensor in CHW layout sharing the converted array's data.
    """
    data = img.astype(np.float32) if float32 else img
    if bgr2rgb:
        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
    return torch.from_numpy(data.transpose(2, 0, 1))
122
-
123
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert a torch tensor back into a numpy image.

    Args:
        tensor: 2D or 3D image tensor (a leading batch dim of 1 is dropped).
        rgb2bgr: Swap RGB back to OpenCV's BGR order (3D tensors only).
        out_type: np.uint8 rescales values into [0, 255]; anything else
            returns the normalised float image unchanged.
        min_max: Value range used to clamp and normalise the tensor.

    Returns:
        numpy image in HWC (3D input) or HW (2D input) layout.

    Raises:
        TypeError: If the squeezed tensor is neither 2D nor 3D.
    """
    lo, hi = min_max
    data = tensor.squeeze(0).float().detach().cpu().clamp_(lo, hi)
    data = (data - lo) / (hi - lo)

    dims = data.dim()
    if dims == 3:
        image = data.numpy().transpose(1, 2, 0)
        if rgb2bgr:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    elif dims == 2:
        image = data.numpy()
    else:
        raise TypeError(f'Only support 3D and 2D tensor. But got {dims}D tensor.')

    if out_type == np.uint8:
        image = (image * 255.0).round().astype(np.uint8)

    return image
142
-
143
# Process image function
def process(image, w_value=0.5, has_aligned=False):
    """Restore faces in *image* with CodeFormer.

    Args:
        image: PIL image or numpy array; None yields an error status.
        w_value: Fidelity weight passed to the model (0 = quality,
            1 = identity).
        has_aligned: True when the input is already a cropped, aligned face.

    Returns:
        Tuple of (restored image, status message).
    """
    if image is None:
        return None, "Please upload an image."

    device = torch.device('cpu')
    model, face_helper = load_model()

    # Gradio hands us a PIL image; work on a numpy array throughout.
    img = np.array(image) if isinstance(image, Image.Image) else image

    if has_aligned:
        # Treat the whole input as one pre-aligned face.
        face_helper.is_gray = len(img.shape) == 2 or (len(img.shape) == 3 and img.shape[2] == 1)
        if face_helper.is_gray:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        face_helper.cropped_faces = [img]
    else:
        # Detect, align and crop every face in the photo.
        face_helper.clean_all()
        face_helper.read_image(img)
        face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
        face_helper.align_warp_face()

    if len(face_helper.cropped_faces) == 0:
        return img, "No face detected in the image!"

    # Restore each cropped face independently.
    for face in face_helper.cropped_faces:
        face_t = img2tensor(face / 255., bgr2rgb=True, float32=True)
        normalize(face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        face_t = face_t.unsqueeze(0).to(device)

        try:
            with torch.no_grad():
                output = model(face_t, w=w_value, adain=True)[0]
            restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
            # Free up memory as soon as possible on CPU.
            del output
        except Exception as e:
            # Best-effort: fall back to the unrestored crop on failure.
            print(f'Error: {e}')
            restored_face = face

        face_helper.add_restored_face(restored_face)

    # Paste the restored faces back, or take the single aligned result.
    if has_aligned:
        restored_img = face_helper.restored_faces[0]
    else:
        face_helper.get_inverse_affine(None)
        restored_img = face_helper.paste_faces_to_input_image()

    return Image.fromarray(restored_img), "Face restoration complete!"
201
-
202
# Create Gradio interface
_inputs = [
    gr.Image(type="pil", label="Input Image"),
    gr.Slider(0, 1, 0.5, step=0.01,
              label="Fidelity Weight (w) - Lower for quality, Higher for identity"),
    gr.Checkbox(label="Is input an aligned face?", value=False),
]
_outputs = [
    gr.Image(type="pil", label="Restored Image"),
    gr.Textbox(label="Status"),
]
demo = gr.Interface(
    fn=process,
    inputs=_inputs,
    outputs=_outputs,
    title="CodeFormer Face Restoration (CPU Version)",
    description="Upload a photo with faces to restore the quality. Running on CPU, so please be patient!",
)

# Launch the app
demo.launch()
 
 
 
 
1
  import gradio as gr
2
+ from gfpgan import GFPGANer
 
 
3
  import cv2
 
4
 
5
# Face-restoration backend: GFPGAN v1.3 with the 'clean' architecture and
# no background upsampler (only faces are enhanced, at original scale).
restorer = GFPGANer(
    model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.3.pth',
    upscale=1,
    arch='clean',
    channel_multiplier=2,
    bg_upsampler=None,
)
12
 
13
def restore_image(img):
    """Run GFPGAN face enhancement on a numpy image.

    Args:
        img: Input image as a numpy array (as delivered by gr.Image), or
            None when the user submits without uploading.

    Returns:
        The fully restored image as a numpy array, or None when no image
        was provided.
    """
    # Guard against an empty submission, which would crash enhance().
    if img is None:
        return None
    # enhance() returns (cropped_faces, restored_faces, restored_img); only
    # the pasted-back full image is needed here.
    _, _, restored_img = restorer.enhance(img, has_aligned=False, only_center_face=False)
    # NOTE(review): GFPGANer conventionally works in BGR while Gradio
    # supplies RGB — confirm the channel order matches what enhance expects.
    return restored_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
# Wire the restorer into a simple numpy-in / numpy-out Gradio app.
app = gr.Interface(fn=restore_image, inputs=gr.Image(type='numpy'), outputs=gr.Image(type='numpy'))
app.launch()