"""
app.py
======
API REST con Flask que expone el modelo SVM entrenado.

Endpoints:
    GET  /health         → Estado de la API y del modelo
    GET  /model/info     → Información del modelo (clases, componentes PCA, etc.)
    POST /predict        → Predice una imagen (array 784 píxeles, valores 0-255)
    POST /predict/batch  → Predice múltiples imágenes a la vez

Iniciar el servidor:
    python app.py
"""
|
|
| import os |
| import pickle |
| import time |
|
|
| import numpy as np |
| from flask import Flask, jsonify, request |
|
|
| |
# Directory that holds the pickled artifacts read by load_artifacts().
MODELS_DIR = "models"

# Bind address and port handed to app.run() when this module runs as a script.
HOST = "0.0.0.0"
PORT = 5000

# Debug mode is opt-in (API_DEBUG=1). HOST binds every interface, and Flask's
# debug mode serves the interactive Werkzeug debugger, which lets anyone who
# can reach the port execute arbitrary code, so it must default to off.
DEBUG = os.environ.get("API_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}

# Human-readable labels, indexed by the integer class id the classifier predicts.
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]
|
|
# Flask application object; the @app.route decorators in this module register
# their handlers on it and the script entry point calls app.run().
app = Flask(import_name=__name__)
|
|
| |
def load_artifacts():
    """Load the pickled scaler, PCA transformer and classifier from disk.

    The files are looked up under ``MODELS_DIR`` as ``scaler.pkl``,
    ``pca.pkl`` and ``svm_model.pkl``.

    Returns:
        A ``(scaler, pca, clf)`` tuple with the unpickled objects, in that
        order.

    Raises:
        FileNotFoundError: If any of the three files does not exist; the
            message lists every missing path.
    """
    filenames = ("scaler.pkl", "pca.pkl", "svm_model.pkl")
    paths = [os.path.join(MODELS_DIR, name) for name in filenames]

    # Check all files up front so the error reports everything that is missing.
    missing = [path for path in paths if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError(
            f"Artefactos no encontrados: {missing}\n"
            "Ejecuta primero: python train_model.py"
        )

    # NOTE: pickle.load can execute code embedded in the file, so these
    # artifacts must only ever come from a trusted training run.
    loaded = []
    for path in paths:
        with open(path, "rb") as handle:
            loaded.append(pickle.load(handle))

    scaler, pca, clf = loaded
    return scaler, pca, clf
|
|
# Load the artifacts once at import time so every request reuses them.
# MODEL_LOADED records the outcome; the endpoints check it and answer 503
# instead of touching the model when loading failed.
print("⏳ Cargando artefactos del modelo...")
try:
    scaler, pca, clf = load_artifacts()
except FileNotFoundError as e:
    # Keep the names bound so the module still imports cleanly; the API then
    # runs in degraded mode and the printed message tells the operator what
    # to run to produce the artifacts.
    scaler = pca = clf = None
    MODEL_LOADED = False
    print(f"❌ {e}")
else:
    MODEL_LOADED = True
    print("✅ Modelo cargado correctamente.")
|
|
|
|
| |
def preprocess(pixels: list) -> np.ndarray:
    """Turn one pixel vector into the feature row the classifier consumes.

    Args:
        pixels: Pixel values of a single image.

    Returns:
        The output of the module-level ``pca`` applied to the output of the
        module-level ``scaler`` for the input reshaped to a single row.
    """
    # reshape(1, -1): the transformers are fed a 2-D array with one sample.
    row = np.array(pixels, dtype=np.float64).reshape(1, -1)
    return pca.transform(scaler.transform(row))
|
|
|
|
def preprocess_batch(batch: list) -> np.ndarray:
    """Turn a list of pixel vectors into the feature matrix for the classifier.

    Args:
        batch: One pixel vector per image.

    Returns:
        The output of the module-level ``pca`` applied to the output of the
        module-level ``scaler`` for the whole batch converted to float64.
    """
    matrix = np.array(batch, dtype=np.float64)
    return pca.transform(scaler.transform(matrix))
|
|
|
|
| |
|
|
@app.route("/health", methods=["GET"])
def health():
    """Report whether the API is up and the model artifacts were loaded.

    Returns:
        A JSON body with ``status``, ``model_loaded`` and ``message``, paired
        with HTTP 200 when the model is loaded and 503 otherwise.
    """
    if MODEL_LOADED:
        status, message, http_code = "ok", "API operativa", 200
    else:
        status = "degraded"
        message = "Modelo no cargado. Ejecuta train_model.py"
        http_code = 503

    payload = {
        "status": status,
        "model_loaded": MODEL_LOADED,
        "message": message,
    }
    return jsonify(payload), http_code
|
|
|
|
@app.route("/model/info", methods=["GET"])
def model_info():
    """Return metadata about the loaded model.

    Returns:
        A JSON body with the classifier's type name and its ``kernel``, ``C``
        and ``gamma`` attributes, the class labels and their count, the PCA
        component count and summed explained-variance ratio, and the expected
        number of input features (784). Responds with HTTP 503 and an
        ``error`` field when the model is not loaded.
    """
    if not MODEL_LOADED:
        return jsonify({"error": "Modelo no disponible"}), 503

    classifier_fields = {
        "model_type": type(clf).__name__,
        "kernel": clf.kernel,
        "C": clf.C,
        "gamma": clf.gamma,
    }
    label_fields = {
        "classes": CLASS_NAMES,
        "n_classes": len(CLASS_NAMES),
    }
    # Cast to built-in int/float so the values serialize as plain JSON numbers.
    pca_fields = {
        "pca_components": int(pca.n_components_),
        "pca_variance_ratio": float(np.sum(pca.explained_variance_ratio_)),
    }

    info = {**classifier_fields, **label_fields, **pca_fields, "input_features": 784}
    return jsonify(info)
|
|
|
|
@app.route("/predict", methods=["POST"])
def predict():
    """Classify a single Fashion-MNIST image.

    Request body (JSON)::

        {"pixels": [0, 0, 128, ..., 255]}

    ``pixels`` must be a flat list of exactly 784 finite numeric values
    (documented range 0-255; the range itself is not enforced).

    Response body (JSON) on success::

        {
            "class_id": 7,
            "class_name": "Sneaker",
            "confidence": 0.94,
            "probabilities": {"T-shirt/top": 0.01, ...},
            "inference_ms": 1.23
        }

    Returns:
        The JSON response. The status is 400 for a missing or unparsable body
        or an invalid ``pixels`` value, 503 when the model is not loaded and
        500 when inference itself raises.
    """
    if not MODEL_LOADED:
        return jsonify({"error": "Modelo no disponible. Ejecuta train_model.py"}), 503

    # silent=True makes an unparsable body come back as None, so the client
    # receives the JSON 400 below rather than Werkzeug's default error page.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict) or "pixels" not in data:
        return jsonify({"error": "Se requiere el campo 'pixels' con 784 valores."}), 400

    pixels = data["pixels"]
    # Check the type first: len() on a JSON number or null would raise here
    # and surface as an unhandled 500.
    if not isinstance(pixels, list):
        return jsonify({"error": "'pixels' debe ser una lista de 784 valores."}), 400
    if len(pixels) != 784:
        return jsonify({"error": f"Se esperaban 784 píxeles, se recibieron {len(pixels)}."}), 400

    # Values that cannot form a flat, finite numeric vector are a client
    # error (400), not an inference failure (500).
    try:
        vector = np.array(pixels, dtype=np.float64)
    except (TypeError, ValueError):
        vector = None
    if vector is None or vector.shape != (784,) or not np.isfinite(vector).all():
        return jsonify({"error": "'pixels' debe contener 784 valores numéricos."}), 400

    try:
        t0 = time.perf_counter()
        X = preprocess(pixels)
        pred = clf.predict(X)[0]
        proba = clf.predict_proba(X)[0]
        elapsed = round((time.perf_counter() - t0) * 1000, 2)

        # The predicted label doubles as the index into CLASS_NAMES and into
        # the probability vector.
        class_id = int(pred)
        class_name = CLASS_NAMES[class_id]
        confidence = round(float(proba[class_id]), 4)
        prob_dict = {CLASS_NAMES[i]: round(float(p), 4) for i, p in enumerate(proba)}

        return jsonify({
            "class_id": class_id,
            "class_name": class_name,
            "confidence": confidence,
            "probabilities": prob_dict,
            "inference_ms": elapsed,
        })
    except Exception as e:
        # Endpoint boundary: keep the traceback in the server log and answer
        # with a JSON error instead of letting the exception propagate.
        app.logger.exception("Inference failed in /predict")
        return jsonify({"error": str(e)}), 500
|
|
|
|
@app.route("/predict/batch", methods=["POST"])
def predict_batch():
    """Classify several Fashion-MNIST images in one call.

    Request body (JSON)::

        {
            "images": [
                [0, 0, 128, ..., 255],
                [...]
            ]
        }

    ``images`` must be a non-empty list in which every element is a flat list
    of exactly 784 finite numeric values.

    Response body (JSON) on success::

        {
            "results": [
                {"class_id": 7, "class_name": "Sneaker", "confidence": 0.94},
                ...
            ],
            "count": 2,
            "inference_ms": 3.21
        }

    Returns:
        The JSON response. The status is 400 for a missing or unparsable body
        or invalid ``images`` content (per-image size problems are reported
        as a list under ``error``), 503 when the model is not loaded and 500
        when inference itself raises.
    """
    if not MODEL_LOADED:
        return jsonify({"error": "Modelo no disponible. Ejecuta train_model.py"}), 503

    # silent=True makes an unparsable body come back as None, so the client
    # receives the JSON 400 below rather than Werkzeug's default error page.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict) or "images" not in data:
        return jsonify({"error": "Se requiere el campo 'images' con una lista de arrays de 784 valores."}), 400

    images = data["images"]
    if not isinstance(images, list) or not images:
        return jsonify({"error": "'images' debe ser una lista no vacía."}), 400

    # Collect every per-image problem so the client can fix them in one pass.
    errors = []
    for i, img in enumerate(images):
        if not isinstance(img, list):
            # len() on a JSON number or null would raise and become a 500.
            errors.append(f"Imagen [{i}]: se esperaba una lista de 784 píxeles.")
        elif len(img) != 784:
            errors.append(f"Imagen [{i}]: se esperaban 784 píxeles, se recibieron {len(img)}.")
    if errors:
        return jsonify({"error": errors}), 400

    # Values that cannot form a finite numeric matrix are a client error
    # (400), not an inference failure (500).
    try:
        matrix = np.array(images, dtype=np.float64)
    except (TypeError, ValueError):
        matrix = None
    if (
        matrix is None
        or matrix.shape != (len(images), 784)
        or not np.isfinite(matrix).all()
    ):
        return jsonify({"error": "Todas las imágenes deben contener 784 valores numéricos."}), 400

    try:
        t0 = time.perf_counter()
        X = preprocess_batch(images)
        preds = clf.predict(X)
        probas = clf.predict_proba(X)
        elapsed = round((time.perf_counter() - t0) * 1000, 2)

        results = []
        for pred, proba in zip(preds, probas):
            # The predicted label doubles as the index into CLASS_NAMES and
            # into that image's probability vector.
            class_id = int(pred)
            results.append({
                "class_id": class_id,
                "class_name": CLASS_NAMES[class_id],
                "confidence": round(float(proba[class_id]), 4),
            })

        return jsonify({
            "results": results,
            "count": len(results),
            "inference_ms": elapsed,
        })
    except Exception as e:
        # Endpoint boundary: keep the traceback in the server log and answer
        # with a JSON error instead of letting the exception propagate.
        app.logger.exception("Inference failed in /predict/batch")
        return jsonify({"error": str(e)}), 500
|
|
|
|
| |
if __name__ == "__main__":
    # Print the startup banner, then start Flask's built-in server with the
    # module-level HOST / PORT / DEBUG settings.
    print(f"\n🚀 Iniciando API en http://{HOST}:{PORT}")
    print(" Endpoints disponibles:")
    print(" GET /health")
    print(" GET /model/info")
    print(" POST /predict")
    print(" POST /predict/batch\n")
    app.run(host=HOST, port=PORT, debug=DEBUG)
|
|