File size: 3,695 Bytes
27ec8ce
1782b9a
ff5af75
d0ff2ca
 
 
 
 
29a3b4c
11b97bf
29a3b4c
 
 
 
 
 
 
 
 
 
 
 
 
11b97bf
 
 
 
 
 
29a3b4c
 
 
 
ff5af75
d0ff2ca
ff5af75
 
 
 
 
 
d0ff2ca
11b97bf
 
 
d0ff2ca
 
ff5af75
 
 
d0ff2ca
 
1782b9a
 
11b97bf
ff5af75
 
 
 
 
 
 
 
 
 
 
11b97bf
 
 
 
 
 
 
 
 
 
 
ff5af75
 
 
11b97bf
1782b9a
ff5af75
 
 
11b97bf
 
 
cd294bb
 
 
11b97bf
5322998
 
11b97bf
5322998
cd294bb
 
11b97bf
cd294bb
11b97bf
ff5af75
d0ff2ca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
import joblib
import os
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from utils import extract_features

def initialize_fallback_model():
    """Build a tiny stand-in model and encoder bundle for when the real one is unavailable.

    The classifier is fitted on three dummy rows of already-encoded
    ``(face_shape, skin_tone, face_size)`` features, so its predictions are
    placeholders rather than real recommendations.

    Returns:
        ``(model, encoders)`` where ``model`` is a fitted
        ``RandomForestClassifier`` and ``encoders`` is a dict with the keys
        ``recommend_mask`` reads: one fitted ``LabelEncoder`` each for
        ``'face_shape'``, ``'skin_tone'``, ``'face_size'`` and ``'mask_style'``,
        plus ``'mask_images'``, a dict mapping a predicted class index to the
        preview image path for ``mask_style.classes_[index]``.
    """
    print("Initializing fallback model...")

    # Dummy training data: each row is an encoded feature triple.
    X = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]])
    y = np.array([0, 1, 0])  # Dummy target (indices into mask_style.classes_)

    # Fixed seed so the fallback gives the same predictions on every start-up.
    model = RandomForestClassifier(n_estimators=10, random_state=0)
    model.fit(X, y)

    mask_style_encoder = LabelEncoder().fit(['Glitter', 'Animal', 'Floral'])
    # LabelEncoder stores its classes sorted (Animal, Floral, Glitter), so the
    # index -> image mapping is derived from classes_ to guarantee that index i
    # always points at the image for the style name classes_[i].
    mask_images = {
        index: f"masks/{style.lower()}.png"
        for index, style in enumerate(mask_style_encoder.classes_)
    }

    encoders = {
        'face_shape': LabelEncoder().fit(['Oval', 'Round', 'Square']),
        'skin_tone': LabelEncoder().fit(['Fair', 'Medium', 'Dark']),
        'face_size': LabelEncoder().fit(['Small', 'Medium', 'Large']),
        'mask_style': mask_style_encoder,
        'mask_images': mask_images,
    }

    return model, encoders

def safe_load_model(model_dir='model'):
    """Load the trained model and encoders, falling back to a dummy model on any problem.

    Args:
        model_dir: Directory containing ``random_forest.pkl`` and
            ``label_encoders.pkl``. Defaults to ``'model'`` (relative to the
            current working directory).

    Returns:
        ``(model, encoders)``. This is the loaded pair when both files exist,
        the model has a ``classes_`` attribute (i.e. it is fitted) and the
        encoder bundle has every key ``recommend_mask`` reads. Otherwise it is
        the result of ``initialize_fallback_model()``.

    Any exception raised while loading or validating is printed and swallowed
    so that the app can always start.
    """
    model_path = os.path.join(model_dir, 'random_forest.pkl')
    encoders_path = os.path.join(model_dir, 'label_encoders.pkl')
    # Every key recommend_mask() looks up. Validating them all here means an
    # incomplete bundle triggers the fallback model at start-up instead of
    # making every single prediction fail later.
    required_keys = ('face_shape', 'skin_tone', 'face_size', 'mask_style', 'mask_images')

    try:
        missing_files = [p for p in (model_path, encoders_path) if not os.path.exists(p)]
        if missing_files:
            raise FileNotFoundError(f"Model files missing: {', '.join(missing_files)}")

        # NOTE: joblib.load unpickles its input; only point this at trusted files.
        model = joblib.load(model_path, mmap_mode='r')
        encoders = joblib.load(encoders_path, mmap_mode='r')

        # Verify model is fitted and the encoder bundle is complete.
        if not hasattr(model, 'classes_'):
            raise ValueError("Model or encoders incomplete (model is not fitted)")
        missing_keys = [key for key in required_keys if key not in encoders]
        if missing_keys:
            raise ValueError(
                f"Model or encoders incomplete (missing: {', '.join(missing_keys)})"
            )

        print("Main model loaded successfully!")
        return model, encoders

    except Exception as e:
        # Deliberate start-up boundary: any failure degrades to the fallback model.
        print(f"Loading failed: {str(e)}")
        return initialize_fallback_model()

