Spaces: Runtime error
| import gradio as gr | |
| import cv2 | |
| import numpy | |
| import os | |
| import random | |
| from basicsr.archs.rrdbnet_arch import RRDBNet | |
| from basicsr.utils.download_util import load_file_from_url | |
| from realesrgan import RealESRGANer | |
| from realesrgan.archs.srvgg_arch import SRVGGNetCompact | |
| from torchvision.transforms.functional import rgb_to_grayscale | |
| import spaces | |
# Module-level state shared by the Gradio callbacks in this file.
# Path of the most recent output image written by realesrgan(); reset() deletes it.
last_file: "str | None" = None
# Colour mode of the current image ("RGBA", "RGB" or "Grayscale"). It is written by
# image_properties() and read by realesrgan() to choose the colour conversion and
# the output file extension.
img_mode: str = "RGBA"
| def realesrgan(img, model_name, denoise_strength, face_enhance, outscale): | |
| """ | |
| Real-ESRGAN function to restore (and upscale) images. | |
| Args: | |
| img (PIL.Image or numpy.ndarray): The input image. | |
| model_name (str): The name of the Real-ESRGAN model to use. | |
| denoise_strength (float): The strength of denoising for 'realesr-general-x4v3' model. | |
| face_enhance (bool): Whether to apply face enhancement using GFPGAN. | |
| outscale (int): The desired upscale factor for the output image. | |
| Returns: | |
| tuple: A tuple containing the path to the output image and its properties string. | |
| Returns None if no image is provided or an error occurs. | |
| """ | |
| if img is None: | |
| return None, "No image provided for upscaling." | |
| # Define model parameters based on the selected model_name | |
| if model_name == 'RealESRGAN_x4plus': | |
| model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) | |
| netscale = 4 | |
| file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'] | |
| elif model_name == 'RealESRNet_x4plus': | |
| model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) | |
| netscale = 4 | |
| file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth'] | |
| elif model_name == 'RealESRGAN_x4plus_anime_6B': | |
| model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) | |
| netscale = 4 | |
| file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth'] | |
| elif model_name == 'RealESRGAN_x2plus': | |
| model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) | |
| netscale = 2 | |
| file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth'] | |
| elif model_name == 'realesr-general-x4v3': | |
| model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') | |
| netscale = 4 | |
| file_url = [ | |
| 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth', | |
| 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth' | |
| ] | |
| else: | |
| return None, "Invalid model name selected." | |
| # Load model weights | |
| model_path = os.path.join('weights', model_name + '.pth') | |
| if not os.path.isfile(model_path): | |
| ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| for url in file_url: | |
| model_path = load_file_from_url( | |
| url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None) | |
| dni_weight = None | |
| if model_name == 'realesr-general-x4v3' and denoise_strength != 1: | |
| wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3') | |
| model_path = [model_path, wdn_model_path] | |
| dni_weight = [denoise_strength, 1 - denoise_strength] | |
| # Initialize RealESRGAN upsampler | |
| upsampler = RealESRGANer( | |
| scale=netscale, | |
| model_path=model_path, | |
| dni_weight=dni_weight, | |
| model=model, | |
| tile=0, # Set to 0 for no tiling, or a positive integer for tile size | |
| tile_pad=10, | |
| pre_pad=10, | |
| half=False, | |
| gpu_id=None # Let RealESRGANer determine GPU if available | |
| ) | |
| face_enhancer = None | |
| if face_enhance: | |
| from gfpgan import GFPGANer | |
| face_enhancer = GFPGANer( | |
| model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', | |
| upscale=outscale, | |
| arch='clean', | |
| channel_multiplier=2, | |
| bg_upsampler=upsampler) | |
| # Convert input image to OpenCV format (BGRA) | |
| cv_img = numpy.array(img) | |
| # Ensure the image has an alpha channel if img_mode is RGBA, otherwise convert to BGR | |
| if img_mode == "RGBA": | |
| img_to_process = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA) | |
| else: | |
| img_to_process = cv2.cvtColor(cv_img, cv2.COLOR_RGB2BGR) | |
| try: | |
| if face_enhance and face_enhancer: | |
| # Enhance faces and upscale background | |
| _, _, output = face_enhancer.enhance(img_to_process, has_aligned=False, only_center_face=False, paste_back=True) | |
| else: | |
| # Only upscale | |
| output, _ = upsampler.enhance(img_to_process, outscale=outscale) | |
| except RuntimeError as error: | |
| print(f'Error during upscaling: {error}') | |
| print('If you encounter CUDA out of memory, try to set --tile with a smaller number.') | |
| return None, f"Error during upscaling: {error}. Try reducing image size or upscale factor." | |
| else: | |
| extension = 'png' if img_mode == 'RGBA' else 'jpg' | |
| out_filename = f"output_{rnd_string(8)}.{extension}" | |
| cv2.imwrite(out_filename, output) | |
| global last_file | |
| last_file = out_filename | |
| # Convert output image back to RGBA if original was RGBA, otherwise keep as is | |
| output_img_display = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA) if img_mode == "RGBA" else output | |
| return out_filename, image_properties(output_img_display) | |
def rnd_string(x):
    """Return ``x`` characters picked at random from lowercase letters, '_' and digits.

    The result is used to build unique output file names; ``x <= 0`` yields "".
    """
    pool = "abcdefghijklmnopqrstuvwxyz_0123456789"
    picked = [random.choice(pool) for _ in range(x)]
    return "".join(picked)
