File size: 6,863 Bytes
731861e
 
 
 
800b2d2
731861e
 
800b2d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
731861e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
800b2d2
 
731861e
800b2d2
731861e
800b2d2
731861e
 
 
 
 
 
 
800b2d2
731861e
 
 
 
 
 
 
 
 
 
 
 
 
800b2d2
 
731861e
 
 
 
 
 
 
 
 
 
 
 
 
800b2d2
731861e
 
800b2d2
 
 
 
 
 
731861e
 
 
 
 
 
800b2d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
731861e
800b2d2
731861e
800b2d2
731861e
800b2d2
 
 
731861e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import os
import cv2
import numpy as np
import torch
import urllib.request
from tqdm import tqdm

# Fix for the torchvision import error
def fix_torchvision_issue():
    """Ensure ``torchvision.transforms.functional_tensor`` is importable.

    Code written against older torchvision releases does
    ``from torchvision.transforms.functional_tensor import ...``. When that
    module cannot be imported, register a stand-in under that name, both in
    ``sys.modules`` and as an attribute of ``torchvision.transforms``, so such
    imports succeed.

    The stand-in is torchvision's private ``_functional_tensor`` module when
    it exists. Otherwise it is a minimal module exposing
    ``torchvision.transforms.functional.rgb_to_grayscale``, which keeps
    torchvision's full signature (including ``num_output_channels``).
    """
    import importlib
    import sys
    import types

    import torchvision

    module_name = 'torchvision.transforms.functional_tensor'

    # Try the real import first: a hasattr() check on the package misses
    # submodules that exist on disk but have not been imported yet.
    try:
        importlib.import_module(module_name)
        return
    except ImportError:
        pass

    try:
        # In recent torchvision releases the implementation lives here.
        replacement = importlib.import_module('torchvision.transforms._functional_tensor')
    except ImportError:
        from torchvision.transforms import functional as tv_functional

        replacement = types.ModuleType(module_name)
        replacement.rgb_to_grayscale = tv_functional.rgb_to_grayscale

    # Add the module to torchvision.transforms and to sys.modules so that
    # both attribute access and `from ... import` forms resolve.
    torchvision.transforms.functional_tensor = replacement
    sys.modules[module_name] = replacement

    print("Added compatibility layer for torchvision.transforms.functional_tensor")

def install_dependencies():
    """Install the packages this script depends on, then apply the torchvision shim.

    Each requirement is installed with ``python -m pip`` for the interpreter
    that is running this script. Failures are reported but not raised, so the
    rest of the setup still runs. ``fix_torchvision_issue()`` is applied
    afterwards in every case.
    """
    import subprocess
    import sys

    print("Installing dependencies...")

    # Install specific versions known to work together
    packages = [
        "torch==1.12.1",
        "torchvision==0.13.1",
        "basicsr>=1.4.2",
        "facexlib>=0.2.5",
        "gfpgan>=1.3.8",
        "opencv-python",
        "tqdm"
    ]

    failed = []
    for package in packages:
        # Pass the requirement as its own argv element, with no shell. Inside a
        # shell string, the '>' in 'basicsr>=1.4.2' would be parsed as output
        # redirection and the version constraint would be dropped.
        result = subprocess.run([sys.executable, "-m", "pip", "install", package])
        if result.returncode != 0:
            failed.append(package)

    # Fix torchvision issue after installation
    fix_torchvision_issue()

    if failed:
        print(f"Warning: failed to install: {', '.join(failed)}")
    else:
        print("Dependencies installed successfully")

def download_model():
    """Download the GFPGAN v1.3 weights unless they are already on disk.

    The file is fetched to a temporary ``.part`` path and only moved into
    place once the download has completed. An interrupted download therefore
    cannot leave a truncated file that later runs would treat as a valid model.

    Returns:
        str: ``'experiments/pretrained_models/GFPGANv1.3.pth'``.
    """
    os.makedirs('experiments/pretrained_models', exist_ok=True)
    model_path = 'experiments/pretrained_models/GFPGANv1.3.pth'

    if not os.path.exists(model_path):
        print("Downloading GFPGANv1.3 model...")
        url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
        partial_path = model_path + '.part'
        try:
            urllib.request.urlretrieve(url, partial_path)
            # Atomic rename: model_path exists only once the data is complete.
            os.replace(partial_path, model_path)
        finally:
            # Only present if the download or rename failed; drop the leftovers.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print(f"Model downloaded to {model_path}")

    return model_path

def setup_gfpgan():
    """Install dependencies, fetch the weights and build a CPU GFPGAN restorer.

    Returns:
        GFPGANer: restorer loaded from the downloaded GFPGANv1.3 weights,
        configured with ``upscale=2``, ``arch='clean'``,
        ``channel_multiplier=2``, no background upsampler, and
        ``torch.device('cpu')``.
    """
    # Runs pip for every listed package on each call.
    install_dependencies()

    # Imported only after the install step, since gfpgan may not be present before it.
    from gfpgan import GFPGANer

    model_path = download_model()

    # Initialize GFPGAN for CPU usage
    device = torch.device('cpu')

    restorer = GFPGANer(
        model_path=model_path,
        upscale=2,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=None,  # No background upsampler for CPU
        device=device,
    )

    return restorer

