"""Gradio app that upscales an uploaded image with a selectable model.

The module downloads upscaler weights on demand, runs a tiled upscale on the
GPU, optionally shrinks the result, writes it to a temp directory as PNG or
TIFF, and reports before/after statistics in the UI.
"""

import os
import uuid

import gradio as gr
import spaces
from PIL import Image
from huggingface_hub import hf_hub_download
from image_gen_aux import UpscaleWithModel

# ---------------------------------
# Paths
# ---------------------------------
BASE_TMP_DIR = "/tmp/image_enhancer"
ENHANCED_DIR = os.path.join(BASE_TMP_DIR, "enhanced")
MODEL_DIR = os.path.join(BASE_TMP_DIR, "models")

os.makedirs(ENHANCED_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

# ---------------------------------
# Model configuration
# ---------------------------------
# Display name -> Hugging Face repo and weight file to download.
MODEL_SPECS = {
    "AnimeSharp": {
        "repo_id": "Kim2091/AnimeSharp",
        "filename": "4x-AnimeSharp.pth",
    },
    "UltraSharp": {
        "repo_id": "Kim2091/UltraSharp",
        "filename": "4x-UltraSharp.pth",
    },
    "UltraMix Balanced": {
        "repo_id": "LykosAI/Upscalers",
        "filename": "UltraMix/4x-UltraMix_Balanced.pth",
    },
}
DEFAULT_MODEL = "AnimeSharp"

# Loaded upscalers keyed by display name, filled lazily by get_model().
MODEL_CACHE = {}

# Aspect-ratio label -> (width units, height units).
RATIO_MAP = {
    "16:9": (16, 9),
    "9:16": (9, 16),
    "4:5": (4, 5),
    "1:1": (1, 1),
    "5:4": (5, 4),
    "2:3": (2, 3),
    "3:2": (3, 2),
}
DEFAULT_RATIO = "2:3"

# Long-side tile sizes offered in the UI (strings, as the radio emits them).
TILE_PRESETS = ["512", "768", "1024"]
DEFAULT_TILE_PRESET = "768"

MODE_4X = "4x High Fidelity"
MODE_8X = "8x Multi-Pass (Drift Likely)"
MODE_CHOICES = [MODE_4X, MODE_8X]

REDUCTION_OFF = "Off"
# Reduction label -> scale factor applied to the enhanced image.
REDUCTION_FACTORS = {
    "80%": 0.80,
    "85%": 0.85,
    "90%": 0.90,
}
REDUCTION_CHOICES = [REDUCTION_OFF, *REDUCTION_FACTORS]

REDUCTION_DISCLAIMER = (
    "Hi-Fi Output offers Post-Processing Size Reduction to further improve results. "
    "Mode is entirely Optional and is defaulted at Off. Please note, this does not work "
    "with Fast modes. If toggled, it will not be applied."
)

EXPORT_FORMATS = ["PNG", "TIFF"]
NO_STATS_MESSAGE = "No stats available yet."

# Two trailing spaces before the newline make a Markdown hard line break;
# a single space would let consecutive lines flow into one paragraph.
_MD_BREAK = "  \n"


# ---------------------------------
# Helpers
# ---------------------------------
def _floor_to_even(value: int) -> int:
    """Round *value* down to an even number, never going below 2."""
    return max(2, value - (value % 2))


def get_model(model_name: str):
    """Return the upscaler for *model_name*, downloading and loading it once.

    The weights are fetched into ``MODEL_DIR``, loaded with
    ``UpscaleWithModel.from_pretrained`` and moved to ``"cuda"``; the result
    is memoized in ``MODEL_CACHE``.

    Raises:
        KeyError: if *model_name* is not a key of ``MODEL_SPECS``.
    """
    cached = MODEL_CACHE.get(model_name)
    if cached is not None:
        return cached

    spec = MODEL_SPECS[model_name]
    # NOTE(review): `local_dir_use_symlinks` is reported as deprecated by
    # recent huggingface_hub releases - confirm against the pinned version.
    local_path = hf_hub_download(
        repo_id=spec["repo_id"],
        filename=spec["filename"],
        local_dir=MODEL_DIR,
        local_dir_use_symlinks=False,
    )

    model = UpscaleWithModel.from_pretrained(local_path).to("cuda")
    MODEL_CACHE[model_name] = model
    return model


def get_tile_dimensions(ratio_name: str, tile_preset: str) -> tuple[int, int]:
    """Compute an even ``(tile_width, tile_height)`` for a ratio and preset.

    *tile_preset* is the length of the longer side in pixels (as a string);
    the shorter side is scaled by the ratio from ``RATIO_MAP``. Both sides
    are rounded down to an even number with a minimum of 2.
    """
    long_side = int(tile_preset)
    ratio_w, ratio_h = RATIO_MAP[ratio_name]

    if ratio_w >= ratio_h:
        width = long_side
        height = round(long_side * ratio_h / ratio_w)
    else:
        height = long_side
        width = round(long_side * ratio_w / ratio_h)

    return _floor_to_even(width), _floor_to_even(height)


def update_tile_display(ratio_name: str, tile_preset: str) -> str:
    """Return the Markdown snippet showing the current tile width/height."""
    tile_width, tile_height = get_tile_dimensions(ratio_name, tile_preset)
    return _MD_BREAK.join(
        (
            f"**Tile Width:** {tile_width}px",
            f"**Tile Height:** {tile_height}px",
        )
    )


def format_megapixels(width: int, height: int) -> str:
    """Format a pixel area as megapixels with two decimals, e.g. ``2.07 MP``."""
    return f"{(width * height) / 1_000_000:.2f} MP"


def format_file_size(num_bytes: int) -> str:
    """Format a byte count using binary units (B, KB, MB, GB)."""
    if num_bytes < 1024:
        return f"{num_bytes} B"
    if num_bytes < 1024 ** 2:
        return f"{num_bytes / 1024:.1f} KB"
    if num_bytes < 1024 ** 3:
        return f"{num_bytes / (1024 ** 2):.2f} MB"
    return f"{num_bytes / (1024 ** 3):.2f} GB"


def build_stats_markdown(
    original_width: int,
    original_height: int,
    enhanced_width: int,
    enhanced_height: int,
    file_size_bytes: int,
    export_format: str,
    mode_name: str,
    reduction_choice: str,
    reduction_applied: bool,
    model_name: str,
) -> str:
    """Build the Markdown stats panel for one enhancement run.

    The panel has four paragraphs: run settings, original size, enhanced
    size, and saved file size. The reduction line shows *reduction_choice*
    only when *reduction_applied* is true; otherwise it reads
    "Ignored / Not Applied".
    """
    reduction_status = (
        reduction_choice if reduction_applied else "Ignored / Not Applied"
    )

    settings = _MD_BREAK.join(
        (
            f"**Model:** {model_name}",
            f"**Mode:** {mode_name}",
            f"**Export Format:** {export_format}",
            f"**Hi-Fi Output Reduction:** {reduction_status}",
        )
    )
    original = _MD_BREAK.join(
        (
            f"**Original Dimensions:** {original_width} × {original_height}px",
            f"**Original Megapixels:** "
            f"{format_megapixels(original_width, original_height)}",
        )
    )
    enhanced = _MD_BREAK.join(
        (
            f"**Enhanced Dimensions:** {enhanced_width} × {enhanced_height}px",
            f"**Enhanced Megapixels:** "
            f"{format_megapixels(enhanced_width, enhanced_height)}",
        )
    )
    saved = f"**Saved File Size:** {format_file_size(file_size_bytes)}"

    return "\n\n".join((settings, original, enhanced, saved))


def reduction_factor_from_choice(choice: str) -> float:
    """Map a reduction label to its scale factor; unknown labels give 1.0."""
    return REDUCTION_FACTORS.get(choice, 1.0)


def apply_output_reduction(img: Image.Image, reduction_choice: str) -> Image.Image:
    """Downscale *img* by the chosen reduction using Lanczos resampling.

    Returns *img* unchanged when the choice maps to a factor of 1.0 or more.
    The target dimensions are rounded, then floored to even values (min 2).
    """
    factor = reduction_factor_from_choice(reduction_choice)
    if factor >= 1.0:
        return img

    target_size = (
        _floor_to_even(round(img.width * factor)),
        _floor_to_even(round(img.height * factor)),
    )
    return img.resize(target_size, Image.LANCZOS)


def upscale_once(
    img: Image.Image,
    model_name: str,
    tile_width: int,
    tile_height: int,
) -> Image.Image:
    """Run one tiled pass of the named upscaler and return an RGB image."""
    upscaler = get_model(model_name)
    result = upscaler(
        img,
        tiling=True,
        tile_width=tile_width,
        tile_height=tile_height,
    )

    # The upscaler's output is wrapped when it is not already a PIL image.
    if not isinstance(result, Image.Image):
        result = Image.fromarray(result)

    return result.convert("RGB")


def run_mode_pipeline(
    img: Image.Image,
    model_name: str,
    mode_name: str,
    tile_width: int,
    tile_height: int,
) -> Image.Image:
    """Produce the enhanced image for the selected processing mode.

    Every mode starts with a single model pass. For ``MODE_8X`` that result
    is then Lanczos-resized to 8x the input dimensions; any other mode name
    (including unrecognized ones) returns the model output directly.
    """
    upscaled = upscale_once(img, model_name, tile_width, tile_height)
    if mode_name != MODE_8X:
        return upscaled

    target_size = (img.width * 8, img.height * 8)
    return upscaled.resize(target_size, Image.LANCZOS).convert("RGB")


def save_output_image(output_img: Image.Image, export_format: str) -> str:
    """Write *output_img* under ``ENHANCED_DIR`` and return the file path.

    "PNG" is saved with ``compress_level=0``; any other *export_format* value
    is saved as TIFF. File names are random UUID hex strings.
    """
    rgb_img = output_img.convert("RGB")
    file_id = uuid.uuid4().hex

    if export_format == "PNG":
        path = os.path.join(ENHANCED_DIR, f"{file_id}.png")
        rgb_img.save(path, format="PNG", compress_level=0)
    else:
        path = os.path.join(ENHANCED_DIR, f"{file_id}.tiff")
        rgb_img.save(path, format="TIFF")

    return path


# ---------------------------------
# GPU function
# ---------------------------------
@spaces.GPU
def enhance_image(
    reduction_choice,
    model_name,
    mode_name,
    ratio_name,
    tile_preset,
    export_format,
    input_image,
):
    """Enhance the uploaded image and return preview, file path and stats.

    Args:
        reduction_choice: label from ``REDUCTION_CHOICES``.
        model_name: key of ``MODEL_SPECS``.
        mode_name: entry of ``MODE_CHOICES``.
        ratio_name: key of ``RATIO_MAP`` used for the tile shape.
        tile_preset: long-side tile size as a string.
        export_format: "PNG" or "TIFF".
        input_image: array from the ``gr.Image(type="numpy")`` input, or
            ``None`` when nothing was uploaded.

    Returns:
        ``(enhanced PIL image, saved file path, stats Markdown)``, or
        ``(None, None, "No stats available yet.")`` when there is no input.
    """
    if input_image is None:
        return None, None, NO_STATS_MESSAGE

    original_img = Image.fromarray(input_image).convert("RGB")
    original_width, original_height = original_img.size

    tile_width, tile_height = get_tile_dimensions(ratio_name, tile_preset)

    enhanced_img = run_mode_pipeline(
        img=original_img,
        model_name=model_name,
        mode_name=mode_name,
        tile_width=tile_width,
        tile_height=tile_height,
    )

    # Size reduction is only honored in the 4x mode; otherwise it is ignored.
    reduction_applied = mode_name == MODE_4X and reduction_choice != REDUCTION_OFF
    if reduction_applied:
        enhanced_img = apply_output_reduction(enhanced_img, reduction_choice)

    enhanced_width, enhanced_height = enhanced_img.size

    output_path = save_output_image(enhanced_img, export_format)
    file_size_bytes = os.path.getsize(output_path)

    stats_markdown = build_stats_markdown(
        original_width=original_width,
        original_height=original_height,
        enhanced_width=enhanced_width,
        enhanced_height=enhanced_height,
        file_size_bytes=file_size_bytes,
        export_format=export_format,
        mode_name=mode_name,
        reduction_choice=reduction_choice,
        reduction_applied=reduction_applied,
        model_name=model_name,
    )

    return enhanced_img, output_path, stats_markdown


# ---------------------------------
# UI
# ---------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Image Enhancer")

    # 0. Hi-Fi Output Reduction
    with gr.Group():
        gr.Markdown("### Hi-Fi Output Reduction")
        gr.Markdown(REDUCTION_DISCLAIMER)
        reduction_choice = gr.Radio(
            choices=REDUCTION_CHOICES,
            value=REDUCTION_OFF,
            label="Reduction Amount",
        )

    # 1. Model / Mode box
    with gr.Group():
        # Choices come from the config dict so the UI cannot drift from it.
        model_name = gr.Radio(
            choices=list(MODEL_SPECS),
            value=DEFAULT_MODEL,
            label="Reconstruction Model",
        )
        mode_name = gr.Radio(
            choices=MODE_CHOICES,
            value=MODE_4X,
            label="Processing Mode",
        )

    # 2. Combined Tile Settings
    with gr.Group():
        gr.Markdown("### Tile Settings")
        ratio_name = gr.Radio(
            choices=list(RATIO_MAP),
            value=DEFAULT_RATIO,
            label="Aspect Ratio",
        )
        tile_preset = gr.Radio(
            choices=TILE_PRESETS,
            value=DEFAULT_TILE_PRESET,
            label="Preset Size",
        )
        tile_display = gr.Markdown(
            value=update_tile_display(DEFAULT_RATIO, DEFAULT_TILE_PRESET)
        )

    # 2.5 Output Settings
    with gr.Group():
        gr.Markdown("### Output Settings")
        export_format = gr.Radio(
            choices=EXPORT_FORMATS,
            value="PNG",
            label="Export Format",
        )

    # 3. Input Image
    with gr.Group():
        input_image = gr.Image(
            type="numpy",
            label="Input Image",
            height=400,
        )

    run_button = gr.Button("Enhance Image")

    # 4. Output Preview
    with gr.Group():
        gr.Markdown("### Output Preview")
        output_preview = gr.Image(
            type="pil",
            label="Enhanced Preview",
            height=400,
        )

    # 5. Download box
    with gr.Group():
        gr.Markdown("### New Enhanced Image File")
        download_file = gr.File(
            label="Download new enhanced image file",
        )

    # Stats
    with gr.Group():
        gr.Markdown("### Image Stats")
        stats_box = gr.Markdown(value=NO_STATS_MESSAGE)

    # Keep the tile readout in sync with either tile control.
    for tile_control in (ratio_name, tile_preset):
        tile_control.change(
            fn=update_tile_display,
            inputs=[ratio_name, tile_preset],
            outputs=tile_display,
        )

    run_button.click(
        fn=enhance_image,
        inputs=[
            reduction_choice,
            model_name,
            mode_name,
            ratio_name,
            tile_preset,
            export_format,
            input_image,
        ],
        outputs=[
            output_preview,
            download_file,
            stats_box,
        ],
        # Documented values are "full" / "minimal" / "hidden".
        show_progress="full",
    )

# allowed_paths lets Gradio serve the files written under BASE_TMP_DIR.
demo.launch(
    ssr_mode=False,
    allowed_paths=[BASE_TMP_DIR],
)