colourchangeapp / app.py
sidhi251287's picture
Update app.py
a7092be verified
import gradio as gr
import keras
import numpy as np
from PIL import Image
import cv2
import os
# Location of the Keras model file, relative to the current working directory.
# Make sure this points at the intended weights file.
model_path = './weights.005.h5'

# Load the model once at import time. Any failure (missing file or a Keras
# load error) is printed and leaves `model` set to None, so the module still
# imports; code that uses `model` should be prepared for it to be None.
model = None
try:
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")
    model = keras.models.load_model(model_path)
except Exception as e:
    print(f"Error loading model: {e}")
    model = None  # never use a partially-initialised model
else:
    print("Model loaded successfully.")
def hex_to_rgb(hex_color):
    """Convert a ``#RRGGBB`` colour string into an ``(r, g, b)`` tuple.

    Args:
        hex_color: Colour string made of ``#`` followed by exactly six
            hexadecimal digits (upper- or lower-case).

    Returns:
        Tuple of three ints, each in ``0..255``.

    Raises:
        ValueError: If ``hex_color`` is not a string of the form ``#RRGGBB``.
    """
    is_valid = (
        isinstance(hex_color, str)
        and len(hex_color) == 7
        and hex_color.startswith('#')
        # int(..., 16) on its own tolerates signs and spaces ("#+1+2+3"), and
        # lstrip('#') would over-strip "##12345"; demand six real hex digits.
        and all(ch in '0123456789abcdefABCDEF' for ch in hex_color[1:])
    )
    if not is_valid:
        raise ValueError(f"Invalid hex color format: {hex_color}. Expected format is #RRGGBB.")
    digits = hex_color[1:]  # drop exactly one leading '#'
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
def preprocess_image(image):
    """Build a one-image float32 batch for the segmentation model.

    The image is resized to 480x480 with OpenCV, divided by 255.0 (mapping
    uint8 pixel values into 0..1), and given a leading batch axis of size 1.
    """
    side = 480  # square model input; cv2.resize takes (width, height)
    resized = cv2.resize(image, (side, side))
    scaled = resized.astype('float32') / 255.0
    return scaled[np.newaxis, ...]
def predict(model, im):
    """Predict a binary hair mask for ``im`` at the image's own resolution.

    Args:
        model: Loaded Keras model; ``model.predict`` is called on the batch
            built by ``preprocess_image``.
        im: Input image array whose first two axes are (height, width).

    Returns:
        ``uint8`` array of shape ``(height, width, 1)`` holding 0/1, where 1
        marks pixels whose model score exceeds 0.5.
    """
    preprocessed_image = preprocess_image(im)
    raw = model.predict(preprocessed_image)[0]
    print(f"Predicted mask shape: {raw.shape}")  # Log shape of the raw prediction
    # Collapse a trailing single channel so the mask is 2-D; otherwise the
    # expand_dims below would yield a 4-D mask that broadcasts badly in recolor().
    # NOTE(review): assumes a single-channel (or channel-less) sigmoid output.
    if raw.ndim == 3 and raw.shape[-1] == 1:
        raw = raw[..., 0]
    mask = (raw > 0.5).astype(np.uint8)  # Apply threshold to create binary mask
    # The model runs at a fixed size; map the mask back onto the input's grid
    # so it lines up pixel-for-pixel with the original image. Nearest-neighbour
    # keeps the mask strictly binary.
    height, width = im.shape[:2]
    mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    print(f"Mask unique values: {np.unique(mask)}")  # Log unique values in mask
    return np.expand_dims(mask, axis=-1)
def recolor(im, mask, color):
    """Paint every pixel of ``im`` where ``mask == 1`` with the solid ``color``.

    ``color`` is written into every pixel of an array shaped like ``im``. The
    result takes that colour wherever the mask is 1 and keeps ``im`` elsewhere.
    ``mask`` must broadcast against ``im`` (e.g. ``(H, W, 1)`` for an
    ``(H, W, 3)`` image).
    """
    fill = np.zeros_like(im)
    fill[:] = color
    selected = mask == 1
    return np.where(selected, fill, im)
def change_hair_color(image, hex_color):
    """Recolour the hair in ``image`` and return ``(mask_image, result_image)``.

    Args:
        image: Uploaded photo from ``gr.Image`` (a NumPy array with the
            component's default ``type``), or ``None`` if nothing was uploaded.
        hex_color: Colour string from ``gr.ColorPicker``, expected as ``#RRGGBB``.

    Returns:
        Pair of PIL images: the hair mask (0/255) and the recoloured photo.
        On any error the problem is printed and ``(None, None)`` is returned
        so the UI keeps running.
    """
    try:
        if model is None:
            raise RuntimeError("Model is not loaded; cannot predict a hair mask.")
        # Check the raw input: np.array(None) is a 0-d object array, never None,
        # so testing after conversion would never fire.
        if image is None:
            raise ValueError("Failed to read the uploaded image. Please check the file format.")
        rgb_color = hex_to_rgb(hex_color)
        src_image = np.array(image)
        print(f"Input image shape: {src_image.shape}")  # Log shape of the input image
        # Predict the hair mask
        mask = predict(model, src_image)
        # Recolor the hair using the mask
        recolored_image = recolor(src_image, mask, rgb_color)
        # Prepare mask and recolored image for display
        mask_image = Image.fromarray(mask.squeeze() * 255)
        # The input pixels and rgb_color share the same channel order, so the
        # array is shown as-is; a BGR->RGB conversion here would swap red and
        # blue in the displayed result (and the chosen colour).
        result_image = Image.fromarray(recolored_image)
        return mask_image, result_image
    except Exception as e:
        print(f"Error during hair color change: {e}")
        return None, None  # Return None if there's an error
# Gradio interface: input controls in the left column, outputs in the right.
with gr.Blocks() as gradio_app:
    gr.Markdown('## Upload a photo and select a new hair color!')
    with gr.Row():
        with gr.Column():
            face_file = gr.Image(label="Upload your photo")
            color = gr.ColorPicker(label="Select a color")
            submit = gr.Button("Change Hair Color", variant="primary")
        with gr.Column():
            result = gr.Image(label="Recolored Image")
            mask_result = gr.Image(label="Hair Mask")
    # change_hair_color returns (mask, recolored image), so the output
    # components are listed in that order.
    submit.click(
        change_hair_color,
        [face_file, color],
        [mask_result, result],
    )

if __name__ == "__main__":
    gradio_app.launch(share=True)