Spaces:
Running
Running
File size: 7,328 Bytes
8ccd726 24bace7 8ccd726 24bace7 cdda156 24bace7 8ccd726 24bace7 6a79ce3 24bace7 6a79ce3 24bace7 8ccd726 24bace7 8ccd726 24bace7 8ccd726 24bace7 6a79ce3 24bace7 8ccd726 6a79ce3 8ccd726 24bace7 6a79ce3 24bace7 8ccd726 24bace7 6a79ce3 24bace7 6a79ce3 24bace7 6a79ce3 24bace7 6a79ce3 24bace7 8ccd726 cdda156 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
import os
import sys
import tempfile
import cv2
import torch
import gradio as gr
from torchvision.transforms import functional
# --- PATCH FOR COMPATIBILITY ---
# Some downstream import (presumably basicsr, imported further below) still
# references `torchvision.transforms.functional_tensor`, which newer
# torchvision releases appear to no longer ship -- TODO confirm against the
# pinned versions. Alias it to `torchvision.transforms.functional` only when
# the real module cannot be imported, so installs that still have it keep the
# genuine module instead of having it silently shadowed.
try:
    import torchvision.transforms.functional_tensor  # noqa: F401
except ImportError:
    sys.modules["torchvision.transforms.functional_tensor"] = functional
# --- EMBEDDED CSS FOR STYLING ---
# Custom stylesheet handed to gr.Blocks(css=CSS_STYLING) when the UI is built.
# It defines the colour palette as CSS custom properties on :root, then styles
# the elements whose elem_id is set in the UI (#main-title, #main-subtitle,
# #submit-button) plus image panes, range-slider thumbs and radio labels.
# NOTE(review): --blue, --background-darker and --error-text are defined but not
# referenced anywhere in this stylesheet.
# NOTE(review): the UI declares no slider component, so the range-thumb rules
# look unused; `.gr-image` / `.gr-radio` are class names from older Gradio
# releases -- confirm they still match the installed Gradio version.
CSS_STYLING = """
:root {
--primary: hsl(265, 100%, 61%); /* Accent Purple */
--secondary: hsl(327, 100%, 72%); /* Accent Pink */
--blue: hsl(204, 100%, 72%); /* Accent Blue */
--background-darker: hsl(240, 14%, 3%);
--background-dark: hsl(240, 14%, 5%);
--card-background: hsl(240, 10%, 7%);
--light-text: hsl(240, 5%, 90%);
--muted-text: hsl(240, 4%, 65%);
--error-text: hsl(0, 100%, 74%);
--card-border: hsl(253, 100%, 72%, 0.15);
--input-background-fill: var(--card-background) !important;
--input-border-color: var(--card-border) !important;
--input-label-color: var(--light-text) !important;
}
.gradio-container {
background: var(--background-dark);
font-family: 'Inter', sans-serif;
}
#main-title {
color: var(--light-text);
text-align: center;
font-size: 2.5rem !important;
font-weight: 900;
}
#main-subtitle {
color: var(--muted-text);
text-align: center;
font-size: 1rem !important;
margin-top: -15px;
margin-bottom: 20px;
}
#submit-button {
background: linear-gradient(135deg, var(--primary), var(--secondary));
color: white;
font-weight: bold;
border-radius: 8px !important;
transition: all 0.3s ease;
}
#submit-button:hover {
box-shadow: 0px 4px 15px rgba(124, 58, 237, 0.4); /* Subtle purple shadow */
transform: translateY(-2px);
}
.gr-image {
border: 1px solid var(--card-border) !important;
border-radius: 12px !important;
min-height: 300px;
}
input[type="range"]::-webkit-slider-thumb {
background: var(--primary) !important;
}
input[type="range"]::-moz-range-thumb {
background: var(--primary) !important;
}
.gr-radio > div {
color: var(--light-text) !important;
}
"""
# --- DOWNLOAD HELPER FUNCTIONS ---
def download_file(url, dir_path, file_name):
    """Download ``url`` to ``dir_path/file_name`` unless that file already exists.

    The payload is streamed into a temporary file inside ``dir_path`` and only
    renamed to its final name once the transfer has completed. An interrupted
    or failed download therefore never leaves a truncated file behind that a
    later run would mistake for a finished one.

    This is best-effort: download errors are printed, not raised.

    Args:
        url: Source URL (any scheme ``urllib.request`` supports, e.g. https).
        dir_path: Destination directory; created if it does not exist.
        file_name: Name of the file to create inside ``dir_path``.

    Returns:
        The destination path ``os.path.join(dir_path, file_name)``, whether or
        not the download succeeded.
    """
    import shutil
    import urllib.request

    os.makedirs(dir_path, exist_ok=True)
    file_path = os.path.join(dir_path, file_name)
    if os.path.exists(file_path):
        return file_path

    print(f"Downloading {file_name}...")
    # urllib instead of `os.system("wget ...")`: no external binary, no shell
    # interpolation of the URL/path, and failures surface as exceptions (the
    # exit status of os.system was ignored, so errors were reported as success).
    fd, tmp_path = tempfile.mkstemp(dir=dir_path, prefix=f".{file_name}.", suffix=".part")
    try:
        with os.fdopen(fd, "wb") as out_file, urllib.request.urlopen(url, timeout=60) as response:
            shutil.copyfileobj(response, out_file)
        # Atomic rename: file_path only ever appears fully written.
        os.replace(tmp_path, file_path)
        print("Download complete.")
    except Exception as e:
        print(f"Error downloading {file_name}: {e}")
    finally:
        # Only still present if the download or rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return file_path
