Codingacademey
Add application file
5de07fb
Raw
History Blame Contribute Delete
9.29 kB
import io
import os
import pickle
from typing import Optional, Tuple
import numpy as np
from PIL import Image
import streamlit as st
def set_modern_page_config() -> None:
    """Configure the Streamlit page and inject the app's custom CSS.

    Sets the title, icon, wide layout and an expanded sidebar, then injects a
    stylesheet that hides the default header/footer and restyles buttons and
    metric deltas.
    """
    page_options = {
        "page_title": "Image Compressor (K-Means)",
        "page_icon": "🎨",
        "layout": "wide",
        "initial_sidebar_state": "expanded",
    }
    st.set_page_config(**page_options)

    # Raw HTML/CSS injected via markdown; unsafe_allow_html is required for <style>.
    custom_css = """
<style>
/* Hide Streamlit default header/footer */
header[data-testid="stHeader"] { display: none; }
footer { visibility: hidden; }
/* Card-like containers */
.block-container { padding-top: 2rem; }
div[data-testid="stSidebar"] { backdrop-filter: blur(6px); }
/* Buttons */
.stButton>button {
border-radius: 10px;
background: linear-gradient(135deg, #6B73FF 0%, #000DFF 100%);
color: white;
border: none;
box-shadow: 0 8px 24px rgba(0, 13, 255, 0.24);
}
.stDownloadButton>button {
border-radius: 10px;
}
/* Metrics */
[data-testid="stMetricDelta"] { font-weight: 600; }
</style>
"""
    st.markdown(custom_css, unsafe_allow_html=True)
def try_load_pretrained_pipeline(pickle_path: str) -> Optional[object]:
    """Best-effort load of a pickled model object from ``pickle_path``.

    Returns:
        The unpickled object, or ``None`` when the file does not exist or
        cannot be opened/unpickled. Any ``Exception`` is deliberately
        swallowed: a failed load simply means "no pretrained model".

    Security:
        ``pickle.load`` can execute arbitrary code embedded in the file. Only
        use this on files you created or otherwise trust, never on uploads.
    """
    loaded: Optional[object] = None
    if os.path.exists(pickle_path):
        try:
            with open(pickle_path, "rb") as model_file:
                loaded = pickle.load(model_file)
        except Exception:  # best-effort: treat any failure as "no model"
            loaded = None
    return loaded
def extract_cluster_centers_from_model(model: object) -> Optional[np.ndarray]:
    """Try to extract K-Means cluster centers from a variety of common objects.

    The following forms of ``model`` are recognised, checked in this order:

    1. an object with a ``cluster_centers_`` attribute (e.g. a fitted KMeans);
    2. a pipeline-like object whose non-empty ``steps`` is a sequence of
       ``(name, estimator)`` pairs and whose last estimator has
       ``cluster_centers_``;
    3. a ``dict`` containing a ``"cluster_centers_"`` key.

    Returns:
        The centers converted with ``np.asarray`` for every form (expected to be
        ``(k, 3)`` RGB values; the shape is not validated here), or ``None``
        when ``model`` is ``None`` or no centers can be found.
    """
    if model is None:
        return None

    # Direct KMeans-like object
    centers = getattr(model, "cluster_centers_", None)
    if centers is not None:
        return np.asarray(centers)

    # sklearn Pipeline-like: the clustering step is the last one
    steps = getattr(model, "steps", None)
    if steps:
        last_estimator = steps[-1][1]
        centers = getattr(last_estimator, "cluster_centers_", None)
        if centers is not None:
            return np.asarray(centers)

    # Dict-like
    if isinstance(model, dict) and "cluster_centers_" in model:
        return np.asarray(model["cluster_centers_"])

    return None
