"""Gradio demo for semantic segmentation of vineyard images.

At import time the script downloads a TorchScript model from the Hugging Face
Hub, loads it onto the available device, and builds a Gradio interface whose
callback (`predict`) turns an input image into a colour-coded class mask.
"""

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from PIL import Image

# Download the model file from the Hugging Face Hub.
model_repo_id = "daniihc16/unet-grape-segmentation"
try:
    model_path = hf_hub_download(
        repo_id=model_repo_id, filename="unet_multiclass_hf.pth"
    )
except Exception as e:
    print(f"Error descargando modelo de HuggingFace: {e}. Asegúrate de subir el modelo primero.")
    # Bare `raise` re-raises the active exception with its original traceback.
    raise

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The file is loaded with torch.jit.load, i.e. it must be a TorchScript archive.
model = torch.jit.load(model_path, map_location=device)
model.eval()

# Class names (index == class id) and the RGB colour used to paint each class.
class_names = ['background', 'leaves', 'wood', 'pole', 'grape']
class_colors = {
    0: [0, 0, 0],         # background - black
    1: [0, 128, 0],       # leaves - green
    2: [139, 69, 19],     # wood - brown
    3: [128, 128, 128],   # pole - grey
    4: [128, 0, 128],     # grape - purple
}

# Network input resolution as (height, width).
INPUT_SIZE = (480, 640)

# Preprocessing pipeline, built once instead of on every request.
# NOTE(review): the original author states this matches the training-time
# preprocessing; the mean/std values are the usual ImageNet statistics.
_preprocess = T.Compose([
    T.Resize(INPUT_SIZE),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict(image):
    """Segment a vineyard image and return the colour-coded class mask.

    Args:
        image: Input PIL image, or ``None`` when no image was provided.

    Returns:
        A PIL RGB image in which every pixel is painted with the
        ``class_colors`` entry of its predicted class, or ``None`` if
        ``image`` is ``None``. The mask has the spatial size of the model
        output (the input is resized to ``INPUT_SIZE`` before inference),
        not the size of the original image.
    """
    if image is None:
        return None

    # Normalize is configured for exactly three channels, so grayscale,
    # palette or RGBA inputs must be converted before preprocessing.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Add the batch dimension and move the tensor to the model's device.
    batch = _preprocess(image).unsqueeze(0).to(device)

    # Inference without gradient tracking.
    with torch.no_grad():
        output = model(batch)

    # Arg-max over dim 1 (treated as the class dimension) yields one class id
    # per pixel; [0] drops the batch dimension. Taking the shape from the
    # prediction avoids hard-coding the output resolution a second time.
    pred = torch.argmax(output, dim=1)[0].cpu().numpy()

    # Paint each class id with its colour; ids missing from class_colors
    # stay black because the canvas starts zero-filled.
    rgb = np.zeros((*pred.shape, 3), dtype=np.uint8)
    for cls_id, color in class_colors.items():
        rgb[pred == cls_id] = color

    return Image.fromarray(rgb)


iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Imagen de viñedo"),
    outputs=gr.Image(type="pil", label="Segmentación (background, leaves, wood, pole, grape)"),
    title="Segmentación Semántica de Viñedos",
    description=(
        "Modelo U-Net con backbone ResNet-50 entrenado para segmentar imágenes de viñedos. "
        "Clases: background (negro), leaves/hojas (verde), wood/madera (marrón), "
        "pole/poste (gris), grape/uva (morado). \n"
        # Interpolate the repo id so the text cannot drift from the repo
        # actually downloaded above; the rendered string is unchanged.
        f"Modelo cargado desde: {model_repo_id}"
    ),
    # Example images are referenced by relative path and must exist next to
    # the script.
    examples=['color_157.jpg', 'color_158.jpg'],
)

if __name__ == "__main__":
    iface.launch()