"""GFPGAN face restoration on CPU, as a command-line tool or a Gradio app.

If the ``SPACE_ID`` environment variable is set (as on a Hugging Face Space)
a Gradio web interface is launched; otherwise the images given by ``--input``
(a single file or a folder) are restored into ``--output``.
"""
import importlib
import os
import subprocess
import sys
import tempfile
import types
import urllib.request

import cv2
import numpy as np  # noqa: F401  (kept from the original import list; currently unused)
import torch
from tqdm import tqdm

# Extensions (compared case-insensitively) picked up when --input is a folder.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def _rgb_to_grayscale(img, num_output_channels=1):
    """Convert an RGB image tensor to grayscale.

    Stand-in for ``torchvision.transforms.functional_tensor.rgb_to_grayscale``
    with the same signature. The channel axis is the third from last
    (``(..., C, H, W)``), so both a single ``(C, H, W)`` image and a
    ``(B, C, H, W)`` batch work; channels 0, 1, 2 are read as R, G, B.

    Args:
        img: image tensor with at least 3 channels on axis ``-3``.
        num_output_channels: 1 for a single gray channel, or 3 to repeat the
            gray channel three times.

    Returns:
        Tensor with ``img``'s dtype and ``num_output_channels`` channels.

    Raises:
        ValueError: if ``num_output_channels`` is not 1 or 3.
    """
    if num_output_channels not in (1, 3):
        raise ValueError('num_output_channels should be either 1 or 3')
    red, green, blue = img[..., 0, :, :], img[..., 1, :, :], img[..., 2, :, :]
    gray = (0.2989 * red + 0.587 * green + 0.114 * blue).to(img.dtype)
    gray = gray.unsqueeze(-3)
    if num_output_channels == 3:
        gray = gray.expand(*gray.shape[:-3], 3, *gray.shape[-2:])
    return gray


# Fix for the torchvision import error
def fix_torchvision_issue():
    """Provide ``torchvision.transforms.functional_tensor`` when it is missing.

    Some torchvision releases do not ship this module (it was made private in
    later releases), which breaks code that still imports ``rgb_to_grayscale``
    from it. If the real module imports, nothing is changed. Otherwise a
    minimal module exposing ``rgb_to_grayscale`` is registered both as an
    attribute of ``torchvision.transforms`` and in ``sys.modules``, so
    ``from torchvision.transforms.functional_tensor import rgb_to_grayscale``
    succeeds.
    """
    import torchvision

    try:
        # Really import it: a hasattr() check misses a module that exists but
        # has not been imported yet, and the shim would then shadow it.
        importlib.import_module('torchvision.transforms.functional_tensor')
    except ImportError:
        pass
    else:
        return

    shim = types.ModuleType('torchvision.transforms.functional_tensor')
    shim.rgb_to_grayscale = _rgb_to_grayscale
    torchvision.transforms.functional_tensor = shim
    # Registering in sys.modules is what makes `import ...functional_tensor` work.
    sys.modules['torchvision.transforms.functional_tensor'] = shim
    print("Added compatibility layer for torchvision.transforms.functional_tensor")


def install_dependencies():
    """Install the pinned dependency set with pip, then apply the torchvision shim.

    A failed install is reported but does not abort, so an environment that
    already has working packages can still proceed.

    NOTE: ``torch`` is imported when this module loads, so installing a
    different torch version here only takes effect in a new Python process.
    """
    print("Installing dependencies...")

    # Install specific versions known to work together
    packages = [
        "torch==1.12.1",
        "torchvision==0.13.1",
        "basicsr>=1.4.2",
        "facexlib>=0.2.5",
        "gfpgan>=1.3.8",
        "opencv-python",
        "tqdm",
    ]

    failed = []
    for package in packages:
        # Argument list, no shell: via os.system the ">" in "basicsr>=1.4.2"
        # was a shell redirect that silently dropped the version constraint.
        # `sys.executable -m pip` targets the interpreter running this script
        # rather than whichever `pip` happens to be first on PATH.
        result = subprocess.run(
            [sys.executable, '-m', 'pip', 'install', package], check=False)
        if result.returncode != 0:
            failed.append(package)

    # Fix torchvision issue after installation
    fix_torchvision_issue()

    if failed:
        print(f"Warning: failed to install: {', '.join(failed)}")
    else:
        print("Dependencies installed successfully")


def download_model():
    """Download the GFPGANv1.3 weights unless they are already present.

    The file is fetched to a ``.part`` path and renamed only once complete, so
    an interrupted download is never mistaken for a finished one on the next
    run.

    Returns:
        Path of the model file.
    """
    os.makedirs('experiments/pretrained_models', exist_ok=True)
    model_path = 'experiments/pretrained_models/GFPGANv1.3.pth'

    if not os.path.exists(model_path):
        print("Downloading GFPGANv1.3 model...")
        url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
        partial_path = model_path + '.part'
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, model_path)
        except BaseException:
            # Also on Ctrl-C: remove the truncated file, then propagate.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        print(f"Model downloaded to {model_path}")

    return model_path


