dannyroxas commited on
Commit
12a6b18
Β·
verified Β·
1 Parent(s): 24a42c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -18
app.py CHANGED
@@ -105,8 +105,31 @@ class MultiAttributeClassifier:
105
  encoder_path = f"models/classification/{category}_encoder.pkl"
106
  if os.path.exists(encoder_path):
107
  with open(encoder_path, 'rb') as f:
108
- self.encoders[category] = pickle.load(f)
109
- print(f"βœ… Loaded {category} encoder")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  else:
111
  print(f"⚠️ {category} encoder not found at {encoder_path}")
112
  else:
@@ -287,8 +310,19 @@ class MultiAttributeClassifier:
287
  predicted_class_idx = np.argmax(pred, axis=1)[0]
288
  confidence = float(np.max(pred))
289
 
290
- # Get class name from encoder
291
- class_name = self.encoders[category].classes_[predicted_class_idx]
 
 
 
 
 
 
 
 
 
 
 
292
 
293
  predictions[category] = {
294
  'class': class_name,
@@ -431,7 +465,8 @@ print("="*50)
431
  def analyze_image(image):
432
  """Analyze uploaded image and provide style recommendations"""
433
  if image is None:
434
- return "Please upload an image first.", "", []
 
435
 
436
  try:
437
  # Get predictions for all attributes
@@ -446,23 +481,27 @@ def analyze_image(image):
446
  # Get style recommendations
447
  recommendations = classifier.get_style_recommendations(predictions)
448
 
449
- # Format recommendations for display
450
- rec_choices = []
451
  if recommendations:
452
- analysis_text += "## 🎨 Available Style Transfers\n\n"
453
  for rec in recommendations:
454
  analysis_text += f"**{rec['transformation'].replace('_', ' β†’ ').title()}** ({rec['confidence']*100:.0f}%) {rec['description']}\n\n"
455
- rec_choices.append(rec['transformation'])
456
  else:
457
- analysis_text += "No specific style transfer recommendations based on this image.\n"
458
 
459
- return analysis_text, gr.update(choices=rec_choices, value=None, visible=True), []
 
 
 
 
460
 
461
  except Exception as e:
462
  print(f"Error in analysis: {e}")
463
  import traceback
464
  traceback.print_exc()
465
- return f"Error analyzing image: {str(e)}", gr.update(visible=False), []
 
 
466
 
467
  def apply_transformations(image, selected_transformations):
468
  """Apply selected style transformations"""
@@ -489,7 +528,7 @@ def apply_transformations(image, selected_transformations):
489
  status_text = "\n".join(status_messages)
490
  return status_text, results
491
 
492
- # Available transformations for manual selection
493
  available_transformations = [
494
  "day_to_night", "night_to_day",
495
  "clear_to_foggy", "foggy_to_clear",
@@ -497,10 +536,23 @@ available_transformations = [
497
  "summer_to_winter", "winter_to_summer"
498
  ]
499
 
 
 
 
 
 
 
 
 
 
 
 
 
500
  # Create Gradio interface
501
  with gr.Blocks(title="Intelligent Multi-Attribute Style Transfer", theme=gr.themes.Soft()) as demo:
502
  gr.Markdown("# 🎨 Intelligent Multi-Attribute Style Transfer")
503
- gr.Markdown("Upload an image and our AI will analyze multiple attributes (content, style, time, weather) and suggest relevant style transfers using trained GAN models!")
 
504
 
505
  # Show available transformations
506
  gr.Markdown("## Available Transformations:")
@@ -508,18 +560,20 @@ with gr.Blocks(title="Intelligent Multi-Attribute Style Transfer", theme=gr.them
508
  gr.Markdown("β€’ 🎨 Photo ↔ Japanese ukiyo-e art style (CycleGAN)")
509
  gr.Markdown("β€’ 🌫️ Foggy ↔ Clear weather transformation (CycleGAN)")
510
  gr.Markdown("β€’ 🌿 Summer ↔ Winter seasonal atmosphere (CycleGAN)")
 
511
 
512
  with gr.Row():
513
  with gr.Column(scale=1):
514
  image_input = gr.Image(label="πŸ“€ Upload Your Image", type="pil")
515
- analyze_btn = gr.Button("πŸ” Analyze Image", variant="primary")
516
 
517
  with gr.Column(scale=1):
518
  analysis_output = gr.Markdown("## πŸ“Š Image Analysis Results", label="Analysis Results")
519
  recommendations = gr.CheckboxGroup(
520
- choices=[],
521
- label="🎨 Available Style Transfers",
522
- visible=False
 
523
  )
524
 
525
  with gr.Row():
 
105
  encoder_path = f"models/classification/{category}_encoder.pkl"
106
  if os.path.exists(encoder_path):
107
  with open(encoder_path, 'rb') as f:
108
+ encoder_data = pickle.load(f)
109
+
110
+ # Handle different encoder formats
111
+ if hasattr(encoder_data, 'classes_'):
112
+ # Standard LabelEncoder
113
+ self.encoders[category] = encoder_data
114
+ print(f"βœ… Loaded {category} encoder (LabelEncoder) - {len(encoder_data.classes_)} classes")
115
+ elif isinstance(encoder_data, dict):
116
+ # Dict format - create a wrapper
117
+ class EncoderWrapper:
118
+ def __init__(self, class_dict):
119
+ if 'classes_' in class_dict:
120
+ self.classes_ = class_dict['classes_']
121
+ elif 'classes' in class_dict:
122
+ self.classes_ = class_dict['classes']
123
+ else:
124
+ # Try to extract classes from dict keys/values
125
+ self.classes_ = list(class_dict.keys()) if class_dict else ['unknown']
126
+
127
+ self.encoders[category] = EncoderWrapper(encoder_data)
128
+ print(f"βœ… Loaded {category} encoder (Dict format) - {len(self.encoders[category].classes_)} classes")
129
+ print(f" Classes: {self.encoders[category].classes_}")
130
+ else:
131
+ print(f"⚠️ Unknown encoder format for {category}: {type(encoder_data)}")
132
+ print(f" Content preview: {str(encoder_data)[:200]}...")
133
  else:
134
  print(f"⚠️ {category} encoder not found at {encoder_path}")
135
  else:
 
310
  predicted_class_idx = np.argmax(pred, axis=1)[0]
311
  confidence = float(np.max(pred))
312
 
313
+ # Get class name from encoder - handle different formats
314
+ try:
315
+ if hasattr(self.encoders[category], 'classes_'):
316
+ classes = self.encoders[category].classes_
317
+ if predicted_class_idx < len(classes):
318
+ class_name = classes[predicted_class_idx]
319
+ else:
320
+ class_name = f"class_{predicted_class_idx}"
321
+ else:
322
+ class_name = f"class_{predicted_class_idx}"
323
+ except Exception as e:
324
+ print(f"Error getting class name for {category}: {e}")
325
+ class_name = f"class_{predicted_class_idx}"
326
 
327
  predictions[category] = {
328
  'class': class_name,
 
465
  def analyze_image(image):
466
  """Analyze uploaded image and provide style recommendations"""
467
  if image is None:
468
+ choices_with_labels = [(transformation_labels[t], t) for t in available_transformations]
469
+ return "Please upload an image first.", gr.update(choices=choices_with_labels, value=None, visible=True), []
470
 
471
  try:
472
  # Get predictions for all attributes
 
481
  # Get style recommendations
482
  recommendations = classifier.get_style_recommendations(predictions)
483
 
484
+ # Format recommendations for display
 
485
  if recommendations:
486
+ analysis_text += "## 🎨 AI Suggestions\n\n"
487
  for rec in recommendations:
488
  analysis_text += f"**{rec['transformation'].replace('_', ' β†’ ').title()}** ({rec['confidence']*100:.0f}%) {rec['description']}\n\n"
 
489
  else:
490
+ analysis_text += "## 🎨 AI Suggestions\n\nNo specific recommendations - but feel free to try any transformation!\n\n"
491
 
492
+ analysis_text += "---\n**Choose any transformation(s) below - you're not limited to the suggestions!**"
493
+
494
+ # Always return ALL available transformations, regardless of analysis
495
+ choices_with_labels = [(transformation_labels[t], t) for t in available_transformations]
496
+ return analysis_text, gr.update(choices=choices_with_labels, value=None, visible=True), []
497
 
498
  except Exception as e:
499
  print(f"Error in analysis: {e}")
500
  import traceback
501
  traceback.print_exc()
502
+ # Even if analysis fails, still show all transformations
503
+ choices_with_labels = [(transformation_labels[t], t) for t in available_transformations]
504
+ return f"Error analyzing image: {str(e)}\n\n**All transformations still available below:**", gr.update(choices=choices_with_labels, value=None, visible=True), []
505
 
506
  def apply_transformations(image, selected_transformations):
507
  """Apply selected style transformations"""
 
528
  status_text = "\n".join(status_messages)
529
  return status_text, results
530
 
531
+ # Available transformations for manual selection - show user-friendly names
532
  available_transformations = [
533
  "day_to_night", "night_to_day",
534
  "clear_to_foggy", "foggy_to_clear",
 
536
  "summer_to_winter", "winter_to_summer"
537
  ]
538
 
539
+ # User-friendly transformation names
540
+ transformation_labels = {
541
+ "day_to_night": "πŸŒ…β†’πŸŒ™ Day to Night",
542
+ "night_to_day": "πŸŒ™β†’πŸŒ… Night to Day",
543
+ "clear_to_foggy": "β˜€οΈβ†’πŸŒ«οΈ Clear to Foggy",
544
+ "foggy_to_clear": "πŸŒ«οΈβ†’β˜€οΈ Foggy to Clear",
545
+ "photo_to_japanese": "πŸ“·β†’πŸŽ¨ Photo to Japanese Art",
546
+ "japanese_to_photo": "πŸŽ¨β†’πŸ“· Japanese Art to Photo",
547
+ "summer_to_winter": "πŸŒΏβ†’β„οΈ Summer to Winter",
548
+ "winter_to_summer": "β„οΈβ†’πŸŒΏ Winter to Summer"
549
+ }
550
+
551
  # Create Gradio interface
552
  with gr.Blocks(title="Intelligent Multi-Attribute Style Transfer", theme=gr.themes.Soft()) as demo:
553
  gr.Markdown("# 🎨 Intelligent Multi-Attribute Style Transfer")
554
+ gr.Markdown("Upload an image and our AI will analyze it to provide smart suggestions - **but you can choose ANY transformation you want!**")
555
+ gr.Markdown("πŸ’‘ **Tip:** You can skip analysis and apply transformations directly!")
556
 
557
  # Show available transformations
558
  gr.Markdown("## Available Transformations:")
 
560
  gr.Markdown("β€’ 🎨 Photo ↔ Japanese ukiyo-e art style (CycleGAN)")
561
  gr.Markdown("β€’ 🌫️ Foggy ↔ Clear weather transformation (CycleGAN)")
562
  gr.Markdown("β€’ 🌿 Summer ↔ Winter seasonal atmosphere (CycleGAN)")
563
+ gr.Markdown("---")
564
 
565
  with gr.Row():
566
  with gr.Column(scale=1):
567
  image_input = gr.Image(label="πŸ“€ Upload Your Image", type="pil")
568
+ analyze_btn = gr.Button("πŸ” Analyze Image (Optional)", variant="primary")
569
 
570
  with gr.Column(scale=1):
571
  analysis_output = gr.Markdown("## πŸ“Š Image Analysis Results", label="Analysis Results")
572
  recommendations = gr.CheckboxGroup(
573
+ choices=[(transformation_labels[t], t) for t in available_transformations],
574
+ label="🎨 Choose Transformations (All Available)",
575
+ visible=True,
576
+ value=None
577
  )
578
 
579
  with gr.Row():