def reset():
    """Reset the UI and delete the most recently generated output file.

    Returns:
        tuple: Four ``gr.update(value=None)`` objects, one per output component
        wired to this callback (input image, output image, input details, output
        details).
    """
    global last_file
    # Forget the path first, so a failed deletion is not retried on every reset.
    path, last_file = last_file, None
    if path:
        print(f"Deleting {path} ...")
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already removed (EAFP avoids the exists()/remove() race).
            pass
        except OSError as err:
            # Cleanup is best-effort; clearing the UI must still succeed.
            print(f"Could not delete {path}: {err}")
    return gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)
def has_transparency(img):
    """Check whether a PIL image carries any transparency information.

    Two sources are checked:
      * a non-None ``transparency`` entry in ``img.info``;
      * an alpha band ("A", present in modes such as RGBA, LA and PA) whose minimum
        value is below 255, i.e. at least one pixel is not fully opaque.

    The previous palette ("P") scan is gone. It only ran after the ``info`` check
    had already failed, so it always compared palette indices against -1 and could
    never return True.

    Args:
        img (PIL.Image.Image): The image to inspect.

    Returns:
        bool: True if the image has transparency, otherwise False.
    """
    if img.info.get("transparency", None) is not None:
        return True
    if "A" in img.getbands():
        alpha_min, _ = img.getchannel("A").getextrema()
        return alpha_min < 255
    return False
def image_properties(img):
    """
    Describe an image's resolution and colour mode, updating the global ``img_mode``.

    For NumPy arrays, the mode is derived from the channel count:
        - 4 channels → "RGBA"
        - 3 channels → "RGB"
        - anything else → "Grayscale"

    For PIL-like objects (anything with ``info``, ``mode`` and ``size``), the global
    mode is "RGBA" when ``has_transparency`` reports transparency, else "RGB". The
    returned text shows the image's own ``img.mode``.

    Args:
        img (PIL.Image or numpy.ndarray): The image to describe.

    Returns:
        str: A resolution/colour-mode summary, or a message for missing or
        unsupported input.
    """
    global img_mode
    if img is None:
        return "No image data available."

    if isinstance(img, numpy.ndarray):
        height, width = img.shape[:2]
        n_channels = img.shape[2] if img.ndim > 2 else 1
        img_mode = {4: "RGBA", 3: "RGB"}.get(n_channels, "Grayscale")
        return f"Resolution: Width: {width}, Height: {height} | Color Mode: {img_mode}"

    looks_like_pil = all(hasattr(img, attr) for attr in ("info", "mode", "size"))
    if not looks_like_pil:
        return "Unsupported image format."

    img_mode = "RGBA" if has_transparency(img) else "RGB"
    # Report PIL's own mode string here; img_mode is only the processing mode.
    return f"Resolution: Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img.mode}"
def _build_theme():
    """Return the customised Soft theme used by the app.

    Gradio resolves values written as ``*name`` against its theme variables (e.g.
    ``*primary_500``, ``*shadow_drop_lg``). Plain strings such as "white" are used
    as literal CSS values. The previous ``*white``, ``*shadow_lg`` and
    ``*shadow_sm`` references name no theme variable, so those styles were
    silently dropped.
    """
    return gr.themes.Soft(
        primary_hue="purple",
        secondary_hue="indigo",
        neutral_hue="gray",
        text_size="lg",
        font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
        spacing_size="md",
        radius_size="lg",
    ).set(
        button_primary_background_fill="*primary_500",
        button_primary_background_fill_hover="*primary_600",
        button_primary_text_color="white",
        button_secondary_background_fill="*neutral_100",
        button_secondary_background_fill_hover="*neutral_200",
        button_secondary_text_color="*neutral_700",
        panel_background_fill="*neutral_50",
        block_background_fill="white",
        block_border_width="1px",
        block_border_color="*neutral_200",
        block_shadow="*shadow_drop_lg",
        input_background_fill="white",
        input_border_color="*neutral_300",
        input_shadow="*shadow_drop",
    )


