# another-demo / train_model.py
# Author: Vincimus
# Commit message: "Add handwritten digit recognizer with MLP classifier"
# Commit: cd8e368
"""
Train a digit classifier on sklearn's digits dataset.
Run this script locally to generate the model file.
Usage:
python train_model.py
"""
import os
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, classification_report
import joblib
def train_digit_classifier(*, model_dir=None, test_size: float = 0.2,
                           random_state: int = 42,
                           verbose: bool = True) -> "MLPClassifier":
    """Train, evaluate, and save an MLP classifier on sklearn's digits dataset.

    The dataset is split into stratified train/test sets. An ``MLPClassifier``
    with two hidden ReLU layers of sizes (128, 64) is fit on the training
    split. Accuracy and a per-class classification report on the test split
    are then printed. The fitted model is written with ``joblib.dump`` to
    ``<model_dir>/digit_classifier.joblib``.

    Args:
        model_dir: Directory to write the model file into. Defaults to
            ``src/model`` next to this script. Created if it does not exist.
        test_size: Fraction of samples held out for evaluation.
        random_state: Seed passed to both ``train_test_split`` and
            ``MLPClassifier`` for reproducible runs.
        verbose: Forwarded to ``MLPClassifier``; when true, it prints
            progress messages while training.

    Returns:
        The fitted ``MLPClassifier``.
    """
    print("Loading digits dataset...")
    digits = load_digits()
    X, y = digits.data, digits.target
    # digits.images keeps the 2-D layout that digits.data flattens, so the
    # reported image size comes from the data instead of being hard-coded.
    height, width = digits.images.shape[1:]
    print(f"Dataset shape: {X.shape}")
    print(f"Number of classes: {len(set(y))}")
    print(f"Image size: {height}x{width} ({X.shape[1]} features)")

    # Stratify so every digit class keeps its proportion in both splits.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=y
    )
    print(f"\nTraining samples: {len(X_train)}")
    print(f"Test samples: {len(X_test)}")

    print("\nTraining MLP classifier...")
    model = MLPClassifier(
        hidden_layer_sizes=(128, 64),
        activation="relu",
        max_iter=500,
        random_state=random_state,
        verbose=verbose,
    )
    model.fit(X_train, y_train)

    # Evaluate on the held-out split only.
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    separator = "=" * 50
    print(f"\n{separator}")
    print(f"Test Accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")
    print(separator)
    print("\nClassification Report:")
    print(classification_report(y_test, y_pred))

    # Anchor the default location to the script's absolute path, so the
    # saved/printed path is absolute regardless of how the script was invoked.
    if model_dir is None:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        model_dir = os.path.join(script_dir, "src", "model")
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, "digit_classifier.joblib")
    joblib.dump(model, model_path)
    print(f"\nModel saved to: {model_path}")

    size_kb = os.path.getsize(model_path) / 1024
    print(f"Model file size: {size_kb:.2f} KB")
    return model
def _main():
    """Script entry point: train, evaluate, and save the digit classifier."""
    train_digit_classifier()


if __name__ == "__main__":
    _main()