# StyleTransfer / app.py
# Last commit by dannyroxas: "Update app.py" (12a6b18, verified)
import tensorflow as tf
import numpy as np
import cv2
import os
import pickle
from PIL import Image
import gradio as gr
# CRITICAL: Define the custom InstanceNormalization layer used in training
class InstanceNormalization(tf.keras.layers.Layer):
    """Normalize each sample over axes 1 and 2, then apply a learned affine map.

    For every input the mean and variance are taken over axes 1 and 2
    (presumably the spatial axes of a channels-last image batch), the input
    is standardized with them, and the result is transformed as
    ``scale * x_hat + offset``. ``scale`` and ``offset`` hold one value per
    entry of the input's last axis.

    The weight names and the order in which ``scale`` and ``offset`` are
    created are deliberately unchanged, since the saved models that use this
    layer presumably rely on them when their weights are restored.
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        # Added to the variance so the reciprocal square root never divides by zero.
        self.epsilon = epsilon

    def build(self, input_shape):
        channels = input_shape[-1]
        # ``scale`` must be created before ``offset`` (weight order).
        self.scale = self.add_weight(
            name='scale',
            shape=[channels],
            initializer=tf.random_normal_initializer(1., 0.02),
            trainable=True,
        )
        self.offset = self.add_weight(
            name='offset',
            shape=[channels],
            initializer='zeros',
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
        inv_std = tf.math.rsqrt(variance + self.epsilon)
        x_hat = (x - mean) * inv_std
        return self.scale * x_hat + self.offset

    def get_config(self):
        # Include epsilon so the layer can be rebuilt from its saved config.
        return {**super().get_config(), 'epsilon': self.epsilon}
# Set up TensorFlow for compatibility
# NOTE(review): tensorflow is already imported at the top of this file, so
# this log-level setting may not silence messages emitted during that import.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')
# Keep every layer computing in plain float32 (no mixed-precision casting).
tf.keras.mixed_precision.set_global_policy('float32')
class MultiAttributeClassifier:
    """Classify image attributes and apply CycleGAN style transfers.

    Construction loads, from disk, one Keras classifier plus label encoder per
    entry in ``categories`` (from ``CLASSIFICATION_DIR``) and every generator
    listed in ``GAN_CANDIDATE_PATHS``. Anything missing or unloadable is
    reported on stdout and skipped, so the instance works with whatever
    subset loaded.

    Attributes:
        categories: Attribute names that get classified.
        models: ``{category: keras model}`` for classifiers that loaded.
        encoders: ``{category: encoder}``; every stored encoder has ``classes_``.
        gan_models: ``{transformation name: keras model}`` for generators.
        custom_objects: Passed to ``load_model`` so models saved with the
            custom ``InstanceNormalization`` layer can be deserialized.
    """

    # Folder holding ``<category>_model.h5`` and ``<category>_encoder.pkl``.
    CLASSIFICATION_DIR = "models/classification"
    # Folder holding one sub-folder per GAN family (scanned for logging only).
    GAN_DIR = "models/gan"

    # Candidate files per generator, tried in order; the first that loads is
    # used. When none loads, the first candidate's folder is listed.
    GAN_CANDIDATE_PATHS = {
        # Day/Night models
        'day_to_night': (
            'models/gan/day_night/day_to_night_generator_final.keras',
            'models/gan/day_night/day_to_night_generator.keras',
            'models/gan/day_night/day_to_night.keras',
            'models/gan/day_night/generator_day_to_night.keras',
        ),
        'night_to_day': (
            'models/gan/day_night/night_to_day_generator_final.keras',
            'models/gan/day_night/night_to_day_generator.keras',
            'models/gan/day_night/night_to_day.keras',
            'models/gan/day_night/generator_night_to_day.keras',
        ),
        # Foggy/Clear models
        'foggy_to_clear': (
            'models/gan/foggy/foggy_to_normal_generator_final.keras',
            'models/gan/foggy/foggy_to_clear_generator.keras',
            'models/gan/foggy/foggy_to_clear.keras',
        ),
        'clear_to_foggy': (
            'models/gan/foggy/normal_to_foggy_generator_final.keras',
            'models/gan/foggy/clear_to_foggy_generator.keras',
            'models/gan/foggy/clear_to_foggy.keras',
        ),
        # Japanese art models
        'photo_to_japanese': (
            'models/gan/japanese/photo_to_ukiyoe_generator.keras',
            'models/gan/japanese/photo_to_japanese_generator.keras',
            'models/gan/japanese/photo_to_japanese.keras',
        ),
        'japanese_to_photo': (
            'models/gan/japanese/ukiyoe_to_photo_generator.keras',
            'models/gan/japanese/japanese_to_photo_generator.keras',
            'models/gan/japanese/japanese_to_photo.keras',
        ),
        # Season models
        'summer_to_winter': (
            'models/gan/summer_winter/summer_to_winter_generator_final.keras',
            'models/gan/summer_winter/summer_to_winter_generator.keras',
            'models/gan/summer_winter/summer_to_winter.keras',
        ),
        'winter_to_summer': (
            'models/gan/summer_winter/winter_to_summer_generator_final.keras',
            'models/gan/summer_winter/winter_to_summer_generator.keras',
            'models/gan/summer_winter/winter_to_summer.keras',
        ),
    }

    # Reported for a category whose classifier or encoder did not load.
    FALLBACK_PREDICTIONS = {
        'content': {'class': 'outdoor', 'confidence': 0.6},
        'style': {'class': 'realistic', 'confidence': 0.7},
        'time_of_day': {'class': 'day', 'confidence': 0.8},
        'weather': {'class': 'clear', 'confidence': 0.8},
    }
    # Reported when a category's prediction fails outright.
    UNKNOWN_PREDICTION = {'class': 'unknown', 'confidence': 0.0}

    class _DictEncoder:
        """Stand-in for a label encoder that was pickled as a plain dict."""

        def __init__(self, class_dict):
            if 'classes_' in class_dict:
                self.classes_ = class_dict['classes_']
            elif 'classes' in class_dict:
                self.classes_ = class_dict['classes']
            else:
                # Last resort: treat the dict's keys as the class names.
                self.classes_ = list(class_dict.keys()) if class_dict else ['unknown']

    def __init__(self):
        # Define categories for classification
        self.categories = ['content', 'style', 'time_of_day', 'weather']
        self.models = {}
        self.encoders = {}
        self.gan_models = {}
        # Must exist before any model is loaded: both loaders pass it on.
        self.custom_objects = {
            'InstanceNormalization': InstanceNormalization,
            'tf': tf,
        }
        self.load_classification_models()
        self.load_gan_models()

    # ------------------------------------------------------------------
    # Classification models
    # ------------------------------------------------------------------
    def load_classification_models(self):
        """Load every classification model and its label encoder.

        For each category, ``<dir>/<category>_model.h5`` is loaded. A plain
        load is tried first, then ``compile=False`` with ``custom_objects``.
        On success, ``<dir>/<category>_encoder.pkl`` is unpickled. Results go
        into ``self.models`` / ``self.encoders``; failures are printed and the
        category is skipped. Returns early if the directory does not exist.
        """
        print("Loading classification models...")
        print(f"πŸ“‚ Looking for models in: {self.CLASSIFICATION_DIR}/")
        classification_path = self.CLASSIFICATION_DIR
        if not os.path.exists(classification_path):
            print(f"❌ Classification directory not found: {classification_path}")
            return
        print("πŸ“ Found classification directory")
        print(f"πŸ“„ Available files: {os.listdir(classification_path)}")

        for category in self.categories:
            try:
                self._load_classifier(category, classification_path)
            except Exception as e:
                print(f"❌ Failed to load {category}: {e}")
                import traceback
                traceback.print_exc()
        print(f"🎯 Successfully loaded {len(self.models)} classification models")

    def _load_classifier(self, category, directory):
        """Load one category's model, then (if that worked) its encoder."""
        model_path = f"{directory}/{category}_model.h5"
        if not os.path.exists(model_path):
            print(f"❌ {category} model not found at {model_path}")
            return
        print(f"πŸ” Loading model: {model_path}")
        try:
            self.models[category] = tf.keras.models.load_model(model_path)
        except Exception as e1:
            print(f" ⚠️ Normal loading failed: {e1}")
            try:
                self.models[category] = tf.keras.models.load_model(
                    model_path,
                    compile=False,
                    custom_objects=self.custom_objects,
                )
                print(f" βœ… Loaded {category} with custom_objects")
            except Exception as e2:
                print(f" ❌ Failed to load {category}: {e2}")
                return
        print(f"βœ… Loaded {category} model ({os.path.getsize(model_path)/1024/1024:.1f} MB)")
        self._load_encoder(category, directory)

    def _load_encoder(self, category, directory):
        """Unpickle ``<category>_encoder.pkl`` into ``self.encoders`` if usable."""
        encoder_path = f"{directory}/{category}_encoder.pkl"
        if not os.path.exists(encoder_path):
            print(f"⚠️ {category} encoder not found at {encoder_path}")
            return
        # SECURITY: unpickling can execute arbitrary code. Only ever load the
        # encoder files shipped with the app, never user-supplied files.
        with open(encoder_path, 'rb') as f:
            encoder_data = pickle.load(f)

        if hasattr(encoder_data, 'classes_'):
            # Already encoder-like (e.g. a fitted LabelEncoder).
            self.encoders[category] = encoder_data
            print(f"βœ… Loaded {category} encoder (LabelEncoder) - {len(encoder_data.classes_)} classes")
        elif isinstance(encoder_data, dict):
            self.encoders[category] = self._DictEncoder(encoder_data)
            print(f"βœ… Loaded {category} encoder (Dict format) - {len(self.encoders[category].classes_)} classes")
            print(f" Classes: {self.encoders[category].classes_}")
        else:
            print(f"⚠️ Unknown encoder format for {category}: {type(encoder_data)}")
            print(f" Content preview: {str(encoder_data)[:200]}...")

    # ------------------------------------------------------------------
    # GAN models
    # ------------------------------------------------------------------
    def load_gan_models(self):
        """Load the first loadable file for every entry in ``GAN_CANDIDATE_PATHS``.

        Successfully loaded generators are stored in ``self.gan_models`` under
        their transformation name. For a generator with no loadable file, the
        ``.keras``/``.h5`` files in its folder are printed to help debugging.
        """
        print("Loading GAN models...")
        self._print_gan_directory_tree()
        for model_name, candidate_paths in self.GAN_CANDIDATE_PATHS.items():
            if not self._load_first_gan(model_name, candidate_paths):
                print(f"⚠️ Could not load GAN model: {model_name}")
                self._print_model_files(candidate_paths)
        print(f"🎯 Successfully loaded {len(self.gan_models)} GAN models")

    def _print_gan_directory_tree(self):
        """Print every sub-folder of ``GAN_DIR`` and the files inside it."""
        gan_base_path = self.GAN_DIR
        if not os.path.exists(gan_base_path):
            return
        print(f"πŸ“‚ Found GAN models directory: {gan_base_path}")
        for folder in os.listdir(gan_base_path):
            folder_path = os.path.join(gan_base_path, folder)
            if os.path.isdir(folder_path):
                print(f"πŸ“ GAN folder: {folder}")
                for file in os.listdir(folder_path):
                    print(f" πŸ“„ {file}")

    def _load_first_gan(self, model_name, candidate_paths):
        """Store the first existing, loadable candidate; return True on success."""
        for model_path in candidate_paths:
            if not os.path.exists(model_path):
                continue
            try:
                print(f"πŸ” Trying to load: {model_path}")
                self.gan_models[model_name] = self._load_gan_file(model_path)
            except Exception as e:
                print(f"❌ Failed to load {model_path}: {e}")
                continue
            print(f"βœ… Loaded GAN: {model_name} from {model_path}")
            return True
        return False

    def _load_gan_file(self, model_path):
        """Load one generator file with progressively more lenient options.

        The attempts are: a plain load, then ``compile=False``, then
        ``compile=False`` with ``custom_objects`` (needed for the
        ``InstanceNormalization`` layer).

        Raises:
            Exception: whatever the last attempt raised, if all three fail.
        """
        try:
            return tf.keras.models.load_model(model_path)
        except Exception as e1:
            print(f" ⚠️ Normal loading failed: {e1}")
        try:
            model = tf.keras.models.load_model(model_path, compile=False)
            print(f" βœ… Loaded with compile=False")
            return model
        except Exception as e2:
            print(f" ⚠️ compile=False failed: {e2}")
        try:
            model = tf.keras.models.load_model(
                model_path,
                compile=False,
                custom_objects=self.custom_objects,
            )
        except Exception as e3:
            print(f" ❌ All loading methods failed: {e3}")
            print(f" πŸ” Error details: {str(e3)}")
            raise
        print(f" βœ… Loaded with custom_objects (InstanceNormalization)")
        return model

    @staticmethod
    def _print_model_files(candidate_paths):
        """List the ``.keras``/``.h5`` files in the folder of the first candidate."""
        if not candidate_paths:
            return
        folder_path = os.path.dirname(candidate_paths[0])
        if os.path.exists(folder_path):
            print(f" πŸ“ Available files in {folder_path}:")
            for file in os.listdir(folder_path):
                if file.endswith(('.keras', '.h5')):
                    print(f" πŸ“„ {file}")

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    @staticmethod
    def _to_rgb_image(image):
        """Return *image* (file path, numpy array or PIL image) as an RGB PIL image.

        ``convert('RGB')`` drops an alpha channel and also expands grayscale
        or palette images to three channels. Previously, those images reached
        the models with the wrong number of channels.
        """
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image

    def preprocess_image(self, image, target_size=(224, 224)):
        """Prepare *image* for the classifiers.

        Args:
            image: File path, numpy array or PIL image.
            target_size: ``(width, height)`` to resize to; defaults to 224x224.

        Returns:
            float32 array of shape ``(1, height, width, 3)`` scaled to [0, 1].
        """
        image = self._to_rgb_image(image).resize(target_size)
        img_array = np.asarray(image, dtype=np.float32) / 255.0
        return np.expand_dims(img_array, axis=0)

    @staticmethod
    def _decode_prediction(pred):
        """Turn the first sample of a model output into ``(class_index, confidence)``.

        With several output units, the arg-max unit and its score are used.
        A single output unit is assumed to be a sigmoid giving the probability
        of class 1 (the usual Keras binary setup; TODO confirm against the
        trained models). In that case an arg-max would always pick class 0.
        """
        scores = np.ravel(np.asarray(pred)[0])
        if scores.size == 1:
            p = float(scores[0])
            class_idx = int(p > 0.5)
            return class_idx, (p if class_idx == 1 else 1.0 - p)
        class_idx = int(np.argmax(scores))
        return class_idx, float(scores[class_idx])

    @staticmethod
    def _class_name(category, encoder, class_idx):
        """Map *class_idx* to a label via ``encoder.classes_``, else ``class_<idx>``."""
        try:
            classes = getattr(encoder, 'classes_', None)
            if classes is not None and class_idx < len(classes):
                return str(classes[class_idx])
        except Exception as e:
            print(f"Error getting class name for {category}: {e}")
        return f"class_{class_idx}"

    def predict_attributes(self, image):
        """Predict every attribute in ``self.categories`` for *image*.

        Returns:
            ``{category: {'class': str, 'confidence': float}}``. Categories
            without a loaded model/encoder get the entry from
            ``FALLBACK_PREDICTIONS``. Categories whose prediction raises get
            ``UNKNOWN_PREDICTION``.
        """
        preprocessed = self.preprocess_image(image)
        predictions = {}
        for category in self.categories:
            if category not in self.models or category not in self.encoders:
                # Copy so callers cannot mutate the shared class-level defaults.
                predictions[category] = dict(
                    self.FALLBACK_PREDICTIONS.get(category, self.UNKNOWN_PREDICTION)
                )
                continue
            try:
                pred = self.models[category].predict(preprocessed, verbose=0)
                class_idx, confidence = self._decode_prediction(pred)
                predictions[category] = {
                    'class': self._class_name(category, self.encoders[category], class_idx),
                    'confidence': confidence,
                }
            except Exception as e:
                print(f"Error predicting {category}: {e}")
                predictions[category] = dict(self.UNKNOWN_PREDICTION)
        return predictions

    def apply_style_transfer(self, image, transformation, target_size=(256, 256)):
        """Run the generator named *transformation* on *image*.

        Args:
            image: File path, numpy array or PIL image.
            transformation: Key into ``self.gan_models``.
            target_size: ``(width, height)`` fed to the generator (default 256x256).

        Returns:
            ``(uint8 RGB array, status message)`` on success, or
            ``(None, error message)`` if the model is missing or inference fails.
        """
        if transformation not in self.gan_models:
            return None, f"GAN model '{transformation}' not available"
        try:
            img = self._to_rgb_image(image).resize(target_size)
            # The generators take inputs scaled to [-1, 1].
            batch = np.expand_dims(np.asarray(img, dtype=np.float32) / 127.5 - 1.0, axis=0)
            generated = self.gan_models[transformation].predict(batch, verbose=0)
            # Map the [-1, 1] output back to 8-bit pixel values.
            output = np.clip((generated[0] + 1.0) * 127.5, 0, 255).astype(np.uint8)
            return output, "Transformation completed!"
        except Exception as e:
            print(f"Error in style transfer: {e}")
            return None, f"Error: {str(e)}"

    def get_style_recommendations(self, predictions):
        """Suggest transformations based on the output of ``predict_attributes``.

        Returns:
            A list of ``{'transformation', 'confidence', 'description'}``
            dicts in this order: time-of-day, weather, season (content),
            style. Models are not checked for availability here.
        """
        recommendations = []

        def recommend(transformation, confidence, description):
            recommendations.append({
                'transformation': transformation,
                'confidence': confidence,
                'description': description,
            })

        # Time of day: flip day <-> night when the classifier is confident.
        if 'time_of_day' in predictions:
            time_pred = predictions['time_of_day']
            if time_pred['class'] == 'day' and time_pred['confidence'] > 0.7:
                recommend('day_to_night', time_pred['confidence'],
                          f"Transform scene to night with {time_pred['confidence']*100:.0f}% confidence")
            elif time_pred['class'] == 'night' and time_pred['confidence'] > 0.7:
                recommend('night_to_day', time_pred['confidence'],
                          f"Transform scene to day with {time_pred['confidence']*100:.0f}% confidence")

        # Weather: add or remove fog.
        if 'weather' in predictions:
            weather_pred = predictions['weather']
            if weather_pred['class'] == 'clear' and weather_pred['confidence'] > 0.6:
                recommend('clear_to_foggy', weather_pred['confidence'],
                          f"Add fog atmosphere with {weather_pred['confidence']*100:.0f}% confidence")
            elif weather_pred['class'] == 'foggy' and weather_pred['confidence'] > 0.6:
                recommend('foggy_to_clear', weather_pred['confidence'],
                          f"Clear fog from scene with {weather_pred['confidence']*100:.0f}% confidence")

        # Content: outdoor scenes get both seasonal transforms at a fixed 0.8 score.
        if 'content' in predictions:
            content_pred = predictions['content']
            if content_pred['class'] in ['outdoor', 'landscape'] and content_pred['confidence'] > 0.6:
                recommend('summer_to_winter', 0.8,
                          "Transform scene to winter with snow and cold atmosphere")
                recommend('winter_to_summer', 0.8,
                          "Transform scene to summer with warm, lush atmosphere")

        # Style: realistic photos can become ukiyo-e.
        if 'style' in predictions:
            style_pred = predictions['style']
            if style_pred['class'] == 'realistic' and style_pred['confidence'] > 0.6:
                recommend('photo_to_japanese', style_pred['confidence'],
                          "Transform to Japanese ukiyo-e art style")

        return recommendations
