# File-viewer residue kept for provenance: BiliSakura — "Update app.py" — d026ddf (verified)
from __future__ import annotations
from typing import Callable, Dict, List
import gradio as gr
import numpy as np
from PIL import Image
# Call signature of a wavelet transform: one image array in, a mapping of
# sub-band name to coefficient array out.
WaveletFn = Callable[[np.ndarray], Dict[str, np.ndarray]]

# Sub-band keys, in the order results are emitted.
COMPONENT_ORDER: List[str] = [
    "LL",
    "LH",
    "HL",
    "HH",
]
def _ensure_even(image: Image.Image) -> Image.Image:
    """Return *image* with both dimensions rounded down to an even number.

    Args:
        image: Image whose ``size`` is ``(width, height)``.

    Returns:
        The same object when both dimensions are already even; otherwise the
        result of cropping from the top-left corner, which drops at most one
        trailing column and one trailing row.
    """
    w, h = image.size
    trimmed = (w // 2 * 2, h // 2 * 2)
    if trimmed == (w, h):
        return image
    return image.crop((0, 0, *trimmed))
def _prepare_image(image: Image.Image) -> np.ndarray:
    """Turn an uploaded image into a float32 RGB array with even dimensions.

    Args:
        image: Input image in any mode; non-RGB modes are converted to RGB.

    Returns:
        The pixel data as a ``float32`` NumPy array.

    Raises:
        gr.Error: If either dimension is smaller than 2 once the image has
            been cropped to even dimensions.
    """
    rgb = image if image.mode == "RGB" else image.convert("RGB")
    rgb = _ensure_even(rgb)

    # A 1-pixel-wide or 1-pixel-tall input collapses to zero after cropping.
    if min(rgb.size) < 2:
        raise gr.Error("Image must be at least 2x2 pixels after cropping.")

    return np.asarray(rgb, dtype=np.float32)
def _rescale_to_byte(plane: np.ndarray) -> np.ndarray:
    """Min-max stretch one array to the 0..255 range and return it as uint8."""
    low = float(plane.min())
    high = float(plane.max())
    span = high - low
    if span < 1e-8:
        # Constant data carries no contrast; render it as black.
        return np.zeros(plane.shape, dtype=np.uint8)
    # Subtracting a Python float promotes integer input to floating point,
    # so the division below never truncates.
    return ((plane - low) / span * 255).clip(0, 255).astype(np.uint8)


def _normalize_component(component: np.ndarray) -> np.ndarray:
    """Rescale wavelet coefficients to a displayable uint8 array.

    A 3-D array is treated as ``(rows, cols, channels)`` and every channel is
    stretched independently so that weak channels remain visible. Any other
    rank is stretched as a whole using its global minimum and maximum.

    Args:
        component: Coefficient array of any numeric dtype.

    Returns:
        A ``uint8`` array with the same shape as ``component``. Channels (or
        whole arrays) whose value range is below ``1e-8`` come back as zeros,
        and an empty input yields an empty result instead of raising.
    """
    if component.size == 0:
        # min()/max() are undefined on empty arrays; nothing to rescale.
        return np.zeros(component.shape, dtype=np.uint8)

    if component.ndim == 3:
        # Iterate over however many channels are present rather than
        # assuming exactly three.
        planes = [
            _rescale_to_byte(component[:, :, index])
            for index in range(component.shape[2])
        ]
        return np.stack(planes, axis=2)

    return _rescale_to_byte(component)
def haar_wavelet_components(image_array: np.ndarray) -> Dict[str, np.ndarray]:
    """Compute one level of the 2-D Haar transform over 2x2 pixel blocks.

    The first two axes are treated as rows and columns; any further axes
    (for example colour channels) are carried through unchanged because the
    strided slices below only touch the first two axes.

    Within each 2x2 block, with ``a`` top-left, ``b`` top-right, ``c``
    bottom-left and ``d`` bottom-right:

    * ``LL = ( a + b + c + d) / 2``  (block sum, scaled)
    * ``LH = (-a - b + c + d) / 2``  (bottom row minus top row)
    * ``HL = (-a + b - c + d) / 2``  (right column minus left column)
    * ``HH = ( a - b - c + d) / 2``  (diagonal difference)

    Args:
        image_array: Array with at least two dimensions. Integer and boolean
            input is promoted to ``float64`` so the sums cannot wrap around;
            floating-point input keeps its dtype. A trailing odd row or
            column is dropped so all four quadrants have the same shape.

    Returns:
        Mapping with keys ``"LL"``, ``"LH"``, ``"HL"`` and ``"HH"``; each
        value has half the (even-trimmed) row and column count of the input.

    Raises:
        IndexError: If ``image_array`` has fewer than two dimensions.
    """
    data = np.asarray(image_array)
    if data.ndim < 2:
        raise IndexError(
            "image_array must have at least two dimensions (rows, columns)"
        )

    # Small unsigned integer types (e.g. uint8 pixels) would overflow in the
    # four-term sums below, so do the arithmetic in floating point.
    if not np.issubdtype(data.dtype, np.inexact):
        data = data.astype(np.float64)

    # With an odd size the strided quadrants differ in shape and would either
    # fail or silently broadcast to wrong values; trim to even first.
    rows = data.shape[0] - data.shape[0] % 2
    cols = data.shape[1] - data.shape[1] % 2
    data = data[:rows, :cols]

    a = data[0::2, 0::2]
    b = data[0::2, 1::2]
    c = data[1::2, 0::2]
    d = data[1::2, 1::2]

    return {
        "LL": (a + b + c + d) / 2.0,
        "LH": (-a - b + c + d) / 2.0,
        "HL": (-a + b - c + d) / 2.0,
        "HH": (a - b - c + d) / 2.0,
    }
# Registry of available transforms, keyed by the method name callers pass in.
WAVELET_METHODS: Dict[str, WaveletFn] = {
    "Haar": haar_wavelet_components,
}
def compute_wavelet(
    image: Image.Image | None, method_name: str
) -> tuple[Image.Image | None, Image.Image | None, Image.Image | None, Image.Image | None]:
    """Decompose an image and return one preview picture per sub-band.

    Args:
        image: Uploaded image, or ``None`` when nothing has been provided.
        method_name: Key into ``WAVELET_METHODS`` selecting the transform.

    Returns:
        A 4-tuple of RGB images in ``COMPONENT_ORDER`` order, or four
        ``None`` values when ``image`` is ``None``.

    Raises:
        gr.Error: If ``method_name`` is not a registered wavelet method.
    """
    if image is None:
        return (None, None, None, None)

    transform = WAVELET_METHODS.get(method_name)
    if transform is None:
        raise gr.Error(f"Unknown wavelet method: {method_name}")

    subbands = transform(_prepare_image(image))

    # Each sub-band is contrast-stretched to uint8 before being wrapped as an
    # RGB picture.
    return tuple(
        Image.fromarray(_normalize_component(subbands[band]), mode="RGB")
        for band in COMPONENT_ORDER
    )
def build_demo() -> gr.Blocks:
    """Assemble the Gradio interface and wire the compute button.

    The page holds a heading, a short description, an input row (image
    upload and method dropdown), a button, and two rows of two output images
    that receive the LL, LH, HL and HH results of ``compute_wavelet``.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    # NOTE(review): block nesting was reconstructed from a listing without
    # indentation; the button is assumed to sit between the input row and
    # the output rows — confirm against the deployed layout.
    intro_text = (
        "Upload an image to view its Haar wavelet components. "
        "Images are cropped to even dimensions for the transform."
    )
    method_names = list(WAVELET_METHODS.keys())

    with gr.Blocks(title="Visualize Your Wavelet") as demo:
        gr.Markdown("## Visualize Your Wavelet")
        gr.Markdown(intro_text)

        with gr.Row():
            source_image = gr.Image(type="pil", label="Input Image")
            method_picker = gr.Dropdown(
                choices=method_names,
                value="Haar",
                label="Wavelet Method",
            )

        compute_button = gr.Button("Compute Wavelet")

        with gr.Row():
            approx_view = gr.Image(label="LL (Approximation)")
            vertical_view = gr.Image(label="LH (Vertical Details)")

        with gr.Row():
            horizontal_view = gr.Image(label="HL (Horizontal Details)")
            diagonal_view = gr.Image(label="HH (Diagonal Details)")

        # Output order must match the tuple order returned by compute_wavelet.
        result_views = [approx_view, vertical_view, horizontal_view, diagonal_view]
        compute_button.click(
            fn=compute_wavelet,
            inputs=[source_image, method_picker],
            outputs=result_views,
        )

    return demo
# Build the interface at import time so the module exposes a ready-made
# `demo` object.
demo = build_demo()

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()