Spaces:
Sleeping
Sleeping
File size: 4,298 Bytes
f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a 0efb498 f93586a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
from fastapi import FastAPI
from pydantic import BaseModel
import base64
import numpy as np
from PIL import Image
import io
import ai_edge_litert.interpreter as interpreter
app = FastAPI(title="AI Edge LiteRT API")
# Cargar el modelo TFLite una sola vez al iniciar
MODEL_PATH = "./my_classification_model_float16.tflite" # Cambia según tu modelo (float32, float16, int8, etc.)
litert_interpreter = interpreter.Interpreter(model_path=MODEL_PATH)
litert_interpreter.allocate_tensors()
# Obtener detalles de entrada/salida
input_details = litert_interpreter.get_input_details()
output_details = litert_interpreter.get_output_details()
# Verificar si el modelo usa cuantización INT8
IS_INT8_MODEL = input_details[0]['dtype'] == np.uint8
class ImagePayload(BaseModel):
image_base64: str
@app.get("/")
def home():
return {
"status": "ok",
"message": "API is running! Use POST /predict",
"model_info": {
"input_shape": input_details[0]['shape'].tolist(),
"input_dtype": str(input_details[0]['dtype']),
"output_shape": output_details[0]['shape'].tolist(),
"output_dtype": str(output_details[0]['dtype']),
"quantized": IS_INT8_MODEL
}
}
def preprocess_image(img_bytes, target_size=(224, 224)):
"""
Preprocesa la imagen usando NumPy y PIL
Args:
img_bytes: Bytes de la imagen
target_size: Tupla (height, width)
Returns:
Imagen preprocesada como numpy array
"""
# Decodificar imagen con PIL
img = Image.open(io.BytesIO(img_bytes))
# Convertir a RGB si es necesario
if img.mode != 'RGB':
img = img.convert('RGB')
# Redimensionar
img = img.resize(target_size, Image.BILINEAR)
# Convertir a numpy array
img_array = np.array(img, dtype=np.float32)
# Normalizar a [0, 1]
img_array = img_array / 255.0
# Expandir dimensiones para batch
img_array = np.expand_dims(img_array, axis=0)
# Si es modelo INT8, convertir directamente a uint8 [0, 255]
# El modelo internamente hace el escalado y zero point
if IS_INT8_MODEL:
# Volver a escala [0, 255] y convertir a uint8
img_array = (img_array).astype(np.uint8)
return img_array
def postprocess_output(output):
"""
Postprocesa la salida del modelo
Args:
output: Salida raw del modelo
Returns:
Probabilidades como lista
"""
# Si es modelo INT8, la salida ya está en uint8 [0, 255]
# El modelo internamente hace el descalado, solo necesitamos
# convertir de uint8 a float [0, 1] o [0, 255] dependiendo del caso
if IS_INT8_MODEL:
# Convertir de uint8 [0, 255] a float [0, 1]
output = output.astype(np.float32)
# El modelo ya tiene softmax, así que solo convertir a lista
return output[0].tolist()
@app.post("/predict")
def predict(payload: ImagePayload):
"""
Endpoint de predicción
Args:
payload: JSON con imagen en base64
Returns:
Predicciones del modelo
"""
try:
# Decodificar base64
img_bytes = base64.b64decode(payload.image_base64)
# Preprocesar imagen
img_array = preprocess_image(img_bytes, target_size=(224, 224))
# Inferencia con AI Edge LiteRT
litert_interpreter.set_tensor(input_details[0]['index'], img_array)
litert_interpreter.invoke()
output = litert_interpreter.get_tensor(output_details[0]['index'])
# Postprocesar salida
predictions = postprocess_output(output)
# Obtener clase predicha y confianza
predicted_class = int(np.argmax(predictions))
confidence = float(predictions[predicted_class])
return {
"prediction": predictions,
"predicted_class": predicted_class,
"confidence": confidence,
"top_5": sorted(
[(i, float(p)) for i, p in enumerate(predictions)],
key=lambda x: x[1],
reverse=True
)[:5]
}
except Exception as e:
return {
"error": str(e),
"status": "failed"
}
@app.get("/health")
def health_check():
"""Health check endpoint"""
return {"status": "healthy", "model_loaded": True}
|