RICHERGIRL commited on
Commit
11b97bf
·
verified ·
1 Parent(s): 32d1a02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -18
app.py CHANGED
@@ -7,7 +7,7 @@ from sklearn.preprocessing import LabelEncoder
7
  from utils import extract_features
8
 
9
  def initialize_fallback_model():
10
- """Creates and trains a simple fallback model"""
11
  print("Initializing fallback model...")
12
 
13
  # Simple training data
@@ -21,7 +21,12 @@ def initialize_fallback_model():
21
  'face_shape': LabelEncoder().fit(['Oval', 'Round', 'Square']),
22
  'skin_tone': LabelEncoder().fit(['Fair', 'Medium', 'Dark']),
23
  'face_size': LabelEncoder().fit(['Small', 'Medium', 'Large']),
24
- 'mask_style': LabelEncoder().fit(['StyleA', 'StyleB', 'StyleC']) # Added mask_style
 
 
 
 
 
25
  }
26
 
27
  return model, encoders
@@ -35,9 +40,9 @@ def safe_load_model():
35
  model = joblib.load('model/random_forest.pkl', mmap_mode='r')
36
  encoders = joblib.load('model/label_encoders.pkl', mmap_mode='r')
37
 
38
- # Verify model is fitted
39
- if not hasattr(model, 'classes_'):
40
- raise ValueError("Model not properly trained")
41
 
42
  print("Main model loaded successfully!")
43
  return model, encoders
@@ -47,7 +52,7 @@ def safe_load_model():
47
  return initialize_fallback_model()
48
 
49
  def recommend_mask(image):
50
- """Process image and make prediction with error handling"""
51
  try:
52
  # Extract features
53
  face_shape, skin_tone, face_size = extract_features(image)
@@ -59,33 +64,40 @@ def recommend_mask(image):
59
 
60
  # Predict
61
  prediction = model.predict([[face_encoded, skin_encoded, size_encoded]])[0]
62
- return encoders["mask_style"].classes_[prediction]
63
-
64
- # Get recommended mask image path
65
- mask_image_path = encoders['mask_images'][prediction]
66
- return (
67
- encoders["mask_style"].classes_[prediction], # Text
68
- mask_image_path # Image
69
- )
 
 
 
70
 
71
  except Exception as e:
72
  print(f"Prediction error: {str(e)}")
73
- return "Error", "default_mask.png" # Fallback
74
-
75
 
76
  # Initialize model and encoders
77
  model, encoders = safe_load_model()
78
 
 
 
 
79
  # Create Gradio interface
80
  demo = gr.Interface(
81
  fn=recommend_mask,
82
- inputs=gr.Image(type="filepath"),
83
  outputs=[
84
  gr.Textbox(label="Recommended Style"),
85
- gr.Image(label="Mask Preview") # Add image output
86
  ],
87
  title="🎭 AI Party Mask Recommender",
88
  description="Upload a photo to get a personalized mask recommendation!",
 
89
  )
 
90
  if __name__ == "__main__":
91
  demo.launch()
 
7
  from utils import extract_features
8
 
9
  def initialize_fallback_model():
10
+ """Creates and trains a simple fallback model with image support"""
11
  print("Initializing fallback model...")
12
 
13
  # Simple training data
 
21
  'face_shape': LabelEncoder().fit(['Oval', 'Round', 'Square']),
22
  'skin_tone': LabelEncoder().fit(['Fair', 'Medium', 'Dark']),
23
  'face_size': LabelEncoder().fit(['Small', 'Medium', 'Large']),
24
+ 'mask_style': LabelEncoder().fit(['Glitter', 'Animal', 'Floral']),
25
+ 'mask_images': {
26
+ 0: 'masks/glitter.png',
27
+ 1: 'masks/animal.png',
28
+ 2: 'masks/floral.png'
29
+ }
30
  }
31
 
32
  return model, encoders
 
40
  model = joblib.load('model/random_forest.pkl', mmap_mode='r')
41
  encoders = joblib.load('model/label_encoders.pkl', mmap_mode='r')
42
 
43
+ # Verify model is fitted and has image mappings
44
+ if not hasattr(model, 'classes_') or 'mask_images' not in encoders:
45
+ raise ValueError("Model or encoders incomplete")
46
 
47
  print("Main model loaded successfully!")
48
  return model, encoders
 
52
  return initialize_fallback_model()
53
 
54
  def recommend_mask(image):
55
+ """Process image and return both text and image recommendations"""
56
  try:
57
  # Extract features
58
  face_shape, skin_tone, face_size = extract_features(image)
 
64
 
65
  # Predict
66
  prediction = model.predict([[face_encoded, skin_encoded, size_encoded]])[0]
67
+
68
+ # Get both text and image recommendations
69
+ mask_style = encoders["mask_style"].classes_[prediction]
70
+ mask_image = encoders['mask_images'].get(prediction, 'masks/default.png')
71
+
72
+ # Verify image exists
73
+ if not os.path.exists(mask_image):
74
+ print(f"Warning: Mask image not found at {mask_image}")
75
+ mask_image = 'masks/default.png'
76
+
77
+ return mask_style, mask_image
78
 
79
  except Exception as e:
80
  print(f"Prediction error: {str(e)}")
81
+ return "Basic Mask (Fallback)", "masks/default.png"
 
82
 
83
  # Initialize model and encoders
84
  model, encoders = safe_load_model()
85
 
86
+ # Verify masks directory exists
87
+ os.makedirs('masks', exist_ok=True)
88
+
89
  # Create Gradio interface
90
  demo = gr.Interface(
91
  fn=recommend_mask,
92
+ inputs=gr.Image(label="Upload Your Face", type="filepath"),
93
  outputs=[
94
  gr.Textbox(label="Recommended Style"),
95
+ gr.Image(label="Mask Preview", type="filepath")
96
  ],
97
  title="🎭 AI Party Mask Recommender",
98
  description="Upload a photo to get a personalized mask recommendation!",
99
+ examples=[["example_face.jpg"]] if os.path.exists("example_face.jpg") else None
100
  )
101
+
102
  if __name__ == "__main__":
103
  demo.launch()