# SMV / app.py
# Author: Tyan1988
# Commit message: "Subo modelo SMV completo" (uploading the complete SMV model)
# Commit: 9885230
"""
app.py
======
API REST con Flask que expone el modelo SVM entrenado.
Endpoints:
GET /health β†’ Estado de la API y del modelo
GET /model/info β†’ InformaciΓ³n del modelo (clases, componentes PCA, etc.)
POST /predict β†’ Predice una imagen (array 784 pΓ­xeles, valores 0-255)
POST /predict/batch β†’ Predice mΓΊltiples imΓ‘genes a la vez
Iniciar el servidor:
python app.py
"""
import os
import pickle
import time
import numpy as np
from flask import Flask, jsonify, request
# ── Configuration ─────────────────────────────────────────────────────────────
# Every setting can be overridden through an environment variable, so the same
# file works for local development and for deployment without code edits.
MODELS_DIR = os.environ.get("SVM_MODELS_DIR", "models")
HOST = os.environ.get("SVM_API_HOST", "0.0.0.0")
PORT = int(os.environ.get("SVM_API_PORT", "5000"))
# Debug mode is OFF by default: the Werkzeug interactive debugger lets anyone
# who can reach the server execute arbitrary code, and HOST binds every
# network interface by default. Set SVM_API_DEBUG=1 for local development only.
DEBUG = os.environ.get("SVM_API_DEBUG", "0").strip().lower() in {"1", "true", "yes", "on"}
# Human-readable label for each integer class id returned by the model
# (the endpoints index this list with the predicted class id).
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]
# The Flask application object; the @app.route decorators register on it.
app = Flask(import_name=__name__)
# ── Load models at startup ────────────────────────────────────────────────────
def load_artifacts(models_dir=None):
    """Load the fitted scaler, PCA transformer and SVM classifier from disk.

    Args:
        models_dir: Directory containing ``scaler.pkl``, ``pca.pkl`` and
            ``svm_model.pkl``. When ``None`` (the default), the module-level
            ``MODELS_DIR`` setting is used, matching the original behavior.

    Returns:
        Tuple ``(scaler, pca, clf)`` in that order.

    Raises:
        FileNotFoundError: If any of the three files is missing. The message
            lists every missing path, not just the first one.
    """
    if models_dir is None:
        models_dir = MODELS_DIR

    filenames = ("scaler.pkl", "pca.pkl", "svm_model.pkl")
    paths = [os.path.join(models_dir, name) for name in filenames]

    # Check all files up front so the user sees every missing artifact at once.
    missing = [p for p in paths if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError(
            f"Artefactos no encontrados: {missing}\n"
            "Ejecuta primero: python train_model.py"
        )

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load artifacts produced by train_model.py from a trusted location.
    loaded = []
    for path in paths:
        with open(path, "rb") as fh:
            loaded.append(pickle.load(fh))
    return tuple(loaded)
print("⏳ Cargando artefactos del modelo...")
try:
    scaler, pca, clf = load_artifacts()
except (OSError, EOFError, pickle.UnpicklingError) as e:
    # Start in "degraded" mode instead of crashing at import time: /health
    # reports the problem and the model endpoints answer 503 while
    # MODEL_LOADED is False. OSError covers missing or unreadable files
    # (FileNotFoundError is a subclass); EOFError and UnpicklingError cover
    # truncated or corrupt pickle files.
    scaler = pca = clf = None  # keep the names defined even when loading fails
    MODEL_LOADED = False
    print(f"❌ {e}")
else:
    MODEL_LOADED = True
    print("✅ Modelo cargado correctamente.")
# ── Helpers ───────────────────────────────────────────────────────────────────
def preprocess(pixels: list) -> np.ndarray:
    """Scale and PCA-project a single image.

    The pixel values are copied into a float64 array and flattened to one row,
    then passed through the module-level fitted ``scaler`` followed by ``pca``.
    Callers are expected to pass 784 values (the /predict endpoint checks
    the length before calling this).
    """
    single_row = np.array(pixels, dtype=np.float64).reshape(1, -1)
    return pca.transform(scaler.transform(single_row))
def preprocess_batch(batch: list) -> np.ndarray:
    """Scale and PCA-project a batch of images.

    ``batch`` is converted to a float64 array as-is (no reshape). The
    /predict/batch endpoint passes one 784-value row per image. The result
    is then transformed by the module-level fitted ``scaler`` and ``pca``.
    """
    matrix = np.array(batch, dtype=np.float64)
    scaled = scaler.transform(matrix)
    return pca.transform(scaled)
# ── Endpoints ─────────────────────────────────────────────────────────────────
@app.route("/health", methods=["GET"])
def health():
    """Report whether the API is up and the model artifacts were loaded.

    Returns HTTP 200 with status "ok" when the model is loaded, otherwise
    HTTP 503 with status "degraded".
    """
    if not MODEL_LOADED:
        status, message, code = "degraded", "Modelo no cargado. Ejecuta train_model.py", 503
    else:
        status, message, code = "ok", "API operativa", 200
    payload = {
        "status": status,
        "model_loaded": MODEL_LOADED,
        "message": message,
    }
    return jsonify(payload), code
@app.route("/model/info", methods=["GET"])
def model_info():
    """Return metadata about the loaded classifier and PCA step as JSON.

    Responds with HTTP 503 and an error payload when the model is not loaded.
    """
    if MODEL_LOADED:
        info = {
            "model_type": type(clf).__name__,
            "kernel": clf.kernel,
            "C": clf.C,
            "gamma": clf.gamma,
            "classes": CLASS_NAMES,
            "n_classes": len(CLASS_NAMES),
            "pca_components": int(pca.n_components_),
            # Total fraction of variance retained by the kept PCA components.
            "pca_variance_ratio": float(np.sum(pca.explained_variance_ratio_)),
            "input_features": 784,
        }
        return jsonify(info)
    return jsonify({"error": "Modelo no disponible"}), 503
@app.route("/predict", methods=["POST"])
def predict():
    """Predict the Fashion-MNIST class of a single image.

    Request body (JSON)::

        {"pixels": [0, 0, 128, ..., 255]}   # exactly 784 numbers in [0, 255]

    Successful response (JSON)::

        {
            "class_id": 7,
            "class_name": "Sneaker",
            "confidence": 0.94,
            "probabilities": {"T-shirt/top": 0.01, ...},
            "inference_ms": 1.23
        }

    Status codes: 200 on success; 400 for a malformed body or invalid pixel
    values; 503 when the model artifacts were not loaded; 500 if
    preprocessing or inference raises.
    """
    if not MODEL_LOADED:
        return jsonify({"error": "Modelo no disponible. Ejecuta train_model.py"}), 503

    # silent=True: malformed JSON yields None, so the client gets the JSON 400
    # below instead of Flask's default HTML error page.
    data = request.get_json(force=True, silent=True)
    # A non-dict body (list, string, number) would otherwise turn the `in`
    # check into a membership/substring test and crash on data["pixels"].
    if not isinstance(data, dict) or not isinstance(data.get("pixels"), list):
        return jsonify({"error": "Se requiere el campo 'pixels' con 784 valores."}), 400

    pixels = data["pixels"]
    if len(pixels) != 784:
        return jsonify({"error": f"Se esperaban 784 píxeles, se recibieron {len(pixels)}."}), 400

    # Reject non-numeric, nested, NaN/inf or out-of-range values up front so
    # they produce a 400 instead of an opaque 500 from the scaler. NaN fails
    # the range comparison, so it is rejected too.
    try:
        arr = np.asarray(pixels, dtype=np.float64)
    except (TypeError, ValueError):
        arr = None
    if arr is None or arr.ndim != 1 or not np.all((arr >= 0) & (arr <= 255)):
        return jsonify({"error": "Los valores de 'pixels' deben ser números entre 0 y 255."}), 400

    try:
        t0 = time.perf_counter()
        X = preprocess(arr)
        pred = clf.predict(X)[0]
        proba = clf.predict_proba(X)[0]
        elapsed = round((time.perf_counter() - t0) * 1000, 2)

        # predict_proba columns follow clf.classes_, so map through it instead
        # of assuming column i holds the probability of label i.
        # NOTE: for an SVC with probability=True, predict() and the argmax of
        # predict_proba() can disagree (Platt scaling is fitted separately).
        labels = [int(c) for c in clf.classes_]
        column_of = {label: col for col, label in enumerate(labels)}

        class_id = int(pred)
        confidence = round(float(proba[column_of[class_id]]), 4)
        prob_dict = {
            CLASS_NAMES[label]: round(float(p), 4) for label, p in zip(labels, proba)
        }
        return jsonify({
            "class_id": class_id,
            "class_name": CLASS_NAMES[class_id],
            "confidence": confidence,
            "probabilities": prob_dict,
            "inference_ms": elapsed,
        })
    except Exception as e:
        # Top-level request boundary: log the traceback server-side, and keep
        # the JSON error response callers already receive.
        app.logger.exception("Error during /predict inference")
        return jsonify({"error": str(e)}), 500
@app.route("/predict/batch", methods=["POST"])
def predict_batch():
    """Predict the Fashion-MNIST class of several images in one call.

    Request body (JSON)::

        {
            "images": [
                [0, 0, 128, ..., 255],   # 784 numbers in [0, 255] per image
                [...]
            ]
        }

    Successful response (JSON)::

        {
            "results": [
                {"class_id": 7, "class_name": "Sneaker", "confidence": 0.94},
                ...
            ],
            "count": 2,
            "inference_ms": 3.21
        }

    Status codes: 200 on success; 400 for a malformed body, a wrong-sized
    image (``error`` is then a list of per-image messages) or invalid pixel
    values; 503 when the model is not loaded; 500 if inference raises.
    """
    if not MODEL_LOADED:
        return jsonify({"error": "Modelo no disponible. Ejecuta train_model.py"}), 503

    # silent=True: malformed JSON yields None -> JSON 400 instead of HTML.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict) or "images" not in data:
        return jsonify({"error": "Se requiere el campo 'images' con una lista de arrays de 784 valores."}), 400

    images = data["images"]
    if not isinstance(images, list) or len(images) == 0:
        return jsonify({"error": "'images' debe ser una lista no vacía."}), 400

    # Validate the shape of every image and report all problems at once.
    # Non-list items are checked before len() so they cannot crash the handler.
    errors = []
    for i, img in enumerate(images):
        if not isinstance(img, list):
            errors.append(f"Imagen [{i}]: se esperaba una lista de 784 píxeles.")
        elif len(img) != 784:
            errors.append(f"Imagen [{i}]: se esperaban 784 píxeles, se recibieron {len(img)}.")
    if errors:
        return jsonify({"error": errors}), 400

    # Reject non-numeric, nested, NaN/inf or out-of-range values with a 400
    # instead of letting the scaler fail with a 500.
    try:
        batch = np.asarray(images, dtype=np.float64)
    except (TypeError, ValueError):
        batch = None
    if batch is None or batch.ndim != 2 or not np.all((batch >= 0) & (batch <= 255)):
        return jsonify({"error": "Los valores de 'images' deben ser números entre 0 y 255."}), 400

    try:
        t0 = time.perf_counter()
        X = preprocess_batch(batch)
        preds = clf.predict(X)
        probas = clf.predict_proba(X)
        elapsed = round((time.perf_counter() - t0) * 1000, 2)

        # predict_proba columns follow clf.classes_; look the predicted label's
        # column up instead of assuming column i is label i.
        column_of = {int(label): col for col, label in enumerate(clf.classes_)}

        results = []
        for pred, proba in zip(preds, probas):
            class_id = int(pred)
            results.append({
                "class_id": class_id,
                "class_name": CLASS_NAMES[class_id],
                "confidence": round(float(proba[column_of[class_id]]), 4),
            })
        return jsonify({
            "results": results,
            "count": len(results),
            "inference_ms": elapsed,
        })
    except Exception as e:
        # Top-level request boundary: log server-side, keep the JSON 500 reply.
        app.logger.exception("Error during /predict/batch inference")
        return jsonify({"error": str(e)}), 500
# ── Main ──────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Print a short banner listing the available routes, then start Flask's
    # built-in development server with the configured host/port/debug flag.
    banner_lines = (
        f"\n🌐 Iniciando API en http://{HOST}:{PORT}",
        " Endpoints disponibles:",
        " GET /health",
        " GET /model/info",
        " POST /predict",
        " POST /predict/batch\n",
    )
    for line in banner_lines:
        print(line)
    app.run(host=HOST, port=PORT, debug=DEBUG)