"""Gradio app: multi-attribute image classification plus CycleGAN style transfer.

At import time this module loads the Keras classifiers found under
``models/classification`` and the generators found under ``models/gan``,
then builds the Gradio ``demo`` interface. Run it as a script to launch the
server on port 7860.
"""

import os

# Must be set before TensorFlow is imported, otherwise the C++ log lines that
# TensorFlow prints while importing are not suppressed.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import pickle
import traceback

import cv2  # noqa: F401  (kept from the original file; not used below)
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image


# CRITICAL: custom layer used by the trained models. It has to be passed via
# ``custom_objects`` for saved models that contain it to deserialize.
class InstanceNormalization(tf.keras.layers.Layer):
    """Normalize each sample/channel over axes 1 and 2, then scale and shift.

    ``scale`` and ``offset`` are learned per channel (last axis). Because the
    moments are taken over axes 1 and 2, inputs are expected to be rank 4
    (presumably NHWC image batches — confirm against the training code).
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        depth = input_shape[-1]
        self.scale = self.add_weight(
            shape=[depth],
            initializer=tf.random_normal_initializer(1., 0.02),
            trainable=True,
            name='scale',
        )
        self.offset = self.add_weight(
            shape=[depth],
            initializer='zeros',
            trainable=True,
            name='offset',
        )
        super().build(input_shape)

    def call(self, x):
        mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
        inv = tf.math.rsqrt(variance + self.epsilon)
        normalized = (x - mean) * inv
        return self.scale * normalized + self.offset

    def get_config(self):
        # Serialize epsilon so the layer round-trips through save/load.
        config = super().get_config()
        config.update({'epsilon': self.epsilon})
        return config


# Run Keras in plain float32 (no mixed precision).
tf.keras.mixed_precision.set_global_policy('float32')


class _DictEncoderWrapper:
    """Give a dict-pickled label encoder the ``classes_`` attribute of a LabelEncoder.

    Uses the ``'classes_'`` key, else the ``'classes'`` key, else the dict's
    keys (``['unknown']`` for an empty dict).
    """

    def __init__(self, class_dict):
        if 'classes_' in class_dict:
            self.classes_ = class_dict['classes_']
        elif 'classes' in class_dict:
            self.classes_ = class_dict['classes']
        else:
            # NOTE(review): assumes key order matches the model's output
            # indices — confirm how these dicts were produced.
            self.classes_ = list(class_dict.keys()) if class_dict else ['unknown']


def _to_rgb_image(image):
    """Return ``image`` as a 3-channel RGB PIL image.

    Args:
        image: a file path, a numpy array, or a PIL image.

    Converting to RGB handles grayscale, palette, LA and RGBA inputs alike
    (alpha is dropped), so downstream arrays always have 3 channels.
    """
    if isinstance(image, str):
        # Close the file handle; convert() loads the pixels into a new image.
        with Image.open(image) as opened:
            return opened.convert('RGB')
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    return image.convert('RGB')


class MultiAttributeClassifier:
    """Predict image attributes and apply GAN-based style transfers.

    Attributes:
        categories: attribute names; each expects ``<name>_model.h5`` and
            ``<name>_encoder.pkl`` under ``models/classification``.
        models: category -> loaded Keras classifier.
        encoders: category -> object exposing ``classes_``.
        gan_models: transformation name -> loaded Keras generator.
        custom_objects: passed to ``load_model`` so custom layers deserialize.
    """

    # Candidate files for each generator, tried in order; the first one that
    # exists and loads wins. All candidates of one entry share a directory.
    _GAN_CANDIDATE_PATHS = {
        # Day/Night models
        'day_to_night': [
            'models/gan/day_night/day_to_night_generator_final.keras',
            'models/gan/day_night/day_to_night_generator.keras',
            'models/gan/day_night/day_to_night.keras',
            'models/gan/day_night/generator_day_to_night.keras',
        ],
        'night_to_day': [
            'models/gan/day_night/night_to_day_generator_final.keras',
            'models/gan/day_night/night_to_day_generator.keras',
            'models/gan/day_night/night_to_day.keras',
            'models/gan/day_night/generator_night_to_day.keras',
        ],
        # Foggy/Clear models
        'foggy_to_clear': [
            'models/gan/foggy/foggy_to_normal_generator_final.keras',
            'models/gan/foggy/foggy_to_clear_generator.keras',
            'models/gan/foggy/foggy_to_clear.keras',
        ],
        'clear_to_foggy': [
            'models/gan/foggy/normal_to_foggy_generator_final.keras',
            'models/gan/foggy/clear_to_foggy_generator.keras',
            'models/gan/foggy/clear_to_foggy.keras',
        ],
        # Japanese art models
        'photo_to_japanese': [
            'models/gan/japanese/photo_to_ukiyoe_generator.keras',
            'models/gan/japanese/photo_to_japanese_generator.keras',
            'models/gan/japanese/photo_to_japanese.keras',
        ],
        'japanese_to_photo': [
            'models/gan/japanese/ukiyoe_to_photo_generator.keras',
            'models/gan/japanese/japanese_to_photo_generator.keras',
            'models/gan/japanese/japanese_to_photo.keras',
        ],
        # Season models
        'summer_to_winter': [
            'models/gan/summer_winter/summer_to_winter_generator_final.keras',
            'models/gan/summer_winter/summer_to_winter_generator.keras',
            'models/gan/summer_winter/summer_to_winter.keras',
        ],
        'winter_to_summer': [
            'models/gan/summer_winter/winter_to_summer_generator_final.keras',
            'models/gan/summer_winter/winter_to_summer_generator.keras',
            'models/gan/summer_winter/winter_to_summer.keras',
        ],
    }

    # Placeholder results used when a category's model or encoder is not
    # loaded. These are NOT model outputs.
    _FALLBACK_PREDICTIONS = {
        'content': {'class': 'outdoor', 'confidence': 0.6},
        'style': {'class': 'realistic', 'confidence': 0.7},
        'time_of_day': {'class': 'day', 'confidence': 0.8},
        'weather': {'class': 'clear', 'confidence': 0.8},
    }

    def __init__(self):
        self.categories = ['content', 'style', 'time_of_day', 'weather']
        self.models = {}
        self.encoders = {}
        self.gan_models = {}
        # Defined before loading because both loaders pass it to load_model.
        self.custom_objects = {
            'InstanceNormalization': InstanceNormalization,
            'tf': tf,
        }
        self.load_classification_models()
        self.load_gan_models()

    # ------------------------------------------------------------------ loading

    def _load_keras_model(self, model_path):
        """Load ``model_path``, retrying with progressively looser options.

        Tries a plain load, then ``compile=False``, then ``compile=False`` with
        ``self.custom_objects``. Returns the first model that loads.

        Raises:
            Exception: the last loader error if every attempt fails.
        """
        attempts = (
            ("Normal loading", {}),
            ("compile=False", {'compile': False}),
            ("custom_objects (InstanceNormalization)",
             {'compile': False, 'custom_objects': self.custom_objects}),
        )
        last_error = None
        for attempt_no, (label, kwargs) in enumerate(attempts):
            try:
                model = tf.keras.models.load_model(model_path, **kwargs)
            except Exception as err:  # Keras raises many unrelated error types
                print(f"   ⚠️ {label} failed: {err}")
                last_error = err
                continue
            if attempt_no:
                print(f"   ✅ Loaded with {label}")
            return model
        print(f"   ❌ All loading methods failed: {last_error}")
        raise last_error

    def load_classification_models(self):
        """Load each category's classifier and label encoder (best effort).

        Categories whose files are missing or fail to load are skipped, so
        ``predict_attributes`` falls back to placeholder values for them.
        """
        print("Loading classification models...")
        classification_path = "models/classification"
        print(f"📂 Looking for models in: {classification_path}/")
        if not os.path.exists(classification_path):
            print(f"❌ Classification directory not found: {classification_path}")
            return
        print("📁 Found classification directory")
        print(f"📄 Available files: {os.listdir(classification_path)}")

        for category in self.categories:
            try:
                self._load_classification_category(category, classification_path)
            except Exception as e:
                print(f"❌ Failed to load {category}: {e}")
                traceback.print_exc()

        print(f"🎯 Successfully loaded {len(self.models)} classification models")

    def _load_classification_category(self, category, base_dir):
        """Load ``<category>_model.h5`` and, if the model loads, its encoder."""
        model_path = os.path.join(base_dir, f"{category}_model.h5")
        if not os.path.exists(model_path):
            print(f"❌ {category} model not found at {model_path}")
            return

        print(f"🔍 Loading model: {model_path}")
        try:
            self.models[category] = self._load_keras_model(model_path)
        except Exception as e:
            print(f"   ❌ Failed to load {category}: {e}")
            return
        size_mb = os.path.getsize(model_path) / 1024 / 1024
        print(f"✅ Loaded {category} model ({size_mb:.1f} MB)")

        encoder_path = os.path.join(base_dir, f"{category}_encoder.pkl")
        if not os.path.exists(encoder_path):
            print(f"⚠️ {category} encoder not found at {encoder_path}")
            return

        # NOTE(security): pickle.load can execute arbitrary code. Only place
        # trusted encoder files in this directory.
        with open(encoder_path, 'rb') as f:
            encoder_data = pickle.load(f)

        if hasattr(encoder_data, 'classes_'):
            # Standard LabelEncoder-like object.
            self.encoders[category] = encoder_data
            print(f"✅ Loaded {category} encoder (LabelEncoder) - "
                  f"{len(encoder_data.classes_)} classes")
        elif isinstance(encoder_data, dict):
            wrapper = _DictEncoderWrapper(encoder_data)
            self.encoders[category] = wrapper
            print(f"✅ Loaded {category} encoder (Dict format) - "
                  f"{len(wrapper.classes_)} classes")
            print(f"   Classes: {wrapper.classes_}")
        else:
            print(f"⚠️ Unknown encoder format for {category}: {type(encoder_data)}")
            print(f"   Content preview: {str(encoder_data)[:200]}...")

    def load_gan_models(self):
        """Load each GAN generator from the first candidate path that loads."""
        print("Loading GAN models...")
        self._print_gan_directory("models/gan")

        for model_name, possible_paths in self._GAN_CANDIDATE_PATHS.items():
            for model_path in possible_paths:
                if not os.path.exists(model_path):
                    continue
                print(f"🔍 Trying to load: {model_path}")
                try:
                    self.gan_models[model_name] = self._load_keras_model(model_path)
                except Exception as e:
                    print(f"❌ Failed to load {model_path}: {e}")
                    continue
                print(f"✅ Loaded GAN: {model_name} from {model_path}")
                break
            else:
                # No candidate loaded: list what is actually there to help debugging.
                print(f"⚠️ Could not load GAN model: {model_name}")
                self._print_model_files(os.path.dirname(possible_paths[0]))

        print(f"🎯 Successfully loaded {len(self.gan_models)} GAN models")

    @staticmethod
    def _print_gan_directory(gan_base_path):
        """Print every sub-folder of ``gan_base_path`` and its files, if present."""
        if not os.path.exists(gan_base_path):
            return
        print(f"📂 Found GAN models directory: {gan_base_path}")
        for folder in os.listdir(gan_base_path):
            folder_path = os.path.join(gan_base_path, folder)
            if os.path.isdir(folder_path):
                print(f"📁 GAN folder: {folder}")
                for file in os.listdir(folder_path):
                    print(f"   📄 {file}")

    @staticmethod
    def _print_model_files(folder_path):
        """Print the ``.keras``/``.h5`` files in ``folder_path``, if it exists."""
        if not os.path.exists(folder_path):
            return
        print(f"   📁 Available files in {folder_path}:")
        for file in os.listdir(folder_path):
            if file.endswith(('.keras', '.h5')):
                print(f"      📄 {file}")

    # --------------------------------------------------------------- inference

    def preprocess_image(self, image):
        """Return a ``(1, 224, 224, 3)`` float32 batch scaled to ``[0, 1]``.

        Args:
            image: a file path, a numpy array, or a PIL image.
        """
        rgb = _to_rgb_image(image).resize((224, 224))
        img_array = np.array(rgb).astype(np.float32) / 255.0
        return np.expand_dims(img_array, axis=0)

    def _class_name(self, category, index):
        """Map a predicted index to a label via the category's encoder."""
        try:
            classes = getattr(self.encoders[category], 'classes_', None)
            if classes is not None and index < len(classes):
                return classes[index]
        except Exception as e:
            print(f"Error getting class name for {category}: {e}")
        return f"class_{index}"

    def predict_attributes(self, image):
        """Predict every attribute in ``self.categories`` for ``image``.

        Returns:
            dict mapping category -> ``{'class': label, 'confidence': float}``.
            Categories without a loaded model and encoder get placeholder
            values from ``_FALLBACK_PREDICTIONS``; prediction errors yield
            ``{'class': 'unknown', 'confidence': 0.0}``.
        """
        preprocessed = self.preprocess_image(image)
        predictions = {}

        for category in self.categories:
            if category not in self.models or category not in self.encoders:
                fallback = self._FALLBACK_PREDICTIONS.get(
                    category, {'class': 'unknown', 'confidence': 0.0})
                predictions[category] = dict(fallback)  # copy: don't share state
                continue

            try:
                pred = self.models[category].predict(preprocessed, verbose=0)
                predicted_class_idx = int(np.argmax(pred, axis=1)[0])
                confidence = float(np.max(pred))
                class_name = self._class_name(category, predicted_class_idx)
            except Exception as e:
                print(f"Error predicting {category}: {e}")
                predictions[category] = {'class': 'unknown', 'confidence': 0.0}
                continue

            predictions[category] = {'class': class_name, 'confidence': confidence}

        return predictions

    def apply_style_transfer(self, image, transformation):
        """Run the ``transformation`` generator on ``image``.

        The input is resized to 256x256 and scaled to ``[-1, 1]``; the output
        is mapped back from ``[-1, 1]`` to uint8 pixels.

        Returns:
            ``(uint8 array, "Transformation completed!")`` on success, or
            ``(None, error message)`` if the model is missing or inference fails.
        """
        if transformation not in self.gan_models:
            return None, f"GAN model '{transformation}' not available"

        try:
            img = _to_rgb_image(image).resize((256, 256))
            img_array = np.array(img).astype(np.float32) / 127.5 - 1.0
            batch = np.expand_dims(img_array, axis=0)

            generated = self.gan_models[transformation].predict(batch, verbose=0)

            # Invert the [-1, 1] scaling applied above.
            generated = (generated[0] + 1.0) * 127.5
            generated = np.clip(generated, 0, 255).astype(np.uint8)
            return generated, "Transformation completed!"
        except Exception as e:
            print(f"Error in style transfer: {e}")
            return None, f"Error: {str(e)}"

    # ---------------------------------------------------------- recommendation

    def get_style_recommendations(self, predictions):
        """Suggest transformations based on ``predict_attributes`` output.

        Returns:
            list of ``{'transformation', 'confidence', 'description'}`` dicts.
            Suggestions are not filtered by which GAN models actually loaded.
        """
        recommendations = []

        # Time-based recommendations
        if 'time_of_day' in predictions:
            time_pred = predictions['time_of_day']
            conf = time_pred['confidence']
            if time_pred['class'] == 'day' and conf > 0.7:
                recommendations.append({
                    'transformation': 'day_to_night',
                    'confidence': conf,
                    'description': f"Transform scene to night with {conf*100:.0f}% confidence",
                })
            elif time_pred['class'] == 'night' and conf > 0.7:
                recommendations.append({
                    'transformation': 'night_to_day',
                    'confidence': conf,
                    'description': f"Transform scene to day with {conf*100:.0f}% confidence",
                })

        # Weather-based recommendations
        if 'weather' in predictions:
            weather_pred = predictions['weather']
            conf = weather_pred['confidence']
            if weather_pred['class'] == 'clear' and conf > 0.6:
                recommendations.append({
                    'transformation': 'clear_to_foggy',
                    'confidence': conf,
                    'description': f"Add fog atmosphere with {conf*100:.0f}% confidence",
                })
            elif weather_pred['class'] == 'foggy' and conf > 0.6:
                recommendations.append({
                    'transformation': 'foggy_to_clear',
                    'confidence': conf,
                    'description': f"Clear fog from scene with {conf*100:.0f}% confidence",
                })

        # Content-based recommendations (no season classifier: offer both ways)
        if 'content' in predictions:
            content_pred = predictions['content']
            if (content_pred['class'] in ['outdoor', 'landscape']
                    and content_pred['confidence'] > 0.6):
                recommendations.extend([
                    {
                        'transformation': 'summer_to_winter',
                        'confidence': 0.8,
                        'description': "Transform scene to winter with snow and cold atmosphere",
                    },
                    {
                        'transformation': 'winter_to_summer',
                        'confidence': 0.8,
                        'description': "Transform scene to summer with warm, lush atmosphere",
                    },
                ])

        # Style-based recommendations
        if 'style' in predictions:
            style_pred = predictions['style']
            if style_pred['class'] == 'realistic' and style_pred['confidence'] > 0.6:
                recommendations.append({
                    'transformation': 'photo_to_japanese',
                    'confidence': style_pred['confidence'],
                    'description': "Transform to Japanese ukiyo-e art style",
                })

        return recommendations


