# GFPGAN face restoration on CPU — dependency setup, model download,
# command-line batch processing, and an optional Gradio web UI.
import os
import cv2
import numpy as np
import torch
import urllib.request
from tqdm import tqdm
# Fix for the torchvision import error
def fix_torchvision_issue():
    """Install a compatibility shim for ``torchvision.transforms.functional_tensor``.

    basicsr (a GFPGAN dependency) imports
    ``torchvision.transforms.functional_tensor.rgb_to_grayscale``, a module
    that was removed in newer torchvision releases.  When the real module
    cannot be imported, register a minimal stand-in exposing
    ``rgb_to_grayscale`` built from the BT.601 luma coefficients
    (0.2989 R + 0.5870 G + 0.1140 B), matching the values already used here.
    """
    import sys
    import torchvision

    try:
        # On older torchvision the real module still exists — prefer it.
        # (The previous hasattr() check was unreliable: the attribute only
        # appears after the submodule has been imported somewhere.)
        import torchvision.transforms.functional_tensor  # noqa: F401
        return
    except ImportError:
        pass

    class FunctionalTensorModule:
        """Minimal stand-in providing only what basicsr needs."""

        @staticmethod
        def rgb_to_grayscale(img):
            # Weighted channel sum; keep a singleton channel dim so the
            # output rank matches the input's.
            if len(img.shape) == 4:  # (N, C, H, W) batch of images
                return (img[:, 0, ...] * 0.2989
                        + img[:, 1, ...] * 0.5870
                        + img[:, 2, ...] * 0.1140).unsqueeze(1)
            # (C, H, W) single image
            return (img[0, ...] * 0.2989
                    + img[1, ...] * 0.5870
                    + img[2, ...] * 0.1140).unsqueeze(0)

    shim = FunctionalTensorModule()
    torchvision.transforms.functional_tensor = shim
    # Register the same object in sys.modules so both attribute access and
    # `import torchvision.transforms.functional_tensor` resolve to the shim.
    sys.modules['torchvision.transforms.functional_tensor'] = shim
    print("Added compatibility layer for torchvision.transforms.functional_tensor")
def install_dependencies():
    """Install pinned versions of all required GFPGAN dependencies.

    Invokes pip through the *current* interpreter (``sys.executable -m pip``)
    so packages land in the active environment — ``os.system("pip install …")``
    could target a different Python.  Afterwards applies the torchvision
    compatibility shim that basicsr relies on.
    """
    import subprocess
    import sys

    print("Installing dependencies...")
    # Versions pinned to a combination known to work together on CPU.
    packages = [
        "torch==1.12.1",
        "torchvision==0.13.1",
        "basicsr>=1.4.2",
        "facexlib>=0.2.5",
        "gfpgan>=1.3.8",
        "opencv-python",
        "tqdm",
    ]
    for package in packages:
        # List-form argv (shell=False) avoids shell quoting issues.
        # check=False preserves the original best-effort behaviour: a failed
        # install is not fatal here.
        subprocess.run([sys.executable, "-m", "pip", "install", package],
                       check=False)
    # basicsr may import the removed torchvision module; patch it now.
    fix_torchvision_issue()
    print("Dependencies installed successfully")
def download_model():
    """Ensure the GFPGANv1.3 checkpoint is present locally.

    Downloads the weights from the official GFPGAN GitHub release when the
    file is missing.

    Returns:
        str: path to the local ``GFPGANv1.3.pth`` file.
    """
    os.makedirs('experiments/pretrained_models', exist_ok=True)
    model_path = 'experiments/pretrained_models/GFPGANv1.3.pth'

    # Already cached from a previous run — nothing to do.
    if os.path.exists(model_path):
        return model_path

    print("Downloading GFPGANv1.3 model...")
    release_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
    urllib.request.urlretrieve(release_url, model_path)
    print(f"Model downloaded to {model_path}")
    return model_path
def setup_gfpgan():
    """Install dependencies, fetch the model weights, and build the restorer.

    Returns:
        gfpgan.GFPGANer: a face restorer configured for CPU use — 2x
        upscaling, 'clean' architecture, no background upsampler.
    """
    # Install dependencies if needed.
    install_dependencies()

    # Deferred import: gfpgan only exists after install_dependencies().
    # (The previous `from basicsr.utils import imwrite` here was unused —
    # process_image() performs its own import.)
    from gfpgan import GFPGANer

    model_path = download_model()

    # CPU-only inference; no background upsampler keeps memory/CPU load low.
    restorer = GFPGANer(
        model_path=model_path,
        upscale=2,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=None,  # No background upsampler for CPU
        device=torch.device('cpu'),
    )
    return restorer
