# ImageEnhance / app.py
# Last commit by endrol: "add dceplus model" (96f2754)
import base64
import io
import os
import tempfile
import time

import gradio as gr
import requests
from PIL import Image
# Backend endpoints come from environment variables. Any variable that is
# unset yields None, which apply_super_resolution reports to the user as
# "endpoint not configured".
SR_API = os.environ.get("sr_api")
ENHANCE_API = os.environ.get("enhance_api")
ZERODCE_API = os.environ.get("zerodce_api")
ZERODCE_PLUS_API = os.environ.get("zerodce_plus_api")

# Dropdown display name -> endpoint base URL (or None when not configured).
API_ENDPOINTS = dict(
    zip(
        ("SR API", "ENHANCE API", "ZERODCE API", "ZERODCE++ API"),
        (SR_API, ENHANCE_API, ZERODCE_API, ZERODCE_PLUS_API),
    )
)
def apply_super_resolution(input_image, scale_factor, tile_size, api_choice):
    """Send *input_image* to the selected enhancement API and collect the result.

    Args:
        input_image: PIL image from the upload widget, or None when nothing
            has been uploaded yet.
        scale_factor: Upscale factor, forwarded to the API as ``outscale``.
        tile_size: Tile size, forwarded to the API as ``tile``. A Gradio
            ``Number`` may deliver this as a float (or None when the field is
            cleared), so it is normalised to an int before sending.
        api_choice: Key into ``API_ENDPOINTS`` naming the backend to call.

    Returns:
        A 3-tuple ``(slider_value, status_message, download_update)``.
        ``slider_value`` is ``(original, enhanced)`` on success and None
        otherwise. ``download_update`` is a ``gr.update`` that reveals the
        download button, pointing at the saved PNG, only on success.
    """
    hidden = gr.update(visible=False)
    if input_image is None:
        return None, "❌ Please upload an image first", hidden

    # .get() so an unexpected choice is reported in the UI instead of raising
    # KeyError outside the try block below.
    api_endpoint = API_ENDPOINTS.get(api_choice)
    if not api_endpoint:
        return None, f"❌ {api_choice} endpoint not configured", hidden

    try:
        # Serialise the uploaded image losslessly as PNG for the request body.
        img_buffer = io.BytesIO()
        input_image.save(img_buffer, format='PNG')
        img_bytes = img_buffer.getvalue()

        response = requests.post(
            f"{api_endpoint}/invocations",
            headers={"Content-Type": "application/octet-stream"},
            data=img_bytes,
            params={
                "model_name": "RealESRGAN_x4plus",
                "outscale": scale_factor,
                # Send "256" rather than "256.0"; a cleared field becomes 0.
                "tile": int(tile_size) if tile_size else 0,
                "fp32": False
            },
            timeout=300,
            # NOTE(review): SSL verification is disabled (originally "for
            # testing"). This exposes the request to MITM on untrusted
            # networks; re-enable before production use.
            verify=False
        )
        if response.status_code != 200:
            return None, f"❌ API Error: {response.status_code}\n{response.text}", hidden

        result = response.json()
        # The enhanced image comes back base64-encoded under "prediction".
        enhanced_data = base64.b64decode(result["prediction"])
        enhanced_image = Image.open(io.BytesIO(enhanced_data))

        # Write into a fresh per-request directory. With a concurrent queue,
        # two requests finishing in the same second would otherwise share the
        # timestamped filename in the working directory and overwrite each
        # other's download.
        timestamp = int(time.time())
        download_dir = tempfile.mkdtemp(prefix="enhanced_")
        download_path = os.path.join(download_dir, f"enhanced_image_{timestamp}.png")
        enhanced_image.save(download_path, format='PNG')

        status = (
            f"✅ Enhancement successful using {api_choice}!\n"
            f"Model: {result.get('model', 'RealESRGAN_x4plus')}\n"
            f"Scale: {result.get('outscale', 4.0)}x\n"
            f"Input: {result.get('input_img_width', 0)}x{result.get('input_img_height', 0)}\n"
            f"Output: {result.get('output_img_width', 0)}x{result.get('output_img_height', 0)}\n"
            f"process_time: {result.get('upscaling_time', 0)}"
        )
        # Tuple for ImageSlider: (original, enhanced)
        return (input_image, enhanced_image), status, gr.update(visible=True, value=download_path)
    except Exception as e:
        # UI boundary: report any failure (network, JSON, decode, save) as status.
        return None, f"❌ Error: {str(e)}", hidden
def main():
    """Build the Image Enhancement UI and launch it behind a request queue.

    Layout: an upload/settings row, a before/after ImageSlider, and a
    download button that stays hidden until an enhancement succeeds.
    """
    with gr.Blocks(title="Image Enhancement App") as demo:
        gr.Markdown("# 🚀 Image Enhancement App")
        gr.Markdown("Upload an image and enhance it with AI-powered super-resolution")

        # Row 1: Upload and Controls
        with gr.Row():
            with gr.Column(scale=3):
                input_image = gr.Image(
                    label="📀 Upload Image",
                    type="pil",
                    height=300
                )
            with gr.Column(scale=1):
                gr.Markdown("### Enhancement Settings")
                api_dropdown = gr.Dropdown(
                    # Derived from API_ENDPOINTS so the dropdown can never
                    # offer a backend the handler does not know about.
                    choices=list(API_ENDPOINTS),
                    value="SR API",
                    label="API Choice",
                    info="Choose which enhancement API to use"
                )
                scale_dropdown = gr.Dropdown(
                    choices=[1, 2, 4],
                    value=4,
                    label="Scale Factor",
                    info="How much to upscale the image"
                )
                tile_size = gr.Number(
                    value=0,
                    precision=0,  # deliver an int, not a float, to the API
                    label="Tile Size",
                    info="Tile size for the image"
                )
                enhance_button = gr.Button(
                    "✨ Enhance Image",
                    variant="primary",
                    size="lg"
                )
                status_text = gr.Textbox(
                    label="Status",
                    lines=6,
                    value="Ready to enhance images!",
                    interactive=False
                )

        # Row 2: Before/After Comparison with Image Slider
        with gr.Row():
            gr.Markdown("### 📊 Before vs After Comparison")
        with gr.Row():
            image_slider = gr.ImageSlider(
                label="Original vs Enhanced",
                height=500,
                interactive=False
            )

        # Download button, revealed by apply_super_resolution on success
        with gr.Row():
            download_button = gr.DownloadButton(
                "📥 Download Enhanced Image",
                visible=False,
                size="lg"
            )

        # Event handlers
        enhance_button.click(
            fn=apply_super_resolution,
            inputs=[input_image, scale_dropdown, tile_size, api_dropdown],
            outputs=[image_slider, status_text, download_button],
            show_progress="full"
        )

        # Reset results whenever the input changes. The change event also fires
        # when the image is cleared, so only claim an upload when one exists.
        input_image.change(
            fn=lambda img: (
                None,
                "Image uploaded! Ready to enhance." if img is not None else "Ready to enhance images!",
                gr.update(visible=False),
            ),
            inputs=input_image,
            outputs=[image_slider, status_text, download_button]
        )

    # Launch the app: up to 3 requests run concurrently, at most 10 wait.
    demo.queue(default_concurrency_limit=3, max_size=10).launch()


if __name__ == "__main__":
    main()