"""Gradio demo: lossy wavelet "compression" by hard-thresholding detail coefficients.

The uploaded image is converted to grayscale and decomposed with a 2-D
discrete wavelet transform. Every detail coefficient whose magnitude is below
a level-dependent threshold is zeroed, and the image is then reconstructed.
"""

import numpy as np
import matplotlib.pyplot as plt  # NOTE(review): plt is not used in the visible code
import pywt
import gradio as gr

# Hard threshold applied to detail coefficients, keyed by decomposition level.
# Deeper decompositions use larger thresholds, so more detail is discarded.
LEVEL_THRESHOLDS = {
    1: 20,
    2: 40,
    3: 60,
    4: 80,
}

# ITU-R BT.601 luma weights used to collapse R, G, B into one channel.
_LUMA_WEIGHTS = np.array([0.299, 0.587, 0.114])


def threshold_single_level(coeffs_level, threshold):
    """Hard-threshold one level of 2-D detail coefficients.

    Args:
        coeffs_level: Tuple of detail-coefficient arrays for one level, as
            produced by ``pywt.wavedec2``.
        threshold: Coefficients with ``abs(c) < threshold`` are set to 0;
            all others are kept unchanged (hard thresholding).

    Returns:
        Tuple of thresholded arrays, in the same order as ``coeffs_level``.
    """
    return tuple(
        np.where(np.abs(band) < threshold, 0, band) for band in coeffs_level
    )


def _to_grayscale(img):
    """Return ``img`` as a 2-D float64 array (luma for colour input)."""
    arr = np.asarray(img)
    if arr.ndim == 3:
        if arr.shape[-1] >= 3:
            # Drop any alpha channel, then weight R, G, B into luma.
            arr = arr[..., :3] @ _LUMA_WEIGHTS
        else:
            # (H, W, 1) grayscale or (H, W, 2) gray+alpha: keep the gray channel.
            arr = arr[..., 0]
    if arr.ndim != 2:
        raise ValueError(
            f"expected a 2-D or 3-D image array, got shape {arr.shape}"
        )
    # Stay in floating point: quantising to uint8 here would truncate the
    # luma values before the transform for no benefit.
    return arr.astype(np.float64)


def wavelet_compress(img, level, wavelet="haar", threshold=None):
    """Grayscale, wavelet-decompose, hard-threshold, and reconstruct ``img``.

    Args:
        img: Image array of shape (H, W), (H, W, C) with C >= 3 (channels
            beyond the third, e.g. alpha, are ignored), or (H, W, 1|2).
            ``None`` (e.g. nothing uploaded in the UI) returns ``None``.
        level: Number of decomposition levels. It also selects the threshold
            from ``LEVEL_THRESHOLDS``. It is converted with ``int()`` because
            a UI slider may deliver a float.
        wavelet: Wavelet name passed to PyWavelets. Defaults to "haar".
        threshold: Optional explicit threshold. When ``None``, the value is
            looked up in ``LEVEL_THRESHOLDS``.

    Returns:
        uint8 array of shape (H, W), or ``None`` if ``img`` is ``None``.

    Raises:
        KeyError: If ``threshold`` is None and ``level`` is not a key of
            ``LEVEL_THRESHOLDS``.
        ValueError: If ``img`` is not a 2-D or 3-D array.
    """
    if img is None:
        return None
    level = int(level)
    if threshold is None:
        threshold = LEVEL_THRESHOLDS[level]

    gray = _to_grayscale(img)
    h, w = gray.shape

    coeffs = pywt.wavedec2(gray, wavelet, level=level)
    # coeffs[0] is the approximation band and is always kept intact; only the
    # per-level detail tuples are thresholded.
    new_coeffs = [coeffs[0]] + [
        threshold_single_level(detail, threshold) for detail in coeffs[1:]
    ]
    recon = pywt.waverec2(new_coeffs, wavelet)
    # The reconstruction can be larger than the input (e.g. odd sizes): crop.
    recon = recon[:h, :w]
    # Round before casting: a plain astype would truncate e.g. 99.9999 to 99.
    return np.clip(np.rint(recon), 0, 255).astype(np.uint8)


def interface(img, level):
    """Gradio callback: forward the upload and slider value to ``wavelet_compress``."""
    return wavelet_compress(img, level)


demo = gr.Interface(
    fn=interface,
    inputs=[
        gr.Image(type="numpy", label="Upload Image"),
        gr.Slider(
            min(LEVEL_THRESHOLDS),
            max(LEVEL_THRESHOLDS),
            step=1,
            value=min(LEVEL_THRESHOLDS),
            label="Wavelet Level",
        ),
    ],
    outputs=gr.Image(label="Reconstructed Image"),
    title="Wavelet Image Compression (A-Version)",
    description="Higher wavelet levels use larger thresholds, resulting in stronger distortion.",
)

if __name__ == "__main__":
    # Guarded so that merely importing this module does not start a server.
    demo.launch()