""" train_model.py ============== Entrena el modelo SVM sobre Fashion-MNIST y guarda los artefactos: - models/scaler.pkl - models/pca.pkl - models/svm_model.pkl Ejecutar UNA SOLA VEZ antes de iniciar la API: python train_model.py Cambios v2 (para superar 90 %): - TRAIN_SIZE subido de 30 000 a 42 000 - param_grid enfocado en C alto (50, 100) y gamma bajo-medio (0.001, 0.005) - Tras GridSearchCV, se reentrena el modelo final con TODOS los datos de entrenamiento (no solo los folds de CV), lo que extrae rendimiento extra. """ import os import pickle import numpy as np from sklearn.datasets import fetch_openml from sklearn.decomposition import PCA from sklearn.metrics import classification_report from sklearn.model_selection import GridSearchCV, train_test_split from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC # ── Configuración ───────────────────────────────────────────────────────────── TRAIN_SIZE = 50_000 # 80/20 → ~33 600 train | ~8 400 test PCA_VARIANCE = 0.95 RANDOM_STATE = 42 MODELS_DIR = "models" CLASS_NAMES = [ "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot", ] os.makedirs(MODELS_DIR, exist_ok=True) # ── 1. Cargar datos ──────────────────────────────────────────────────────────── print("📥 Cargando Fashion-MNIST desde OpenML...") X, y = fetch_openml("Fashion-MNIST", version=1, return_X_y=True, as_frame=False) # ── 2. Submuestra estratificada ──────────────────────────────────────────────── print(f"✂️ Seleccionando {TRAIN_SIZE:,} muestras estratificadas...") X_small, _, y_small, _ = train_test_split( X, y, train_size=TRAIN_SIZE, stratify=y, random_state=RANDOM_STATE, ) # ── 3. Escalar ───────────────────────────────────────────────────────────────── print("⚖️ Escalando (StandardScaler)...") scaler = StandardScaler() X_scaled = scaler.fit_transform(X_small) # ── 4. PCA ───────────────────────────────────────────────────────────────────── print(f"🔬 Aplicando PCA (varianza={PCA_VARIANCE})...") pca = PCA(PCA_VARIANCE, random_state=RANDOM_STATE) X_pca = pca.fit_transform(X_scaled) print(f" Componentes resultantes: {pca.n_components_}") # ── 5. Split 80 / 20 ─────────────────────────────────────────────────────────── X_train, X_test, y_train, y_test = train_test_split( X_pca, y_small, test_size=0.20, random_state=RANDOM_STATE, stratify=y_small, ) print(f"📊 Train: {len(X_train):,} | Test: {len(X_test):,}") # ── 6. GridSearchCV ──────────────────────────────────────────────────────────── # C alto + gamma bajo-medio es la zona optima para SVM-RBF en Fashion-MNIST. print("🔍 Ejecutando GridSearchCV (puede tardar varios minutos)...") param_grid = { "C": [50, 100], "gamma": [0.001, 0.005], "kernel": ["rbf"], } grid = GridSearchCV( SVC(probability=True), param_grid, cv=3, n_jobs=-1, verbose=2, ) grid.fit(X_train, y_train) best_params = grid.best_params_ print(f"\n✅ Mejores parámetros: {best_params}") print(f" Mejor CV score: {grid.best_score_:.4f}") # ── 7. Reentrenar con TODOS los datos de train ───────────────────────────────── # GridSearchCV usa solo una parte de los datos en cada fold. # Reentrenar con el dataset completo de train aprovecha cada muestra disponible. print("\n🔁 Reentrenando modelo final con 100 % de los datos de entrenamiento...") clf = SVC(probability=True, **best_params) clf.fit(X_train, y_train) # ── 8. Evaluación ────────────────────────────────────────────────────────────── accuracy = clf.score(X_test, y_test) print(f"\n🎯 Precisión en test (80/20): {accuracy:.4f} ({accuracy*100:.2f} %)") if accuracy < 0.90: print("⚠️ La precisión es menor al 90 %. Prueba subir TRAIN_SIZE a 50 000.") else: print("🏆 ¡Modelo supera el umbral del 90 %!") print("\n📋 Reporte completo:") y_pred = clf.predict(X_test) print(classification_report(y_test, y_pred, target_names=CLASS_NAMES)) # ── 9. Guardar artefactos ────────────────────────────────────────────────────── scaler_path = os.path.join(MODELS_DIR, "scaler.pkl") pca_path = os.path.join(MODELS_DIR, "pca.pkl") model_path = os.path.join(MODELS_DIR, "svm_model.pkl") with open(scaler_path, "wb") as f: pickle.dump(scaler, f) with open(pca_path, "wb") as f: pickle.dump(pca, f) with open(model_path, "wb") as f: pickle.dump(clf, f) print(f"\n💾 Artefactos guardados en '{MODELS_DIR}/':") print(f" • {scaler_path}") print(f" • {pca_path}") print(f" • {model_path}") print("\n🚀 Ahora puedes iniciar la API con: python app.py")