File size: 1,477 Bytes
7bfaca7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 | import numpy as np
import matplotlib.pyplot as plt
import pywt
import gradio as gr
# Coefficient-magnitude cutoff for each supported wavelet level (1 through 4).
# The cutoff grows linearly with the level: level k maps to 20 * k, i.e.
# {1: 20, 2: 40, 3: 60, 4: 80}. `wavelet_compress` indexes this by `level`.
LEVEL_THRESHOLDS = {depth: 20 * depth for depth in range(1, 5)}
def threshold_single_level(coeffs_level, threshold):
    """Hard-threshold the three coefficient arrays of one wavelet level.

    Args:
        coeffs_level: Iterable of exactly three arrays, unpacked in the
            order (LH, HL, HH). Any other length raises ``ValueError``
            from the unpacking.
        threshold: Magnitude cutoff. Entries whose absolute value is
            strictly below it become 0; all other entries are kept as-is.

    Returns:
        A tuple of three new arrays in the same (LH, HL, HH) order. The
        input arrays are not modified.
    """

    def _zero_small(band):
        # np.where builds a fresh array: 0 where |value| < threshold,
        # otherwise the original value.
        return np.where(np.abs(band) < threshold, 0, band)

    lh_band, hl_band, hh_band = coeffs_level
    return tuple(_zero_small(band) for band in (lh_band, hl_band, hh_band))
def wavelet_compress(img, level):
    """Reconstruct a grayscale image after thresholding its Haar wavelet details.

    The image is reduced to a single 8-bit channel, decomposed with a
    multi-level 2-D Haar transform, every detail band is hard-thresholded
    with the cutoff configured for ``level``, and the result is transformed
    back into an 8-bit image of the original height and width.

    Args:
        img: NumPy array, either 2-D (already single-channel) or 3-D with
            at least three channels on the last axis (extra channels such
            as alpha are ignored).
        level: Decomposition depth. Must be a key of ``LEVEL_THRESHOLDS``;
            integral floats (e.g. ``2.0`` coming from a UI slider) are
            accepted as well.

    Returns:
        2-D ``uint8`` array with the same height and width as ``img``.

    Raises:
        KeyError: If ``level`` is not a key of ``LEVEL_THRESHOLDS``.
    """
    if img.ndim == 3:
        # Keep only the first three channels and collapse them to one
        # luminance channel with the weights 0.299 / 0.587 / 0.114.
        img = img[..., :3]
        img = np.dot(img, [0.299, 0.587, 0.114])
    # NOTE: astype truncates fractional values rather than rounding them.
    img = img.astype(np.uint8)
    h, w = img.shape

    # Look the cutoff up with the caller's value first so an unsupported
    # level still raises KeyError exactly as before.
    threshold = LEVEL_THRESHOLDS[level]
    # pywt needs a true integer depth; a float such as 2.0 passes the dict
    # lookup above but would fail inside wavedec2, so normalise it here.
    depth = int(level)

    coeffs = pywt.wavedec2(img, "haar", level=depth)
    # coeffs[0] is passed through untouched; every following entry is a
    # 3-tuple of detail arrays that gets thresholded with the same cutoff.
    new_coeffs = [coeffs[0]]
    new_coeffs.extend(
        threshold_single_level(level_tuple, threshold)
        for level_tuple in coeffs[1:]
    )

    recon = pywt.waverec2(new_coeffs, "haar")
    # The reconstruction can be larger than the input (e.g. odd-sized
    # images), so crop back to the original height and width.
    recon = recon[:h, :w]
    return np.clip(recon, 0, 255).astype(np.uint8)
def interface(img, level):
    """Gradio callback that runs the wavelet compression on an uploaded image.

    Args:
        img: Image array supplied by the UI, or ``None`` when the callback
            is triggered without an image.
        level: Wavelet level selected in the UI; forwarded unchanged to
            ``wavelet_compress``.

    Returns:
        The array returned by ``wavelet_compress(img, level)``, or ``None``
        when ``img`` is ``None`` so the output stays empty instead of the
        call failing on ``None.ndim``.
    """
    if img is None:
        return None
    return wavelet_compress(img, level)
# Build the Gradio UI: an image upload and a level slider (1 to 4, step 1,
# default 1) are passed to `interface`; its return value is shown as an image.
demo = gr.Interface(
    fn=interface,
    inputs=[
        gr.Image(label="Upload Image", type="numpy"),
        gr.Slider(minimum=1, maximum=4, value=1, step=1, label="Wavelet Level"),
    ],
    outputs=gr.Image(label="Reconstructed Image"),
    title="Wavelet Image Compression (A-Version)",
    description="Higher wavelet levels use larger thresholds, resulting in stronger distortion.",
)

# Launch the app as soon as this module is executed.
demo.launch()