# Initialize classifier globally (this loads every model at import time).
print("πŸš€ Starting StyleTransfer App...")
classifier = MultiAttributeClassifier()

# Summary of what actually loaded.
print("🎯 Initialization complete!")
print(f" πŸ“Š Classification models loaded: {len(classifier.models)}")
print(f" 🎨 GAN models loaded: {len(classifier.gan_models)}")
if classifier.models:
    print(f" βœ… Available categories: {list(classifier.models)}")
if classifier.gan_models:
    print(f" βœ… Available transformations: {list(classifier.gan_models)}")
print("=" * 50)
def analyze_image(image):
    """Analyze an uploaded image and describe suggested style transfers.

    Args:
        image: The value of the upload widget (``None`` when nothing has been
            uploaded).

    Returns:
        A 3-tuple for the ``analysis_output``, ``recommendations`` and
        ``results_gallery`` components:
        - Markdown text.
        - A ``gr.update`` that always offers every entry of
          ``available_transformations`` with nothing pre-selected.
        - An empty list that clears the gallery.
    """
    # Every outcome (no image, success, failure) offers the full list of
    # transformations; the analysis only adds suggestions on top.
    choices_update = gr.update(
        choices=[(transformation_labels[t], t) for t in available_transformations],
        value=None,
        visible=True,
    )
    if image is None:
        return "Please upload an image first.", choices_update, []
    try:
        predictions = classifier.predict_attributes(image)

        parts = ["## πŸ” Image Analysis Results\n\n"]
        for category, pred in predictions.items():
            confidence_pct = pred['confidence'] * 100
            parts.append(
                f"**{category.replace('_', ' ').title()}:** {pred['class']} ({confidence_pct:.1f}% confidence)\n\n"
            )

        recommendations = classifier.get_style_recommendations(predictions)
        if recommendations:
            parts.append("## 🎨 AI Suggestions\n\n")
            for rec in recommendations:
                # Show the same label as the checkbox list. Replacing '_' in the
                # id used to render e.g. "Day -> To -> Night".
                label = transformation_labels.get(rec['transformation'], rec['transformation'])
                parts.append(f"**{label}** ({rec['confidence']*100:.0f}%) {rec['description']}\n\n")
        else:
            parts.append("## 🎨 AI Suggestions\n\nNo specific recommendations - but feel free to try any transformation!\n\n")
        parts.append("---\n**Choose any transformation(s) below - you're not limited to the suggestions!**")

        return "".join(parts), choices_update, []
    except Exception as e:
        print(f"Error in analysis: {e}")
        import traceback
        traceback.print_exc()
        # Even if analysis fails, still show all transformations.
        return (
            f"Error analyzing image: {str(e)}\n\n**All transformations still available below:**",
            choices_update,
            [],
        )
