""" app.py ====== API REST con Flask que expone el modelo SVM entrenado. Endpoints: GET /health → Estado de la API y del modelo GET /model/info → Información del modelo (clases, componentes PCA, etc.) POST /predict → Predice una imagen (array 784 píxeles, valores 0-255) POST /predict/batch → Predice múltiples imágenes a la vez Iniciar el servidor: python app.py """ import os import pickle import time import numpy as np from flask import Flask, jsonify, request # ── Configuración ───────────────────────────────────────────────────────────── MODELS_DIR = "models" HOST = "0.0.0.0" PORT = 5000 DEBUG = True # Cambiar a False en producción CLASS_NAMES = [ "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot", ] app = Flask(__name__) # ── Cargar modelos al iniciar ───────────────────────────────────────────────── def load_artifacts(): scaler_path = os.path.join(MODELS_DIR, "scaler.pkl") pca_path = os.path.join(MODELS_DIR, "pca.pkl") model_path = os.path.join(MODELS_DIR, "svm_model.pkl") missing = [p for p in (scaler_path, pca_path, model_path) if not os.path.exists(p)] if missing: raise FileNotFoundError( f"Artefactos no encontrados: {missing}\n" "Ejecuta primero: python train_model.py" ) with open(scaler_path, "rb") as f: scaler = pickle.load(f) with open(pca_path, "rb") as f: pca = pickle.load(f) with open(model_path, "rb") as f: clf = pickle.load(f) return scaler, pca, clf print("⏳ Cargando artefactos del modelo...") try: scaler, pca, clf = load_artifacts() MODEL_LOADED = True print("✅ Modelo cargado correctamente.") except FileNotFoundError as e: MODEL_LOADED = False print(f"❌ {e}") # ── Helpers ─────────────────────────────────────────────────────────────────── def preprocess(pixels: list) -> np.ndarray: """Escala y aplica PCA a un vector de 784 píxeles.""" arr = np.array(pixels, dtype=np.float64).reshape(1, -1) arr = scaler.transform(arr) arr = pca.transform(arr) return arr def preprocess_batch(batch: list) -> np.ndarray: """Escala y aplica PCA a un lote de vectores.""" arr = np.array(batch, dtype=np.float64) arr = scaler.transform(arr) arr = pca.transform(arr) return arr # ── Endpoints ───────────────────────────────────────────────────────────────── @app.route("/health", methods=["GET"]) def health(): """Verificar que la API y el modelo están operativos.""" return jsonify({ "status": "ok" if MODEL_LOADED else "degraded", "model_loaded": MODEL_LOADED, "message": "API operativa" if MODEL_LOADED else "Modelo no cargado. Ejecuta train_model.py", }), 200 if MODEL_LOADED else 503 @app.route("/model/info", methods=["GET"]) def model_info(): """Devuelve metadatos del modelo cargado.""" if not MODEL_LOADED: return jsonify({"error": "Modelo no disponible"}), 503 return jsonify({ "model_type": type(clf).__name__, "kernel": clf.kernel, "C": clf.C, "gamma": clf.gamma, "classes": CLASS_NAMES, "n_classes": len(CLASS_NAMES), "pca_components": int(pca.n_components_), "pca_variance_ratio": float(np.sum(pca.explained_variance_ratio_)), "input_features": 784, }) @app.route("/predict", methods=["POST"]) def predict(): """ Predice la clase de una imagen de Fashion-MNIST. Body JSON: { "pixels": [0, 0, 128, ..., 255] // lista de 784 valores int 0-255 } Respuesta JSON: { "class_id": 7, "class_name": "Sneaker", "confidence": 0.94, "probabilities": { "T-shirt/top": 0.01, ... } } """ if not MODEL_LOADED: return jsonify({"error": "Modelo no disponible. Ejecuta train_model.py"}), 503 data = request.get_json(force=True) if not data or "pixels" not in data: return jsonify({"error": "Se requiere el campo 'pixels' con 784 valores."}), 400 pixels = data["pixels"] if len(pixels) != 784: return jsonify({"error": f"Se esperaban 784 píxeles, se recibieron {len(pixels)}."}), 400 try: t0 = time.perf_counter() X = preprocess(pixels) pred = clf.predict(X)[0] proba = clf.predict_proba(X)[0] elapsed = round((time.perf_counter() - t0) * 1000, 2) class_id = int(pred) class_name = CLASS_NAMES[class_id] confidence = round(float(proba[class_id]), 4) prob_dict = {CLASS_NAMES[i]: round(float(p), 4) for i, p in enumerate(proba)} return jsonify({ "class_id": class_id, "class_name": class_name, "confidence": confidence, "probabilities": prob_dict, "inference_ms": elapsed, }) except Exception as e: return jsonify({"error": str(e)}), 500 @app.route("/predict/batch", methods=["POST"]) def predict_batch(): """ Predice múltiples imágenes en una sola llamada. Body JSON: { "images": [ [0, 0, 128, ..., 255], // 784 valores por imagen [...] ] } Respuesta JSON: { "results": [ {"class_id": 7, "class_name": "Sneaker", "confidence": 0.94}, ... ], "count": 2 } """ if not MODEL_LOADED: return jsonify({"error": "Modelo no disponible. Ejecuta train_model.py"}), 503 data = request.get_json(force=True) if not data or "images" not in data: return jsonify({"error": "Se requiere el campo 'images' con una lista de arrays de 784 valores."}), 400 images = data["images"] if not isinstance(images, list) or len(images) == 0: return jsonify({"error": "'images' debe ser una lista no vacía."}), 400 errors = [] for i, img in enumerate(images): if len(img) != 784: errors.append(f"Imagen [{i}]: se esperaban 784 píxeles, se recibieron {len(img)}.") if errors: return jsonify({"error": errors}), 400 try: t0 = time.perf_counter() X = preprocess_batch(images) preds = clf.predict(X) probas = clf.predict_proba(X) elapsed = round((time.perf_counter() - t0) * 1000, 2) results = [] for pred, proba in zip(preds, probas): class_id = int(pred) class_name = CLASS_NAMES[class_id] confidence = round(float(proba[class_id]), 4) results.append({ "class_id": class_id, "class_name": class_name, "confidence": confidence, }) return jsonify({ "results": results, "count": len(results), "inference_ms": elapsed, }) except Exception as e: return jsonify({"error": str(e)}), 500 # ── Main ────────────────────────────────────────────────────────────────────── if __name__ == "__main__": print(f"\n🌐 Iniciando API en http://{HOST}:{PORT}") print(" Endpoints disponibles:") print(" GET /health") print(" GET /model/info") print(" POST /predict") print(" POST /predict/batch\n") app.run(host=HOST, port=PORT, debug=DEBUG)