Spaces:
Runtime error
Runtime error
| import io | |
| import os | |
| import pickle | |
| from typing import Optional, Tuple | |
| import numpy as np | |
| from PIL import Image | |
| import streamlit as st | |
def set_modern_page_config() -> None:
    """Configure the Streamlit page and inject the app's custom CSS.

    Sets the page title, icon, wide layout and expanded sidebar, then adds a
    ``<style>`` block via ``st.markdown`` that hides the default header and
    footer and restyles buttons and metric deltas.
    """
    page_options = {
        "page_title": "Image Compressor (K-Means)",
        "page_icon": "🎨",
        "layout": "wide",
        "initial_sidebar_state": "expanded",
    }
    st.set_page_config(**page_options)

    # Subtle custom styling for a modern look. The CSS is kept flush-left so
    # the Markdown renderer never mistakes it for an indented code block.
    custom_css = """
<style>
/* Hide Streamlit default header/footer */
header[data-testid="stHeader"] { display: none; }
footer { visibility: hidden; }
/* Card-like containers */
.block-container { padding-top: 2rem; }
div[data-testid="stSidebar"] { backdrop-filter: blur(6px); }
/* Buttons */
.stButton>button {
border-radius: 10px;
background: linear-gradient(135deg, #6B73FF 0%, #000DFF 100%);
color: white;
border: none;
box-shadow: 0 8px 24px rgba(0, 13, 255, 0.24);
}
.stDownloadButton>button {
border-radius: 10px;
}
/* Metrics */
[data-testid="stMetricDelta"] { font-weight: 600; }
</style>
"""
    st.markdown(custom_css, unsafe_allow_html=True)
def try_load_pretrained_pipeline(pickle_path: str) -> Optional[object]:
    """Best-effort load of a pickled object from ``pickle_path``.

    Returns the unpickled object, or ``None`` when the path does not exist or
    anything goes wrong while opening or unpickling it.

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the file.
    Only point this at files you created or fully trust.
    """
    if os.path.exists(pickle_path):
        try:
            with open(pickle_path, "rb") as handle:
                loaded = pickle.load(handle)
        except Exception:
            # Deliberately broad: a corrupt or incompatible pickle can raise
            # almost anything, and a missing model is not fatal to the caller.
            return None
        else:
            return loaded
    return None
def extract_cluster_centers_from_model(model: object) -> Optional[np.ndarray]:
    """Try to extract K-Means cluster centers from a variety of common objects.

    The lookup order is:

    1. a ``cluster_centers_`` attribute on ``model`` itself,
    2. a ``cluster_centers_`` attribute on the last estimator of a
       pipeline-like object (one exposing ``steps`` as ``(name, estimator)``
       pairs),
    3. a ``"cluster_centers_"`` key when ``model`` is a dict.

    Returns:
        The centers coerced to a NumPy array, or ``None`` when ``model`` is
        ``None`` or no centers were found. The result is always an
        ``ndarray`` regardless of which branch matched, so callers can rely
        on array methods such as ``.astype``.
    """
    if model is None:
        return None

    # Direct KMeans-like object
    centers = getattr(model, "cluster_centers_", None)

    # sklearn Pipeline-like, last step a KMeans
    if centers is None:
        steps = getattr(model, "steps", None)
        if steps:
            centers = getattr(steps[-1][1], "cluster_centers_", None)

    # Dict-like; a stored ``None`` counts as "not available"
    if centers is None and isinstance(model, dict):
        centers = model.get("cluster_centers_")

    return None if centers is None else np.asarray(centers)
