thenightfury's picture
Upload 3 files
f349437 verified
import gradio as gr
import cv2
import numpy
import os
import random
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from torchvision.transforms.functional import rgb_to_grayscale
import spaces
# Global variables for file management and image mode.
# NOTE(review): both are module-level mutable state, so they are shared by every
# request handled by this process.

# Path of the most recent output file written by realesrgan(); reset() deletes it.
last_file = None
# Color mode recorded by image_properties() for the most recently inspected image.
img_mode = "RGBA"
def _select_model(model_name):
    """Return ``(network, netscale, weight_urls)`` for *model_name*, or ``None`` if it is unknown."""
    if model_name == 'RealESRGAN_x4plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    elif model_name == 'RealESRNet_x4plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
    elif model_name == 'RealESRGAN_x4plus_anime_6B':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif model_name == 'RealESRGAN_x2plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    elif model_name == 'realesr-general-x4v3':
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
        ]
    else:
        return None
    return model, netscale, file_url


def _resolve_weights(model_name, file_url):
    """Return the path of ``<model_name>.pth``, downloading every file in *file_url* if any is missing."""
    local_dir = 'weights'
    # Check every file the model needs, not only the main one, so that a
    # partially populated weights folder triggers a download instead of a
    # load failure later on.
    needed = [os.path.join(local_dir, os.path.basename(url)) for url in file_url]
    if all(os.path.isfile(path) for path in needed):
        return os.path.join(local_dir, model_name + '.pth')

    root_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = None
    for url in file_url:
        # The last URL of every model is its main weight file, so the final
        # value of model_path is the one the caller should load.
        model_path = load_file_from_url(
            url=url, model_dir=os.path.join(root_dir, 'weights'), progress=True, file_name=None)
    return model_path


def _to_bgr(img):
    """Convert a PIL image or an RGB(A)/gray ndarray to an OpenCV BGR or BGRA array."""
    if isinstance(img, numpy.ndarray):
        if img.ndim == 2 or img.shape[2] == 1:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        if img.shape[2] == 4:
            return cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
        return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    # PIL input: normalise palette / grayscale / CMYK / ... to RGB or RGBA first,
    # so that the array always has the channel count cvtColor expects.
    if has_transparency(img):
        return cv2.cvtColor(numpy.array(img.convert("RGBA")), cv2.COLOR_RGBA2BGRA)
    return cv2.cvtColor(numpy.array(img.convert("RGB")), cv2.COLOR_RGB2BGR)


