lucky0146 commited on
Commit
731861e
·
verified ·
1 Parent(s): 8eb8d49

Create gfpgan_cpu.py

Browse files
Files changed (1) hide show
  1. gfpgan_cpu.py +185 -0
gfpgan_cpu.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import cv2
4
+ import glob
5
+ import numpy as np
6
+ import torch
7
+ from tqdm import tqdm
8
+ from pathlib import Path
9
+ from basicsr.utils import imwrite
10
+
11
+ # This is a simple implementation of GFPGAN for CPU usage on Hugging Face
12
+ def download_model():
13
+ """Download the GFPGAN model if not already present"""
14
+ import urllib.request
15
+ import os
16
+
17
+ os.makedirs('experiments/pretrained_models', exist_ok=True)
18
+ model_path = 'experiments/pretrained_models/GFPGANv1.3.pth'
19
+
20
+ if not os.path.exists(model_path):
21
+ print("Downloading GFPGANv1.3 model...")
22
+ url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
23
+ urllib.request.urlretrieve(url, model_path)
24
+ print(f"Model downloaded to {model_path}")
25
+
26
+ return model_path
27
+
28
+ def setup_gfpgan():
29
+ """Set up GFPGAN with the required dependencies"""
30
+ # Install required packages if not already installed
31
+ try:
32
+ import basicsr
33
+ except ImportError:
34
+ print("Installing basicsr...")
35
+ os.system('pip install basicsr')
36
+
37
+ try:
38
+ import facexlib
39
+ except ImportError:
40
+ print("Installing facexlib...")
41
+ os.system('pip install facexlib')
42
+
43
+ try:
44
+ import gfpgan
45
+ except ImportError:
46
+ print("Installing GFPGAN...")
47
+ os.system('pip install gfpgan')
48
+
49
+ from gfpgan import GFPGANer
50
+
51
+ # Download the model
52
+ model_path = download_model()
53
+
54
+ # Initialize GFPGAN for CPU usage
55
+ device = torch.device('cpu')
56
+
57
+ # Set up the restorer - note we're using CPU mode with half=False
58
+ restorer = GFPGANer(
59
+ model_path=model_path,
60
+ upscale=2,
61
+ arch='clean',
62
+ channel_multiplier=2,
63
+ bg_upsampler=None, # No background upsampler for CPU
64
+ device=device
65
+ )
66
+
67
+ return restorer
68
+
69
+ def process_image(restorer, img_path, output_dir='results'):
70
+ """Process a single image with GFPGAN"""
71
+ os.makedirs(output_dir, exist_ok=True)
72
+ os.makedirs(os.path.join(output_dir, 'restored_faces'), exist_ok=True)
73
+ os.makedirs(os.path.join(output_dir, 'restored_imgs'), exist_ok=True)
74
+ os.makedirs(os.path.join(output_dir, 'cmp'), exist_ok=True)
75
+
76
+ # Read image
77
+ img_name = os.path.basename(img_path)
78
+ print(f'Processing {img_name} ...')
79
+
80
+ basename, ext = os.path.splitext(img_name)
81
+ input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
82
+
83
+ if input_img is None:
84
+ print(f"Warning: Cannot read image {img_path}")
85
+ return
86
+
87
+ # Restore faces and background
88
+ cropped_faces, restored_faces, restored_img = restorer.enhance(
89
+ input_img, has_aligned=False, only_center_face=False, paste_back=True)
90
+
91
+ # Save faces
92
+ for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
93
+ # Save restored face
94
+ save_face_name = f'{basename}_{idx:02d}.png'
95
+ save_restore_path = os.path.join(output_dir, 'restored_faces', save_face_name)
96
+ imwrite(restored_face, save_restore_path)
97
+
98
+ # Save comparison image
99
+ cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
100
+ imwrite(cmp_img, os.path.join(output_dir, 'cmp', f'{basename}_{idx:02d}.png'))
101
+
102
+ # Save restored image
103
+ if restored_img is not None:
104
+ extension = ext[1:] if ext else 'png'
105
+ save_restore_path = os.path.join(output_dir, 'restored_imgs', f'{basename}.{extension}')
106
+ imwrite(restored_img, save_restore_path)
107
+
108
+ return os.path.join(output_dir, 'restored_imgs', f'{basename}.{extension}')
109
+
110
+ def main():
111
+ """Main function to run GFPGAN on CPU"""
112
+ parser = argparse.ArgumentParser(description='GFPGAN for CPU')
113
+ parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
114
+ parser.add_argument('--output', type=str, default='results', help='Output folder')
115
+ args = parser.parse_args()
116
+
117
+ # Set up GFPGAN
118
+ restorer = setup_gfpgan()
119
+
120
+ # Process images
121
+ input_path = args.input
122
+ output_dir = args.output
123
+
124
+ if os.path.isfile(input_path):
125
+ # Single image
126
+ process_image(restorer, input_path, output_dir)
127
+ else:
128
+ # Directory of images
129
+ os.makedirs(input_path, exist_ok=True)
130
+ img_list = sorted(glob.glob(os.path.join(input_path, '*.[jp][pn]g')))
131
+ for img_path in tqdm(img_list):
132
+ process_image(restorer, img_path, output_dir)
133
+
134
+ print(f'Results are saved in {output_dir}')
135
+
136
+ # For Hugging Face Spaces (Gradio interface)
137
+ def create_gradio_app():
138
+ import gradio as gr
139
+
140
+ restorer = setup_gfpgan()
141
+
142
+ def process_image_gradio(image):
143
+ # Save input image temporarily
144
+ temp_input = 'temp_input.jpg'
145
+ cv2.imwrite(temp_input, image[:, :, ::-1]) # Convert RGB to BGR for OpenCV
146
+
147
+ # Process the image
148
+ output_path = process_image(restorer, temp_input, 'results')
149
+
150
+ # Read the output image
151
+ restored_img = cv2.imread(output_path)
152
+
153
+ # Convert back to RGB for Gradio
154
+ if restored_img is not None:
155
+ restored_img = restored_img[:, :, ::-1]
156
+ return restored_img
157
+ else:
158
+ return image # Return original if processing failed
159
+
160
+ # Create Gradio interface
161
+ app = gr.Interface(
162
+ fn=process_image_gradio,
163
+ inputs=gr.Image(),
164
+ outputs=gr.Image(),
165
+ title="GFPGAN - Face Restoration",
166
+ description="Upload an image to improve facial details with GFPGAN running on CPU"
167
+ )
168
+
169
+ return app
170
+
171
+ if __name__ == '__main__':
172
+ import sys
173
+
174
+ # Check if running in a Hugging Face Space
175
+ if os.getenv('SPACE_ID'):
176
+ try:
177
+ import gradio as gr
178
+ except ImportError:
179
+ os.system('pip install gradio')
180
+ import gradio as gr
181
+
182
+ app = create_gradio_app()
183
+ app.launch()
184
+ else:
185
+ main()