File size: 1,477 Bytes
7bfaca7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import numpy as np
import matplotlib.pyplot as plt
import pywt
import gradio as gr

# Hard-threshold applied to every detail sub-band, keyed by decomposition
# depth (1 -> 20, 2 -> 40, 3 -> 60, 4 -> 80). Larger thresholds zero out more
# detail coefficients, so deeper levels give coarser reconstructions.
LEVEL_THRESHOLDS = {depth: 20 * depth for depth in range(1, 5)}

def threshold_single_level(coeffs_level, threshold):
    """Hard-threshold one level of 2-D wavelet detail coefficients.

    Args:
        coeffs_level: A 3-tuple of detail arrays for a single decomposition
            level, as found in ``pywt.wavedec2`` output after the first entry.
        threshold: Magnitude below which a coefficient is replaced by 0.

    Returns:
        A new 3-tuple of arrays in the same order. Entries with
        ``abs(x) < threshold`` become 0 and all others keep their value.
        The input arrays are not modified.
    """
    # Unpacking first keeps the "exactly three bands" requirement explicit.
    horizontal, vertical, diagonal = coeffs_level
    return tuple(
        np.where(np.abs(band) < threshold, 0, band)
        for band in (horizontal, vertical, diagonal)
    )

def wavelet_compress(img, level, wavelet="haar", threshold=None):
    """Return a lossy grayscale reconstruction of *img* after wavelet thresholding.

    The image is converted to 8-bit grayscale and decomposed with a 2-D
    discrete wavelet transform of depth *level*. Every detail coefficient
    whose magnitude is below the threshold is zeroed, and the image is then
    reconstructed. The coarse approximation band is left untouched.

    Args:
        img: Image array, either 2-D grayscale or 3-D ``(height, width, channels)``.
            With three or more channels, the first three are assumed to be RGB
            and any further channel (e.g. alpha) is ignored. With one or two
            channels, the first channel is used as intensity.
        level: Decomposition depth. It must be a key of ``LEVEL_THRESHOLDS``
            unless *threshold* is given. Whole-number floats such as ``2.0``
            are accepted.
        wavelet: Wavelet name passed to ``pywt`` (default ``"haar"``).
        threshold: Optional explicit threshold that overrides the value from
            ``LEVEL_THRESHOLDS``.

    Returns:
        A ``uint8`` array cropped to the grayscale input's height and width.

    Raises:
        KeyError: If *threshold* is None and *level* is not in ``LEVEL_THRESHOLDS``.
    """
    img = np.asarray(img)
    if img.ndim == 3:
        if img.shape[-1] >= 3:
            # Ignore any alpha channel and collapse RGB to luma (BT.601 weights).
            img = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        else:
            # Single-channel or grayscale+alpha image: keep the intensity plane.
            img = img[..., 0]
    # Round before casting: a bare astype() truncates, so a luma of 254.9999
    # would become 254. Clipping also prevents uint8 wrap-around.
    img = np.clip(np.rint(img.astype(np.float64)), 0, 255).astype(np.uint8)
    h, w = img.shape

    # pywt.wavedec2 expects an integer level; UI sliders may deliver 2.0.
    if isinstance(level, float) and level.is_integer():
        level = int(level)
    if threshold is None:
        threshold = LEVEL_THRESHOLDS[level]

    coeffs = pywt.wavedec2(img, wavelet, level=level)
    # coeffs[0] is the approximation band and is kept as-is; only the detail
    # tuples that follow it are thresholded.
    new_coeffs = [coeffs[0]]
    new_coeffs.extend(
        threshold_single_level(details, threshold) for details in coeffs[1:]
    )
    recon = pywt.waverec2(new_coeffs, wavelet)
    # The inverse transform can come back larger than the input (odd sizes).
    recon = recon[:h, :w]
    # Same rounding concern as above: float reconstruction noise such as
    # 99.99999999 must map to 100, not 99.
    return np.clip(np.rint(recon), 0, 255).astype(np.uint8)

def interface(img, level):
    """Gradio callback: compress the uploaded image at the chosen wavelet level.

    Args:
        img: The uploaded image array, or None when nothing has been uploaded.
        level: Wavelet decomposition level selected on the slider.

    Returns:
        The reconstructed ``uint8`` image, or None when *img* is None.
        Returning None avoids crashing on ``None.ndim`` inside
        ``wavelet_compress``, and Gradio shows it as an empty output.
    """
    if img is None:
        return None
    return wavelet_compress(img, level)

# Build the UI. The slider's range and default are derived from
# LEVEL_THRESHOLDS so the selectable levels stay in sync with the threshold
# table (this assumes its keys are consecutive integers).
demo = gr.Interface(
    fn=interface,
    inputs=[
        gr.Image(type="numpy", label="Upload Image"),
        gr.Slider(
            min(LEVEL_THRESHOLDS),
            max(LEVEL_THRESHOLDS),
            step=1,
            value=min(LEVEL_THRESHOLDS),
            label="Wavelet Level",
        ),
    ],
    outputs=gr.Image(label="Reconstructed Image"),
    title="Wavelet Image Compression (A-Version)",
    description="Higher wavelet levels use larger thresholds, resulting in stronger distortion."
)

# Launching at module level is kept as-is: the script is meant to be run directly.
demo.launch()