# Page-header residue from the hosting site (not program code):
#   throwaway74's picture
#   Update app.py
#   bc9f66e verified
import os
import uuid
import gradio as gr
import spaces
from PIL import Image
from huggingface_hub import hf_hub_download
from image_gen_aux import UpscaleWithModel
# ---------------------------------
# Paths
# ---------------------------------
# Everything the app writes at runtime lives under this single scratch root;
# the same root is what gets passed to demo.launch(allowed_paths=...).
BASE_TMP_DIR = "/tmp/image_enhancer"
# Where save_output_image() writes the exported PNG/TIFF files.
ENHANCED_DIR = os.path.join(BASE_TMP_DIR, "enhanced")
# Where get_model() asks hf_hub_download() to place the weights files.
MODEL_DIR = os.path.join(BASE_TMP_DIR, "models")

# Create both working directories at import time; exist_ok=True makes this a
# no-op when they are already present.
for _runtime_dir in (ENHANCED_DIR, MODEL_DIR):
    os.makedirs(_runtime_dir, exist_ok=True)
# ---------------------------------
# Model configuration
# ---------------------------------
# Registry of selectable upscalers: UI label -> where the weights live on the
# Hugging Face Hub ("repo_id") and which file inside that repo ("filename").
MODEL_SPECS = {
    "AnimeSharp": {"repo_id": "Kim2091/AnimeSharp", "filename": "4x-AnimeSharp.pth"},
    "UltraSharp": {"repo_id": "Kim2091/UltraSharp", "filename": "4x-UltraSharp.pth"},
    "UltraMix Balanced": {
        "repo_id": "LykosAI/Upscalers",
        "filename": "UltraMix/4x-UltraMix_Balanced.pth",
    },
}

# Loaded upscaler objects keyed by the same labels as MODEL_SPECS. Starts
# empty and is filled by get_model() the first time each model is requested.
MODEL_CACHE = {}
# Supported tile aspect ratios: UI label "W:H" -> (W, H) as integers.
# The integer pairs are parsed from the labels themselves, in display order.
RATIO_MAP = {
    label: tuple(int(part) for part in label.split(":"))
    for label in ("16:9", "9:16", "4:5", "1:1", "5:4", "2:3", "3:2")
}
# Processing modes offered in the UI. run_mode_pipeline() and enhance_image()
# compare against these exact strings.
MODE_CHOICES = ["4x High Fidelity", "8x Multi-Pass (Drift Likely)"]

# Post-upscale size-reduction options: "Off" plus the percentages that
# reduction_factor_from_choice() maps to scale factors.
REDUCTION_CHOICES = ["Off", *(f"{percent}%" for percent in (80, 85, 90))]
# User-facing note rendered above the reduction selector. Its wording has to
# match enhance_image(), which applies the reduction only when the mode is
# "4x High Fidelity" and skips it for every other mode (there is no "Fast"
# mode in MODE_CHOICES).
REDUCTION_DISCLAIMER = (
    "Hi-Fi Output offers Post-Processing Size Reduction to further improve results. "
    "Mode is entirely Optional and is defaulted at Off. Please note, this only works "
    "with the 4x High Fidelity mode. If toggled in any other mode, it will not be applied."
)
# ---------------------------------
# Helpers
# ---------------------------------
def get_model(model_name: str):
    """Return the upscaler registered under *model_name*, loading it on first use.

    On the first request for a name, the weights file described by its
    ``MODEL_SPECS`` entry is fetched with ``hf_hub_download`` into
    ``MODEL_DIR``, an ``UpscaleWithModel`` is built from that file, moved to
    the ``"cuda"`` device and stored in ``MODEL_CACHE``. Later requests for
    the same name return the stored object.

    Raises:
        KeyError: if *model_name* has no entry in ``MODEL_SPECS``.
    """
    if model_name not in MODEL_CACHE:
        spec = MODEL_SPECS[model_name]
        weights_path = hf_hub_download(
            repo_id=spec["repo_id"],
            filename=spec["filename"],
            local_dir=MODEL_DIR,
            local_dir_use_symlinks=False,
        )
        upscaler = UpscaleWithModel.from_pretrained(weights_path)
        # Item assignment mutates the module-level dict in place, so no
        # ``global`` declaration is required.
        MODEL_CACHE[model_name] = upscaler.to("cuda")
    return MODEL_CACHE[model_name]
