# low_light_naxi / app.py
# NEXAS's picture
# Update app.py
# 5f900b6 verified
import streamlit as st
import numpy as np
from PIL import Image, ImageOps
import tensorflow as tf
from tensorflow import keras
import io
@st.cache_resource
def load_model():
    """Deserialize the Keras model stored in ``naxi_lowlight.keras``.

    The loader is given a ``custom_objects`` mapping for the two custom
    names the model needs at load time: the ``charbonnier_loss`` loss and
    the ``psnr_metric`` metric. The function is wrapped in
    ``st.cache_resource`` so Streamlit can reuse the returned object.

    Returns:
        The model object produced by ``keras.models.load_model``.
    """
    custom = {
        # Charbonnier loss: mean over sqrt(squared error + 1e-6).
        "charbonnier_loss": lambda y_true, y_pred: tf.reduce_mean(
            tf.sqrt(tf.square(y_true - y_pred) + 1e-6)
        ),
        # PSNR with max_val=1.0; note the (y_pred, y_true) argument order.
        "psnr_metric": lambda y_true, y_pred: tf.image.psnr(
            y_pred, y_true, max_val=1.0
        ),
    }
    return keras.models.load_model("naxi_lowlight.keras", custom_objects=custom)
# Module-level model handle used by enhance_image_pil. load_model is wrapped
# in st.cache_resource, so this call is expected to reuse the cached object
# on script reruns rather than reading the file again.
model = load_model()
def enhance_image_pil(pil_img):
    """Run the module-level Keras model on a PIL image.

    The image is turned into a float32 array, divided by 255, given a
    leading batch axis and passed to ``model.predict``. The first element
    of the prediction is multiplied by 255, clipped to [0, 255], cast to
    uint8 and wrapped in a PIL image.

    Args:
        pil_img: The PIL image to enhance.

    Returns:
        A PIL image built from the model's output for ``pil_img``.
    """
    pixels = keras.utils.img_to_array(pil_img).astype("float32")
    # Scale to [0, 1] and prepend the batch axis in one step.
    batch = (pixels / 255.0)[np.newaxis, ...]
    prediction = model.predict(batch)
    # Only one image was submitted, so take the first prediction.
    scaled = prediction[0] * 255.0
    as_bytes = np.clip(scaled, 0, 255).astype(np.uint8)
    return Image.fromarray(as_bytes)
# Streamlit UI: upload a picture, show it next to an autocontrast baseline and
# the model's output, and offer the model's output as a PNG download.
#
# NOTE(review): some Streamlit releases require set_page_config to be the
# first Streamlit command of the script; the cached model load runs earlier
# in this file — confirm against the Streamlit version this app is pinned to.
st.set_page_config(page_title="NaxiLowLight", layout="centered")
st.title("🌙 NaxiLowLight Enhancer")

uploaded_file = st.file_uploader("Upload a low‑light image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Honour the EXIF orientation tag before discarding metadata: without
    # this, photos stored with a rotation flag would be displayed, enhanced
    # and exported sideways.
    img = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")

    # Model output plus a plain autocontrast of the same input for comparison.
    enhanced = enhance_image_pil(img)
    autocon = ImageOps.autocontrast(img)

    # Original, autocontrast and enhanced images side by side.
    cols = st.columns(3)
    cols[0].image(img, caption="Original")
    cols[1].image(autocon, caption="Autocontrast")
    cols[2].image(enhanced, caption="Enhanced")

    # Encode the enhanced image as PNG in memory for the download button.
    buf = io.BytesIO()
    enhanced.save(buf, format="PNG")
    st.download_button(
        "Download enhanced",
        buf.getvalue(),
        "enhanced.png",
        "image/png",
    )