def setup_gfpgan():
    """Install dependencies, fetch the model and build a CPU ``GFPGANer``.

    Returns:
        A ``GFPGANer`` with 2x upscaling, the 'clean' architecture and no
        background upsampler, running on the CPU.
    """
    install_dependencies()

    # Imported only after install_dependencies() so a fresh environment works.
    from gfpgan import GFPGANer

    model_path = download_model()

    device = torch.device('cpu')
    restorer = GFPGANer(
        model_path=model_path,
        upscale=2,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=None,  # No background upsampler for CPU
        device=device,
    )
    return restorer


def process_image(restorer, img_path, output_dir='results'):
    """Restore the faces in one image with GFPGAN and save the results.

    Writes the full restored image to
    ``<output_dir>/restored_imgs/<name>.<ext>`` (``png`` when the input has no
    extension) and each restored face crop to
    ``<output_dir>/restored_faces/<name>_<index>.png``.

    Args:
        restorer: object whose ``enhance`` method returns
            ``(cropped_faces, restored_faces, restored_img)``.
        img_path: path of the input image.
        output_dir: root folder for the outputs.

    Returns:
        Path of the saved restored image, or ``None`` if the image could not be
        read, enhancement raised ``RuntimeError``, or no image was produced.
    """
    from basicsr.utils import imwrite

    faces_dir = os.path.join(output_dir, 'restored_faces')
    imgs_dir = os.path.join(output_dir, 'restored_imgs')
    os.makedirs(faces_dir, exist_ok=True)
    os.makedirs(imgs_dir, exist_ok=True)

    img_name = os.path.basename(img_path)
    print(f'Processing {img_name} ...')
    basename, ext = os.path.splitext(img_name)
    input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if input_img is None:
        print(f"Warning: Cannot read image {img_path}")
        return None

    try:
        _, restored_faces, restored_img = restorer.enhance(
            input_img, has_aligned=False, only_center_face=False, paste_back=True)
    except RuntimeError as e:
        print(f"Error processing image: {e}")
        return None

    # restored_faces/ is created above; fill it instead of discarding the crops.
    for idx, face in enumerate(restored_faces):
        imwrite(face, os.path.join(faces_dir, f'{basename}_{idx:02d}.png'))

    if restored_img is None:
        return None
    extension = ext[1:] if ext else 'png'
    save_restore_path = os.path.join(imgs_dir, f'{basename}.{extension}')
    imwrite(restored_img, save_restore_path)
    return save_restore_path


def _to_bgr(image):
    """Convert a Gradio numpy image (grayscale, RGB or RGBA) to 3-channel BGR."""
    if image.ndim == 2 or image.shape[2] == 1:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)


def create_gradio_app():
    """Create a Gradio web interface for the GFPGAN model.

    Installs ``gradio`` with pip if it cannot be imported, and builds the
    restorer once so every request reuses it.

    Returns:
        A ``gradio.Interface`` (not yet launched).
    """
    try:
        import gradio as gr
    except ImportError:
        subprocess.run([sys.executable, '-m', 'pip', 'install', 'gradio'], check=False)
        import gradio as gr

    restorer = setup_gfpgan()

    def process_image_gradio(image):
        """Restore one uploaded RGB image; return the input unchanged on failure."""
        if image is None:
            return None

        # A private temp dir per request: concurrent requests cannot overwrite
        # each other's files, nothing accumulates on disk, and PNG avoids an
        # extra lossy JPEG round-trip.
        with tempfile.TemporaryDirectory() as work_dir:
            temp_input = os.path.join(work_dir, 'input.png')
            cv2.imwrite(temp_input, _to_bgr(image))
            output_path = process_image(
                restorer, temp_input, os.path.join(work_dir, 'results'))
            if output_path and os.path.exists(output_path):
                restored_img = cv2.imread(output_path)
                if restored_img is not None:
                    return restored_img[:, :, ::-1]  # BGR -> RGB for Gradio

        return image  # Return original if processing failed

    app = gr.Interface(
        fn=process_image_gradio,
        inputs=gr.Image(type="numpy"),
        outputs=gr.Image(),
        title="GFPGAN Face Restoration (CPU)",
        description="Upload an image to improve facial details with GFPGAN running on CPU",
    )
    return app


# For command-line usage
def main():
    """Restore a single image or every image in a folder from the command line.

    A missing input folder is created (so the user can drop images into it).
    Dependency installation and model setup are skipped when there is nothing
    to process.
    """
    import argparse
    import glob

    parser = argparse.ArgumentParser(description='GFPGAN for CPU')
    parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
    parser.add_argument('--output', type=str, default='results', help='Output folder')
    args = parser.parse_args()

    input_path = args.input
    output_dir = args.output

    if os.path.isfile(input_path):
        restorer = setup_gfpgan()
        process_image(restorer, input_path, output_dir)
    else:
        os.makedirs(input_path, exist_ok=True)
        # The old glob '*.[jp][pn]g' missed '.jpeg' and upper-case extensions.
        img_list = sorted(
            path for path in glob.glob(os.path.join(input_path, '*'))
            if os.path.isfile(path) and path.lower().endswith(IMAGE_EXTENSIONS)
        )
        if not img_list:
            print(f'No images found in {input_path}')
            return
        restorer = setup_gfpgan()
        for img_path in tqdm(img_list):
            process_image(restorer, img_path, output_dir)

    print(f'Results are saved in {output_dir}')


if __name__ == '__main__':
    # Check if running in a Hugging Face Space
    if os.getenv('SPACE_ID'):
        app = create_gradio_app()
        app.launch()
    else:
        main()