def get_tile_dimensions(ratio_name: str, tile_preset: str):
    """Compute the tile size for an aspect-ratio label and a long-side preset.

    Args:
        ratio_name: key of ``RATIO_MAP`` (for example ``"2:3"``).
        tile_preset: length of the tile's longer side in pixels, as a string
            that ``int()`` can parse (the UI passes ``"512"``/``"768"``/``"1024"``).

    Returns:
        ``(tile_width, tile_height)``; both values are even and at least 2.

    Raises:
        KeyError: if *ratio_name* is not in ``RATIO_MAP``.
        ValueError: if *tile_preset* is not an integer string.
    """
    longest = int(tile_preset)
    ratio_w, ratio_h = RATIO_MAP[ratio_name]

    # The longer side of the ratio gets the preset; the other side is scaled
    # proportionally. Square ratios take the first branch.
    if ratio_w >= ratio_h:
        width, height = longest, round(longest * ratio_h / ratio_w)
    else:
        width, height = round(longest * ratio_w / ratio_h), longest

    # Round each side down to an even number, never below 2.
    width = max(2, width - width % 2)
    height = max(2, height - height % 2)
    return width, height
def update_tile_display(ratio_name: str, tile_preset: str):
    """Return the Markdown read-out of the tile size for the current selection.

    Args:
        ratio_name: aspect-ratio label, forwarded to ``get_tile_dimensions``.
        tile_preset: long-side preset string, forwarded to ``get_tile_dimensions``.

    Returns:
        A two-line Markdown string with the tile width and the tile height.
    """
    tile_width, tile_height = get_tile_dimensions(ratio_name, tile_preset)
    # Markdown needs two (or more) trailing spaces before the newline to emit
    # a hard line break; with a single space the newline is only a soft break
    # and width and height end up on the same rendered line.
    return (
        f"**Tile Width:** {tile_width}px  \n"
        f"**Tile Height:** {tile_height}px"
    )
def format_megapixels(width: int, height: int) -> str:
    """Format the pixel count ``width * height`` as megapixels, e.g. ``"2.07 MP"``.

    One megapixel is taken as 1,000,000 pixels; the value is shown with two
    decimal places.
    """
    megapixels = (width * height) / 1_000_000
    return f"{megapixels:.2f} MP"
def format_file_size(num_bytes: int) -> str:
    """Render a byte count with a binary (1024-based) unit.

    Values below 1024 are shown as whole bytes, kilobytes with one decimal
    place, megabytes and gigabytes with two. Gigabytes is the largest unit.
    """
    kib, mib, gib = 1024, 1024 ** 2, 1024 ** 3

    # Test the largest unit first; whatever falls through is plain bytes.
    if num_bytes >= gib:
        return f"{num_bytes / gib:.2f} GB"
    if num_bytes >= mib:
        return f"{num_bytes / mib:.2f} MB"
    if num_bytes >= kib:
        return f"{num_bytes / kib:.1f} KB"
    return f"{num_bytes} B"
def build_stats_markdown(
    original_width: int,
    original_height: int,
    enhanced_width: int,
    enhanced_height: int,
    file_size_bytes: int,
    export_format: str,
    mode_name: str,
    reduction_choice: str,
    reduction_applied: bool,
    model_name: str,
):
    """Build the Markdown summary shown in the "Image Stats" box.

    Args:
        original_width, original_height: size of the input image in pixels.
        enhanced_width, enhanced_height: size of the final output image in pixels.
        file_size_bytes: size of the saved output file.
        export_format: export format label as chosen in the UI.
        mode_name: processing mode label as chosen in the UI.
        reduction_choice: reduction label as chosen in the UI.
        reduction_applied: whether the reduction was actually performed; when
            false the summary shows "Ignored / Not Applied" instead of the choice.
        model_name: model label as chosen in the UI.

    Returns:
        A Markdown string with four paragraphs: settings, original image,
        enhanced image, and saved file size.
    """
    reduction_status = reduction_choice if reduction_applied else "Ignored / Not Applied"
    # Lines inside a paragraph end with two spaces before "\n": that is what
    # Markdown requires for a hard line break. With a single space the fields
    # of one paragraph would be rendered run together on one line. Paragraphs
    # are separated by a blank line ("\n\n"), which needs no trailing spaces.
    return (
        f"**Model:** {model_name}  \n"
        f"**Mode:** {mode_name}  \n"
        f"**Export Format:** {export_format}  \n"
        f"**Hi-Fi Output Reduction:** {reduction_status} \n\n"
        f"**Original Dimensions:** {original_width} × {original_height}px  \n"
        f"**Original Megapixels:** {format_megapixels(original_width, original_height)} \n\n"
        f"**Enhanced Dimensions:** {enhanced_width} × {enhanced_height}px  \n"
        f"**Enhanced Megapixels:** {format_megapixels(enhanced_width, enhanced_height)} \n\n"
        f"**Saved File Size:** {format_file_size(file_size_bytes)}"
    )