def quantize_with_centers(image_array: np.ndarray, centers: np.ndarray) -> np.ndarray:
    """Map each pixel to the nearest color in ``centers`` (RGB).

    Args:
        image_array: Pixel data whose last axis has length 3, typically
            ``(H, W, 3)`` with dtype ``uint8``.
        centers: Palette of shape ``(k, 3)`` with ``k >= 1``; any array-like
            of numbers is accepted.

    Returns:
        A ``uint8`` array with the same shape as ``image_array`` in which
        every pixel has been replaced by its nearest palette color. Palette
        values outside ``[0, 255]`` are clipped before the cast.

    Raises:
        ValueError: If ``centers`` is not a non-empty ``(k, 3)`` array or the
            last axis of ``image_array`` is not of length 3.
    """
    image_array = np.asarray(image_array)
    centers = np.asarray(centers, dtype=np.float32)
    if centers.ndim != 2 or centers.shape[0] == 0 or centers.shape[1] != 3:
        raise ValueError(f"centers must have shape (k, 3) with k >= 1, got {centers.shape}")
    if image_array.ndim < 1 or image_array.shape[-1] != 3:
        raise ValueError(f"image_array must have 3 channels on its last axis, got {image_array.shape}")

    # Flatten to (N, 3)
    pixels = image_array.reshape(-1, 3)
    n_pixels = pixels.shape[0]

    # ||p - c||^2 = ||p||^2 + ||c||^2 - 2 p.c ; the ||p||^2 term is the same
    # for every center of a given pixel, so it cannot change the argmin and
    # is left out.
    ct2 = np.sum(centers * centers, axis=1)  # (k,)
    centers_t = centers.T  # (3, k)

    # Work in fixed-size slices so the (chunk, k) score matrix stays small
    # even for multi-megapixel images.
    chunk = 65536
    labels = np.empty(n_pixels, dtype=np.intp)
    for start in range(0, n_pixels, chunk):
        block = pixels[start:start + chunk].astype(np.float32)
        labels[start:start + chunk] = np.argmin(ct2 - 2.0 * (block @ centers_t), axis=1)

    # Clip first: casting out-of-range floats to uint8 wraps around.
    palette = np.clip(centers, 0, 255).astype(np.uint8)
    return palette[labels].reshape(image_array.shape)
def kmeans_quantize(image_array: np.ndarray, n_colors: int, random_state: int) -> Tuple[np.ndarray, np.ndarray]:
    """Fit K-Means on the image pixels and return quantized image and cluster centers.

    Args:
        image_array: Pixel data that is reshaped to ``(N, 3)`` for fitting.
        n_colors: Requested palette size. It is capped at the number of
            pixels, because K-Means cannot fit more clusters than samples.
        random_state: Seed forwarded to ``sklearn.cluster.KMeans``.

    Returns:
        ``(quantized, centers)`` where ``quantized`` has the shape of
        ``image_array`` and ``centers`` is the ``uint8`` palette of shape
        ``(k, 3)``; ``k`` may be smaller than ``n_colors`` for tiny images.
    """
    # Imported lazily so the module can be imported without scikit-learn.
    from sklearn.cluster import KMeans

    pixels = image_array.reshape(-1, 3).astype(np.float32)
    # Asking for more clusters than samples makes KMeans raise (e.g. a 1x1 image).
    n_clusters = min(int(n_colors), pixels.shape[0])
    model = KMeans(n_clusters=n_clusters, n_init=4, random_state=random_state)
    labels = model.fit_predict(pixels)
    # Round to the nearest level instead of truncating, which would darken
    # every palette entry by up to one step; clip guards the uint8 cast.
    centers = np.clip(np.rint(model.cluster_centers_), 0, 255).astype(np.uint8)
    quantized = centers[labels].reshape(image_array.shape)
    return quantized, centers
