"""Streamlit app that enhances low-light images with a pre-trained Keras model."""

import io

import numpy as np
import streamlit as st
import tensorflow as tf
from PIL import Image, ImageOps
from tensorflow import keras

# Configure the page before anything else touches Streamlit. The cached model
# load below may render a spinner, and Streamlit documents set_page_config as
# having to be the first Streamlit command issued on a page.
st.set_page_config(page_title="NaxiLowLight", layout="centered")


def _charbonnier_loss(y_true, y_pred):
    """Return the mean of sqrt((y_true - y_pred)^2 + 1e-6)."""
    return tf.reduce_mean(tf.sqrt(tf.square(y_true - y_pred) + 1e-6))


def _psnr_metric(y_true, y_pred):
    """Return the PSNR between prediction and target, using max_val=1.0."""
    return tf.image.psnr(y_pred, y_true, max_val=1.0)


@st.cache_resource
def load_model():
    """Load the pre-trained Keras model from the local file ``naxi_lowlight.keras``.

    The model is loaded with its custom loss and metric registered under the
    names ``charbonnier_loss`` and ``psnr_metric``. The function is wrapped in
    ``st.cache_resource``, so Streamlit reuses the returned object instead of
    reloading it on each call.

    Returns:
        The object returned by ``keras.models.load_model``.
    """
    return keras.models.load_model(
        "naxi_lowlight.keras",
        custom_objects={
            "charbonnier_loss": _charbonnier_loss,
            "psnr_metric": _psnr_metric,
        },
    )


# Load the model once; later reruns of the script hit the resource cache.
model = load_model()


def enhance_image_pil(pil_img):
    """Enhance a low-light PIL image using the loaded Keras model.

    The image is converted to a float32 array scaled to [0, 1], given a
    leading batch axis, passed through ``model.predict``, and the first
    prediction is clipped and rescaled to 8-bit values.

    Args:
        pil_img: The PIL image to enhance. The caller in this script passes
            an RGB image; other modes are not converted here.

    Returns:
        A new PIL image built from the model output.
    """
    # Normalize pixel values to [0, 1] and add the batch axis the model call uses.
    pixels = keras.utils.img_to_array(pil_img).astype("float32") / 255.0
    batch = np.expand_dims(pixels, axis=0)

    # verbose=0 keeps Keras from printing a progress bar on every rerun.
    prediction = model.predict(batch, verbose=0)[0]

    # Scale back to [0, 255], clip out-of-range values, and convert to uint8.
    restored = np.clip(prediction * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(restored)


st.title("🌙 NaxiLowLight Enhancer")

uploaded_file = st.file_uploader(
    "Upload a low‑light image", type=["jpg", "jpeg", "png"]
)

if uploaded_file is not None:
    # Open the uploaded file as an RGB PIL image.
    img = Image.open(uploaded_file).convert("RGB")

    # Run the model, and compute a plain autocontrast version for comparison.
    enhanced = enhance_image_pil(img)
    autocon = ImageOps.autocontrast(img)

    # Show the original, autocontrasted, and enhanced images side by side.
    cols = st.columns(3)
    cols[0].image(img, caption="Original")
    cols[1].image(autocon, caption="Autocontrast")
    cols[2].image(enhanced, caption="Enhanced")

    # Offer the enhanced image as a PNG download.
    buf = io.BytesIO()
    enhanced.save(buf, format="PNG")
    st.download_button(
        "Download enhanced",
        buf.getvalue(),
        "enhanced.png",
        "image/png",
    )