@spaces.GPU
def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
    """
    Real-ESRGAN function to restore (and upscale) images.

    The color layout is derived from the image passed in (and from the array
    that comes back from the upscaler), never from the module-level
    ``img_mode``, so one request cannot change how another one is converted.

    Args:
        img (PIL.Image or numpy.ndarray): The input image (RGB/RGBA channel order).
        model_name (str): The name of the Real-ESRGAN model to use.
        denoise_strength (float): The strength of denoising for 'realesr-general-x4v3' model.
        face_enhance (bool): Whether to apply face enhancement using GFPGAN.
        outscale (int): The desired upscale factor for the output image.

    Returns:
        tuple: ``(output_path, properties_string)`` on success, or
        ``(None, message)`` when no image was given, the model name is
        unknown, upscaling failed, or the result could not be written.
    """
    global last_file

    if img is None:
        return None, "No image provided for upscaling."

    # Define model parameters based on the selected model_name
    selection = _select_model(model_name)
    if selection is None:
        return None, "Invalid model name selected."
    model, netscale, file_url = selection

    # Load model weights
    model_path = _resolve_weights(model_name, file_url)
    dni_weight = None
    if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
        # Blend the regular and the "wdn" weights. Build the sibling path from
        # the directory so that only the file name differs.
        wdn_model_path = os.path.join(os.path.dirname(model_path), 'realesr-general-wdn-x4v3.pth')
        model_path = [model_path, wdn_model_path]
        dni_weight = [denoise_strength, 1 - denoise_strength]

    # Initialize RealESRGAN upsampler
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=0,  # Set to 0 for no tiling, or a positive integer for tile size
        tile_pad=10,
        pre_pad=10,
        half=False,
        gpu_id=None  # Let RealESRGANer determine GPU if available
    )

    face_enhancer = None
    if face_enhance:
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=outscale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler)

    try:
        # Convert input image to OpenCV format (BGR or BGRA)
        img_to_process = _to_bgr(img)
        if face_enhancer is not None:
            # Enhance faces and upscale background
            _, _, output = face_enhancer.enhance(
                img_to_process, has_aligned=False, only_center_face=False, paste_back=True)
        else:
            # Only upscale
            output, _ = upsampler.enhance(img_to_process, outscale=outscale)
    except (RuntimeError, cv2.error) as error:
        print(f'Error during upscaling: {error}')
        print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
        return None, f"Error during upscaling: {error}. Try reducing image size or upscale factor."

    # Describe the result from its own shape: the face-enhancement path is not
    # guaranteed to return the same number of channels that went in.
    height, width = output.shape[:2]
    channels = output.shape[2] if output.ndim > 2 else 1
    if channels == 4:
        output_mode = "RGBA"
    elif channels == 3:
        output_mode = "RGB"
    else:
        output_mode = "Grayscale"

    # PNG keeps the alpha channel; everything else is written as JPEG.
    extension = 'png' if output_mode == "RGBA" else 'jpg'
    out_filename = f"output_{rnd_string(8)}.{extension}"
    if not cv2.imwrite(out_filename, output):
        return None, f"Error during upscaling: could not write {out_filename}."
    last_file = out_filename

    return out_filename, f"Resolution: Width: {width}, Height: {height} | Color Mode: {output_mode}"
def rnd_string(x):
    """Return a random string of ``x`` characters drawn from lowercase letters, ``_`` and digits."""
    alphabet = "abcdefghijklmnopqrstuvwxyz_0123456789"
    picked = [random.choice(alphabet) for _ in range(x)]
    return "".join(picked)
def reset():
    """
    Clear the input/output images and their property boxes, and delete the last generated file.

    Deleting is best effort: if the file cannot be removed the problem is
    printed and the UI is still cleared, so the Reset button never fails
    because of a stale or locked file.

    Returns:
        tuple: Four ``gr.update(value=None)`` objects, one per component wired
        to this handler.
    """
    global last_file
    # Forget the path first so a failed deletion cannot leave a stale entry behind.
    stale, last_file = last_file, None
    if stale and os.path.exists(stale):
        print(f"Deleting {stale} ...")
        try:
            os.remove(stale)
        except OSError as error:
            # The file may have vanished or be locked between the check and the removal.
            print(f"Could not delete {stale}: {error}")
    return gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)
def has_transparency(img):
    """
    Check whether a PIL image carries transparency.

    An image counts as transparent when either:
      * its ``info`` dict has a non-None ``"transparency"`` entry, or
      * it has an alpha band (modes ``"RGBA"`` and ``"LA"``) whose minimum
        value is below 255, i.e. at least one pixel is not fully opaque.

    Args:
        img (PIL.Image): The image to inspect.

    Returns:
        bool: True if the image has transparency, False otherwise.
    """
    if img.info.get("transparency") is not None:
        return True
    if img.mode in ("RGBA", "LA"):
        # getextrema() yields one (min, max) pair per band; alpha is the last band.
        alpha_min = img.getextrema()[-1][0]
        return alpha_min < 255
    return False
