"""Streamlit app that shrinks images by K-Means colour quantization.

The uploaded image is reduced to a small colour palette (fitted per image
with scikit-learn's ``KMeans``), re-encoded in the chosen output format, and
offered for download together with size metrics and a palette preview.

Start it with ``streamlit run``; running the file with plain ``python`` only
prints a hint.
"""

import io
import os
import pickle
from typing import Optional, Tuple

import numpy as np
from PIL import Image, ImageOps
import streamlit as st

# Mode label that allows compress_image() to use a pretrained palette.
_AUTO_MODE = "Auto (use pretrained if available)"
# Message printed when the script is not executed by the Streamlit runner.
_LAUNCH_HINT = "This app must be started with: streamlit run Image_Compressor/app.py"
# Modes Pillow's JPEG encoder can write directly; anything else is converted.
_JPEG_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})
# Edge length (px) of one colour cell in the palette preview.
_SWATCH_CELL_PX = 40


def set_modern_page_config() -> None:
    """Configure the Streamlit page (title, icon, wide layout, open sidebar)."""
    st.set_page_config(
        page_title="Image Compressor (K-Means)",
        page_icon="🎨",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    # Subtle custom styling for a modern look
    # NOTE(review): the markup string is blank in this copy of the file, so no
    # styling is actually injected — confirm whether CSS is meant to be here.
    st.markdown(
        """ """,
        unsafe_allow_html=True,
    )


def try_load_pretrained_pipeline(pickle_path: str) -> Optional[object]:
    """Best-effort load of a pickled model.

    Returns the unpickled object, or ``None`` when the file does not exist or
    cannot be loaded for any reason.
    """
    if not os.path.exists(pickle_path):
        return None
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only ever point this at files produced by a trusted source.
    try:
        with open(pickle_path, "rb") as f:
            return pickle.load(f)
    except Exception:
        # Deliberate best-effort: an unreadable/incompatible pickle simply
        # means "no pretrained model available".
        return None


def _to_uint8_palette(centers) -> np.ndarray:
    """Round colour centres to the nearest integer and clamp them to 0..255.

    A plain ``astype(np.uint8)`` truncates (254.9 -> 254) and wraps values
    outside the byte range; rounding plus clipping avoids both.
    """
    values = np.asarray(centers, dtype=np.float32)
    return np.clip(np.rint(values), 0, 255).astype(np.uint8)


def extract_cluster_centers_from_model(model: object) -> Optional[np.ndarray]:
    """Try to extract K-Means cluster centers from a variety of common objects.

    Looks, in order, at ``model.cluster_centers_``, at the last estimator of
    ``model.steps`` (sklearn Pipeline-like), and at a ``"cluster_centers_"``
    dict entry.

    Returns the centers as a NumPy array, or ``None`` if none were found.
    """
    if model is None:
        return None

    # Direct KMeans-like object
    centers = getattr(model, "cluster_centers_", None)

    # sklearn Pipeline-like, last step a KMeans
    if centers is None:
        steps = getattr(model, "steps", None)
        if steps:
            try:
                last_estimator = steps[-1][1]
            except (IndexError, KeyError, TypeError):
                # ``steps`` is not a sequence of (name, estimator) pairs.
                last_estimator = None
            centers = getattr(last_estimator, "cluster_centers_", None)

    # Dict-like
    if centers is None and isinstance(model, dict):
        centers = model.get("cluster_centers_")

    return None if centers is None else np.asarray(centers)


def quantize_with_centers(image_array: np.ndarray, centers: np.ndarray) -> np.ndarray:
    """Map each pixel to the nearest color in centers (RGB).

    image_array: (H, W, 3), dtype uint8
    centers: (k, 3), float or uint8

    Returns a uint8 array with the same shape as ``image_array``.

    Raises:
        ValueError: if ``centers`` is not a non-empty (k, 3) array.
    """
    palette_f = np.asarray(centers, dtype=np.float32)
    if palette_f.ndim != 2 or palette_f.shape[1] != 3 or palette_f.shape[0] == 0:
        raise ValueError(
            f"centers must be a non-empty (k, 3) array, got shape {palette_f.shape}"
        )

    # Flatten to (N,3)
    pixels = image_array.reshape(-1, 3).astype(np.float32)

    # ||p - c||^2 = ||p||^2 + ||c||^2 - 2 p.c ; the ||p||^2 term is the same
    # for every centre of a given pixel, so it cannot change the argmin and
    # is left out.
    center_norms = np.sum(palette_f * palette_f, axis=1)  # (k,)
    scores = center_norms[np.newaxis, :] - 2.0 * (pixels @ palette_f.T)  # (N,k)
    labels = np.argmin(scores, axis=1)

    # Convert the k centres once, then index, instead of casting N rows.
    palette_u8 = _to_uint8_palette(palette_f)
    return palette_u8[labels].reshape(image_array.shape)


def kmeans_quantize(image_array: np.ndarray, n_colors: int, random_state: int) -> Tuple[np.ndarray, np.ndarray]:
    """Fit K-Means on the image pixels and return quantized image and cluster centers.

    The number of clusters is capped at the number of pixels so that very
    small images do not make ``KMeans`` fail; the returned centers array may
    therefore have fewer than ``n_colors`` rows.
    """
    # Imported lazily so the module can be imported without scikit-learn.
    from sklearn.cluster import KMeans

    pixels = image_array.reshape(-1, 3).astype(np.float32)
    n_clusters = max(1, min(int(n_colors), pixels.shape[0]))
    model = KMeans(n_clusters=n_clusters, n_init=4, random_state=random_state)
    labels = model.fit_predict(pixels)
    centers = _to_uint8_palette(model.cluster_centers_)
    quantized = centers[labels].reshape(image_array.shape)
    return quantized, centers


def compress_image(
    image: Image.Image,
    mode: str,
    n_colors: int,
    use_pretrained: bool,
    pretrained_centers: Optional[np.ndarray],
    random_state: int,
    max_side: int,
) -> Tuple[Image.Image, Optional[np.ndarray]]:
    """Quantize ``image`` to a limited palette.

    The image is processed at no more than ``max_side`` pixels per side and,
    if it had to be shrunk, scaled back to its original size afterwards.
    A pretrained palette is used only when ``mode`` is the "Auto" label,
    ``use_pretrained`` is true and ``pretrained_centers`` is given; otherwise
    K-Means is fitted on the image itself.

    Returns the quantized RGB image and the uint8 palette that was used.
    """
    # Resize for faster processing if needed
    original_size = image.size
    working = image.copy()
    working.thumbnail((max_side, max_side))
    image_np = np.array(working.convert("RGB"))

    if mode == _AUTO_MODE and use_pretrained and pretrained_centers is not None:
        quant_np = quantize_with_centers(image_np, pretrained_centers)
        centers = _to_uint8_palette(pretrained_centers)
    else:
        quant_np, centers = kmeans_quantize(image_np, n_colors=n_colors, random_state=random_state)

    # (H, W, 3) uint8 is inferred as RGB; the explicit ``mode`` argument is
    # deprecated in recent Pillow releases.
    quant_img = Image.fromarray(quant_np)

    # If we resized for speed, upscale back to original using nearest to preserve palette
    if quant_img.size != original_size:
        quant_img = quant_img.resize(original_size, resample=Image.NEAREST)
    return quant_img, centers


def image_bytes(img: Image.Image, fmt: str, quality: int) -> bytes:
    """Encode ``img`` as ``fmt`` ("PNG", "JPEG"/"JPG" or "WEBP") and return the bytes.

    ``quality`` is passed through for JPEG/WEBP. PNG is lossless, so there it
    is only mapped onto the zlib ``compress_level``.
    """
    fmt_name = fmt.upper()
    if fmt_name == "JPG":
        fmt_name = "JPEG"  # Pillow only knows the long name

    params = {}
    if fmt_name in {"JPEG", "WEBP"}:
        params["quality"] = int(quality)
    if fmt_name == "JPEG":
        params["optimize"] = True
        params["progressive"] = True
        if img.mode not in _JPEG_MODES:
            # e.g. RGBA / P cannot be written as JPEG and would raise.
            img = img.convert("RGB")
    elif fmt_name == "PNG":
        # Pillow uses 0-9 for compress_level (opposite of quality). Map roughly.
        params["compress_level"] = int(np.clip((100 - quality) / 11, 0, 9))

    buf = io.BytesIO()
    img.save(buf, format=fmt_name, **params)
    return buf.getvalue()


def human_size(num_bytes: int) -> str:
    """Format a byte count with two decimals in B, KB, MB or GB (1024-based)."""
    size = float(num_bytes)
    for unit in ("B", "KB", "MB"):
        if size < 1024.0:
            return f"{size:.2f} {unit}"
        size /= 1024.0
    # Anything that is still >= 1024 MB is reported in GB, however large.
    return f"{size:.2f} GB"


def _load_rgb_image(data: bytes) -> Image.Image:
    """Decode ``data`` into an RGB image, honouring the EXIF orientation tag."""
    img = Image.open(io.BytesIO(data))
    try:
        # Without this, photos taken with a rotated camera are shown and
        # saved sideways because the orientation tag is dropped on convert().
        img = ImageOps.exif_transpose(img)
    except Exception:
        # Best effort only: malformed EXIF must not block compression, so
        # fall back to the pixels as stored.
        pass
    return img.convert("RGB")


def _palette_swatch(centers: np.ndarray, cell: int = _SWATCH_CELL_PX) -> np.ndarray:
    """Return a (cell, cell * k, 3) uint8 strip with one square per palette colour."""
    palette = centers.astype(np.uint8)
    row = np.repeat(palette, cell, axis=0)[np.newaxis, :, :]  # (1, cell*k, 3)
    return np.tile(row, (cell, 1, 1))


def main() -> None:
    """Render the app: controls, upload, compression, metrics, previews, download."""
    set_modern_page_config()

    st.markdown("## 🎨 Image Compressor — K-Means Color Quantization")
    st.caption("Reduce image size by limiting its color palette while keeping it visually appealing.")

    with st.sidebar:
        st.markdown("### Controls")
        n_colors = st.slider("Number of colors", min_value=2, max_value=64, value=16, step=1)
        # Use a slightly lower default quality to help reduce sizes, and default WEBP
        quality = st.slider("Output quality", min_value=10, max_value=100, value=85, step=1)
        output_format = st.selectbox("Output format", options=["PNG", "JPEG", "WEBP"], index=2)
        random_state = st.number_input("Random seed", min_value=0, max_value=2**31 - 1, value=42, step=1)
        max_side = st.slider("Process at max side (px)", min_value=256, max_value=4096, value=1024, step=128)
        st.caption("Tip: JPEG→PNG can increase size. Prefer JPEG/WEBP for photos.")

    # Always train per image — do not use any pretrained model
    use_pretrained = False
    mode = "Train per image"
    centers = None

    uploaded = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg", "webp"])
    if uploaded is None:
        st.stop()

    # Read original
    original_bytes = uploaded.read()
    try:
        original_img = _load_rgb_image(original_bytes)
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        # Corrupt, truncated or oversized uploads: report instead of crashing.
        st.error(f"Could not read the uploaded file as an image: {exc}")
        st.stop()

    # Process/compress
    with st.spinner("Compressing with K-Means…"):
        quant_img, used_centers = compress_image(
            image=original_img,
            mode=mode,
            n_colors=n_colors,
            use_pretrained=use_pretrained,
            pretrained_centers=centers,
            random_state=random_state,
            max_side=max_side,
        )

    # Prepare bytes for download
    compressed_bytes = image_bytes(quant_img, fmt=output_format, quality=quality)

    # Metrics
    orig_size = len(original_bytes)
    new_size = len(compressed_bytes)
    saving = orig_size - new_size
    saving_pct = (saving / orig_size * 100.0) if orig_size > 0 else 0.0

    col_m1, col_m2, col_m3 = st.columns(3)
    col_m1.metric("Original size", human_size(orig_size))
    col_m2.metric("Compressed size", human_size(new_size))
    col_m3.metric("Saved", f"{human_size(max(saving, 0))}", delta=f"{saving_pct:.1f}%")

    # Preview side-by-side
    left, right = st.columns(2)
    with left:
        st.markdown("#### Before")
        st.image(original_img, use_container_width=True)
    with right:
        st.markdown("#### After")
        st.image(quant_img, use_container_width=True)

    # Palette preview
    if used_centers is not None and used_centers.size > 0:
        st.markdown("#### Extracted palette")
        k = used_centers.shape[0]
        # ``use_container_width`` matches the other st.image calls above;
        # ``use_column_width`` is deprecated.
        st.image(
            Image.fromarray(_palette_swatch(used_centers)),
            caption=f"{k} colors",
            use_container_width=False,
        )

    # Download
    file_root, _ = os.path.splitext(uploaded.name)
    outfile = f"{file_root}_compressed.{output_format.lower()}"
    st.download_button(
        label="Download compressed image",
        data=compressed_bytes,
        file_name=outfile,
        mime=Image.MIME.get(output_format.upper(), f"image/{output_format.lower()}"),
    )

    with st.expander("Advanced details"):
        st.write({
            "quantization_mode": mode,
            "n_colors_requested": n_colors,
            "n_colors_used": int(used_centers.shape[0]) if used_centers is not None else None,
            "output_format": output_format,
            "quality": quality,
            "random_state": random_state,
            "max_side": max_side,
        })


def _running_under_streamlit() -> bool:
    """Return True only when a Streamlit ScriptRunContext exists."""
    try:
        from streamlit.runtime.scriptrunner import get_script_run_ctx  # type: ignore

        return get_script_run_ctx() is not None
    except Exception:
        # Older/newer Streamlit versions may not expose get_script_run_ctx.
        return False


if __name__ == "__main__":
    # When run via `streamlit run`, a ScriptRunContext exists. If not, avoid calling
    # Streamlit APIs directly to prevent "missing ScriptRunContext" warnings/errors.
    # main() is deliberately called outside any try/except so that real errors
    # surface in the app instead of being replaced by the launch hint.
    if _running_under_streamlit():
        main()
    else:
        print(_LAUNCH_HINT)