"""
train_model.py
==============
Entrena el modelo SVM sobre Fashion-MNIST y guarda los artefactos:
  - models/scaler.pkl
  - models/pca.pkl
  - models/svm_model.pkl

Ejecutar UNA SOLA VEZ antes de iniciar la API:
    python train_model.py

Cambios v2 (para superar el 90 %):
  - TRAIN_SIZE subido de 30 000 a 50 000.
  - param_grid enfocado en C alto (50, 100) y gamma bajo-medio (0.001, 0.005).
  - Tras GridSearchCV, el modelo final se reentrena con los mejores
    hiperparámetros sobre TODOS los datos de entrenamiento (no solo sobre
    los folds de validación cruzada).
"""
import os
import pickle
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
# ── Configuration ───────────────────────────────────────────────────────────
# Size of the stratified subsample drawn from Fashion-MNIST. It is later split
# 80/20, i.e. 40 000 train | 10 000 test rows.
TRAIN_SIZE = 50_000
# Fraction of the variance the PCA projection must retain (float -> PCA picks
# the number of components automatically).
PCA_VARIANCE = 0.95
# Single seed shared by every randomized step so runs are reproducible.
RANDOM_STATE = 42
# Directory where the pickled scaler / PCA / SVM artifacts are written.
MODELS_DIR = "models"
# Human-readable class names. They are passed as `target_names` to the
# classification report, so their order must match the sorted label values.
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

# Create the output directory up front (no-op if it already exists).
os.makedirs(MODELS_DIR, exist_ok=True)
# ── 1. Load data ────────────────────────────────────────────────────────────
# Full Fashion-MNIST from OpenML as plain NumPy arrays (no DataFrame).
print("π₯ Cargando Fashion-MNIST desde OpenML...")
X, y = fetch_openml("Fashion-MNIST", version=1, return_X_y=True, as_frame=False)

# ── 2. Stratified subsample ─────────────────────────────────────────────────
# train_test_split is used only to draw a class-balanced subsample of
# TRAIN_SIZE rows. It returns [X_keep, X_rest, y_keep, y_rest]; taking every
# second element keeps just the two "keep" arrays and drops the rest.
print(f"βοΈ Seleccionando {TRAIN_SIZE:,} muestras estratificadas...")
X_small, y_small = train_test_split(
    X,
    y,
    stratify=y,
    train_size=TRAIN_SIZE,
    random_state=RANDOM_STATE,
)[::2]
# ── 3. Split 80 / 20 (before any fitting) ───────────────────────────────────
# Split FIRST so that the scaler and PCA below are fitted on training rows
# only. Fitting them on the whole subsample (as before) leaked test-set
# statistics into the preprocessing and biased the test accuracy upward.
# The stratified split depends only on the labels, row count and seed, so the
# train/test membership is the same as when splitting the transformed rows.
X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    X_small, y_small,
    test_size=0.20,
    random_state=RANDOM_STATE,
    stratify=y_small,
)
print(f"π Train: {len(X_train_raw):,} | Test: {len(X_test_raw):,}")

# ── 4. Scale ────────────────────────────────────────────────────────────────
print("βοΈ Escalando (StandardScaler)...")
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train_raw)
X_test_scaled = scaler.transform(X_test_raw)  # reuse train mean/std

# ── 5. PCA ──────────────────────────────────────────────────────────────────
print(f"π¬ Aplicando PCA (varianza={PCA_VARIANCE})...")
pca = PCA(n_components=PCA_VARIANCE, random_state=RANDOM_STATE)
X_train = pca.fit_transform(X_train_scaled)
X_test = pca.transform(X_test_scaled)  # project with train-fitted components
print(f" Componentes resultantes: {pca.n_components_}")
# ── 6. GridSearchCV ─────────────────────────────────────────────────────────
# High C + low-to-medium gamma is the best region for RBF-SVM on Fashion-MNIST.
#
# The search uses a plain SVC (probability=False). SVC.predict, which the
# default accuracy scorer calls, is driven by the decision function and not by
# Platt-calibrated probabilities, so CV scores are unaffected. Each of the 12
# fits (4 candidates x 3 folds) skips the extra internal 5-fold calibration,
# which makes the search several times faster.
# refit=False: the final model is fitted explicitly in the next step, so an
# automatic refit here would only duplicate that work. best_params_ and
# best_score_ remain available with refit=False for single-metric scoring.
# NOTE(review): scaler/PCA are fitted on the whole training split before the
# CV folds, so best_score_ is slightly optimistic; a Pipeline would avoid it.
print("π Ejecutando GridSearchCV (puede tardar varios minutos)...")
param_grid = {
    "C": [50, 100],
    "gamma": [0.001, 0.005],
    "kernel": ["rbf"],
}
grid = GridSearchCV(
    SVC(),
    param_grid,
    cv=3,
    n_jobs=-1,
    verbose=2,
    refit=False,
)
grid.fit(X_train, y_train)
best_params = grid.best_params_
print(f"\n✅ Mejores parámetros: {best_params}")
print(f" Mejor CV score: {grid.best_score_:.4f}")
# ── 7. Final model on the full training split ───────────────────────────────
# Fit the final classifier with the selected hyperparameters on every training
# row (each GridSearchCV fold only sees 2/3 of them).
# probability=True enables predict_proba (presumably needed by the API — not
# visible here). It is computed via Platt scaling with an internal
# cross-validation that is randomized, so random_state is set to keep the
# calibrated probabilities reproducible, like every other step in this script.
print("\nπ Reentrenando modelo final con 100 % de los datos de entrenamiento...")
clf = SVC(probability=True, random_state=RANDOM_STATE, **best_params)
clf.fit(X_train, y_train)
# ── 8. Evaluation ───────────────────────────────────────────────────────────
# Predict the test split ONCE and reuse it for both the accuracy figure and the
# report. Previously clf.score() and clf.predict() each ran a full SVC
# prediction. The mean of exact matches is what SVC.score (accuracy) computes.
y_pred = clf.predict(X_test)
accuracy = float(np.mean(y_pred == y_test))
print(f"\nπ― PrecisiΓ³n en test (80/20): {accuracy:.4f} ({accuracy*100:.2f} %)")
if accuracy < 0.90:
    # Advice is relative to the current value; the old text suggested 50 000,
    # which is already the configured size.
    print(
        f"⚠️ La precisión es menor al 90 %. "
        f"Prueba subir TRAIN_SIZE por encima de {TRAIN_SIZE:,}."
    )
else:
    print("π Β‘Modelo supera el umbral del 90 %!")
print("\nπ Reporte completo:")
print(classification_report(y_test, y_pred, target_names=CLASS_NAMES))
# ── 9. Save artifacts ───────────────────────────────────────────────────────
# Each fitted object is pickled to its own file under MODELS_DIR. The three
# files are written in the order scaler -> PCA -> SVM.
scaler_path = os.path.join(MODELS_DIR, "scaler.pkl")
pca_path = os.path.join(MODELS_DIR, "pca.pkl")
model_path = os.path.join(MODELS_DIR, "svm_model.pkl")

for artifact, path in (
    (scaler, scaler_path),
    (pca, pca_path),
    (clf, model_path),
):
    # Binary mode is required by pickle; `with` guarantees the file is closed.
    with open(path, "wb") as f:
        pickle.dump(artifact, f)

print(f"\nπΎ Artefactos guardados en '{MODELS_DIR}/':")
print(f" β’ {scaler_path}")
print(f" β’ {pca_path}")
print(f" β’ {model_path}")
print("\nπ Ahora puedes iniciar la API con: python app.py")