def image_properties(img):
    """
    Describe an image's resolution and color mode.

    As a side effect the detected mode is stored in the module-level
    ``img_mode`` variable (left untouched when ``img`` is None or of an
    unsupported type).

    Args:
        img (PIL.Image or numpy.ndarray): The image to describe.

    Returns:
        str: ``"Resolution: Width: W, Height: H | Color Mode: M"``, or a short
        message when there is no image or the type is not supported.
    """
    global img_mode

    if img is None:
        return "No image data available."

    if isinstance(img, numpy.ndarray):
        # Arrays: the mode follows from the channel count (no third axis means one channel).
        shape = img.shape
        height, width = shape[0], shape[1]
        channel_count = 1 if len(shape) <= 2 else shape[2]
        img_mode = {4: "RGBA", 3: "RGB"}.get(channel_count, "Grayscale")
        return f"Resolution: Width: {width}, Height: {height} | Color Mode: {img_mode}"

    # Anything exposing info/mode/size is treated as a PIL image.
    looks_like_pil = all(hasattr(img, attr) for attr in ("info", "mode", "size"))
    if not looks_like_pil:
        return "Unsupported image format."

    img_mode = "RGBA" if has_transparency(img) else "RGB"
    # The text reports PIL's own mode, which can differ from the stored img_mode.
    return f"Resolution: Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img.mode}"