def quantize_with_centers(
    image_array: np.ndarray,
    centers: np.ndarray,
    chunk_size: int = 65536,
) -> np.ndarray:
    """Map each pixel to the nearest color in ``centers`` (RGB).

    Args:
        image_array: Image of shape ``(H, W, 3)``, dtype ``uint8``.
        centers: Palette of shape ``(k, 3)``, float or ``uint8``. Values are
            clipped to ``[0, 255]`` first, so out-of-range centers cannot wrap
            around when cast to ``uint8``. In-range float values are truncated
            by the final cast.
        chunk_size: Number of pixels handled per distance computation. This
            bounds the temporary ``(chunk_size, k)`` matrix instead of
            allocating one of shape ``(H*W, k)`` for the whole image.

    Returns:
        ``uint8`` array with the same shape as ``image_array`` in which every
        pixel is replaced by its nearest palette color.

    Raises:
        ValueError: If ``centers`` contains no colors.
    """
    palette = np.clip(np.asarray(centers, dtype=np.float32), 0.0, 255.0)
    if palette.shape[0] == 0:
        raise ValueError("centers must contain at least one color")

    pixels = image_array.reshape(-1, 3).astype(np.float32)
    # ||c||^2 per palette color, shape (1, k); reused by every chunk.
    palette_sq = np.sum(palette * palette, axis=1)[np.newaxis, :]

    labels = np.empty(pixels.shape[0], dtype=np.intp)
    step = max(1, int(chunk_size))
    for start in range(0, pixels.shape[0], step):
        chunk = pixels[start:start + step]
        # ||p - c||^2 = ||p||^2 + ||c||^2 - 2 p.c. The ||p||^2 term is constant
        # per row, so it cannot change the argmin and is omitted.
        dist2 = palette_sq - 2.0 * (chunk @ palette.T)
        labels[start:start + step] = np.argmin(dist2, axis=1)

    quantized = palette[labels].astype(np.uint8)
    return quantized.reshape(image_array.shape)
def kmeans_quantize(image_array: np.ndarray, n_colors: int, random_state: int) -> Tuple[np.ndarray, np.ndarray]:
    """Fit K-Means on the image pixels and return the quantized image and palette.

    Args:
        image_array: RGB image of shape ``(H, W, 3)``, presumably ``uint8``.
        n_colors: Requested palette size. It is clamped to the number of pixels,
            because K-Means cannot fit more clusters than there are samples.
            Without the clamp, tiny images crash.
        random_state: Seed forwarded to ``KMeans`` for reproducible palettes.

    Returns:
        ``(quantized, centers)``: ``quantized`` has ``image_array``'s shape.
        ``centers`` is a ``(k, 3)`` ``uint8`` palette with ``k <= n_colors``.

    Raises:
        ValueError: If the image has no pixels. scikit-learn also rejects
            ``n_colors < 1``.
    """
    from sklearn.cluster import KMeans  # lazy: only this code path needs sklearn

    pixels = image_array.reshape(-1, 3).astype(np.float32)
    n_pixels = pixels.shape[0]
    if n_pixels == 0:
        raise ValueError("cannot quantize an image with no pixels")

    n_clusters = min(int(n_colors), n_pixels)
    model = KMeans(n_clusters=n_clusters, n_init=4, random_state=random_state)
    labels = model.fit_predict(pixels)
    # Round to the nearest 8-bit value instead of truncating, which would bias
    # every color down. The clip only guards the uint8 cast.
    centers = np.clip(np.rint(model.cluster_centers_), 0, 255).astype(np.uint8)
    quantized = centers[labels].reshape(image_array.shape)
    return quantized, centers
def compress_image(
    image: Image.Image,
    mode: str,
    n_colors: int,
    use_pretrained: bool,
    pretrained_centers: Optional[np.ndarray],
    random_state: int,
    max_side: int,
) -> Tuple[Image.Image, Optional[np.ndarray]]:
    """Color-quantize ``image`` and return it at its original size with its palette.

    The image is converted to RGB and downscaled (``thumbnail`` never upscales)
    so that its longest side is at most ``max_side`` px. It is then quantized
    and scaled back to the original size with nearest-neighbour resampling, so
    no colors outside the palette are introduced.

    Args:
        image: Source image in any mode; it is converted to RGB and not modified.
        mode: Quantization mode label. Pretrained centers are used only when it
            equals ``"Auto (use pretrained if available)"``, ``use_pretrained``
            is true and ``pretrained_centers`` is given. Otherwise K-Means is
            fitted on this image.
        n_colors: Palette size for per-image K-Means.
        use_pretrained: Whether pretrained centers may be used at all.
        pretrained_centers: Optional ``(k, 3)`` palette. It is clipped to
            ``[0, 255]`` before being returned as ``uint8``.
        random_state: Seed for per-image K-Means.
        max_side: Maximum side length (px) used while quantizing.

    Returns:
        ``(quantized_image, centers)``: an RGB image the same size as ``image``
        and the ``uint8`` palette that was used.
    """
    original_size = image.size
    # convert() returns a new image, so thumbnail() never touches the caller's
    # image. Converting before downscaling also avoids Pillow's forced
    # nearest-neighbour resampling for palette ("P") and bilevel ("1") images.
    working = image.convert("RGB")
    working.thumbnail((max_side, max_side))
    image_np = np.array(working)

    use_pretrained_palette = (
        mode == "Auto (use pretrained if available)"
        and use_pretrained
        and pretrained_centers is not None
    )
    if use_pretrained_palette:
        quant_np = quantize_with_centers(image_np, pretrained_centers)
        # Clip before the cast so out-of-range centers cannot wrap around.
        centers = np.clip(
            np.asarray(pretrained_centers, dtype=np.float32), 0, 255
        ).astype(np.uint8)
    else:
        quant_np, centers = kmeans_quantize(image_np, n_colors=n_colors, random_state=random_state)

    # A uint8 (H, W, 3) array is inferred as RGB; fromarray's `mode` argument is
    # deprecated in recent Pillow releases.
    quant_img = Image.fromarray(quant_np)
    # If we resized for speed, upscale back to original using nearest to preserve palette
    if quant_img.size != original_size:
        quant_img = quant_img.resize(original_size, resample=Image.NEAREST)
    return quant_img, centers
