Spaces:
Build error
Build error
| import tensorflow as tf | |
| import numpy as np | |
| import cv2 | |
| import os | |
| import pickle | |
| from PIL import Image | |
| import gradio as gr | |
# CRITICAL: the saved models were trained with this custom layer, so it must be
# defined (and passed via custom_objects) for load_model to rebuild them.
class InstanceNormalization(tf.keras.layers.Layer):
    """Per-sample normalization over axes 1 and 2 with a learned scale/offset.

    For every sample, ``call`` subtracts the mean and divides by
    ``sqrt(variance + epsilon)`` computed over axes 1 and 2 (presumably the
    H/W axes of an NHWC tensor — TODO confirm), then applies a per-channel
    ``scale * normalized + offset``.
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        # Small constant added to the variance so rsqrt never divides by zero.
        self.epsilon = epsilon

    def build(self, input_shape):
        channels = input_shape[-1]
        # Names and creation order ('scale' first, then 'offset') are kept as-is;
        # presumably saved weights are matched against them on load.
        self.scale = self.add_weight(
            name='scale',
            shape=[channels],
            initializer=tf.random_normal_initializer(1., 0.02),
            trainable=True,
        )
        self.offset = self.add_weight(
            name='offset',
            shape=[channels],
            initializer='zeros',
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
        centered = x - mean
        return self.scale * (centered * tf.math.rsqrt(variance + self.epsilon)) + self.offset

    def get_config(self):
        # Serialize epsilon alongside the base Layer config so the layer round-trips.
        return {**super().get_config(), 'epsilon': self.epsilon}
# Set up TensorFlow for compatibility.
# Presumably level '2' is meant to hide TensorFlow's C++ INFO/WARNING messages.
# NOTE(review): tensorflow is already imported above this line, so messages logged
# during import are not affected; moving this before `import tensorflow` may be
# required — confirm.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')
# Force every Keras layer to use the plain float32 dtype policy.
tf.keras.mixed_precision.set_global_policy('float32')
class MultiAttributeClassifier:
    """Classify images along several attributes and apply GAN style transfers.

    On construction it loads one Keras classifier plus a pickled label encoder
    per attribute in ``categories`` from ``<models_dir>/classification``, and a
    set of generator models from ``<models_dir>/gan``. Missing or unloadable
    models are reported on stdout and skipped, so the object is still usable
    with partial functionality.

    Args:
        models_dir: Root folder holding the ``classification`` and ``gan``
            sub-folders. Defaults to ``"models"`` (relative to the CWD).
    """

    # Placeholder predictions returned for a category whose model or encoder is
    # not loaded. NOTE(review): these are fixed values, not model output.
    _FALLBACK_PREDICTIONS = {
        'content': {'class': 'outdoor', 'confidence': 0.6},
        'style': {'class': 'realistic', 'confidence': 0.7},
        'time_of_day': {'class': 'day', 'confidence': 0.8},
        'weather': {'class': 'clear', 'confidence': 0.8},
    }

    # GAN name -> (sub-folder of <models_dir>/gan, candidate file names tried in order).
    _GAN_CANDIDATES = {
        'day_to_night': ('day_night', [
            'day_to_night_generator_final.keras',
            'day_to_night_generator.keras',
            'day_to_night.keras',
            'generator_day_to_night.keras',
        ]),
        'night_to_day': ('day_night', [
            'night_to_day_generator_final.keras',
            'night_to_day_generator.keras',
            'night_to_day.keras',
            'generator_night_to_day.keras',
        ]),
        'foggy_to_clear': ('foggy', [
            'foggy_to_normal_generator_final.keras',
            'foggy_to_clear_generator.keras',
            'foggy_to_clear.keras',
        ]),
        'clear_to_foggy': ('foggy', [
            'normal_to_foggy_generator_final.keras',
            'clear_to_foggy_generator.keras',
            'clear_to_foggy.keras',
        ]),
        'photo_to_japanese': ('japanese', [
            'photo_to_ukiyoe_generator.keras',
            'photo_to_japanese_generator.keras',
            'photo_to_japanese.keras',
        ]),
        'japanese_to_photo': ('japanese', [
            'ukiyoe_to_photo_generator.keras',
            'japanese_to_photo_generator.keras',
            'japanese_to_photo.keras',
        ]),
        'summer_to_winter': ('summer_winter', [
            'summer_to_winter_generator_final.keras',
            'summer_to_winter_generator.keras',
            'summer_to_winter.keras',
        ]),
        'winter_to_summer': ('summer_winter', [
            'winter_to_summer_generator_final.keras',
            'winter_to_summer_generator.keras',
            'winter_to_summer.keras',
        ]),
    }

    class _DictEncoder:
        """Minimal LabelEncoder stand-in exposing ``classes_`` for dict encoders."""

        def __init__(self, classes):
            self.classes_ = classes

    def __init__(self, models_dir="models"):
        self.models_dir = models_dir
        # Attribute categories, each backed by its own classifier + encoder.
        self.categories = ['content', 'style', 'time_of_day', 'weather']
        self.models = {}      # category -> Keras classifier
        self.encoders = {}    # category -> object exposing ``classes_``
        self.gan_models = {}  # transformation name -> Keras generator
        # Needed by load_model to rebuild layers that are not built into Keras.
        self.custom_objects = {
            'InstanceNormalization': InstanceNormalization,
            'tf': tf,
        }
        self.load_classification_models()
        self.load_gan_models()

    # ------------------------------------------------------------------ loading

    def _load_keras_model(self, model_path):
        """Load a Keras model, retrying with compile=False and custom objects.

        Raises:
            Exception: whatever the final ``load_model`` attempt raised.
        """
        try:
            return tf.keras.models.load_model(model_path)
        except Exception as first_error:
            print(f"   ⚠️ Normal loading failed: {first_error}")
        # compile=False skips restoring the training configuration; custom_objects
        # lets Keras resolve InstanceNormalization.
        model = tf.keras.models.load_model(
            model_path,
            compile=False,
            custom_objects=self.custom_objects,
        )
        print("   ✅ Loaded with compile=False and custom_objects")
        return model

    @staticmethod
    def _classes_from_dict(class_dict):
        """Return the ordered class labels stored in a dict-format encoder.

        Supports ``{'classes_': [...]}``, ``{'classes': [...]}``,
        label->index mappings (``{'day': 0, ...}``, ordered by index) and
        index->label mappings (``{0: 'day', ...}``, ordered by key). Any other
        non-empty dict falls back to its keys; an empty dict gives ``['unknown']``.
        """
        if 'classes_' in class_dict:
            return class_dict['classes_']
        if 'classes' in class_dict:
            return class_dict['classes']
        if not class_dict:
            return ['unknown']

        def is_index(value):
            return isinstance(value, (int, np.integer)) and not isinstance(value, bool)

        if all(is_index(v) for v in class_dict.values()):
            return [label for label, _ in sorted(class_dict.items(), key=lambda kv: kv[1])]
        if all(is_index(k) for k in class_dict):
            return [class_dict[k] for k in sorted(class_dict)]
        return list(class_dict)

    def _register_encoder(self, category, encoder_data):
        """Store *encoder_data* for *category* if its format is recognised."""
        if hasattr(encoder_data, 'classes_'):
            # Standard scikit-learn style LabelEncoder.
            self.encoders[category] = encoder_data
            print(f"✅ Loaded {category} encoder (LabelEncoder) - {len(encoder_data.classes_)} classes")
        elif isinstance(encoder_data, dict):
            wrapped = self._DictEncoder(self._classes_from_dict(encoder_data))
            self.encoders[category] = wrapped
            print(f"✅ Loaded {category} encoder (Dict format) - {len(wrapped.classes_)} classes")
            print(f"   Classes: {wrapped.classes_}")
        else:
            print(f"⚠️ Unknown encoder format for {category}: {type(encoder_data)}")
            print(f"   Content preview: {str(encoder_data)[:200]}...")

    def _load_classifier(self, category, directory):
        """Load ``<category>_model.h5`` and ``<category>_encoder.pkl`` from *directory*."""
        model_path = os.path.join(directory, f"{category}_model.h5")
        if not os.path.exists(model_path):
            print(f"❌ {category} model not found at {model_path}")
            return
        print(f"🔄 Loading model: {model_path}")
        try:
            self.models[category] = self._load_keras_model(model_path)
        except Exception as e:
            print(f"   ❌ Failed to load {category}: {e}")
            return
        print(f"✅ Loaded {category} model ({os.path.getsize(model_path) / 1024 / 1024:.1f} MB)")

        encoder_path = os.path.join(directory, f"{category}_encoder.pkl")
        if not os.path.exists(encoder_path):
            print(f"⚠️ {category} encoder not found at {encoder_path}")
            return
        # SECURITY: pickle.load can execute arbitrary code; only deploy trusted encoder files.
        with open(encoder_path, 'rb') as f:
            encoder_data = pickle.load(f)
        self._register_encoder(category, encoder_data)

    def load_classification_models(self):
        """Load every classification model and encoder into ``self.models``/``self.encoders``.

        Per-category failures are printed (with traceback) and skipped.
        """
        import traceback

        print("Loading classification models...")
        classification_path = os.path.join(self.models_dir, "classification")
        print(f"🔍 Looking for models in: {classification_path}/")
        if not os.path.isdir(classification_path):
            print(f"❌ Classification directory not found: {classification_path}")
            return
        print("📁 Found classification directory")
        print(f"📄 Available files: {os.listdir(classification_path)}")

        for category in self.categories:
            try:
                self._load_classifier(category, classification_path)
            except Exception as e:
                print(f"❌ Failed to load {category}: {e}")
                traceback.print_exc()
        print(f"🎯 Successfully loaded {len(self.models)} classification models")

    def _load_first_gan(self, model_name, candidate_paths):
        """Load the first existing, loadable file in *candidate_paths*; return True on success."""
        for model_path in candidate_paths:
            if not os.path.exists(model_path):
                continue
            print(f"🔄 Trying to load: {model_path}")
            try:
                self.gan_models[model_name] = self._load_keras_model(model_path)
            except Exception as e:
                print(f"❌ Failed to load {model_path}: {e}")
                continue
            print(f"✅ Loaded GAN: {model_name} from {model_path}")
            return True
        return False

    @staticmethod
    def _list_model_files(folder_path):
        """Print the .keras/.h5 files in *folder_path* to help debug missing models."""
        if not os.path.isdir(folder_path):
            return
        print(f"   📁 Available files in {folder_path}:")
        for file_name in os.listdir(folder_path):
            if file_name.endswith(('.keras', '.h5')):
                print(f"      📄 {file_name}")

    def load_gan_models(self):
        """Load every GAN listed in ``_GAN_CANDIDATES`` into ``self.gan_models``."""
        print("Loading GAN models...")
        gan_base_path = os.path.join(self.models_dir, "gan")

        # Diagnostic listing of what is actually on disk.
        if os.path.isdir(gan_base_path):
            print(f"📁 Found GAN models directory: {gan_base_path}")
            for folder in os.listdir(gan_base_path):
                folder_path = os.path.join(gan_base_path, folder)
                if os.path.isdir(folder_path):
                    print(f"📂 GAN folder: {folder}")
                    for file_name in os.listdir(folder_path):
                        print(f"   📄 {file_name}")

        for model_name, (folder, file_names) in self._GAN_CANDIDATES.items():
            folder_path = os.path.join(gan_base_path, folder)
            candidates = [os.path.join(folder_path, name) for name in file_names]
            if not self._load_first_gan(model_name, candidates):
                print(f"⚠️ Could not load GAN model: {model_name}")
                self._list_model_files(folder_path)
        print(f"🎯 Successfully loaded {len(self.gan_models)} GAN models")

    # ---------------------------------------------------------------- inference

    @staticmethod
    def _to_rgb_image(image):
        """Return *image* (file path, ndarray or PIL image) as a 3-channel RGB PIL image.

        Converting to RGB handles RGBA (alpha dropped), grayscale, LA and
        palette images, which would otherwise yield wrongly shaped arrays.
        """
        if isinstance(image, str):
            # Close the file handle; convert() returns an independent, loaded copy.
            with Image.open(image) as opened:
                return opened.convert('RGB')
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        return image.convert('RGB')

    def preprocess_image(self, image):
        """Return a (1, 224, 224, 3) float32 batch scaled to [0, 1] for the classifiers."""
        rgb = self._to_rgb_image(image).resize((224, 224))
        batch = np.asarray(rgb, dtype=np.float32) / 255.0
        return batch[np.newaxis, ...]

    def _class_name(self, category, class_idx):
        """Map a predicted index to its label, or ``class_<idx>`` if unavailable."""
        try:
            classes = getattr(self.encoders[category], 'classes_', None)
            if classes is not None and class_idx < len(classes):
                return classes[class_idx]
        except Exception as e:
            print(f"Error getting class name for {category}: {e}")
        return f"class_{class_idx}"

    def predict_attributes(self, image):
        """Predict every attribute in ``self.categories`` for *image*.

        Returns:
            ``{category: {'class': label, 'confidence': float}}``. Categories
            without a loaded model and encoder get a fixed placeholder (a fresh
            copy each call); categories whose inference raises get
            ``{'class': 'unknown', 'confidence': 0.0}``.

        Raises:
            Exception: from preprocessing, if at least one model is loaded and
            *image* cannot be decoded.
        """
        ready = {c for c in self.categories if c in self.models and c in self.encoders}
        # Only decode/resize the image when some model will actually consume it.
        batch = self.preprocess_image(image) if ready else None

        predictions = {}
        for category in self.categories:
            if category not in ready:
                fallback = self._FALLBACK_PREDICTIONS.get(category, {'class': 'unknown', 'confidence': 0.0})
                predictions[category] = dict(fallback)
                continue
            try:
                probabilities = self.models[category].predict(batch, verbose=0)
                class_idx = int(np.argmax(probabilities, axis=1)[0])
                predictions[category] = {
                    'class': self._class_name(category, class_idx),
                    'confidence': float(np.max(probabilities)),
                }
            except Exception as e:
                print(f"Error predicting {category}: {e}")
                predictions[category] = {'class': 'unknown', 'confidence': 0.0}
        return predictions

    def apply_style_transfer(self, image, transformation):
        """Run the GAN named *transformation* on *image*.

        Returns:
            ``(uint8 RGB array of shape (256, 256, 3), message)`` on success, or
            ``(None, error message)`` if the model is missing or inference fails.
        """
        model = self.gan_models.get(transformation)
        if model is None:
            return None, f"GAN model '{transformation}' not available"
        try:
            rgb = self._to_rgb_image(image).resize((256, 256))
            # The generators work in [-1, 1].
            batch = (np.asarray(rgb, dtype=np.float32) / 127.5 - 1.0)[np.newaxis, ...]
            generated = model.predict(batch, verbose=0)
            output = np.clip((generated[0] + 1.0) * 127.5, 0, 255).astype(np.uint8)
            return output, "Transformation completed!"
        except Exception as e:
            print(f"Error in style transfer: {e}")
            return None, f"Error: {str(e)}"

    def get_style_recommendations(self, predictions):
        """Suggest transformations from attribute *predictions*.

        Returns a list of ``{'transformation', 'confidence', 'description'}``
        dicts, in the order: time of day, weather, content, style.
        """
        recommendations = []

        def suggest(transformation, confidence, description):
            recommendations.append({
                'transformation': transformation,
                'confidence': confidence,
                'description': description,
            })

        if 'time_of_day' in predictions:
            time_pred = predictions['time_of_day']
            if time_pred['class'] == 'day' and time_pred['confidence'] > 0.7:
                suggest('day_to_night', time_pred['confidence'],
                        f"Transform scene to night with {time_pred['confidence']*100:.0f}% confidence")
            elif time_pred['class'] == 'night' and time_pred['confidence'] > 0.7:
                suggest('night_to_day', time_pred['confidence'],
                        f"Transform scene to day with {time_pred['confidence']*100:.0f}% confidence")

        if 'weather' in predictions:
            weather_pred = predictions['weather']
            if weather_pred['class'] == 'clear' and weather_pred['confidence'] > 0.6:
                suggest('clear_to_foggy', weather_pred['confidence'],
                        f"Add fog atmosphere with {weather_pred['confidence']*100:.0f}% confidence")
            elif weather_pred['class'] == 'foggy' and weather_pred['confidence'] > 0.6:
                suggest('foggy_to_clear', weather_pred['confidence'],
                        f"Clear fog from scene with {weather_pred['confidence']*100:.0f}% confidence")

        if 'content' in predictions:
            content_pred = predictions['content']
            if content_pred['class'] in ['outdoor', 'landscape'] and content_pred['confidence'] > 0.6:
                # Both seasonal directions are offered with a fixed 0.8 confidence.
                suggest('summer_to_winter', 0.8, "Transform scene to winter with snow and cold atmosphere")
                suggest('winter_to_summer', 0.8, "Transform scene to summer with warm, lush atmosphere")

        if 'style' in predictions:
            style_pred = predictions['style']
            if style_pred['class'] == 'realistic' and style_pred['confidence'] > 0.6:
                suggest('photo_to_japanese', style_pred['confidence'], "Transform to Japanese ukiyo-e art style")

        return recommendations
# Build the shared classifier once at import time; the Gradio handlers below use it.
print("🚀 Starting StyleTransfer App...")
classifier = MultiAttributeClassifier()
print("🎯 Initialization complete!")
print(f"   📊 Classification models loaded: {len(classifier.models)}")
print(f"   🎨 GAN models loaded: {len(classifier.gan_models)}")
if classifier.models:
    print(f"   ✅ Available categories: {list(classifier.models)}")
if classifier.gan_models:
    print(f"   ✅ Available transformations: {list(classifier.gan_models)}")
print("=" * 50)
def analyze_image(image):
    """Classify *image* and build the Markdown analysis shown in the UI.

    Returns a 3-tuple matching the click handler's outputs
    ``[analysis_output, recommendations, results_gallery]``:
    the Markdown text; a ``gr.update`` that resets the checkbox group to every
    entry of ``available_transformations`` with nothing selected; and ``[]``
    to clear the gallery. Analysis errors are reported in the text, not raised.
    """
    import traceback

    def pretty(name):
        # "day_to_night" -> "Day → Night": only the "_to_" separator becomes an arrow.
        return name.replace('_to_', ' → ').replace('_', ' ').title()

    # Every transformation stays selectable regardless of the analysis outcome.
    checkbox_update = gr.update(
        choices=[(transformation_labels[t], t) for t in available_transformations],
        value=None,
        visible=True,
    )
    if image is None:
        return "Please upload an image first.", checkbox_update, []

    try:
        predictions = classifier.predict_attributes(image)
        recommendations = classifier.get_style_recommendations(predictions)

        parts = ["## 🔍 Image Analysis Results\n\n"]
        for category, pred in predictions.items():
            parts.append(
                f"**{category.replace('_', ' ').title()}:** {pred['class']} "
                f"({pred['confidence'] * 100:.1f}% confidence)\n\n"
            )
        parts.append("## 🎨 AI Suggestions\n\n")
        if recommendations:
            for rec in recommendations:
                parts.append(
                    f"**{pretty(rec['transformation'])}** ({rec['confidence'] * 100:.0f}%) — "
                    f"{rec['description']}\n\n"
                )
        else:
            parts.append("No specific recommendations - but feel free to try any transformation!\n\n")
        parts.append("---\n**Choose any transformation(s) below - you're not limited to the suggestions!**")
        return "".join(parts), checkbox_update, []
    except Exception as e:
        print(f"Error in analysis: {e}")
        traceback.print_exc()
        return (
            f"Error analyzing image: {str(e)}\n\n**All transformations still available below:**",
            checkbox_update,
            [],
        )
def apply_transformations(image, selected_transformations):
    """Apply each selected GAN transformation to *image*.

    Returns:
        ``(status_text, gallery_items)``. ``status_text`` has one ✅/❌ line
        per transformation. ``gallery_items`` holds ``(image, caption)`` pairs
        for the successful ones, so each output image is labeled in the gallery.
    """
    if image is None:
        return "Please upload an image first.", []
    if not selected_transformations:
        return "Please select at least one transformation.", []

    results = []
    status_messages = []
    for transformation in selected_transformations:
        # "day_to_night" -> "Day → Night": only the "_to_" separator becomes an arrow.
        display_name = transformation.replace('_to_', ' → ').replace('_', ' ').title()
        try:
            transformed_img, message = classifier.apply_style_transfer(image, transformation)
        except Exception as e:
            status_messages.append(f"❌ {display_name}: Error - {str(e)}")
            continue
        if transformed_img is None:
            status_messages.append(f"❌ {display_name}: {message}")
        else:
            results.append((transformed_img, display_name))
            status_messages.append(f"✅ {display_name}: {message}")
    return "\n".join(status_messages), results
# Transformations offered for manual selection, in display order. The names
# match the GAN model keys used by MultiAttributeClassifier.apply_style_transfer.
available_transformations = [
    "day_to_night", "night_to_day",
    "clear_to_foggy", "foggy_to_clear",
    "photo_to_japanese", "japanese_to_photo",
    "summer_to_winter", "winter_to_summer",
]

# User-friendly checkbox labels, one per entry of available_transformations.
transformation_labels = {
    "day_to_night": "🌞→🌙 Day to Night",
    "night_to_day": "🌙→🌞 Night to Day",
    "clear_to_foggy": "☀️→🌫️ Clear to Foggy",
    "foggy_to_clear": "🌫️→☀️ Foggy to Clear",
    "photo_to_japanese": "📷→🎨 Photo to Japanese Art",
    "japanese_to_photo": "🎨→📷 Japanese Art to Photo",
    "summer_to_winter": "🌿→❄️ Summer to Winter",
    "winter_to_summer": "❄️→🌿 Winter to Summer",
}
# Build the Gradio interface; the button handlers are analyze_image and
# apply_transformations defined above.
with gr.Blocks(title="Intelligent Multi-Attribute Style Transfer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Intelligent Multi-Attribute Style Transfer")
    gr.Markdown("Upload an image and our AI will analyze it to provide smart suggestions - **but you can choose ANY transformation you want!**")
    gr.Markdown("💡 **Tip:** You can skip analysis and apply transformations directly!")

    # Overview of the transformation families.
    gr.Markdown("## Available Transformations:")
    gr.Markdown("• 🌓 Day ↔ Night conversion (CycleGAN)")
    gr.Markdown("• 🎨 Photo ↔ Japanese ukiyo-e art style (CycleGAN)")
    gr.Markdown("• 🌫️ Foggy ↔ Clear weather transformation (CycleGAN)")
    gr.Markdown("• 🌿 Summer ↔ Winter seasonal atmosphere (CycleGAN)")
    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(label="📤 Upload Your Image", type="pil")
            analyze_btn = gr.Button("🔍 Analyze Image (Optional)", variant="primary")
        with gr.Column(scale=1):
            analysis_output = gr.Markdown("## 🔍 Image Analysis Results", label="Analysis Results")

    # Always lists every transformation; analyze_image only refreshes it.
    recommendations = gr.CheckboxGroup(
        choices=[(transformation_labels[t], t) for t in available_transformations],
        label="🎨 Choose Transformations (All Available)",
        visible=True,
        value=None,
    )

    with gr.Row():
        with gr.Column():
            apply_btn = gr.Button("🎯 Apply Selected Transfers", variant="secondary")
    with gr.Row():
        status_output = gr.Textbox(label="📋 Applied Transformations", interactive=False)
    with gr.Row():
        results_gallery = gr.Gallery(
            label="🖼️ Transformed Images",
            show_label=True,
            elem_id="gallery",
            columns=2,
            rows=2,
            height="auto",
        )

    # Event handlers: analysis also clears the gallery (its third output).
    analyze_btn.click(
        fn=analyze_image,
        inputs=[image_input],
        outputs=[analysis_output, recommendations, results_gallery],
    )
    apply_btn.click(
        fn=apply_transformations,
        inputs=[image_input, recommendations],
        outputs=[status_output, results_gallery],
    )
if __name__ == "__main__":
    # "0.0.0.0" binds every network interface; 7860 is the port the app serves on.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)