# Available transformations for manual selection (internal keys).
available_transformations = [
    "day_to_night", "night_to_day",
    "clear_to_foggy", "foggy_to_clear",
    "photo_to_japanese", "japanese_to_photo",
    "summer_to_winter", "winter_to_summer",
]

# User-friendly transformation names shown in the UI.
transformation_labels = {
    "day_to_night": "🌅→🌙 Day to Night",
    "night_to_day": "🌙→🌅 Night to Day",
    "clear_to_foggy": "☀️→🌫️ Clear to Foggy",
    "foggy_to_clear": "🌫️→☀️ Foggy to Clear",
    "photo_to_japanese": "📷→🎨 Photo to Japanese Art",
    "japanese_to_photo": "🎨→📷 Japanese Art to Photo",
    "summer_to_winter": "🌿→❄️ Summer to Winter",
    "winter_to_summer": "❄️→🌿 Winter to Summer",
}


def _display_name(transformation):
    """Human-readable label for a transformation key.

    Uses ``transformation_labels``; unknown keys fall back to title case
    (e.g. ``"foo_bar"`` -> ``"Foo Bar"``).
    """
    return transformation_labels.get(
        transformation, transformation.replace('_', ' ').title())


def _all_transformation_choices():
    """Checkbox update listing every transformation with nothing selected."""
    choices_with_labels = [(transformation_labels[t], t) for t in available_transformations]
    return gr.update(choices=choices_with_labels, value=None, visible=True)