def image_bytes(img: Image.Image, fmt: str, quality: int) -> bytes:
    """Encode ``img`` in format ``fmt`` and return the encoded bytes.

    Args:
        img: Image to encode.
        fmt: Output format name, case-insensitive (e.g. "PNG", "JPEG", "WEBP").
        quality: Quality passed to the lossy JPEG/WEBP encoders. PNG is
            lossless, so quality cannot affect it. For PNG the smallest output
            is always requested instead.

    Returns:
        The encoded file contents.
    """
    fmt_upper = fmt.upper()
    params = {}
    if fmt_upper in {"JPEG", "WEBP"}:
        params["quality"] = int(quality)
    if fmt_upper == "JPEG":
        params["optimize"] = True
        params["progressive"] = True
        # JPEG cannot store alpha or palette images; Pillow refuses e.g. RGBA.
        if img.mode not in ("RGB", "L", "CMYK"):
            img = img.convert("RGB")
    elif fmt_upper == "PNG":
        # PNG compression is lossless: the compression level only trades encode
        # time for file size and never changes pixels. optimize=True asks Pillow
        # for the smallest file.
        params["optimize"] = True
    buf = io.BytesIO()
    img.save(buf, format=fmt_upper, **params)
    return buf.getvalue()
def human_size(num_bytes: int) -> str:
    """Format a byte count as a human-readable string with two decimals.

    Units step by 1024 through B, KB and MB. Anything of 1024 MB or more is
    reported in GB, the largest unit, e.g. ``human_size(1536) == "1.50 KB"``.
    """
    value = float(num_bytes)
    *smaller_units, largest_unit = ("B", "KB", "MB", "GB")
    for unit in smaller_units:
        if value < 1024.0:
            return f"{value:.2f} {unit}"
        value /= 1024.0
    # Only reached once every smaller unit was exceeded.
    return f"{value:.2f} {largest_unit}"
def _open_uploaded_image(data: bytes) -> Optional[Image.Image]:
    """Decode uploaded bytes into an upright RGB image, or ``None`` if undecodable."""
    from PIL import ImageOps

    try:
        with Image.open(io.BytesIO(data)) as opened:
            # Bake the EXIF orientation into the pixels: the re-encoded output
            # carries no EXIF metadata, so phone photos would otherwise come
            # out rotated.
            upright = ImageOps.exif_transpose(opened)
            return upright.convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError and truncated-file errors are OSError subclasses.
        return None


def _render_size_metrics(orig_size: int, new_size: int) -> None:
    """Show original size, compressed size and the saving as three metrics."""
    saving = orig_size - new_size
    saving_pct = (saving / orig_size * 100.0) if orig_size > 0 else 0.0
    col_m1, col_m2, col_m3 = st.columns(3)
    col_m1.metric("Original size", human_size(orig_size))
    col_m2.metric("Compressed size", human_size(new_size))
    # If the output is larger than the input, "Saved" shows 0 with a negative delta.
    col_m3.metric("Saved", f"{human_size(max(saving, 0))}", delta=f"{saving_pct:.1f}%")