# --- DOWNLOAD MODELS AND EXAMPLES ---
print("Checking for required files...")

# Model checkpoints: the Real-ESRGAN background upsampler plus the two
# face-restoration models offered in the UI.
models_dir = 'models'
for _url, _name in (
    ('https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth', 'realesr-general-x4v3.pth'),
    ('https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', 'GFPGANv1.4.pth'),
    ('https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth', 'RestoreFormer.pth'),
):
    download_file(_url, models_dir, _name)
del _url, _name  # keep the loop temporaries out of the module namespace

# Sample inputs for the gr.Examples gallery; their local paths are kept.
examples_dir = 'examples'
example1_path = download_file(
    'https://raw.githubusercontent.com/TencentARC/GFPGAN/master/inputs/whole_imgs/10045.png',
    examples_dir,
    'example1.png',
)
example2_path = download_file(
    'https://raw.githubusercontent.com/TencentARC/GFPGAN/master/inputs/whole_imgs/Blake_Lively.jpg',
    examples_dir,
    'example2.jpg',
)
# --- LOAD MODELS INTO MEMORY ---
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan.utils import GFPGANer
from realesrgan.utils import RealESRGANer
bg_upsampler = None
try:
    # Real-ESRGAN "general x4 v3" checkpoint on the compact SRVGG architecture.
    model_path = os.path.join(models_dir, 'realesr-general-x4v3.pth')
    model = SRVGGNetCompact(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_conv=32,
        upscale=4,
        act_type='prelu',
    )
    # Request half precision only when a CUDA device is available.
    half = torch.cuda.is_available()
    bg_upsampler = RealESRGANer(
        scale=4,
        model=model,
        model_path=model_path,
        tile=0,
        tile_pad=10,
        pre_pad=0,
        half=half,
    )
except Exception as e:
    # Best-effort: startup continues with bg_upsampler left as None, and
    # upscale_image then refuses to run.
    print(f"Error loading background upsampler: {e}. The app may not work correctly.")
else:
    print("Background Upsampler (Real-ESRGAN) loaded for 4x enhancement.")
# --- CORE IMAGE PROCESSING FUNCTION ---
# Face-restoration models offered in the UI, mapped to the GFPGANer `arch`
# each checkpoint needs. This also acts as an allow-list: `version` becomes
# part of a file path, so it must not be an arbitrary string.
_FACE_MODEL_ARCHS = {
    'GFPGANv1.4': 'clean',
    'RestoreFormer': 'RestoreFormer',
}

# Overall output scale. In gfpgan's GFPGANer this `upscale` value is both the
# face-helper upscale factor and the `outscale` used for the background
# upsampler, so it sets the final image size (the previous value of 2 gave a
# 2x result despite the advertised 4x).
_UPSCALE_FACTOR = 4

# GFPGANer instances keyed by model version, so checkpoints are loaded once
# per model instead of on every request.
# NOTE(review): assumes requests are processed one at a time (demo.queue()
# defaults); confirm before raising queue concurrency, as instances are shared.
_face_enhancers = {}


def _get_face_enhancer(version):
    """Return the cached GFPGANer for `version`, constructing it on first use."""
    enhancer = _face_enhancers.get(version)
    if enhancer is None:
        enhancer = GFPGANer(
            model_path=os.path.join(models_dir, f'{version}.pth'),
            upscale=_UPSCALE_FACTOR,
            arch=_FACE_MODEL_ARCHS[version],
            channel_multiplier=2,
            bg_upsampler=bg_upsampler,  # Real-ESRGAN upsamples the non-face background
        )
        _face_enhancers[version] = enhancer
    return enhancer