def compress_image(
    image: Image.Image,
    mode: str,
    n_colors: int,
    use_pretrained: bool,
    pretrained_centers: Optional[np.ndarray],
    random_state: int,
    max_side: int,
) -> Tuple[Image.Image, Optional[np.ndarray]]:
    """Quantize the colors of ``image`` and return it at its original size.

    A copy of the image is shrunk to fit within ``max_side`` x ``max_side``
    for faster processing and converted to RGB. The pretrained palette is
    used only when ``mode`` is ``"Auto (use pretrained if available)"``,
    ``use_pretrained`` is true and ``pretrained_centers`` is given; otherwise
    K-Means is fitted on the image itself.

    Returns:
        ``(quantized_image, palette)``: the quantized image, scaled back to
        the input size when it was shrunk, and the palette that was applied.
    """
    full_size = image.size

    # Work on a (possibly smaller) copy so the caller's image is untouched.
    working = image.copy()
    working.thumbnail((max_side, max_side))
    rgb_pixels = np.array(working.convert("RGB"))

    wants_pretrained = (
        mode == "Auto (use pretrained if available)"
        and use_pretrained
        and pretrained_centers is not None
    )
    if not wants_pretrained:
        quantized, palette = kmeans_quantize(rgb_pixels, n_colors=n_colors, random_state=random_state)
    else:
        quantized = quantize_with_centers(rgb_pixels, pretrained_centers)
        palette = pretrained_centers.astype(np.uint8)

    result = Image.fromarray(quantized, mode="RGB")
    if result.size == full_size:
        return result, palette

    # Nearest-neighbour upscaling introduces no colors outside the palette.
    return result.resize(full_size, resample=Image.NEAREST), palette
def image_bytes(img: Image.Image, fmt: str, quality: int) -> bytes:
    """Encode ``img`` in the given format and return the encoded bytes.

    Args:
        img: Image to encode.
        fmt: Output format name, case-insensitive (``"PNG"``, ``"JPEG"``,
            ``"WEBP"``, ...). ``"JPG"`` is accepted as an alias for ``"JPEG"``.
        quality: 0-100 quality setting. Passed through for JPEG and WEBP; for
            PNG it is mapped inversely onto Pillow's 0-9 ``compress_level``.

    Returns:
        The encoded image file contents.
    """
    target = fmt.upper()
    if target == "JPG":
        # Pillow only registers the encoder under the name "JPEG".
        target = "JPEG"

    params = {}
    if target in {"JPEG", "WEBP"}:
        params["quality"] = int(quality)
    if target == "JPEG":
        params["optimize"] = True
        params["progressive"] = True
        # JPEG has no alpha or palette support; saving e.g. RGBA raises.
        if img.mode not in {"L", "RGB", "CMYK"}:
            img = img.convert("RGB")
    elif target == "PNG":
        # Pillow uses 0-9 for compress_level (opposite of quality). Map roughly.
        params["compress_level"] = int(np.clip((100 - quality) / 11, 0, 9))

    buf = io.BytesIO()
    img.save(buf, format=target, **params)
    return buf.getvalue()
def human_size(num_bytes: int) -> str:
    """Format a byte count as a string with two decimals and a binary unit.

    The value is divided by 1024 until its magnitude drops below 1024 or the
    largest unit (GB) is reached, so very large values stay in GB. Negative
    values are scaled by magnitude and keep their sign.

    >>> human_size(1536)
    '1.50 KB'
    """
    units = ("B", "KB", "MB", "GB")
    size = float(num_bytes)
    for unit in units[:-1]:
        if abs(size) < 1024.0:
            return f"{size:.2f} {unit}"
        size /= 1024.0
    # Anything that survived the loop is expressed in the largest unit.
    return f"{size:.2f} {units[-1]}"
