# Spaces: Runtime error
import os

import gradio as gr
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models, transforms
# Location of the model checkpoint on the Hugging Face Hub.
REPO_ID = "macapa/blindness_clas"
MODEL_FILENAME = "best_model_resnet18.pth"

# Model URL on Hugging Face.
# NOTE(review): this URL is not used anywhere in this file and names a
# different file (blindness_model.pth) than the one actually downloaded
# below (MODEL_FILENAME) — confirm which checkpoint is intended.
model_url = "https://huggingface.co/macapa/blindness_clas/resolve/main/blindness_model.pth"

# Download the checkpoint into the current directory. hf_hub_download returns
# the local path of the downloaded file, so use it directly instead of
# repeating the filename as a separate hard-coded path.
model_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILENAME,
    local_dir=".",
)
# Number of classifier outputs; must equal the number of entries in the
# `labels` list that names the predictions.
NUM_CLASSES = 5


def _load_model(path, num_classes=NUM_CLASSES):
    """Load the checkpoint at *path* and return a model ready for inference.

    The checkpoint may contain either a whole pickled ``torch.nn.Module`` or a
    bare ``state_dict``. In the latter case the weights are loaded into a
    torchvision ResNet-18 with ``num_classes`` outputs.

    Raises:
        TypeError: if the checkpoint is neither a Module nor a dict.
        RuntimeError: (from ``load_state_dict``) if the state_dict does not
            match the ResNet-18 architecture.
    """
    # weights_only=False keeps the pre-torch-2.6 default so a fully pickled
    # Module can still be loaded (the keyword needs torch >= 1.13).
    # SECURITY: this unpickles arbitrary objects; acceptable only because the
    # file comes from this project's own Hub repo — never use on untrusted files.
    checkpoint = torch.load(
        path, map_location=torch.device("cpu"), weights_only=False
    )

    if isinstance(checkpoint, torch.nn.Module):
        net = checkpoint
    elif isinstance(checkpoint, dict):
        # NOTE(review): assumes this is the state_dict of a stock torchvision
        # ResNet-18 with a num_classes-way fc layer (inferred from the file
        # name) — confirm against the training script.
        net = models.resnet18(num_classes=num_classes)
        net.load_state_dict(checkpoint)
    else:
        raise TypeError(
            f"Unsupported checkpoint type {type(checkpoint).__name__!r} in {path}"
        )

    # Switch any BatchNorm/Dropout layers to inference behavior; in training
    # mode BatchNorm would use per-request statistics and keep updating its
    # running stats on every prediction.
    net.eval()
    return net


# Load the PyTorch model once at startup.
model = _load_model(model_path)
# Fixed (height, width) every input image is resized to.
_INPUT_SIZE = (256, 256)

# Preprocessing applied to each image before inference: resize to
# _INPUT_SIZE, then convert to a torch tensor via transforms.ToTensor.
# NOTE(review): no mean/std normalization is applied — confirm this matches
# the preprocessing used when the checkpoint was trained.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=_INPUT_SIZE),
        transforms.ToTensor(),
    ]
)
# Classification labels, ordered by model output index: a prediction of
# index i is reported to the user as labels[i].
labels = [
    "No Blindness",   # 0
    "Mild",           # 1
    "Moderate",       # 2
    "Severe",         # 3
    "Proliferative",  # 4
]
# Predict the severity class for one uploaded image.
def classify_image(img):
    """Predict the severity label for a single retinal image.

    Args:
        img: The uploaded image, either a ``PIL.Image.Image`` or a numpy
            array (what ``gr.Image`` passes unless ``type="pil"`` is set).
            ``None`` when the user submits without uploading an image.

    Returns:
        The entry of ``labels`` for the highest-scoring model output.

    Raises:
        gr.Error: if no image was provided; Gradio shows the message to the user.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    # torchvision's Resize in `preprocess` rejects numpy arrays, so convert
    # anything that is not already a PIL image.
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    # Force 3 channels so RGBA or grayscale uploads yield an RGB tensor
    # (presumably what the network was trained on — confirm).
    img = img.convert("RGB")

    batch = preprocess(img).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        logits = model(batch)
    predicted = int(logits.argmax(dim=1).item())
    return labels[predicted]
# Gradio UI: one image input and a label output showing the top class.
interface = gr.Interface(
    fn=classify_image,
    # type="pil" delivers a PIL image to classify_image, which is what the
    # torchvision preprocessing pipeline accepts (Gradio's default is numpy).
    # The label is Spanish for "Upload an image here"; the previous literal
    # was mis-encoded (UTF-8 bytes decoded as GBK).
    inputs=gr.Image(type="pil", label="Carga una imagen aquí"),
    outputs=gr.Label(num_top_classes=1),
    title="Blindness Classification",
    description="Classify the severity of blindness from retinal images.",
)
# Run the app only when this file is executed as a script, so importing the
# module (e.g. from tests or tooling) does not start a web server.
# NOTE(review): Gradio documents share links as unsupported on Hugging Face
# Spaces; share=True only has an effect when run locally.
if __name__ == "__main__":
    interface.launch(share=True)