def main():
    """
    Build the Gradio interface, wire its event handlers and launch the app.

    The layout is a header, a two-column row (controls on the left, input and
    output images on the right) and a footer. The app is launched with
    ``share=True``.

    NOTE(review): the nesting of the ``with`` blocks below was reconstructed
    from the statement order; confirm it against the intended layout (in
    particular whether the action buttons sit inside or below the accordion).
    """
    # Define a custom theme for a more modern look
    custom_theme = gr.themes.Soft(
        primary_hue="purple",
        secondary_hue="indigo",
        neutral_hue="gray",
        text_size="lg",
        font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
        spacing_size="md",
        radius_size="lg",
    ).set(
        # General customizations for a sleek look
        button_primary_background_fill="*primary_500",
        button_primary_background_fill_hover="*primary_600",
        button_primary_text_color="*white",
        button_secondary_background_fill="*neutral_100",
        button_secondary_background_fill_hover="*neutral_200",
        button_secondary_text_color="*neutral_700",
        panel_background_fill="*neutral_50",
        block_background_fill="*white",
        block_border_width="1px",
        block_border_color="*neutral_200",
        block_shadow="*shadow_lg", # Use standard shadow variable
        input_background_fill="*white",
        input_border_color="*neutral_300",
        input_shadow="*shadow_sm",
    )
    with gr.Blocks(theme=custom_theme, title="Premium Image Upscaler") as app:
        # Page header banner
        gr.Markdown(
            """
<div style="text-align: center; padding: 20px; background: linear-gradient(to right, #6a11cb 0%, #2575fc 100%); color: white; border-radius: 10px; margin-bottom: 30px; box-shadow: 0 8px 16px rgba(0,0,0,0.2);">
<h1 style="font-size: 3em; margin-bottom: 10px;">✨ Elite Image Enhancement ✨</h1>
<p style="font-size: 1.2em; opacity: 0.9;">
Unleash the full potential of your visuals with state-of-the-art AI upscaling.
Experience unparalleled clarity and detail.
</p>
</div>
"""
        )
        with gr.Row(variant="panel", elem_id="main-layout-row"):
            with gr.Column(scale=1, min_width=300): # Controls column
                with gr.Accordion("⚙️ Upscaling Parameters", open=True, elem_id="upscale-options-accordion"):
                    gr.Markdown(
                        """
<p style="font-size: 1.1em; color: #555;">
Fine-tune the AI to achieve your desired image quality.
</p>
"""
                    )
                    # The choices must match the model names handled by realesrgan()
                    model_name = gr.Dropdown(
                        label="Select Real-ESRGAN Model",
                        choices=[
                            "RealESRGAN_x4plus",
                            "RealESRNet_x4plus",
                            "RealESRGAN_x4plus_anime_6B",
                            "RealESRGAN_x2plus",
                            "realesr-general-x4v3"
                        ],
                        value="RealESRGAN_x4plus",
                        # info="Choose the base model optimized for various image types.", # Removed 'info'
                        elem_id="model-dropdown"
                    )
                    denoise_strength = gr.Slider(
                        label="Denoise Strength (for 'realesr-general-x4v3')",
                        minimum=0,
                        maximum=1,
                        step=0.1,
                        value=0.5,
                        # info="Controls the denoising level. Higher value = more denoising (only for 'realesr-general-x4v3').", # Removed 'info'
                        elem_id="denoise-slider"
                    )
                    outscale = gr.Slider(
                        label="Output Resolution Upscale Factor",
                        minimum=1,
                        maximum=6,
                        step=1,
                        value=4,
                        # info="Increases image resolution (e.g., 4x for 4 times wider/taller).", # Removed 'info'
                        elem_id="outscale-slider"
                    )
                    face_enhance = gr.Checkbox(
                        label="Enable Face Enhancement (GFPGAN)",
                        # info="Detects and enhances faces using GFPGAN alongside general upscaling.", # Removed 'info'
                        elem_id="face-enhance-checkbox"
                    )
                with gr.Column(elem_id="action-buttons-column"):
                    gr.Markdown("<h3 style='text-align: center; margin-top: 20px; color: #333;'>Actions</h3>")
                    with gr.Row():
                        reset_btn = gr.Button("🔄 Reset All", variant="secondary", size="lg", elem_id="reset-button")
                        upscale_btn = gr.Button("🚀 Upscale Image", variant="primary", size="lg", elem_id="upscale-button")
            with gr.Column(scale=2, min_width=500): # Image display column
                gr.Markdown(
                    """
<h3 style="text-align: center; margin-bottom: 20px; color: #333;">
<span style="color: #6a11cb;">Input</span> vs <span style="color: #2575fc;">Output</span>
</h3>
"""
                )
                with gr.Row():
                    with gr.Column():
                        # type="pil" means the handlers receive a PIL image for this component
                        input_image = gr.Image(
                            label="Original Image",
                            type="pil",
                            interactive=True,
                            height=350, # Adjusted height for better fit in column layout
                            elem_id="input_image_upload",
                            # info="Drop your image here or click to upload." # Removed 'info'
                        )
                        input_properties = gr.Textbox(
                            label="Input Image Details",
                            interactive=False,
                            placeholder="Image properties will appear here.",
                            elem_id="input-properties-textbox"
                        )
                    with gr.Column():
                        # Filled with the file path returned by realesrgan()
                        output_image = gr.Image(
                            label="Enhanced Image",
                            interactive=False,
                            height=350, # Adjusted height
                            elem_id="output_image_display",
                            # info="Your high-resolution image will be displayed here." # Removed 'info'
                        )
                        output_properties = gr.Textbox(
                            label="Output Image Details",
                            interactive=False,
                            placeholder="Enhanced image properties will appear here.",
                            elem_id="output-properties-textbox"
                        )
        # Page footer with credits
        gr.Markdown(
            """
<div style="text-align: center; padding: 15px; margin-top: 30px; background-color: #f0f2f5; border-radius: 8px; color: #666; font-size: 0.9em;">
<p>Powered by <a href="https://github.com/xinntao/Real-ESRGAN" target="_blank" style="color: #6a11cb; text-decoration: none;">Real-ESRGAN</a> and <a href="https://github.com/TencentARC/GFPGAN" target="_blank" style="color: #2575fc; text-decoration: none;">GFPGAN</a>.</p>
<p>Built with ❤️ using <a href="https://gradio.app/" target="_blank" style="color: #666; text-decoration: none;">Gradio</a>.</p>
</div>
"""
        )
        # Event listeners
        # Refresh the input details whenever the input image changes
        input_image.change(fn=image_properties, inputs=input_image, outputs=input_properties)
        # The order of `inputs` must match the parameter order of realesrgan()
        upscale_btn.click(
            fn=realesrgan,
            inputs=[input_image, model_name, denoise_strength, face_enhance, outscale],
            outputs=[output_image, output_properties]
        )
        # reset() returns one update per component listed in `outputs`
        reset_btn.click(
            fn=reset,
            inputs=[],
            outputs=[input_image, output_image, input_properties, output_properties]
        )
    app.launch(share=True)
if __name__ == "__main__":
main()