def _render_palette(used_centers: np.ndarray, swatch_px: int = 40) -> None:
    """Display the palette as a horizontal strip of square color swatches."""
    k = used_centers.shape[0]
    colors = used_centers.astype(np.uint8)
    # Widen each color to swatch_px columns, then stack swatch_px identical rows.
    row = np.repeat(colors[np.newaxis, :, :], swatch_px, axis=1)
    swatch = np.repeat(row, swatch_px, axis=0)
    # Omitting the deprecated use_column_width=False keeps the same native-width display.
    st.image(Image.fromarray(swatch), caption=f"{k} colors")


def main() -> None:
    """Run the Streamlit UI.

    The app reads the sidebar controls, quantizes the uploaded image and shows
    before/after previews, size metrics and the palette. It then offers the
    compressed file for download.
    """
    set_modern_page_config()
    st.markdown("## 🎨 Image Compressor — K-Means Color Quantization")
    st.caption("Reduce image size by limiting its color palette while keeping it visually appealing.")

    with st.sidebar:
        st.markdown("### Controls")
        n_colors = st.slider("Number of colors", min_value=2, max_value=64, value=16, step=1)
        # Use a slightly lower default quality to help reduce sizes, and default WEBP
        quality = st.slider("Output quality", min_value=10, max_value=100, value=85, step=1)
        output_format = st.selectbox("Output format", options=["PNG", "JPEG", "WEBP"], index=2)
        random_state = st.number_input("Random seed", min_value=0, max_value=2**31 - 1, value=42, step=1)
        max_side = st.slider("Process at max side (px)", min_value=256, max_value=4096, value=1024, step=128)
        st.caption("Tip: JPEG→PNG can increase size. Prefer JPEG/WEBP for photos.")

    # Always train per image — do not use any pretrained model
    use_pretrained = False
    mode = "Train per image"
    centers = None

    uploaded = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg", "webp"])
    if uploaded is None:
        st.stop()

    # getvalue() returns the whole upload regardless of the buffer position,
    # whereas read() returns b"" once the buffer has been consumed.
    original_bytes = uploaded.getvalue()
    original_img = _open_uploaded_image(original_bytes)
    if original_img is None:
        st.error("Could not read the uploaded file as an image.")
        st.stop()

    with st.spinner("Compressing with K-Means…"):
        quant_img, used_centers = compress_image(
            image=original_img,
            mode=mode,
            n_colors=n_colors,
            use_pretrained=use_pretrained,
            pretrained_centers=centers,
            random_state=random_state,
            max_side=max_side,
        )

    compressed_bytes = image_bytes(quant_img, fmt=output_format, quality=quality)
    _render_size_metrics(len(original_bytes), len(compressed_bytes))

    # Preview side-by-side
    left, right = st.columns(2)
    with left:
        st.markdown("#### Before")
        st.image(original_img, use_container_width=True)
    with right:
        st.markdown("#### After")
        st.image(quant_img, use_container_width=True)

    if used_centers is not None and used_centers.size > 0:
        st.markdown("#### Extracted palette")
        _render_palette(used_centers)

    file_root, _ = os.path.splitext(uploaded.name)
    outfile = f"{file_root}_compressed.{output_format.lower()}"
    st.download_button(
        label="Download compressed image",
        data=compressed_bytes,
        file_name=outfile,
        mime=Image.MIME.get(output_format.upper(), f"image/{output_format.lower()}"),
    )

    with st.expander("Advanced details"):
        st.write({
            "quantization_mode": mode,
            "n_colors_requested": n_colors,
            "n_colors_used": int(used_centers.shape[0]) if used_centers is not None else None,
            "output_format": output_format,
            "quality": quality,
            "random_state": random_state,
            "max_side": max_side,
        })
if __name__ == "__main__":
    # `streamlit run` executes this file as __main__ inside a ScriptRunContext.
    # Under a bare `python app.py` there is no context, and Streamlit API calls
    # only emit "missing ScriptRunContext" warnings, so usage is printed instead.
    try:
        from streamlit.runtime.scriptrunner import get_script_run_ctx  # type: ignore
    except ImportError:
        # Older/newer Streamlit versions may not expose get_script_run_ctx here.
        get_script_run_ctx = None  # type: ignore[assignment]

    # Only the launch detection is guarded. Exceptions raised by main() must
    # reach Streamlit so they are shown in the app rather than silently
    # replaced by the usage message.
    if get_script_run_ctx is None or get_script_run_ctx() is not None:
        # When detection is unavailable, run the app instead of refusing to
        # start under `streamlit run`.
        main()
    else:
        print("This app must be started with: streamlit run Image_Compressor/app.py")