def main() -> None:
    """Render the Streamlit UI for the K-Means image compressor.

    Reads the controls from the sidebar, waits for an uploaded image,
    quantizes it with :func:`compress_image`, and then shows size metrics, a
    before/after preview, the extracted palette, a download button and an
    expander with the settings that were used.
    """
    set_modern_page_config()
    st.markdown("## 🎨 Image Compressor — K-Means Color Quantization")
    st.caption("Reduce image size by limiting its color palette while keeping it visually appealing.")

    with st.sidebar:
        st.markdown("### Controls")
        n_colors = st.slider("Number of colors", min_value=2, max_value=64, value=16, step=1)
        # Use a slightly lower default quality to help reduce sizes, and default WEBP
        quality = st.slider("Output quality", min_value=10, max_value=100, value=85, step=1)
        output_format = st.selectbox("Output format", options=["PNG", "JPEG", "WEBP"], index=2)
        random_state = st.number_input("Random seed", min_value=0, max_value=2**31 - 1, value=42, step=1)
        max_side = st.slider("Process at max side (px)", min_value=256, max_value=4096, value=1024, step=128)
        st.caption("Tip: JPEG→PNG can increase size. Prefer JPEG/WEBP for photos.")

    # Always train per image — do not use any pretrained model
    use_pretrained = False
    mode = "Train per image"
    centers = None

    uploaded = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg", "webp"])
    if uploaded is None:
        st.stop()

    # Read original. The upload is untrusted input: show a readable error
    # instead of a traceback when it cannot be decoded.
    original_bytes = uploaded.read()
    try:
        original_img = Image.open(io.BytesIO(original_bytes)).convert("RGB")
    except OSError:
        st.error("Could not read the uploaded file as an image.")
        st.stop()

    # Process/compress
    with st.spinner("Compressing with K-Means…"):
        quant_img, used_centers = compress_image(
            image=original_img,
            mode=mode,
            n_colors=n_colors,
            use_pretrained=use_pretrained,
            pretrained_centers=centers,
            random_state=int(random_state),
            max_side=max_side,
        )

    # Prepare bytes for download
    compressed_bytes = image_bytes(quant_img, fmt=output_format, quality=quality)

    # Metrics
    orig_size = len(original_bytes)
    new_size = len(compressed_bytes)
    saving = orig_size - new_size
    saving_pct = (saving / orig_size * 100.0) if orig_size > 0 else 0.0
    col_m1, col_m2, col_m3 = st.columns(3)
    col_m1.metric("Original size", human_size(orig_size))
    col_m2.metric("Compressed size", human_size(new_size))
    # The absolute figure is floored at zero; the delta still shows a
    # negative percentage when the output grew.
    col_m3.metric("Saved", f"{human_size(max(saving, 0))}", delta=f"{saving_pct:.1f}%")

    # Preview side-by-side
    left, right = st.columns(2)
    with left:
        st.markdown("#### Before")
        st.image(original_img, use_container_width=True)
    with right:
        st.markdown("#### After")
        st.image(quant_img, use_container_width=True)

    # Palette preview
    if used_centers is not None and used_centers.size > 0:
        st.markdown("#### Extracted palette")
        k = used_centers.shape[0]
        # Create a small swatch image displaying centers: one 40x40 tile per color
        swatch_h, swatch_w = 40, 40 * k
        swatch = np.zeros((swatch_h, swatch_w, 3), dtype=np.uint8)
        for i, color in enumerate(used_centers.astype(np.uint8)):
            swatch[:, i * 40 : (i + 1) * 40, :] = color
        # Same keyword as the previews above; `use_column_width` is deprecated.
        st.image(Image.fromarray(swatch), caption=f"{k} colors", use_container_width=False)

    # Download
    file_root, _ = os.path.splitext(uploaded.name)
    outfile = f"{file_root}_compressed.{output_format.lower()}"
    st.download_button(
        label="Download compressed image",
        data=compressed_bytes,
        file_name=outfile,
        mime=Image.MIME.get(output_format.upper(), f"image/{output_format.lower()}"),
    )

    with st.expander("Advanced details"):
        st.write({
            "quantization_mode": mode,
            "n_colors_requested": n_colors,
            "n_colors_used": int(used_centers.shape[0]) if used_centers is not None else None,
            "output_format": output_format,
            "quality": quality,
            "random_state": random_state,
            "max_side": max_side,
        })
if __name__ == "__main__":
    # When run via `streamlit run`, a ScriptRunContext exists. If not, avoid calling
    # Streamlit APIs directly to prevent "missing ScriptRunContext" warnings/errors.
    try:
        from streamlit.runtime.scriptrunner import get_script_run_ctx  # type: ignore
    except ImportError:
        # Older/newer Streamlit versions may not expose get_script_run_ctx.
        get_script_run_ctx = None

    # main() is called outside the try block on purpose: a genuine failure in
    # the app must surface as an error rather than be swallowed and replaced
    # by the launch hint below.
    if get_script_run_ctx is not None and get_script_run_ctx() is not None:
        main()
    else:
        print("This app must be started with: streamlit run Image_Compressor/app.py")