def _existing_path_or_none(path):
    """Return *path* if something exists there on disk, otherwise ``None``."""
    return path if os.path.exists(path) else None


def recommend_mask(image):
    """Recommend a party-mask style for the face in *image*.

    Uses the module-level ``model`` and ``encoders``.

    Args:
        image: Value from the ``gr.Image(type="filepath")`` input, passed
            unchanged to ``extract_features``. It may be ``None`` when nothing
            was uploaded; that case ends up on the fallback path below.

    Returns:
        ``(mask_style, mask_image)``: the predicted style name and a path to
        its preview image. If the predicted preview is missing,
        ``'masks/default.png'`` is used instead. If that file is missing too,
        the image is ``None``, so the UI shows no preview rather than being
        handed a path to a nonexistent file. Any error during feature
        extraction, encoding or prediction yields
        ``("Basic Mask (Fallback)", <default image path or None>)``.
    """
    default_image = 'masks/default.png'
    try:
        # Extract the three categorical face attributes.
        face_shape, skin_tone, face_size = extract_features(image)

        # Encode each attribute with its own fitted encoder.
        face_encoded = encoders["face_shape"].transform([face_shape])[0]
        skin_encoded = encoders["skin_tone"].transform([skin_tone])[0]
        size_encoded = encoders["face_size"].transform([face_size])[0]

        # The predicted label is both an index into mask_style.classes_ and a
        # key into the mask_images mapping.
        prediction = model.predict([[face_encoded, skin_encoded, size_encoded]])[0]

        mask_style = encoders["mask_style"].classes_[prediction]
        mask_image = encoders['mask_images'].get(prediction, default_image)

        # Verify image exists; the default preview may be absent as well.
        if not os.path.exists(mask_image):
            print(f"Warning: Mask image not found at {mask_image}")
            mask_image = _existing_path_or_none(default_image)

        return mask_style, mask_image

    except Exception as e:
        # UI request boundary: report the error and return a safe fallback.
        print(f"Prediction error: {str(e)}")
        return "Basic Mask (Fallback)", _existing_path_or_none(default_image)

# Load the trained model (or the dummy fallback) once at import time; the
# request handler reads these module-level names.
model, encoders = safe_load_model()

# Create the folder that holds the mask preview images if it is not there yet.
os.makedirs('masks', exist_ok=True)

# UI components, constructed in the same order as inputs then outputs.
_face_input = gr.Image(label="Upload Your Face", type="filepath")
_result_outputs = [
    gr.Textbox(label="Recommended Style"),
    gr.Image(label="Mask Preview", type="filepath"),
]

# Offer the bundled sample photo as a clickable example only when it is present.
if os.path.exists("example_face.jpg"):
    _example_rows = [["example_face.jpg"]]
else:
    _example_rows = None

demo = gr.Interface(
    fn=recommend_mask,
    inputs=_face_input,
    outputs=_result_outputs,
    title="🎭 AI Party Mask Recommender",
    description="Upload a photo to get a personalized mask recommendation!",
    examples=_example_rows,
)

if __name__ == "__main__":
    demo.launch()