practica3 / app.py
patricp9's picture
Create app.py
d9027c0 verified
import gradio as gr
import numpy as np
import torch
from fastai.vision.all import *
REPO_ID = "TU_USUARIO/segmentation-unet-resnet50" # <- cámbialo
# Descarga el model.pkl del Hub y cárgalo
from huggingface_hub import hf_hub_download
model_path = hf_hub_download(repo_id=REPO_ID, filename="model.pkl")
learn = load_learner(model_path, cpu=True)
def predict_mask(img):
# img llega como PIL
pred_mask, _, _ = learn.predict(img) # pred_mask es PILMask
arr = np.array(pred_mask).astype(np.uint8)
# Para visualizar mejor, lo devolvemos como imagen (0..255)
# (No intentamos colorear por clase para mantenerlo simple y robusto)
vis = (arr * (255 // max(1, arr.max()))).astype(np.uint8)
return vis
demo = gr.Interface(
fn=predict_mask,
inputs=gr.Image(type="pil", label="Imagen de entrada"),
outputs=gr.Image(type="numpy", label="Máscara predicha"),
title="Segmentación multiclase U-Net (FastAI)",
description="Sube una imagen y se genera la máscara predicha."
)
if __name__ == "__main__":
demo.launch()