def reduction_factor_from_choice(choice: str):
    """Translate a reduction label into a scale factor.

    ``"80%"``, ``"85%"`` and ``"90%"`` map to 0.80, 0.85 and 0.90. Any other
    label (including ``"Off"``) yields 1.0, meaning no reduction.
    """
    factors = {"80%": 0.80, "85%": 0.85, "90%": 0.90}
    try:
        return factors[choice]
    except KeyError:
        # Unknown label: treat it as "leave the size untouched".
        return 1.0
def apply_output_reduction(img: Image.Image, reduction_choice: str):
    """Shrink *img* by the factor that *reduction_choice* stands for.

    The factor comes from ``reduction_factor_from_choice``. When it is 1.0 or
    more (for example for ``"Off"``), the very same image object is returned
    untouched. Otherwise the image is resized with Lanczos resampling to the
    scaled size, with each side rounded down to an even number and never
    smaller than 2 pixels.
    """
    scale = reduction_factor_from_choice(reduction_choice)
    if scale >= 1.0:
        return img

    def scaled_even(side: int) -> int:
        # Scale, drop to the nearest even value at or below, floor at 2.
        scaled = int(round(side * scale))
        return max(2, scaled - scaled % 2)

    target_size = (scaled_even(img.width), scaled_even(img.height))
    return img.resize(target_size, Image.LANCZOS)
def upscale_once(img: Image.Image, model_name: str, tile_width: int, tile_height: int):
    """Run one tiled pass of the named upscaler over *img*.

    The model is obtained from ``get_model`` and called with tiling enabled
    and the given tile size. The result is always returned as an RGB PIL image.
    """
    model = get_model(model_name)
    result = model(
        img,
        tiling=True,
        tile_width=tile_width,
        tile_height=tile_height,
    )
    # The upscaler's return type is not fixed here: anything that is not
    # already a PIL image is handed to Image.fromarray() first.
    if isinstance(result, Image.Image):
        return result.convert("RGB")
    return Image.fromarray(result).convert("RGB")
def run_mode_pipeline(
    img: Image.Image,
    model_name: str,
    mode_name: str,
    tile_width: int,
    tile_height: int,
):
    """Produce the enhanced image for the selected processing mode.

    Every mode begins with a single ``upscale_once`` pass.

    * ``"8x Multi-Pass (Drift Likely)"``: the pass result is resampled with
      Lanczos to 8x the width and height of the *input* image and returned
      as RGB.
    * ``"4x High Fidelity"`` and any unrecognised mode: the pass result is
      returned as is.

    NOTE(review): despite its "Multi-Pass" label, the 8x mode runs the model
    only once and reaches 8x through resampling — confirm this is intended.
    """
    upscaled = upscale_once(img, model_name, tile_width, tile_height)
    if mode_name != "8x Multi-Pass (Drift Likely)":
        return upscaled

    target_size = (img.width * 8, img.height * 8)
    return upscaled.resize(target_size, Image.LANCZOS).convert("RGB")
def save_output_image(output_img: Image.Image, export_format: str):
    """Write *output_img* into ``ENHANCED_DIR`` and return the file path.

    The image is converted to RGB and stored under a random hex file name.
    ``"PNG"`` produces a ``.png`` written with ``compress_level=0``; any other
    *export_format* value produces a ``.tiff``.
    """
    rgb_image = output_img.convert("RGB")
    wants_png = export_format == "PNG"

    extension, pil_format = ("png", "PNG") if wants_png else ("tiff", "TIFF")
    destination = os.path.join(ENHANCED_DIR, f"{uuid.uuid4().hex}.{extension}")

    # Only the PNG branch passes an extra save option.
    save_options = {"compress_level": 0} if wants_png else {}
    rgb_image.save(destination, format=pil_format, **save_options)
    return destination
