Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| from PIL import Image, ImageOps | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| import io | |
@st.cache_resource(show_spinner=False)
def load_model(path="naxi_lowlight.keras"):
    """
    Load the pre-trained low-light enhancement Keras model from a local file.

    Streamlit re-executes this whole script on every user interaction, so the
    loader is wrapped in ``st.cache_resource``: the model is deserialized once
    per server process (per distinct ``path``) and the same object is reused on
    later reruns instead of being reloaded from disk each time. The cache
    spinner is disabled so that no Streamlit element is emitted before
    ``st.set_page_config`` runs.

    The saved model references a custom loss and metric by name, so matching
    callables must be supplied through ``custom_objects`` at load time.

    Args:
        path: Filesystem path of the saved ``.keras`` model. Defaults to the
            file shipped with the app.

    Returns:
        The loaded Keras model.
    """

    def charbonnier_loss(y_true, y_pred):
        # Charbonnier loss: mean of sqrt(diff^2 + eps), a differentiable
        # approximation of the L1 (mean absolute) error.
        return tf.reduce_mean(tf.sqrt(tf.square(y_true - y_pred) + 1e-6))

    def psnr_metric(y_true, y_pred):
        # PSNR with max_val=1.0, i.e. images are expected to be scaled to [0, 1].
        return tf.image.psnr(y_pred, y_true, max_val=1.0)

    return keras.models.load_model(
        path,
        custom_objects={
            "charbonnier_loss": charbonnier_loss,
            "psnr_metric": psnr_metric,
        },
    )


# Load the model at startup; st.cache_resource makes subsequent script reruns
# reuse the same instance instead of reloading it from disk.
model = load_model()
def enhance_image_pil(pil_img):
    """
    Enhance a low-light PIL image with the module-level Keras ``model``.

    The image is converted to RGB, scaled to float32 in [0, 1], given a batch
    dimension of 1, run through the model, and the prediction is mapped back
    to an 8-bit RGB PIL image (clipped to [0, 255] and rounded to the nearest
    integer level).

    Args:
        pil_img: Input ``PIL.Image.Image``. Images in other modes (e.g. RGBA,
            L, P) are converted to RGB first so the channel count is always 3.

    Returns:
        A new ``PIL.Image.Image`` built from the model's prediction.
    """
    # NOTE(review): assumes the model accepts a (1, H, W, 3) float input of
    # arbitrary H/W and returns an (H, W, 3)-shaped image scaled like its
    # input — confirm against the training code.
    rgb = pil_img if pil_img.mode == "RGB" else pil_img.convert("RGB")

    # PIL -> float32 array in [0, 1], then add the batch axis.
    image = keras.utils.img_to_array(rgb).astype("float32") / 255.0
    batch = np.expand_dims(image, axis=0)

    # verbose=0 turns off Keras' per-call progress logging, which would
    # otherwise be written to the server log on every request.
    output = model.predict(batch, verbose=0)[0]

    # Scale back to [0, 255]; round rather than truncate so values are not
    # systematically biased downward, and clip any out-of-range predictions.
    output = np.clip(np.rint(output * 255.0), 0, 255).astype(np.uint8)
    return Image.fromarray(output)
st.set_page_config(page_title="NaxiLowLight", layout="centered")
st.title("🌙 NaxiLowLight Enhancer")

uploaded_file = st.file_uploader("Upload a low‑light image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # The extension filter does not guarantee the bytes decode as an image, so
    # report decode failures in the page instead of surfacing a traceback.
    try:
        img = Image.open(uploaded_file)
        # Apply the EXIF orientation tag (common on camera/phone photos) so the
        # image is not processed and shown rotated, then normalize to RGB.
        img = ImageOps.exif_transpose(img).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        st.error(f"Could not read the uploaded file as an image: {err}")
        st.stop()

    # Run the model (potentially slow on CPU) and a classic autocontrast
    # baseline for comparison.
    with st.spinner("Enhancing image..."):
        enhanced = enhance_image_pil(img)
    autocon = ImageOps.autocontrast(img)

    # Display the original, autocontrasted, and enhanced images side by side.
    col_orig, col_auto, col_enh = st.columns(3)
    col_orig.image(img, caption="Original")
    col_auto.image(autocon, caption="Autocontrast")
    col_enh.image(enhanced, caption="Enhanced")

    # Offer the enhanced image as a lossless PNG download.
    buf = io.BytesIO()
    enhanced.save(buf, format="PNG")
    st.download_button(
        "Download enhanced",
        buf.getvalue(),
        file_name="enhanced.png",
        mime="image/png",
    )