# Spaces:
# Configuration error
# Configuration error
# train_model.py
"""Train a RandomForest mask-style recommender on a synthetic dataset.

Steps: build a synthetic face-attribute dataset, fit one LabelEncoder per
column, train a RandomForestClassifier to predict ``mask_style`` from the
face features, print evaluation metrics, and save the model plus encoders
to ``model/`` with joblib.

The saved ``model/label_encoders.pkl`` is a dict holding one fitted
LabelEncoder per column plus ``'mask_images'``: a mapping from the encoded
``mask_style`` class index (the value ``model.predict`` returns) to the
image path for that style.
"""
import os

import joblib
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

FEATURE_COLUMNS = ['face_shape', 'skin_tone', 'face_size']
TARGET_COLUMN = 'mask_style'
MODEL_DIR = 'model'
MASK_DIR = 'masks'
MODEL_PATH = f"{MODEL_DIR}/random_forest.pkl"
ENCODERS_PATH = f"{MODEL_DIR}/label_encoders.pkl"


def build_mask_image_map(frame: pd.DataFrame, style_encoder: LabelEncoder) -> dict:
    """Map each encoded mask_style class index to its image path.

    LabelEncoder stores ``classes_`` in sorted order (here
    ['Animal', 'Floral', 'Glitter']), which is NOT the order the styles
    appear in the data. Building the mapping from the encoder keeps the
    indices returned by ``model.predict`` aligned with the right image.

    Args:
        frame: DataFrame with ``mask_style`` and ``mask_image`` columns.
        style_encoder: LabelEncoder already fitted on ``frame['mask_style']``.

    Returns:
        dict mapping int class index -> image path.

    Raises:
        ValueError: if one mask_style is paired with more than one image.
    """
    pairs = frame[[TARGET_COLUMN, 'mask_image']].drop_duplicates()
    if pairs[TARGET_COLUMN].duplicated().any():
        raise ValueError("Each mask_style must map to exactly one mask_image")
    style_to_image = dict(zip(pairs[TARGET_COLUMN], pairs['mask_image']))
    return {
        idx: style_to_image[style]
        for idx, style in enumerate(style_encoder.classes_)
    }


# 1. Dataset Preparation
print("=== Preparing Dataset ===")
# NOTE: every column cycles through its three values in lockstep (row i uses
# the (i % 3)-th entry of each list), so each feature on its own fully
# determines mask_style. Expect near-perfect accuracy; swap in real data
# before trusting the metrics below.
data = {
    'face_shape': ['Oval', 'Round', 'Square'] * 100,
    'skin_tone': ['Fair', 'Medium', 'Dark'] * 100,
    'face_size': ['Small', 'Medium', 'Large'] * 100,
    'mask_style': ['Glitter', 'Animal', 'Floral'] * 100,
    'mask_image': ['masks/glitter.png', 'masks/animal.png', 'masks/floral.png'] * 100,
}
df = pd.DataFrame(data)
print(f"Dataset created with {len(df)} samples")

# 2. Initialize Encoders with Image Mappings
print("\n=== Initializing Encoders ===")
encoders = {
    col: LabelEncoder().fit(df[col])
    for col in [*FEATURE_COLUMNS, TARGET_COLUMN]
}
# Derived from the fitted encoder rather than hard-coded: a literal
# {0: glitter, 1: animal, 2: floral} mismatches LabelEncoder's sorted classes.
encoders['mask_images'] = build_mask_image_map(df, encoders[TARGET_COLUMN])

# 3. Feature Engineering
print("\n=== Encoding Features ===")
X = pd.DataFrame({col: encoders[col].transform(df[col]) for col in FEATURE_COLUMNS})
y = encoders[TARGET_COLUMN].transform(df[TARGET_COLUMN])

# 4. Train-Test Split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
print(f"Train samples: {len(X_train)}, Test samples: {len(X_test)}")

# 5. Model Training
print("\n=== Training Model ===")
model = RandomForestClassifier(
    n_estimators=150,
    max_depth=7,
    min_samples_split=5,
    # Reweights classes inversely to their frequency. This only matters for
    # skewed data; the synthetic set above has 100 rows per style.
    class_weight='balanced',
    random_state=42,
)
model.fit(X_train, y_train)

# 6. Evaluation
print("\n=== Model Evaluation ===")
print(f"Training Accuracy: {model.score(X_train, y_train):.2f}")
print(f"Test Accuracy: {model.score(X_test, y_test):.2f}")

# Feature Importance
print("\nFeature Importances:")
for col, imp in zip(X.columns, model.feature_importances_):
    print(f"- {col}: {imp:.3f}")

# Classification Report
print("\nClassification Report:")
style_classes = encoders[TARGET_COLUMN].classes_
# Pass labels explicitly so target_names stays aligned even if a class is
# missing from the (unstratified) test split.
print(classification_report(
    y_test,
    model.predict(X_test),
    labels=list(range(len(style_classes))),
    target_names=list(style_classes),
))

# 7. Save Model with Verification
print("\n=== Saving Assets ===")
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(MASK_DIR, exist_ok=True)

# Verify mask images exist
print("\nMask Image Verification:")
for class_idx, path in encoders['mask_images'].items():
    if os.path.exists(path):
        print(f"✓ {style_classes[class_idx]}: {path}")
    else:
        print(f"✗ Missing: {path}")

# Save files. joblib output is a pickle: only load these files from trusted
# sources.
joblib.dump(model, MODEL_PATH, protocol=4)
joblib.dump(encoders, ENCODERS_PATH, protocol=4)

print("\n=== Saved Files ===")
print(f"Model: {MODEL_PATH}")
print(f"Encoders: {ENCODERS_PATH}")
print("\nClass Mappings:")
print("- Face Shapes:", list(encoders['face_shape'].classes_))
print("- Mask Styles:", list(style_classes))
print("- Mask Images:", encoders['mask_images'])