def apply_transformations(image, selected_transformations):
    """Apply each selected GAN transformation to *image*.

    Args:
        image: The uploaded image, or ``None``.
        selected_transformations: Transformation ids ticked in the checkbox
            group (may be empty or ``None``).

    Returns:
        ``(status_text, gallery_items)``:
        - ``status_text`` has one line per selected transformation.
        - ``gallery_items`` is a list of ``(image_array, caption)`` pairs,
          one per transformation that produced an image. Each caption is the
          transformation's display label, so an output can still be matched
          to its transformation when others fail.
    """
    if image is None:
        return "Please upload an image first.", []
    if not selected_transformations:
        return "Please select at least one transformation.", []

    gallery_items = []
    status_lines = []
    for transformation in selected_transformations:
        # The user-facing label, instead of rewriting the id (which rendered
        # e.g. "Day -> To -> Night").
        label = transformation_labels.get(transformation, transformation)
        try:
            transformed_img, message = classifier.apply_style_transfer(image, transformation)
        except Exception as e:
            status_lines.append(f"❌ {label}: Error - {str(e)}")
            continue
        if transformed_img is None:
            status_lines.append(f"❌ {label}: {message}")
        else:
            gallery_items.append((transformed_img, label))
            status_lines.append(f"βœ… {label}: {message}")

    return "\n".join(status_lines), gallery_items
