"""GFPGAN face restoration on CPU.

Installs pinned dependencies, patches a torchvision compatibility gap,
downloads the GFPGANv1.3 weights, and exposes both a command-line workflow
and a Gradio web interface (used when running inside a Hugging Face Space).
"""
import os
import subprocess
import sys
import tempfile
import urllib.request

import cv2
import numpy as np
import torch
from tqdm import tqdm
# Fix for the torchvision import error
def fix_torchvision_issue():
    """Install a compatibility shim for ``torchvision.transforms.functional_tensor``.

    Newer torchvision releases removed the private ``functional_tensor``
    module that older ``basicsr``/``gfpgan`` releases import
    ``rgb_to_grayscale`` from. This registers a stand-in so those imports
    keep working.

    Fixes over the previous version: the shim is now a real
    ``types.ModuleType`` object (a plain class instance in ``sys.modules``
    is fragile for ``import torchvision.transforms.functional_tensor``
    statements), and torchvision's own implementation is preferred when it
    exists instead of always using the hand-rolled one.
    """
    import sys
    import types

    import torchvision

    if hasattr(torchvision.transforms, 'functional_tensor'):
        return  # already present (old torchvision) or already patched

    shim = types.ModuleType('torchvision.transforms.functional_tensor')

    try:
        # Prefer the maintained implementation when available.
        from torchvision.transforms.functional import rgb_to_grayscale
    except ImportError:
        def rgb_to_grayscale(img):
            # ITU-R BT.601 luma weights; channel axis is 1 for batched
            # (N, C, H, W) tensors and 0 for single (C, H, W) tensors.
            if len(img.shape) == 4:  # batch of images
                return (img[:, 0, ...] * 0.2989 + img[:, 1, ...] * 0.5870
                        + img[:, 2, ...] * 0.1140).unsqueeze(1)
            # single image
            return (img[0, ...] * 0.2989 + img[1, ...] * 0.5870
                    + img[2, ...] * 0.1140).unsqueeze(0)
    shim.rgb_to_grayscale = rgb_to_grayscale

    # Expose it both as an attribute and as an importable module so that
    # `from torchvision.transforms.functional_tensor import rgb_to_grayscale`
    # (as basicsr does) succeeds.
    torchvision.transforms.functional_tensor = shim
    sys.modules['torchvision.transforms.functional_tensor'] = shim
    print("Added compatibility layer for torchvision.transforms.functional_tensor")
def install_dependencies():
    """Install pinned dependencies known to work together on CPU.

    Bug fixed: the previous ``os.system(f"pip install {package}")`` ran
    through a shell, where a specifier like ``basicsr>=1.4.2`` is parsed as
    an output *redirection* (``>``) — pip silently installed the unpinned
    package and a stray ``=1.4.2`` file appeared. Using
    ``subprocess.run`` with an argument list (no shell) passes the
    specifier through verbatim, and ``sys.executable -m pip`` guarantees
    the packages land in the interpreter actually running this script.
    """
    print("Installing dependencies...")
    # Versions pinned to a combination known to work with basicsr/gfpgan.
    packages = [
        "torch==1.12.1",
        "torchvision==0.13.1",
        "basicsr>=1.4.2",
        "facexlib>=0.2.5",
        "gfpgan>=1.3.8",
        "opencv-python",
        "tqdm",
    ]
    for package in packages:
        # List form => shell=False: '>=' cannot be misinterpreted.
        subprocess.run([sys.executable, "-m", "pip", "install", package],
                       check=False)
    # Patch the torchvision API gap that breaks basicsr imports.
    fix_torchvision_issue()
    print("Dependencies installed successfully")
def download_model():
    """Download the GFPGANv1.3 weights if missing and return the local path.

    Bug fixed: a failed/interrupted ``urlretrieve`` used to leave a
    truncated ``GFPGANv1.3.pth`` behind, which the ``os.path.exists`` check
    would treat as a valid model on every later run. The file is now
    downloaded to a ``.part`` temp name and atomically renamed into place
    only on success; partial files are removed on failure.

    Returns:
        str: path to the model weights file.
    """
    os.makedirs('experiments/pretrained_models', exist_ok=True)
    model_path = 'experiments/pretrained_models/GFPGANv1.3.pth'
    if not os.path.exists(model_path):
        print("Downloading GFPGANv1.3 model...")
        url = ('https://github.com/TencentARC/GFPGAN/releases/download/'
               'v1.3.0/GFPGANv1.3.pth')
        tmp_path = model_path + '.part'
        try:
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, model_path)  # atomic rename into place
        except Exception:
            # Never leave a partial download that masquerades as a model.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        print(f"Model downloaded to {model_path}")
    return model_path
