Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import keras | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| import os | |
# Location of the trained Keras weights (relative to the working directory).
model_path = './weights.005.h5'

# Load the model once at import time. On any failure the reason is printed and
# ``model`` stays None, so the UI can still start; prediction will fail later.
model = None
try:
    if os.path.exists(model_path):
        model = keras.models.load_model(model_path)
        print("Model loaded successfully.")
    else:
        raise FileNotFoundError(f"Model file not found at {model_path}")
except Exception as e:
    print(f"Error loading model: {e}")
    model = None
def hex_to_rgb(hex_color):
    """Convert a ``#RRGGBB`` (or shorthand ``#RGB``) hex string to an RGB tuple.

    Args:
        hex_color: Colour string such as ``"#1a2b3c"`` or ``"#abc"``
            (hex digits are case-insensitive).

    Returns:
        tuple[int, int, int]: Red, green and blue components, each 0..255.

    Raises:
        ValueError: If ``hex_color`` is not a string of the form ``#RRGGBB``
            or ``#RGB`` made of hexadecimal digits.
    """
    digits = None
    if isinstance(hex_color, str) and hex_color.startswith('#'):
        # Slice off exactly one '#'. lstrip('#') would also remove extra
        # leading '#' characters and let e.g. '##12345' through as a bogus colour.
        digits = hex_color[1:]
        if len(digits) == 3:
            # Shorthand '#abc' expands to '#aabbcc'.
            digits = ''.join(ch * 2 for ch in digits)
    # Check the digits explicitly: int(..., 16) would also accept '_' separators.
    if (
        digits is None
        or len(digits) != 6
        or not all(ch in '0123456789abcdefABCDEF' for ch in digits)
    ):
        raise ValueError(f"Invalid hex color format: {hex_color}. Expected format is #RRGGBB.")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
def preprocess_image(image):
    """Build a single-image batch for the segmentation model.

    The image is resized to 480x480 with ``cv2.resize``, converted to
    ``float32`` and divided by 255. For uint8 input this maps pixels to
    [0, 1]. A leading batch axis of length 1 is then added.
    """
    batch = cv2.resize(image, (480, 480)).astype('float32')  # model input size
    batch /= 255.0
    return batch[np.newaxis, ...]
def predict(model, im, threshold=0.5):
    """Run the segmentation model on ``im`` and return a binary hair mask.

    Args:
        model: Loaded Keras model, or ``None`` if loading failed.
        im: Input image array; it is resized and normalized by
            ``preprocess_image`` before prediction.
        threshold: Scores strictly above this value become 1, others 0.

    Returns:
        np.ndarray: ``uint8`` 0/1 mask with a trailing channel axis, at the
        model's output resolution (not necessarily the size of ``im``).

    Raises:
        RuntimeError: If ``model`` is ``None``.
    """
    if model is None:
        raise RuntimeError("Model is not loaded; cannot predict a hair mask.")
    preprocessed_image = preprocess_image(im)
    mask = model.predict(preprocessed_image)[0]  # drop the batch axis
    print(f"Predicted mask shape: {mask.shape}")
    mask = (mask > threshold).astype(np.uint8)
    print(f"Mask unique values: {np.unique(mask)}")
    # A single-channel output such as (H, W, 1) would otherwise become
    # (H, W, 1, 1) below and fail to broadcast against an (H, W, 3) image.
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask[..., 0]
    return np.expand_dims(mask, axis=-1)
def recolor(im, mask, color):
    """Paint ``color`` over every pixel of ``im`` where ``mask`` equals 1.

    Args:
        im: Image array, (H, W, C) or (H, W); its dtype is preserved.
        mask: 0/1 mask. It may have the same shape as ``im``, or spatial shape
            (H, W) given as (H, W), (H, W, 1) or with more trailing singleton
            axes such as (H, W, 1, 1).
        color: Fill value assignable to one pixel of ``im``, e.g. an
            ``(r, g, b)`` tuple for a 3-channel image.

    Returns:
        np.ndarray: New array shaped like ``im``; ``im`` is not modified.

    Raises:
        ValueError: If the mask's spatial shape does not match the image.
    """
    im = np.asarray(im)
    mask = np.asarray(mask)

    if mask.shape == im.shape:
        # Already aligned element-for-element with the image.
        condition = mask == 1
    else:
        # Collapse trailing singleton axes down to a 2-D (H, W) mask.
        while mask.ndim > 2 and mask.shape[-1] == 1:
            mask = mask[..., 0]
        if mask.shape != im.shape[:2]:
            raise ValueError(
                f"Mask shape {mask.shape} does not match image size {im.shape[:2]}."
            )
        condition = mask == 1
        if im.ndim == 3:
            # Broadcast the per-pixel decision across every channel.
            condition = condition[..., np.newaxis]

    color_layer = np.empty_like(im)
    color_layer[...] = color
    return np.where(condition, color_layer, im)
def change_hair_color(image, hex_color):
    """Recolor the hair in ``image``; Gradio handler for the submit button.

    Args:
        image: Uploaded photo from ``gr.Image`` (a numpy RGB array with the
            component defaults), or ``None`` when nothing was uploaded.
        hex_color: Colour from ``gr.ColorPicker``, expected as ``#RRGGBB``.

    Returns:
        tuple: ``(mask_image, result_image)`` as PIL images, or
        ``(None, None)`` if any step fails (the error is printed).
    """
    try:
        rgb_color = hex_to_rgb(hex_color)
        # Check before converting: np.array(None) is a 0-d array, never None,
        # so a post-conversion None check can never fire.
        if image is None:
            raise ValueError("Failed to read the uploaded image. Please check the file format.")
        if model is None:
            raise RuntimeError("Model is not loaded; cannot predict a hair mask.")
        src_image = np.array(image)
        print(f"Input image shape: {src_image.shape}")

        # Predict the hair mask (returned at the model's resolution).
        mask = np.squeeze(predict(model, src_image))
        if mask.ndim != 2:
            raise ValueError(f"Unexpected mask shape from model: {mask.shape}")

        # Resize the mask back to the photo so it lines up pixel-for-pixel.
        # Nearest-neighbour keeps it strictly 0/1. cv2 takes dsize as (width, height).
        height, width = src_image.shape[:2]
        mask = cv2.resize(mask.astype(np.uint8), (width, height),
                          interpolation=cv2.INTER_NEAREST)

        recolored_image = recolor(src_image, mask[..., np.newaxis], rgb_color)

        mask_image = Image.fromarray(mask * 255)
        # The input array and rgb_color are both RGB, so no BGR->RGB swap is
        # needed; swapping here would exchange the red and blue channels.
        result_image = Image.fromarray(recolored_image)
        return mask_image, result_image
    except Exception as e:
        print(f"Error during hair color change: {e}")
        return None, None  # Gradio shows empty outputs on failure
# Gradio UI: photo and colour inputs in the left column, results in the right.
with gr.Blocks() as gradio_app:
    gr.Markdown('## Upload a photo and select a new hair color!')
    with gr.Row():
        with gr.Column():
            face_file = gr.Image(label="Upload your photo")
            color = gr.ColorPicker(label="Select a color")
            submit = gr.Button("Change Hair Color", variant="primary")
        with gr.Column():
            result = gr.Image(label="Recolored Image")
            mask_result = gr.Image(label="Hair Mask")
    # change_hair_color returns (mask, recolored), so the outputs are listed
    # in that order: the mask component first, then the result.
    submit.click(change_hair_color, [face_file, color], [mask_result, result])

if __name__ == "__main__":
    # share=True asks Gradio for a public share link alongside the local server.
    gradio_app.launch(share=True)