SMV / train_model.py
Tyan1988's picture
Subo modelo SMV completo
9885230
"""
train_model.py
==============
Entrena el modelo SVM sobre Fashion-MNIST y guarda los artefactos:
- models/scaler.pkl
- models/pca.pkl
- models/svm_model.pkl
Ejecutar UNA SOLA VEZ antes de iniciar la API:
python train_model.py
Cambios v2 (para superar 90 %):
- TRAIN_SIZE subido de 30 000 a 42 000
- param_grid enfocado en C alto (50, 100) y gamma bajo-medio (0.001, 0.005)
- Tras GridSearchCV, se reentrena el modelo final con TODOS los datos de
entrenamiento (no solo los folds de CV), lo que extrae rendimiento extra.
"""
import os
import pickle
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
# ── ConfiguraciΓ³n ─────────────────────────────────────────────────────────────
TRAIN_SIZE = 50_000 # 80/20 β†’ ~33 600 train | ~8 400 test
PCA_VARIANCE = 0.95
RANDOM_STATE = 42
MODELS_DIR = "models"
CLASS_NAMES = [
"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]
os.makedirs(MODELS_DIR, exist_ok=True)
# ── 1. Cargar datos ────────────────────────────────────────────────────────────
print("πŸ“₯ Cargando Fashion-MNIST desde OpenML...")
X, y = fetch_openml("Fashion-MNIST", version=1, return_X_y=True, as_frame=False)
# ── 2. Submuestra estratificada ────────────────────────────────────────────────
print(f"βœ‚οΈ Seleccionando {TRAIN_SIZE:,} muestras estratificadas...")
X_small, _, y_small, _ = train_test_split(
X, y,
train_size=TRAIN_SIZE,
stratify=y,
random_state=RANDOM_STATE,
)
# ── 3. Escalar ─────────────────────────────────────────────────────────────────
print("βš–οΈ Escalando (StandardScaler)...")
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_small)
# ── 4. PCA ─────────────────────────────────────────────────────────────────────
print(f"πŸ”¬ Aplicando PCA (varianza={PCA_VARIANCE})...")
pca = PCA(PCA_VARIANCE, random_state=RANDOM_STATE)
X_pca = pca.fit_transform(X_scaled)
print(f" Componentes resultantes: {pca.n_components_}")
# ── 5. Split 80 / 20 ───────────────────────────────────────────────────────────
X_train, X_test, y_train, y_test = train_test_split(
X_pca, y_small,
test_size=0.20,
random_state=RANDOM_STATE,
stratify=y_small,
)
print(f"πŸ“Š Train: {len(X_train):,} | Test: {len(X_test):,}")
# ── 6. GridSearchCV ────────────────────────────────────────────────────────────
# C alto + gamma bajo-medio es la zona optima para SVM-RBF en Fashion-MNIST.
print("πŸ” Ejecutando GridSearchCV (puede tardar varios minutos)...")
param_grid = {
"C": [50, 100],
"gamma": [0.001, 0.005],
"kernel": ["rbf"],
}
grid = GridSearchCV(
SVC(probability=True),
param_grid,
cv=3,
n_jobs=-1,
verbose=2,
)
grid.fit(X_train, y_train)
best_params = grid.best_params_
print(f"\nβœ… Mejores parΓ‘metros: {best_params}")
print(f" Mejor CV score: {grid.best_score_:.4f}")
# ── 7. Reentrenar con TODOS los datos de train ─────────────────────────────────
# GridSearchCV usa solo una parte de los datos en cada fold.
# Reentrenar con el dataset completo de train aprovecha cada muestra disponible.
print("\nπŸ” Reentrenando modelo final con 100 % de los datos de entrenamiento...")
clf = SVC(probability=True, **best_params)
clf.fit(X_train, y_train)
# ── 8. EvaluaciΓ³n ──────────────────────────────────────────────────────────────
accuracy = clf.score(X_test, y_test)
print(f"\n🎯 Precisión en test (80/20): {accuracy:.4f} ({accuracy*100:.2f} %)")
if accuracy < 0.90:
print("⚠️ La precisión es menor al 90 %. Prueba subir TRAIN_SIZE a 50 000.")
else:
print("πŸ† Β‘Modelo supera el umbral del 90 %!")
print("\nπŸ“‹ Reporte completo:")
y_pred = clf.predict(X_test)
print(classification_report(y_test, y_pred, target_names=CLASS_NAMES))
# ── 9. Guardar artefactos ──────────────────────────────────────────────────────
scaler_path = os.path.join(MODELS_DIR, "scaler.pkl")
pca_path = os.path.join(MODELS_DIR, "pca.pkl")
model_path = os.path.join(MODELS_DIR, "svm_model.pkl")
with open(scaler_path, "wb") as f:
pickle.dump(scaler, f)
with open(pca_path, "wb") as f:
pickle.dump(pca, f)
with open(model_path, "wb") as f:
pickle.dump(clf, f)
print(f"\nπŸ’Ύ Artefactos guardados en '{MODELS_DIR}/':")
print(f" β€’ {scaler_path}")
print(f" β€’ {pca_path}")
print(f" β€’ {model_path}")
print("\nπŸš€ Ahora puedes iniciar la API con: python app.py")