ejemplo / app.py
RaulHuarote's picture
Upload 5 files
3b6f144 verified
# app.py
import io
import numpy as np
from PIL import Image
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from tensorflow.keras.models import load_model
# Inicializar la aplicación FastAPI
app = FastAPI(
title="API de Clasificación de TBC (Mejorado)",
description="Una API para clasificar radiografías de tórax como 'normal' o 'tbc' usando un modelo CNN con Transfer Learning (MobileNetV2).",
version="2.0"
)
# --- Carga del Modelo ---
# Cargar el modelo .h5 mejorado al iniciar la aplicación.
try:
# Asegúrate de que el nombre del archivo coincida con el que guardaste
model = load_model('modelo_cnn_tbc_mejorado.h5')
print("Modelo mejorado cargado exitosamente.")
except Exception as e:
print(f"Error al cargar el modelo: {e}")
model = None
# Definir las constantes del nuevo modelo
IMG_HEIGHT = 224 # <-- ¡CAMBIO IMPORTANTE!
IMG_WIDTH = 224 # <-- ¡CAMBIO IMPORTANTE!
CLASS_NAMES = ["normal", "tbc"]
# --- Endpoints de la API ---
@app.get("/")
def read_root():
"""Endpoint raíz para verificar que la API está funcionando."""
return {"message": "Bienvenido a la API de Clasificación de TBC v2.0. Usa el endpoint /predict/."}
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
"""
Endpoint para predecir si una radiografía es 'normal' o 'tbc'.
- Acepta un archivo de imagen (JPG, PNG, etc.).
- Devuelve la clase predicha y la confianza de la predicción.
"""
if not model:
return JSONResponse(status_code=500, content={"error": "El modelo no está cargado. Revisa los logs del servidor."})
# 1. Leer el contenido del archivo subido en memoria
contents = await file.read()
# 2. Convertir los bytes en una imagen PIL
try:
image = Image.open(io.BytesIO(contents)).convert('RGB')
except Exception as e:
return JSONResponse(status_code=400, content={"error": f"Archivo inválido. No se pudo procesar la imagen: {e}"})
# 3. Preprocesar la imagen para que coincida con la entrada del modelo
# Redimensionar la imagen al tamaño esperado por MobileNetV2
image = image.resize((IMG_WIDTH, IMG_HEIGHT))
# Convertir la imagen a un array de numpy
img_array = np.array(image)
# Normalizar los valores de los píxeles (de 0-255 a 0-1)
img_array = img_array / 255.0
# Añadir una dimensión de batch (el modelo espera una forma de [1, 224, 224, 3])
image_batch = np.expand_dims(img_array, axis=0)
# 4. Realizar la predicción
prediction = model.predict(image_batch)
score = prediction[0][0]
# 5. Interpretar el resultado
if score > 0.5:
predicted_class = CLASS_NAMES[1] # 'tbc'
confidence = score
else:
predicted_class = CLASS_NAMES[0] # 'normal'
confidence = 1 - score
# 6. Devolver el resultado en formato JSON
return {
"filename": file.filename,
"prediction": predicted_class,
"confidence": float(confidence)
}