nicolasrl's picture
Update app.py
a448502 verified
# app.py
import os
import zipfile
import numpy as np
from PIL import Image, ImageDraw
import gradio as gr
from huggingface_hub import hf_hub_download
from tensorflow.keras.models import load_model
import keras
# -----------------------------
# Configuración del modelo
# -----------------------------
REPO_ID = "nicolasrl/matriculas" # tu repo en HF
ZIP_FILE = "matriculas_resnetv2.zip" # archivo zip subido con todo el SavedModel
# -----------------------------
# Función para descargar y cargar modelo
# -----------------------------
def get_model():
"""
Descarga el modelo en zip desde Hugging Face, lo descomprime
y lo carga como SavedModel de TensorFlow/Keras.
"""
tmp_dir = "/tmp/matriculas_resnetv2"
os.makedirs(tmp_dir, exist_ok=True)
# Descargar el zip desde HF
zip_path = hf_hub_download(repo_id=REPO_ID, filename=ZIP_FILE)
# Descomprimir en carpeta temporal
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(tmp_dir)
# Habilitar deserialización insegura (si hay Lambda layers)
keras.config.enable_unsafe_deserialization()
# Cargar el modelo desde la carpeta descomprimida
model = load_model(tmp_dir, safe_mode=False)
return model
# -----------------------------
# Cargar modelo al iniciar app
# -----------------------------
model = get_model()
# -----------------------------
# Función de predicción
# -----------------------------
def predict(img: Image.Image):
"""
Recibe imagen PIL, realiza inferencia y devuelve imagen
anotada con la caja y texto con coordenadas.
"""
pil = img.convert("RGB")
w, h = pil.size
# Preprocesamiento igual al usado en entrenamiento
img_resized = pil.resize((224, 224))
x = np.array(img_resized).astype("float32") / 255.0
x = np.expand_dims(x, 0)
# Inferencia
pred = model.predict(x)[0] # [nxmin, nxmax, nymin, nymax]
nxmin, nxmax, nymin, nymax = pred
# Convertir coordenadas normalizadas a píxeles
xmin = int(max(0, nxmin * w))
xmax = int(min(w, nxmax * w))
ymin = int(max(0, nymin * h))
ymax = int(min(h, nymax * h))
# Dibujar caja
annotated = pil.copy()
draw = ImageDraw.Draw(annotated)
draw.rectangle([xmin, ymin, xmax, ymax], outline="cyan", width=3)
return annotated, f"{xmin},{ymin},{xmax},{ymax}"
# -----------------------------
# Interfaz Gradio
# -----------------------------
demo = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil"),
outputs=[gr.Image(type="pil"), gr.Textbox()],
title="Detección de matrículas",
description="Sube una imagen. El modelo devuelve la caja delimitadora (xmin,ymin,xmax,ymax)."
)
# -----------------------------
# Lanzar app
# -----------------------------
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860, enable_queue=True)