Spaces:
Runtime error
Runtime error
| import os | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import urllib.request | |
| from tqdm import tqdm | |
# Fix for the torchvision import error
def fix_torchvision_issue():
    """Make ``torchvision.transforms.functional_tensor`` importable if it is missing.

    Some torchvision versions do not provide this public module, while code
    loaded later (presumably basicsr/gfpgan, which this script installs) does
    ``from torchvision.transforms.functional_tensor import rgb_to_grayscale``.

    If the module already imports, nothing is changed. Otherwise a replacement
    module is registered both as an attribute of ``torchvision.transforms`` and
    in ``sys.modules``, so later ``import``/``from ... import`` statements
    resolve to it. Calling this more than once is harmless: after the first
    patch the module imports successfully and the function returns early.
    """
    import importlib
    import sys
    import types

    import torchvision
    import torchvision.transforms.functional as tv_functional

    module_name = 'torchvision.transforms.functional_tensor'

    # Test importability rather than attribute presence: a submodule can be
    # importable without being an attribute yet, and vice versa.
    try:
        importlib.import_module(module_name)
    except ImportError:
        pass
    else:
        return

    try:
        # NOTE(review): recent torchvision releases appear to keep the old
        # implementation under this private name — verify for the installed
        # version. Reusing it exposes every original function, not just one.
        shim = importlib.import_module('torchvision.transforms._functional_tensor')
    except ImportError:
        shim = types.ModuleType(module_name)
        # The public functional API accepts tensors and takes the same
        # (img, num_output_channels=1) arguments callers of the old module pass.
        shim.rgb_to_grayscale = tv_functional.rgb_to_grayscale

    # Register as a real module so both attribute access and import statements work.
    torchvision.transforms.functional_tensor = shim
    sys.modules[module_name] = shim
    print("Added compatibility layer for torchvision.transforms.functional_tensor")
def install_dependencies(packages=None):
    """Install the packages this script needs, then apply the torchvision shim.

    Args:
        packages: iterable of pip requirement specifiers to install. Defaults
            to the pinned set below when None.

    Installation is best-effort: a failing package is reported rather than
    raising, and ``fix_torchvision_issue()`` runs afterwards regardless.
    """
    import subprocess
    import sys

    if packages is None:
        # Install specific versions known to work together.
        # NOTE: torch is imported at module load time (top of this file), so
        # reinstalling it here cannot change the torch already loaded in this
        # process.
        packages = [
            "torch==1.12.1",
            "torchvision==0.13.1",
            "basicsr>=1.4.2",
            "facexlib>=0.2.5",
            "gfpgan>=1.3.8",
            "opencv-python",
            "tqdm",
        ]

    print("Installing dependencies...")
    failed = []
    for package in packages:
        # Pass an argv list (no shell). Through a shell, the '>' in specifiers
        # like "basicsr>=1.4.2" is an output redirection, so the version
        # constraint would be silently dropped. Run pip via sys.executable so
        # packages land in the interpreter that is running this script.
        result = subprocess.run(
            [sys.executable, "-m", "pip", "install", package], check=False
        )
        if result.returncode != 0:
            failed.append(package)

    # Fix torchvision issue after installation
    fix_torchvision_issue()

    if failed:
        print(f"Failed to install: {', '.join(failed)}")
    else:
        print("Dependencies installed successfully")
def download_model(
    model_dir='experiments/pretrained_models',
    model_name='GFPGANv1.3.pth',
    url='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
):
    """Download the GFPGAN weights unless they are already on disk.

    Args:
        model_dir: directory the weights live in; created if missing.
        model_name: file name of the weights inside ``model_dir``.
        url: where to fetch the weights from when the file is absent.

    Returns:
        ``os.path.join(model_dir, model_name)``. With the defaults this is
        'experiments/pretrained_models/GFPGANv1.3.pth'.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises (e.g. ``URLError``) if
        the download fails. No partial file is left behind in that case.
    """
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, model_name)
    if not os.path.exists(model_path):
        print(f"Downloading {os.path.splitext(model_name)[0]} model...")
        # Download to a side file and rename it only once complete. An
        # interrupted download written straight to model_path would leave a
        # truncated file that the exists() check above then accepts forever.
        partial_path = model_path + '.part'
        try:
            urllib.request.urlretrieve(url, partial_path)
        except BaseException:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        os.replace(partial_path, model_path)
        print(f"Model downloaded to {model_path}")
    return model_path
def setup_gfpgan():
    """Install dependencies, fetch the GFPGAN weights and build a CPU restorer.

    Returns:
        A ``GFPGANer`` loaded from the weights returned by ``download_model()``.
        It is configured with upscale=2, arch='clean', channel_multiplier=2,
        no background upsampler, and the CPU device.
    """
    # Dependencies are (re)installed on every call, before gfpgan is imported.
    install_dependencies()
    # Import after installing dependencies so a fresh environment can find it.
    from gfpgan import GFPGANer

    model_path = download_model()
    return GFPGANer(
        model_path=model_path,
        upscale=2,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=None,  # No background upsampler for CPU
        device=torch.device('cpu'),
    )