# User-friendly transformation names, keyed by the internal transformation id.
# The key order here is also the order shown in the UI.
transformation_labels = {
    "day_to_night": "πŸŒ…β†’πŸŒ™ Day to Night",
    "night_to_day": "πŸŒ™β†’πŸŒ… Night to Day",
    "clear_to_foggy": "β˜€οΈβ†’πŸŒ«οΈ Clear to Foggy",
    "foggy_to_clear": "πŸŒ«οΈβ†’β˜€οΈ Foggy to Clear",
    "photo_to_japanese": "πŸ“·β†’πŸŽ¨ Photo to Japanese Art",
    "japanese_to_photo": "πŸŽ¨β†’πŸ“· Japanese Art to Photo",
    "summer_to_winter": "πŸŒΏβ†’β„οΈ Summer to Winter",
    "winter_to_summer": "β„οΈβ†’πŸŒΏ Winter to Summer",
}
# Available transformations for manual selection (every labelled id, in order).
available_transformations = list(transformation_labels)
# Create Gradio interface
with gr.Blocks(title="Intelligent Multi-Attribute Style Transfer", theme=gr.themes.Soft()) as demo:
    # Header, tip and the list of available transformations, rendered in order.
    for _intro_markdown in (
        "# 🎨 Intelligent Multi-Attribute Style Transfer",
        "Upload an image and our AI will analyze it to provide smart suggestions - **but you can choose ANY transformation you want!**",
        "πŸ’‘ **Tip:** You can skip analysis and apply transformations directly!",
        "## Available Transformations:",
        "β€’ πŸŒ… Day ↔ Night conversion (CycleGAN)",
        "β€’ 🎨 Photo ↔ Japanese ukiyo-e art style (CycleGAN)",
        "β€’ 🌫️ Foggy ↔ Clear weather transformation (CycleGAN)",
        "β€’ 🌿 Summer ↔ Winter seasonal atmosphere (CycleGAN)",
        "---",
    ):
        gr.Markdown(_intro_markdown)

    # Left: upload + analyze. Right: analysis text + transformation picker.
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(label="πŸ“€ Upload Your Image", type="pil")
            analyze_btn = gr.Button("πŸ” Analyze Image (Optional)", variant="primary")
        with gr.Column(scale=1):
            analysis_output = gr.Markdown("## πŸ“Š Image Analysis Results", label="Analysis Results")
            recommendations = gr.CheckboxGroup(
                label="🎨 Choose Transformations (All Available)",
                choices=[(transformation_labels[t], t) for t in available_transformations],
                value=None,
                visible=True,
            )

    with gr.Row():
        with gr.Column():
            apply_btn = gr.Button("🎯 Apply Selected Transfers", variant="secondary")

    with gr.Row():
        status_output = gr.Textbox(label="πŸ“‹ Applied Transformations", interactive=False)

    with gr.Row():
        results_gallery = gr.Gallery(
            label="πŸ–ΌοΈ Transformed Images",
            show_label=True,
            elem_id="gallery",
            columns=2,
            rows=2,
            height="auto",
        )

    # Event handlers: analysis refreshes the picker and clears the gallery;
    # applying fills the status box and the gallery.
    analyze_btn.click(
        fn=analyze_image,
        inputs=[image_input],
        outputs=[analysis_output, recommendations, results_gallery],
    )
    apply_btn.click(
        fn=apply_transformations,
        inputs=[image_input, recommendations],
        outputs=[status_output, results_gallery],
    )
if __name__ == "__main__":
    # Bind to all network interfaces on port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")