def process_image(restorer, img_path, output_dir='results'):
    """Run GFPGAN restoration on a single image file.

    Args:
        restorer: an initialised ``GFPGANer`` instance.
        img_path: path to the input image.
        output_dir: root results folder; the restored image is written under
            ``<output_dir>/restored_imgs``.

    Returns:
        str | None: path of the saved restored image, or ``None`` when the
        input cannot be read or restoration fails.
    """
    from basicsr.utils import imwrite

    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'restored_faces'), exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'restored_imgs'), exist_ok=True)

    img_name = os.path.basename(img_path)
    print(f'Processing {img_name} ...')
    basename, ext = os.path.splitext(img_name)

    # Guard: cv2.imread returns None on unreadable/missing files.
    input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if input_img is None:
        print(f"Warning: Cannot read image {img_path}")
        return None

    try:
        _cropped, _faces, restored_img = restorer.enhance(
            input_img, has_aligned=False, only_center_face=False, paste_back=True)
    except RuntimeError as e:
        print(f"Error processing image: {e}")
        return None

    if restored_img is None:
        return None

    # Keep the input's extension when it has one; default to png otherwise.
    suffix = ext[1:] if ext else 'png'
    save_restore_path = os.path.join(output_dir, 'restored_imgs', f'{basename}.{suffix}')
    imwrite(restored_img, save_restore_path)
    return save_restore_path
def create_gradio_app():
    """Build a Gradio ``Interface`` wrapping GFPGAN face restoration.

    Returns:
        gradio.Interface: ready to ``launch()``.
    """
    try:
        import gradio as gr
    except ImportError:
        # gradio is not in the pinned dependency list; install on demand.
        os.system('pip install gradio')
        import gradio as gr

    restorer = setup_gfpgan()

    def restore(image):
        """Gradio callback: RGB numpy image in, restored RGB image out."""
        if image is None:
            return None
        # Gradio supplies RGB; OpenCV writes BGR, so flip channels on disk.
        temp_input = 'temp_input.jpg'
        cv2.imwrite(temp_input, image[:, :, ::-1])

        result_path = process_image(restorer, temp_input, 'results')
        if result_path and os.path.exists(result_path):
            bgr = cv2.imread(result_path)
            if bgr is not None:
                return bgr[:, :, ::-1]  # back to RGB for display
        # Restoration failed somewhere — show the untouched input instead.
        return image

    return gr.Interface(
        fn=restore,
        inputs=gr.Image(type="numpy"),
        outputs=gr.Image(),
        title="GFPGAN Face Restoration (CPU)",
        description="Upload an image to improve facial details with GFPGAN running on CPU",
    )
# For command-line usage
def main():
    """Command-line entry point: restore one image or a folder of images.

    ``--input`` may be a single image file or a directory; ``--output`` is
    the results folder passed through to :func:`process_image`.
    """
    import argparse
    import glob

    parser = argparse.ArgumentParser(description='GFPGAN for CPU')
    parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
    parser.add_argument('--output', type=str, default='results', help='Output folder')
    args = parser.parse_args()

    restorer = setup_gfpgan()

    input_path = args.input
    output_dir = args.output
    if os.path.isfile(input_path):
        # Single image.
        process_image(restorer, input_path, output_dir)
    else:
        # Directory of images.
        os.makedirs(input_path, exist_ok=True)
        # The old glob '*.[jp][pn]g' also matched .jng/.ppg and missed .jpeg
        # and uppercase extensions; filter by explicit extensions instead.
        wanted = ('.jpg', '.jpeg', '.png')
        img_list = sorted(
            p for p in glob.glob(os.path.join(input_path, '*'))
            if os.path.splitext(p)[1].lower() in wanted
        )
        for img_path in tqdm(img_list):
            process_image(restorer, img_path, output_dir)
    print(f'Results are saved in {output_dir}')
if __name__ == '__main__':
    # Hugging Face Spaces define SPACE_ID in the environment: launch the
    # web UI there, otherwise fall back to the command-line workflow.
    in_hf_space = os.getenv('SPACE_ID')
    if in_hf_space:
        create_gradio_app().launch()
    else:
        main()