def _bgr_to_rgb(image):
    """Convert an OpenCV BGR / BGRA array to RGB / RGBA; other layouts pass through."""
    if image.ndim == 3 and image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
    if image.ndim == 3 and image.shape[2] == 3:
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image


def upscale_image(img_path, version):
    """Enhance an image using GFPGAN and Real-ESRGAN with a fixed 4x upscale.

    Args:
        img_path: Path of the uploaded image (the input component uses
            ``type="filepath"``), or a falsy value when nothing was uploaded.
        version: Face-restoration model name, one of ``_FACE_MODEL_ARCHS``
            ('GFPGANv1.4' or 'RestoreFormer').

    Returns:
        ``(image, file_path)``: the enhanced image as an RGB/RGBA array for
        display, and the path of a temporary file holding the same image
        (PNG when the input had an alpha channel, JPEG otherwise).

    Raises:
        gr.Error: If no image was given, the background upsampler is not
            loaded, the model version is unknown, or processing fails.
    """
    if not img_path:
        raise gr.Error("Please upload an image.")
    if bg_upsampler is None:
        raise gr.Error("Background upsampler not loaded. Cannot proceed.")
    if version not in _FACE_MODEL_ARCHS:
        raise gr.Error(f"Unknown model version: {version}")

    try:
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            raise RuntimeError("Failed to read image.")
        # Grayscale files load as 2-D arrays, so check ndim before touching
        # the channel axis (indexing shape[2] raised IndexError for them).
        has_alpha = img.ndim == 3 and img.shape[2] == 4

        face_enhancer = _get_face_enhancer(version)
        _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)

        # PNG keeps the alpha channel; opaque images are saved as JPEG.
        ext = 'png' if has_alpha else 'jpg'
        fd, out_path = tempfile.mkstemp(suffix=f'.{ext}')
        os.close(fd)
        # `output` is already in OpenCV's BGR(A) order, so write it directly;
        # a round trip through RGB would drop a 4th (alpha) channel.
        if not cv2.imwrite(out_path, output):
            raise RuntimeError(f"Failed to write output image to {out_path}.")
        return _bgr_to_rgb(output), out_path
    except Exception as error:
        print(f"Error processing image: {error}")
        raise gr.Error(f"An error occurred: {error}") from error
# --- GRADIO UI LAYOUT ---
with gr.Blocks(css=CSS_STYLING, theme=gr.themes.Base()) as demo:
    # Header; the elem_ids are targeted by the #main-title / #main-subtitle CSS rules.
    gr.Markdown(
        value="<h1 id='main-title'>NeuraVision AI Image Upscaler</h1>",
        elem_id="main-title",
    )
    gr.Markdown(
        value="<p id='main-subtitle'>Enhance old, blurry, and low-resolution photos with AI (Fixed 4x Upscale).</p>",
        elem_id="main-subtitle",
    )

    with gr.Row(variant="panel"):
        # Left column: upload, model choice, submit button and examples.
        with gr.Column(scale=1):
            input_image = gr.Image(label="Upload Image", type="filepath")
            version = gr.Radio(
                choices=['GFPGANv1.4', 'RestoreFormer'],
                value='GFPGANv1.4',
                label='AI Model',
                info="v1.4 for general use. RestoreFormer for old photos.",
            )
            submit_btn = gr.Button(value="Enhance Image", variant="primary", elem_id="submit-button")
            gr.Examples(
                examples=[
                    [example1_path, "RestoreFormer"],
                    [example2_path, "GFPGANv1.4"],
                ],
                inputs=[input_image, version],
                label="Click an example to start",
            )

        # Right column: preview of the result plus a downloadable copy.
        with gr.Column(scale=1):
            output_image = gr.Image(label="Enhanced Result", type="numpy", interactive=False)
            download_button = gr.File(label="Download Image", interactive=False)

    # --- BUTTON & EVENT HANDLING ---
    # upscale_image returns (image, file path), matching the two outputs.
    submit_btn.click(
        fn=upscale_image,
        inputs=[input_image, version],
        outputs=[output_image, download_button],
    )
    # Clearing the input also blanks both outputs.
    input_image.clear(
        fn=lambda: (None, None),
        inputs=None,
        outputs=[output_image, download_button],
    )
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server (share=True as before).
    # The stray trailing "|" that followed launch(...) was a SyntaxError.
    demo.queue()
    demo.launch(share=True)