# Initialize classifier globally (loads all models at import time).
print("🚀 Starting StyleTransfer App...")
classifier = MultiAttributeClassifier()
print("🎯 Initialization complete!")
print(f"   📊 Classification models loaded: {len(classifier.models)}")
print(f"   🎨 GAN models loaded: {len(classifier.gan_models)}")
if classifier.models:
    print(f"   ✅ Available categories: {list(classifier.models.keys())}")
if classifier.gan_models:
    print(f"   ✅ Available transformations: {list(classifier.gan_models.keys())}")
print("=" * 50)


def analyze_image(image):
    """Analyze an uploaded image and describe suggested transformations.

    Returns:
        A 3-tuple for the Gradio outputs: markdown text, a checkbox update
        that always lists ALL transformations (not only the suggested ones),
        and an empty list that clears the results gallery.
    """
    if image is None:
        return "Please upload an image first.", _all_transformation_choices(), []

    try:
        predictions = classifier.predict_attributes(image)

        parts = ["## 🔍 Image Analysis Results\n\n"]
        for category, pred in predictions.items():
            confidence_pct = pred['confidence'] * 100
            parts.append(
                f"**{category.replace('_', ' ').title()}:** {pred['class']} "
                f"({confidence_pct:.1f}% confidence)\n\n")

        recommendations = classifier.get_style_recommendations(predictions)
        if recommendations:
            parts.append("## 🎨 AI Suggestions\n\n")
            for rec in recommendations:
                parts.append(
                    f"**{_display_name(rec['transformation'])}** "
                    f"({rec['confidence']*100:.0f}%) {rec['description']}\n\n")
        else:
            parts.append("## 🎨 AI Suggestions\n\nNo specific recommendations - "
                         "but feel free to try any transformation!\n\n")
        parts.append("---\n**Choose any transformation(s) below - "
                     "you're not limited to the suggestions!**")

        return "".join(parts), _all_transformation_choices(), []

    except Exception as e:
        print(f"Error in analysis: {e}")
        traceback.print_exc()
        # Even if analysis fails, still show all transformations.
        return (f"Error analyzing image: {str(e)}\n\n"
                "**All transformations still available below:**",
                _all_transformation_choices(), [])