def main():
    """Build the Gradio Blocks UI, wire its callbacks and launch the app.

    Layout:
        - A header banner.
        - A controls column: model, denoise strength, upscale factor, face
          enhancement, and the Reset/Upscale buttons.
        - An input/output image column with property textboxes.
        - A footer.

    The app is launched with ``share=True``.
    """
    with gr.Blocks(theme=_build_theme(), title="Premium Image Upscaler") as app:
        gr.Markdown(
            """
            <div style="text-align: center; padding: 20px; background: linear-gradient(to right, #6a11cb 0%, #2575fc 100%); color: white; border-radius: 10px; margin-bottom: 30px; box-shadow: 0 8px 16px rgba(0,0,0,0.2);">
            <h1 style="font-size: 3em; margin-bottom: 10px;">✨ Elite Image Enhancement ✨</h1>
            <p style="font-size: 1.2em; opacity: 0.9;">
            Unleash the full potential of your visuals with state-of-the-art AI upscaling.
            Experience unparalleled clarity and detail.
            </p>
            </div>
            """
        )
        with gr.Row(variant="panel", elem_id="main-layout-row"):
            with gr.Column(scale=1, min_width=300):  # Controls column
                with gr.Accordion("⚙️ Upscaling Parameters", open=True, elem_id="upscale-options-accordion"):
                    gr.Markdown(
                        """
                        <p style="font-size: 1.1em; color: #555;">
                        Fine-tune the AI to achieve your desired image quality.
                        </p>
                        """
                    )
                    model_name = gr.Dropdown(
                        label="Select Real-ESRGAN Model",
                        choices=[
                            "RealESRGAN_x4plus",
                            "RealESRNet_x4plus",
                            "RealESRGAN_x4plus_anime_6B",
                            "RealESRGAN_x2plus",
                            "realesr-general-x4v3",
                        ],
                        value="RealESRGAN_x4plus",
                        elem_id="model-dropdown",
                    )
                    denoise_strength = gr.Slider(
                        label="Denoise Strength (for 'realesr-general-x4v3')",
                        minimum=0,
                        maximum=1,
                        step=0.1,
                        value=0.5,
                        elem_id="denoise-slider",
                    )
                    outscale = gr.Slider(
                        label="Output Resolution Upscale Factor",
                        minimum=1,
                        maximum=6,
                        step=1,
                        value=4,
                        elem_id="outscale-slider",
                    )
                    face_enhance = gr.Checkbox(
                        label="Enable Face Enhancement (GFPGAN)",
                        elem_id="face-enhance-checkbox",
                    )
                with gr.Column(elem_id="action-buttons-column"):
                    gr.Markdown("<h3 style='text-align: center; margin-top: 20px; color: #333;'>Actions</h3>")
                    with gr.Row():
                        reset_btn = gr.Button("🔄 Reset All", variant="secondary", size="lg", elem_id="reset-button")
                        upscale_btn = gr.Button("🚀 Upscale Image", variant="primary", size="lg", elem_id="upscale-button")
            with gr.Column(scale=2, min_width=500):  # Image display column
                gr.Markdown(
                    """
                    <h3 style="text-align: center; margin-bottom: 20px; color: #333;">
                    <span style="color: #6a11cb;">Input</span> vs <span style="color: #2575fc;">Output</span>
                    </h3>
                    """
                )
                with gr.Row():
                    with gr.Column():
                        input_image = gr.Image(
                            label="Original Image",
                            type="pil",
                            interactive=True,
                            height=350,
                            elem_id="input_image_upload",
                        )
                        input_properties = gr.Textbox(
                            label="Input Image Details",
                            interactive=False,
                            placeholder="Image properties will appear here.",
                            elem_id="input-properties-textbox",
                        )
                    with gr.Column():
                        output_image = gr.Image(
                            label="Enhanced Image",
                            interactive=False,
                            height=350,
                            elem_id="output_image_display",
                        )
                        output_properties = gr.Textbox(
                            label="Output Image Details",
                            interactive=False,
                            placeholder="Enhanced image properties will appear here.",
                            elem_id="output-properties-textbox",
                        )
        gr.Markdown(
            """
            <div style="text-align: center; padding: 15px; margin-top: 30px; background-color: #f0f2f5; border-radius: 8px; color: #666; font-size: 0.9em;">
            <p>Powered by <a href="https://github.com/xinntao/Real-ESRGAN" target="_blank" style="color: #6a11cb; text-decoration: none;">Real-ESRGAN</a> and <a href="https://github.com/TencentARC/GFPGAN" target="_blank" style="color: #2575fc; text-decoration: none;">GFPGAN</a>.</p>
            <p>Built with ❤️ using <a href="https://gradio.app/" target="_blank" style="color: #666; text-decoration: none;">Gradio</a>.</p>
            </div>
            """
        )

        # Event listeners. image_properties also records the colour mode that
        # realesrgan later reads from the module-level img_mode.
        input_image.change(fn=image_properties, inputs=input_image, outputs=input_properties)
        upscale_btn.click(
            fn=realesrgan,
            inputs=[input_image, model_name, denoise_strength, face_enhance, outscale],
            outputs=[output_image, output_properties],
        )
        reset_btn.click(
            fn=reset,
            inputs=[],
            outputs=[input_image, output_image, input_properties, output_properties],
        )
    app.launch(share=True)
# Build and launch the UI only when this file is executed as a script.
if __name__ == "__main__":
    main()