# ---------------------------------
# GPU function
# ---------------------------------
@spaces.GPU
def enhance_image(
    reduction_choice,
    model_name,
    mode_name,
    ratio_name,
    tile_preset,
    export_format,
    input_image,
):
    """Handle a click on "Enhance Image": upscale, optionally reduce, save, report.

    The parameter order matches the ``inputs`` list wired to the button.
    *input_image* is what the ``gr.Image(type="numpy")`` component delivers,
    or ``None`` when nothing was uploaded.

    Steps: convert the input to an RGB PIL image, derive the tile size, run
    the mode pipeline, apply the size reduction (only in "4x High Fidelity"
    mode and only when the choice is not "Off"), save the result and build
    the stats text.

    Returns:
        ``(enhanced_image, saved_file_path, stats_markdown)``, or
        ``(None, None, "No stats available yet.")`` when *input_image* is ``None``.
    """
    if input_image is None:
        return None, None, "No stats available yet."

    source = Image.fromarray(input_image).convert("RGB")
    source_width, source_height = source.size

    tile_width, tile_height = get_tile_dimensions(ratio_name, tile_preset)
    result = run_mode_pipeline(source, model_name, mode_name, tile_width, tile_height)

    # Reduction is honoured only for the 4x mode; in any other mode the
    # user's choice is ignored and reported as such in the stats.
    reduction_applied = mode_name == "4x High Fidelity" and reduction_choice != "Off"
    if reduction_applied:
        result = apply_output_reduction(result, reduction_choice)

    result_width, result_height = result.size
    saved_path = save_output_image(result, export_format)

    stats = build_stats_markdown(
        original_width=source_width,
        original_height=source_height,
        enhanced_width=result_width,
        enhanced_height=result_height,
        file_size_bytes=os.path.getsize(saved_path),
        export_format=export_format,
        mode_name=mode_name,
        reduction_choice=reduction_choice,
        reduction_applied=reduction_applied,
        model_name=model_name,
    )
    return result, saved_path, stats
# ---------------------------------
# UI
# ---------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Image Enhancer")

    # Initial tile selections, named once so the two radios and the initial
    # tile read-out below cannot drift apart.
    default_ratio = "2:3"
    default_tile_preset = "768"

    # 0. Hi-Fi Output Reduction
    with gr.Group():
        gr.Markdown("### Hi-Fi Output Reduction")
        gr.Markdown(REDUCTION_DISCLAIMER)
        reduction_choice = gr.Radio(
            choices=REDUCTION_CHOICES,
            value="Off",
            label="Reduction Amount",
        )

    # 1. Model / Mode box
    with gr.Group():
        model_name = gr.Radio(
            # Offer exactly the models that get_model() can resolve, in the
            # order they are declared in MODEL_SPECS.
            choices=list(MODEL_SPECS),
            value="AnimeSharp",
            label="Reconstruction Model",
        )
        mode_name = gr.Radio(
            choices=MODE_CHOICES,
            value="4x High Fidelity",
            label="Processing Mode",
        )

    # 2. Combined Tile Settings
    with gr.Group():
        gr.Markdown("### Tile Settings")
        ratio_name = gr.Radio(
            # Offer exactly the ratios that get_tile_dimensions() can look up.
            choices=list(RATIO_MAP),
            value=default_ratio,
            label="Aspect Ratio",
        )
        tile_preset = gr.Radio(
            choices=["512", "768", "1024"],
            value=default_tile_preset,
            label="Preset Size",
        )
        tile_display = gr.Markdown(
            value=update_tile_display(default_ratio, default_tile_preset)
        )

    # 2.5 Output Settings
    with gr.Group():
        gr.Markdown("### Output Settings")
        export_format = gr.Radio(
            choices=["PNG", "TIFF"],
            value="PNG",
            label="Export Format",
        )

    # 3. Input Image
    with gr.Group():
        input_image = gr.Image(
            type="numpy",
            label="Input Image",
            height=400,
        )

    run_button = gr.Button("Enhance Image")

    # 4. Output Preview
    with gr.Group():
        gr.Markdown("### Output Preview")
        output_preview = gr.Image(
            type="pil",
            label="Enhanced Preview",
            height=400,
        )

    # 5. Download box
    with gr.Group():
        gr.Markdown("### New Enhanced Image File")
        download_file = gr.File(
            label="Download new enhanced image file"
        )

    # Stats
    with gr.Group():
        gr.Markdown("### Image Stats")
        stats_box = gr.Markdown(
            value="No stats available yet."
        )

    # Changing either tile control refreshes the tile read-out.
    for tile_control in (ratio_name, tile_preset):
        tile_control.change(
            fn=update_tile_display,
            inputs=[ratio_name, tile_preset],
            outputs=tile_display,
        )

    # The order of ``inputs`` must match the parameter order of
    # enhance_image(); ``outputs`` receives its three return values in order.
    run_button.click(
        fn=enhance_image,
        inputs=[
            reduction_choice,
            model_name,
            mode_name,
            ratio_name,
            tile_preset,
            export_format,
            input_image,
        ],
        outputs=[
            output_preview,
            download_file,
            stats_box,
        ],
        show_progress=True,
    )
# Start the app. BASE_TMP_DIR is listed in allowed_paths because the files
# returned for download are written beneath it (see save_output_image()).
# ssr_mode=False is passed through to Gradio unchanged — see the Gradio
# launch() documentation for its effect.
demo.launch(ssr_mode=False, allowed_paths=[BASE_TMP_DIR])