|
|
| import gradio as gr |
| import torch |
| import numpy as np |
| import torchvision.transforms as T |
| from PIL import Image |
| from huggingface_hub import hf_hub_download |
|
|
| |
# Hugging Face Hub repository that hosts the trained segmentation model.
model_repo_id = "daniihc16/unet-grape-segmentation"


# Fetch the model file from the Hub. Without it the app cannot serve any
# prediction, so a failure is reported and then propagated (fatal at import).
try:
    model_path = hf_hub_download(repo_id=model_repo_id, filename="unet_multiclass_hf.pth")
except Exception as e:
    print(f"Error descargando modelo de HuggingFace: {e}. Asegúrate de subir el modelo primero.")
    # Bare `raise` re-raises the active exception unchanged instead of re-binding it.
    raise


# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Loaded via torch.jit.load, so the checkpoint must be a TorchScript archive
# (not a plain state_dict); map_location places its tensors on `device`.
model = torch.jit.load(model_path, map_location=device)
# Inference mode for layers such as dropout / batch-norm.
model.eval()
|
|
| |
# Segmentation classes, in the order of the integer label the model predicts.
class_names = ['background', 'leaves', 'wood', 'pole', 'grape']

# RGB color painted for each class id in the output mask; the id is the
# position in this list (built into a {class_id: [r, g, b]} dict).
class_colors = dict(enumerate([
    [0, 0, 0],        # background -> black
    [0, 128, 0],      # leaves -> green
    [139, 69, 19],    # wood -> brown
    [128, 128, 128],  # pole -> gray
    [128, 0, 128],    # grape -> purple
]))
|
|
# (height, width) the network input is resized to.
_INPUT_SIZE = (480, 640)

# Preprocessing is identical for every request, so build the pipeline once.
# NOTE(review): mean/std look like the usual ImageNet statistics for the
# ResNet-50 backbone — confirm they match training.
_TRANSFORM = T.Compose([
    T.Resize(_INPUT_SIZE),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict(image):
    """Segment a vineyard image and return a color-coded class mask.

    Args:
        image: PIL image from the Gradio input, or None when nothing was given.

    Returns:
        A PIL RGB image at the model's output resolution (presumably
        480x640, the input size) where every pixel is painted with the
        ``class_colors`` entry of its predicted class; pixels whose class id
        is missing from ``class_colors`` stay black. Returns None if
        ``image`` is None.
    """
    if image is None:
        return None

    # Normalize uses 3-channel statistics; RGBA / grayscale / palette images
    # would produce a channel-count mismatch, so force RGB first.
    tensor = _TRANSFORM(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(tensor)
    # Class id per pixel: argmax over the class dimension, then take the
    # single batch item instead of hard-coding the spatial size.
    pred = torch.argmax(output, 1)[0].cpu().numpy()

    # Paint each class id with its color on a black canvas.
    rgb = np.zeros((*pred.shape, 3), dtype=np.uint8)
    for cls_id, color in class_colors.items():
        rgb[pred == cls_id] = color

    return Image.fromarray(rgb)
|
|
# Gradio UI: one vineyard image in, one color-coded segmentation mask out.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Imagen de viñedo"),
    outputs=gr.Image(type="pil", label="Segmentación (background, leaves, wood, pole, grape)"),
    title="Segmentación Semántica de Viñedos",
    # The repo id is interpolated from `model_repo_id` (the value actually
    # downloaded) so the displayed source cannot drift from the real one.
    description=(
        "Modelo U-Net con backbone ResNet-50 entrenado para segmentar imágenes de viñedos. "
        "Clases: background (negro), leaves/hojas (verde), wood/madera (marrón), "
        "pole/poste (gris), grape/uva (morado). "
        f"Modelo cargado desde: {model_repo_id}"
    ),
    # NOTE(review): relative paths — assumes these images ship alongside the
    # app and it is launched from that directory; confirm.
    examples=['color_157.jpg', 'color_158.jpg'],
)
|
|
def _main():
    """Start the Gradio demo server with its default launch settings."""
    iface.launch()


if __name__ == "__main__":
    _main()
|
|