Spaces:
Sleeping
Sleeping
"""
Train a digit classifier on sklearn's digits dataset.

Run this script locally to generate the model file.

Usage:
    python train_model.py
"""
| import os | |
| from sklearn.datasets import load_digits | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.neural_network import MLPClassifier | |
| from sklearn.metrics import accuracy_score, classification_report | |
| import joblib | |
def train_digit_classifier() -> "MLPClassifier":
    """Train an MLP classifier on the sklearn digits dataset and save it.

    Loads the dataset with ``load_digits``, holds out a stratified 20% test
    split, fits an ``MLPClassifier`` with hidden layers of 128 and 64 ReLU
    units, prints the test accuracy and a classification report, and writes
    the fitted model with ``joblib.dump`` to
    ``src/model/digit_classifier.joblib`` under this script's directory
    (the directory is created when missing).

    Returns:
        The fitted ``MLPClassifier``.
    """
    print("Loading digits dataset...")
    digits = load_digits()
    X, y = digits.data, digits.target

    # Report the image geometry from the data itself rather than hard-coding
    # it, so the message cannot drift from what was actually loaded.
    height, width = digits.images.shape[1:]
    print(f"Dataset shape: {X.shape}")
    print(f"Number of classes: {len(set(y))}")
    print(f"Image size: {height}x{width} ({X.shape[1]} features)")

    # Split data; stratify keeps the class proportions equal in both splits.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    print(f"\nTraining samples: {len(X_train)}")
    print(f"Test samples: {len(X_test)}")

    # Train MLP classifier (fixed random_state for reproducible runs).
    print("\nTraining MLP classifier...")
    model = MLPClassifier(
        hidden_layer_sizes=(128, 64),
        activation='relu',
        max_iter=500,
        random_state=42,
        verbose=True,
    )
    model.fit(X_train, y_train)

    # Evaluate on the held-out split.
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    print(f"\n{'='*50}")
    print(f"Test Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"{'='*50}")
    print("\nClassification Report:")
    print(classification_report(y_test, y_pred))

    # Save model. Resolve the script location to an absolute path first so
    # the output directory and the reported path are unambiguous even when
    # __file__ is relative.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_dir = os.path.join(script_dir, "src", "model")
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, "digit_classifier.joblib")
    joblib.dump(model, model_path)
    print(f"\nModel saved to: {model_path}")

    # Check file size (bytes converted to KB).
    file_size = os.path.getsize(model_path) / 1024
    print(f"Model file size: {file_size:.2f} KB")

    return model
# Script entry point: train and save the model only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    train_digit_classifier()