Spaces:
Runtime error
Runtime error
# app.py
import os
import shutil
import zipfile

import gradio as gr
import keras
import numpy as np
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw
from tensorflow.keras.models import load_model
# -----------------------------
# Model configuration
# -----------------------------
# Hugging Face repository that hosts the model artifact.
REPO_ID: str = "nicolasrl/matriculas"
# Zip archive uploaded to that repository containing the whole exported model.
ZIP_FILE: str = "matriculas_resnetv2.zip"
# -----------------------------
# Download and load the model
# -----------------------------
def _find_model_path(root):
    """Return the path inside ``root`` that should be handed to the loader.

    The shallowest directory containing ``saved_model.pb`` wins (searching,
    rather than assuming ``root`` itself, handles zips whose contents are
    wrapped in their own top-level folder). If there is no SavedModel, the
    first ``.keras`` / ``.h5`` file found is returned instead.

    Raises:
        FileNotFoundError: if ``root`` contains no recognizable model.
    """
    keras_file = None
    for dirpath, dirnames, filenames in os.walk(root):
        dirnames.sort()  # walk in a deterministic order
        if "saved_model.pb" in filenames:
            return dirpath
        if keras_file is None:
            for name in sorted(filenames):
                if name.endswith((".keras", ".h5")):
                    keras_file = os.path.join(dirpath, name)
                    break
    if keras_file is None:
        raise FileNotFoundError(
            f"No model (saved_model.pb, .keras or .h5) found under {root}"
        )
    return keras_file


class _SavedModelPredictor:
    """Expose a ``predict`` method on top of a Keras 3 ``TFSMLayer`` (inference only)."""

    def __init__(self, path):
        self._layer = keras.layers.TFSMLayer(path, call_endpoint="serving_default")

    def predict(self, x, **kwargs):
        """Run the serving signature on ``x`` and return a NumPy array.

        Extra keyword arguments (e.g. ``verbose``) are accepted for
        compatibility with ``keras.Model.predict`` and ignored.
        """
        out = self._layer(x)
        if isinstance(out, dict):
            # Serving signatures may return named outputs; this assumes the
            # model has a single output (the box) — TODO confirm against the export.
            out = next(iter(out.values()))
        return np.asarray(out)


def get_model():
    """Download the zipped model from Hugging Face, unpack it and load it.

    Returns:
        An object with a ``predict(x)`` method. This is a Keras model when the
        installed Keras can load the artifact directly. Under Keras 3, whose
        ``load_model`` no longer reads the SavedModel format, it is a thin
        ``TFSMLayer`` wrapper instead.

    Raises:
        FileNotFoundError: if the archive contains no recognizable model.
    """
    tmp_dir = "/tmp/matriculas_resnetv2"
    # Start from an empty directory so files left over from a previous
    # extraction cannot be picked up by the search below.
    shutil.rmtree(tmp_dir, ignore_errors=True)
    os.makedirs(tmp_dir, exist_ok=True)

    # Download the zip from the Hub and unpack it into the temp folder.
    zip_path = hf_hub_download(repo_id=REPO_ID, filename=ZIP_FILE)
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(tmp_dir)

    model_path = _find_model_path(tmp_dir)

    # Allow unsafe deserialization (needed if the model contains Lambda layers).
    keras.config.enable_unsafe_deserialization()
    try:
        # compile=False: the app only runs inference, so there is no need to
        # restore the optimizer/loss (custom training objects would fail here).
        return load_model(model_path, safe_mode=False, compile=False)
    except ValueError:
        # Keras 3's load_model rejects SavedModel directories. Only in that
        # case fall back to the inference-only TFSMLayer; otherwise re-raise.
        if not (os.path.isdir(model_path) and hasattr(keras.layers, "TFSMLayer")):
            raise
        return _SavedModelPredictor(model_path)
# -----------------------------
# Load the model once, when the app starts
# -----------------------------
# Assigned at import time so every call to `predict` reuses this same
# in-memory model instead of downloading it again.
model = get_model()
# -----------------------------
# Prediction
# -----------------------------
def _to_pixel_box(pred, w, h):
    """Map normalized ``[nxmin, nxmax, nymin, nymax]`` to an ordered pixel box.

    Each coordinate is scaled by the image size, clamped to ``[0, w]`` /
    ``[0, h]`` on both sides and truncated to ``int``. Each min/max pair is
    then sorted, so the result always satisfies ``xmin <= xmax`` and
    ``ymin <= ymax``.

    Returns:
        Tuple ``(xmin, ymin, xmax, ymax)`` in pixels.
    """
    nxmin, nxmax, nymin, nymax = (float(v) for v in pred)

    def scale(v, size):
        return int(min(size, max(0.0, v * size)))

    # Sorting guards against an inverted prediction; ImageDraw.rectangle
    # rejects boxes whose max corner is smaller than the min corner.
    xmin, xmax = sorted((scale(nxmin, w), scale(nxmax, w)))
    ymin, ymax = sorted((scale(nymin, h), scale(nymax, h)))
    return xmin, ymin, xmax, ymax


def predict(img: Image.Image):
    """Run the plate detector on ``img`` and draw the predicted box.

    Args:
        img: The uploaded PIL image. Gradio passes ``None`` when the user
            submits without one.

    Returns:
        ``(annotated, coords)``: an RGB copy of the image with the box drawn in
        cyan, and the box as the string ``"xmin,ymin,xmax,ymax"`` in pixels.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Sube una imagen primero.")

    pil = img.convert("RGB")
    w, h = pil.size

    # Preprocessing (stated to match training): resize to 224x224, scale
    # pixels to [0, 1] and add a batch axis.
    x = np.asarray(pil.resize((224, 224)), dtype="float32") / 255.0
    x = np.expand_dims(x, 0)

    # Output order is [nxmin, nxmax, nymin, nymax], normalized.
    # verbose=0 keeps Keras from printing a progress bar on every request.
    pred = model.predict(x, verbose=0)[0]
    xmin, ymin, xmax, ymax = _to_pixel_box(pred, w, h)

    # Draw the box on a copy so the converted input image is left untouched.
    annotated = pil.copy()
    ImageDraw.Draw(annotated).rectangle(
        [xmin, ymin, xmax, ymax], outline="cyan", width=3
    )
    return annotated, f"{xmin},{ymin},{xmax},{ymax}"
# -----------------------------
# Gradio interface
# -----------------------------
# One PIL image in; the annotated image plus a text box with the pixel
# coordinates out (the two values `predict` returns).
_image_in = gr.Image(type="pil")
_image_out = gr.Image(type="pil")
_coords_out = gr.Textbox()

demo = gr.Interface(
    fn=predict,
    inputs=_image_in,
    outputs=[_image_out, _coords_out],
    title="Detección de matrículas",
    description=(
        "Sube una imagen. El modelo devuelve la caja delimitadora "
        "(xmin,ymin,xmax,ymax)."
    ),
)
# -----------------------------
# Launch the app
# -----------------------------
if __name__ == "__main__":
    # `launch(enable_queue=...)` was removed in Gradio 4 and now raises a
    # TypeError at start-up. Request queueing is enabled with `queue()`
    # instead, which also works on Gradio 3.x and returns the same app.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)