# csabhay's picture
# Use MiniBatchKMeans for faster compression using a random pixel sample for centroid fitting
# 6b9ba41
# Raw
# History Blame Contribute Delete
# 2.58 kB
import gradio as gr
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sklearn.utils import shuffle
from PIL import Image
import os
import tempfile
from io import BytesIO
def compress_kmeans(
    image: np.ndarray,
    k: int,
    random_state: int = 42,
) -> np.ndarray:
    """Quantise an RGB image to at most ``k`` colours with mini-batch K-Means.

    Centroids are fitted on a random sample of at most 20,000 pixels, then
    every pixel of the image is replaced by its nearest centroid.

    Args:
        image: Array whose elements can be viewed as rows of 3 channel
            values (it is reshaped to ``(-1, 3)``).
        k: Requested number of colours. Coerced to ``int`` and capped at the
            number of sampled pixels, because the clusterer cannot fit more
            clusters than it has samples.
        random_state: Seed used for both the pixel sample and the clusterer,
            so repeated calls produce the same palette.

    Returns:
        A ``uint8`` array with the same shape as ``image`` whose pixels are
        drawn from the fitted colour palette.
    """
    pixels = image.reshape(-1, 3).astype(np.float32)
    sample_size = min(20000, len(pixels))

    # Ask shuffle for only the rows we need instead of permuting a full copy
    # of the image and slicing it afterwards.
    sample = shuffle(pixels, random_state=random_state, n_samples=sample_size)

    # UI sliders may hand over a float; scikit-learn requires an integer
    # cluster count that does not exceed the number of samples.
    n_clusters = min(int(k), sample_size)

    model = MiniBatchKMeans(
        n_clusters=n_clusters,
        batch_size=min(1024, sample_size),
        n_init='auto',
        max_iter=50,
        random_state=random_state,
        verbose=0,
    )
    model.fit(sample)
    labels = model.predict(pixels)

    # Round to the nearest 8-bit value; a bare uint8 cast truncates and
    # biases every palette colour towards darker values.
    centres = np.clip(np.rint(model.cluster_centers_), 0, 255).astype(np.uint8)

    # Fancy-index the palette by label to rebuild the image.
    compressed_pixels = centres[labels]
    return compressed_pixels.reshape(image.shape)
def process(filepath, k):
    """Compress an uploaded image and report size and quality statistics.

    Args:
        filepath: Path of the uploaded image, or ``None`` when nothing has
            been uploaded.
        k: Number of colours to keep; coerced to ``int`` before clustering.

    Returns:
        A 3-tuple ``(compressed, stats, out_path)``: the compressed image as
        an array, a Markdown summary string, and the path of the PNG written
        for download. When ``filepath`` is ``None`` the tuple is
        ``(None, "Please upload an image.", None)``.

    Both reported sizes are the PNG-encoded sizes of the decoded RGB pixel
    data, not the on-disk size of the uploaded file.
    """
    if filepath is None:
        return None, "Please upload an image.", None

    # Load the original inside a context manager so the file handle is
    # released deterministically once the pixels have been copied out.
    with Image.open(filepath) as img:
        orig = np.array(img.convert('RGB'))

    # Compress
    comp = compress_kmeans(orig, int(k), random_state=42)

    def encode_png(arr):
        """Return ``arr`` encoded as PNG bytes."""
        with BytesIO() as buf:
            Image.fromarray(arr).save(buf, format='PNG')
            return buf.getvalue()

    orig_size = len(encode_png(orig))
    # Encode the result once and reuse the bytes for both the size figure
    # and the download, so the reported size is exactly the file's size.
    comp_png = encode_png(comp)
    comp_size = len(comp_png)

    ratio = orig_size / comp_size
    saved = (1 - comp_size / orig_size) * 100

    # PSNR over all channels; identical images have zero error, hence inf.
    mse = np.mean((orig.astype(float) - comp.astype(float)) ** 2)
    psnr = 20 * np.log10(255.0 / np.sqrt(mse)) if mse > 0 else float('inf')

    # Two trailing spaces before each newline form a Markdown hard line
    # break, so every statistic is rendered on its own line.
    stats = (f"**Original:** {orig_size/1024:.1f} KB  \n"
             f"**Compressed:** {comp_size/1024:.1f} KB  \n"
             f"**Compression ratio:** {ratio:.1f}x  \n"
             f"**Space saved:** {saved:.1f}%  \n"
             f"**PSNR:** {psnr:.1f} dB")

    # The directory is deliberately not removed here: the returned path must
    # still exist after this function returns so the file can be served.
    temp = tempfile.mkdtemp()
    out_path = os.path.join(temp, 'compressed.png')
    with open(out_path, 'wb') as fh:
        fh.write(comp_png)

    return comp, stats, out_path
# Input widgets: the image to compress and the palette size.
upload_box = gr.Image(type='filepath', label='Upload an Image')
colour_slider = gr.Slider(
    minimum=2,
    maximum=64,
    step=2,
    value=16,
    label='Number of colours (k)',
)

# Output widgets, in the order `process` returns its three values.
preview_box = gr.Image(label='Compressed Image')
stats_box = gr.Markdown(label='Compression Statistics')
download_box = gr.File(label='Download Compressed Image')

# Wire the widgets to `process` and start the app.
iface = gr.Interface(
    fn=process,
    inputs=[upload_box, colour_slider],
    outputs=[preview_box, stats_box, download_box],
    title='Image Compression with K‑Means',
    description='Reduce the number of colours in an image using K-Means clustering.',
)
iface.launch()