def process_image(restorer, img_path, output_dir='results', save_faces=False):
    """Restore the faces in one image with GFPGAN and save the result.

    Args:
        restorer: object exposing ``enhance(img, has_aligned=..., only_center_face=...,
            paste_back=...)`` returning ``(cropped_faces, restored_faces, restored_img)``
            (a ``GFPGANer`` in this script).
        img_path: image file to read with OpenCV (BGR, colour).
        output_dir: root output directory. The restored image is written to
            ``<output_dir>/restored_imgs/<basename>.<ext>`` (ext defaults to png).
        save_faces: when True, also write each restored face crop to
            ``<output_dir>/restored_faces/<basename>_<idx>.png``.

    Returns:
        The path of the saved restored image, or None if the image cannot be
        read, ``enhance`` raises RuntimeError, or no restored image is produced.
    """
    from basicsr.utils import imwrite

    os.makedirs(os.path.join(output_dir, 'restored_imgs'), exist_ok=True)

    img_name = os.path.basename(img_path)
    print(f'Processing {img_name} ...')
    basename, ext = os.path.splitext(img_name)

    input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if input_img is None:
        print(f"Warning: Cannot read image {img_path}")
        return None

    # Restore faces and paste them back onto the full image
    try:
        _, restored_faces, restored_img = restorer.enhance(
            input_img, has_aligned=False, only_center_face=False, paste_back=True)
    except RuntimeError as e:
        print(f"Error processing image: {e}")
        return None

    # Only create the faces directory when something is actually written to it.
    if save_faces and restored_faces:
        faces_dir = os.path.join(output_dir, 'restored_faces')
        os.makedirs(faces_dir, exist_ok=True)
        for idx, face in enumerate(restored_faces):
            imwrite(face, os.path.join(faces_dir, f'{basename}_{idx:02d}.png'))

    if restored_img is None:
        return None
    extension = ext[1:] if ext else 'png'
    save_restore_path = os.path.join(output_dir, 'restored_imgs', f'{basename}.{extension}')
    imwrite(restored_img, save_restore_path)
    return save_restore_path
def create_gradio_app():
    """Build a Gradio web interface that restores faces with GFPGAN on CPU.

    Installs gradio with the running interpreter's pip if it cannot be
    imported. It then builds the restorer once via ``setup_gfpgan()`` and wraps
    it in a ``gr.Interface`` with an image input and an image output.

    Returns:
        The ``gr.Interface``; call ``.launch()`` to serve it.
    """
    import subprocess
    import sys
    import tempfile

    try:
        import gradio as gr
    except ImportError:
        # argv list + sys.executable: install into this interpreter, no shell.
        subprocess.run([sys.executable, '-m', 'pip', 'install', 'gradio'], check=False)
        import gradio as gr

    # Set up GFPGAN once; every request reuses this restorer.
    restorer = setup_gfpgan()

    def process_image_gradio(image):
        """Restore an RGB numpy image from Gradio; return the original if restoration fails."""
        if image is None:
            return None
        # Each request gets its own scratch directory. A shared fixed file
        # name would let concurrent requests overwrite each other's input and
        # output. Writing PNG avoids a lossy JPEG re-encode of the upload
        # before it is restored.
        with tempfile.TemporaryDirectory() as work_dir:
            temp_input = os.path.join(work_dir, 'input.png')
            cv2.imwrite(temp_input, image[:, :, ::-1])  # Convert RGB to BGR for OpenCV
            output_path = process_image(restorer, temp_input, work_dir)
            if output_path and os.path.exists(output_path):
                restored_img = cv2.imread(output_path)
                if restored_img is not None:
                    # Convert back to RGB for Gradio (array outlives the temp dir).
                    return restored_img[:, :, ::-1]
        return image  # Return original if processing failed

    return gr.Interface(
        fn=process_image_gradio,
        inputs=gr.Image(type="numpy"),
        outputs=gr.Image(),
        title="GFPGAN Face Restoration (CPU)",
        description="Upload an image to improve facial details with GFPGAN running on CPU",
    )
# For command-line usage
def _list_images(folder, extensions=('.jpg', '.jpeg', '.png', '.bmp', '.webp')):
    """Return sorted paths of regular files in *folder* whose extension is in *extensions* (case-insensitive)."""
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(extensions) and os.path.isfile(os.path.join(folder, name))
    )


def main():
    """Command-line entry point: restore faces in one image or a folder of images.

    ``--input`` (default 'inputs') is an image file or a directory of images.
    ``--output`` (default 'results') is the output root passed to
    ``process_image``. A nonexistent ``--input`` is reported as an argparse
    error (exit status 2) instead of being created as an empty directory.
    """
    import argparse

    parser = argparse.ArgumentParser(description='GFPGAN for CPU')
    parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
    parser.add_argument('--output', type=str, default='results', help='Output folder')
    args = parser.parse_args()
    input_path = args.input
    output_dir = args.output

    # Resolve the inputs before setup_gfpgan(), which installs packages and
    # downloads weights. Bad input should fail fast and cheaply.
    if os.path.isfile(input_path):
        img_list = [input_path]
    elif os.path.isdir(input_path):
        img_list = _list_images(input_path)
    else:
        parser.error(f'Input path does not exist: {input_path}')  # raises SystemExit

    if not img_list:
        print(f'No images found in {input_path}')
        return

    restorer = setup_gfpgan()
    for img_path in tqdm(img_list):
        process_image(restorer, img_path, output_dir)
    print(f'Results are saved in {output_dir}')
if __name__ == '__main__':
    # A Hugging Face Space exposes SPACE_ID: serve the Gradio UI there,
    # otherwise act as a command-line tool.
    if not os.getenv('SPACE_ID'):
        main()
    else:
        create_gradio_app().launch()