def apply_transformations(image, selected_transformations):
    """Apply each selected transformation to ``image``.

    Returns:
        ``(status text, list of generated images)``; failed transformations
        are reported in the status text and omitted from the list.
    """
    if image is None:
        return "Please upload an image first.", []
    if not selected_transformations:
        return "Please select at least one transformation.", []

    results = []
    status_messages = []
    for transformation in selected_transformations:
        name = _display_name(transformation)
        try:
            transformed_img, message = classifier.apply_style_transfer(image, transformation)
        except Exception as e:
            status_messages.append(f"❌ {name}: Error - {str(e)}")
            continue

        if transformed_img is None:
            status_messages.append(f"❌ {name}: {message}")
        else:
            results.append(transformed_img)
            status_messages.append(f"✅ {name}: {message}")

    return "\n".join(status_messages), results


# Create Gradio interface
with gr.Blocks(title="Intelligent Multi-Attribute Style Transfer",
               theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 Intelligent Multi-Attribute Style Transfer")
    gr.Markdown("Upload an image and our AI will analyze it to provide smart suggestions "
                "- **but you can choose ANY transformation you want!**")
    gr.Markdown("💡 **Tip:** You can skip analysis and apply transformations directly!")

    # Show available transformations
    gr.Markdown("## Available Transformations:")
    gr.Markdown("• 🌅 Day ↔ Night conversion (CycleGAN)")
    gr.Markdown("• 🎨 Photo ↔ Japanese ukiyo-e art style (CycleGAN)")
    gr.Markdown("• 🌫️ Foggy ↔ Clear weather transformation (CycleGAN)")
    gr.Markdown("• 🌿 Summer ↔ Winter seasonal atmosphere (CycleGAN)")
    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(label="📤 Upload Your Image", type="pil")
            analyze_btn = gr.Button("🔍 Analyze Image (Optional)", variant="primary")
        with gr.Column(scale=1):
            analysis_output = gr.Markdown("## 📊 Image Analysis Results",
                                          label="Analysis Results")
            recommendations = gr.CheckboxGroup(
                choices=[(transformation_labels[t], t) for t in available_transformations],
                label="🎨 Choose Transformations (All Available)",
                visible=True,
                value=None,
            )

    with gr.Row():
        with gr.Column():
            apply_btn = gr.Button("🎯 Apply Selected Transfers", variant="secondary")

    with gr.Row():
        status_output = gr.Textbox(label="📋 Applied Transformations", interactive=False)

    with gr.Row():
        results_gallery = gr.Gallery(
            label="🖼️ Transformed Images",
            show_label=True,
            elem_id="gallery",
            columns=2,
            rows=2,
            height="auto",
        )

    # Event handlers
    analyze_btn.click(
        fn=analyze_image,
        inputs=[image_input],
        outputs=[analysis_output, recommendations, results_gallery],
    )
    apply_btn.click(
        fn=apply_transformations,
        inputs=[image_input, recommendations],
        outputs=[status_output, results_gallery],
    )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)