def setup_gfpgan():
    """Install dependencies, fetch the model, and return a CPU ``GFPGANer``.

    Returns:
        GFPGANer: restorer configured for CPU inference (2x upscale, no
        background upsampler).
    """
    # Make sure gfpgan/basicsr exist before we try to import them.
    install_dependencies()
    # Imported lazily: gfpgan may not be installed until the line above ran.
    # (The unused `from basicsr.utils import imwrite` was removed.)
    from gfpgan import GFPGANer
    # Fetch the pretrained weights (no-op if already on disk).
    model_path = download_model()
    # Force CPU inference.
    device = torch.device('cpu')
    restorer = GFPGANer(
        model_path=model_path,
        upscale=2,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=None,  # background upsampling is too slow on CPU
        device=device
    )
    return restorer
def process_image(restorer, img_path, output_dir='results'):
    """Run GFPGAN on one image file and save the restored result.

    Args:
        restorer: an object with a GFPGANer-style ``enhance`` method.
        img_path (str): path to the input image.
        output_dir (str): root folder for results.

    Returns:
        str | None: path of the saved restored image, or ``None`` when the
        input cannot be read or restoration fails.

    Change: uses ``cv2.imwrite`` directly instead of importing
    ``basicsr.utils.imwrite`` — the only extra thing basicsr's helper does
    is create the parent directory, which is already created below.
    """
    os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): 'restored_faces' is created but nothing is saved into it
    # here — kept for compatibility with the standard GFPGAN output layout.
    os.makedirs(os.path.join(output_dir, 'restored_faces'), exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'restored_imgs'), exist_ok=True)

    img_name = os.path.basename(img_path)
    print(f'Processing {img_name} ...')
    basename, ext = os.path.splitext(img_name)
    input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if input_img is None:
        print(f"Warning: Cannot read image {img_path}")
        return None

    # Restore faces (and paste them back into the full frame).
    try:
        cropped_faces, restored_faces, restored_img = restorer.enhance(
            input_img, has_aligned=False, only_center_face=False, paste_back=True)
    except RuntimeError as e:
        print(f"Error processing image: {e}")
        return None

    # Save the full restored image, keeping the input's extension.
    if restored_img is not None:
        extension = ext[1:] if ext else 'png'
        save_restore_path = os.path.join(output_dir, 'restored_imgs', f'{basename}.{extension}')
        # Target directory was created above, so cv2.imwrite suffices.
        cv2.imwrite(save_restore_path, restored_img)
        return save_restore_path
    return None
def create_gradio_app():
    """Create a Gradio web interface for the GFPGAN model.

    Bug fixed: the handler used to write every upload to a fixed
    ``temp_input.jpg`` that was never deleted — a disk leak, and a data race
    between concurrent Gradio requests. Each request now gets a unique
    temp file that is removed in a ``finally`` block.
    """
    try:
        import gradio as gr
    except ImportError:
        os.system('pip install gradio')
        import gradio as gr

    # Set up GFPGAN (installs deps, downloads weights, builds the restorer).
    restorer = setup_gfpgan()

    def process_image_gradio(image):
        # Gradio supplies an RGB numpy array, or None if nothing uploaded.
        if image is None:
            return None
        # Unique per-request temp file; closed immediately so cv2 can write it.
        fd, temp_input = tempfile.mkstemp(suffix='.jpg')
        os.close(fd)
        try:
            cv2.imwrite(temp_input, image[:, :, ::-1])  # RGB -> BGR for OpenCV
            output_path = process_image(restorer, temp_input, 'results')
        finally:
            # Always clean up the temporary input.
            if os.path.exists(temp_input):
                os.remove(temp_input)
        if output_path and os.path.exists(output_path):
            restored_img = cv2.imread(output_path)
            if restored_img is not None:
                return restored_img[:, :, ::-1]  # BGR -> RGB for Gradio
        return image  # fall back to the original if processing failed

    app = gr.Interface(
        fn=process_image_gradio,
        inputs=gr.Image(type="numpy"),
        outputs=gr.Image(),
        title="GFPGAN Face Restoration (CPU)",
        description="Upload an image to improve facial details with GFPGAN running on CPU"
    )
    return app
# For command-line usage
def main():
import argparse
import glob
parser = argparse.ArgumentParser(description='GFPGAN for CPU')
parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
parser.add_argument('--output', type=str, default='results', help='Output folder')
args = parser.parse_args()
# Set up GFPGAN
restorer = setup_gfpgan()
# Process images
input_path = args.input
output_dir = args.output
if os.path.isfile(input_path):
# Single image
process_image(restorer, input_path, output_dir)
else:
# Directory of images
os.makedirs(input_path, exist_ok=True)
img_list = sorted(glob.glob(os.path.join(input_path, '*.[jp][pn]g')))
for img_path in tqdm(img_list):
process_image(restorer, img_path, output_dir)
print(f'Results are saved in {output_dir}')
if __name__ == '__main__':
    # Check if running in a Hugging Face Space
    # SPACE_ID is set automatically inside HF Spaces: serve the Gradio UI
    # there, otherwise fall back to the command-line workflow.
    if os.getenv('SPACE_ID'):
        app = create_gradio_app()
        app.launch()
    else:
        main()