def process_image(restorer, img_path, output_dir='results'):
    """Run face restoration on one image file and save the restored image.

    Args:
        restorer: object with an ``enhance(img, has_aligned=..., only_center_face=...,
            paste_back=...)`` method returning a 3-tuple whose last item is the
            restored image (presumably the GFPGANer from ``setup_gfpgan``).
        img_path: path of the image to read with OpenCV (BGR colour).
        output_dir: root output folder. ``restored_faces`` and ``restored_imgs``
            sub-folders are created; only ``restored_imgs`` is written to.

    Returns:
        The path of the saved image, or ``None`` if the image could not be
        read, ``enhance`` raised ``RuntimeError``, or no restored image came back.
    """
    from basicsr.utils import imwrite

    os.makedirs(output_dir, exist_ok=True)
    for sub_dir in ('restored_faces', 'restored_imgs'):
        os.makedirs(os.path.join(output_dir, sub_dir), exist_ok=True)

    img_name = os.path.basename(img_path)
    print(f'Processing {img_name} ...')
    basename, ext = os.path.splitext(img_name)

    input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if input_img is None:
        print(f"Warning: Cannot read image {img_path}")
        return None

    try:
        _, _, restored_img = restorer.enhance(
            input_img,
            has_aligned=False,
            only_center_face=False,
            paste_back=True,
        )
    except RuntimeError as e:
        print(f"Error processing image: {e}")
        return None

    if restored_img is None:
        return None

    # Keep the input's extension (without the dot); fall back to PNG if it has none.
    extension = ext[1:] if ext else 'png'
    save_restore_path = os.path.join(output_dir, 'restored_imgs', f'{basename}.{extension}')
    imwrite(restored_img, save_restore_path)
    return save_restore_path

def create_gradio_app():
    """Build a Gradio web interface that runs GFPGAN face restoration on CPU.

    Installs ``gradio`` with ``python -m pip`` for the current interpreter if
    it is not importable. The restorer is created once via ``setup_gfpgan``
    and wrapped in a numpy-image-in / numpy-image-out ``gr.Interface``.

    Returns:
        The (not yet launched) ``gr.Interface``.
    """
    import subprocess
    import sys
    import tempfile

    try:
        import gradio as gr
    except ImportError:
        # Use this interpreter's pip; the re-import below raises ImportError if it failed.
        subprocess.run([sys.executable, '-m', 'pip', 'install', 'gradio'])
        import gradio as gr

    # Set up GFPGAN
    restorer = setup_gfpgan()

    def process_image_gradio(image):
        """Restore an RGB image from Gradio; return the input unchanged on failure."""
        if image is None:
            return None

        # Per-request scratch directory. This avoids clashes between concurrent
        # requests on a shared file name and removes the input copy and the
        # restored output afterwards. PNG keeps the uploaded pixels lossless.
        with tempfile.TemporaryDirectory() as work_dir:
            temp_input = os.path.join(work_dir, 'input.png')
            cv2.imwrite(temp_input, image[:, :, ::-1])  # Convert RGB to BGR for OpenCV

            output_path = process_image(restorer, temp_input, work_dir)

            if output_path and os.path.exists(output_path):
                restored_img = cv2.imread(output_path)
                if restored_img is not None:
                    # Convert back to RGB for Gradio (in-memory copy, safe after cleanup)
                    return restored_img[:, :, ::-1]

        return image  # Return original if processing failed

    # Create Gradio interface
    app = gr.Interface(
        fn=process_image_gradio,
        inputs=gr.Image(type="numpy"),
        outputs=gr.Image(),
        title="GFPGAN Face Restoration (CPU)",
        description="Upload an image to improve facial details with GFPGAN running on CPU"
    )

    return app

# For command-line usage
def main():
    """Command-line entry point: restore faces in one image or a folder of images.

    ``--input`` (default ``inputs``) is a single image file or a directory. In
    a directory, regular files whose extension is .jpg, .jpeg or .png (in any
    case) are processed in sorted order. Results go under ``--output``
    (default ``results``). A nonexistent input path is reported as a usage
    error before any dependency install or model download.
    """
    import argparse
    import glob

    # Extensions (compared lower-cased) treated as images inside a folder.
    image_exts = {'.jpg', '.jpeg', '.png'}

    parser = argparse.ArgumentParser(description='GFPGAN for CPU')
    parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
    parser.add_argument('--output', type=str, default='results', help='Output folder')
    args = parser.parse_args()

    input_path = args.input
    output_dir = args.output

    # Fail fast, before the slow setup. Creating a missing path instead would
    # turn a mistyped file name into an empty folder and silently process nothing.
    if not os.path.exists(input_path):
        parser.error(f'input path does not exist: {input_path}')

    # Set up GFPGAN
    restorer = setup_gfpgan()

    if os.path.isfile(input_path):
        # Single image
        process_image(restorer, input_path, output_dir)
    else:
        # Filter explicitly: a pattern like '*.[jp][pn]g' misses .jpeg and upper-case names.
        img_list = sorted(
            path for path in glob.glob(os.path.join(input_path, '*'))
            if os.path.isfile(path) and os.path.splitext(path)[1].lower() in image_exts
        )
        if not img_list:
            print(f'No images found in {input_path}')
        for img_path in tqdm(img_list):
            process_image(restorer, img_path, output_dir)

    print(f'Results are saved in {output_dir}')

if __name__ == '__main__':
    # A non-empty SPACE_ID environment variable is taken to mean we are running
    # in a Hugging Face Space: serve the web UI there, otherwise run the CLI.
    if not os.getenv('SPACE_ID'):
        main()
    else:
        gradio_app = create_gradio_app()
        gradio_app.launch()