Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| from PIL import Image | |
| from torch import nn | |
| import torch | |
| import gradio as gr | |
| import numpy as np | |
| # Leemos las etiquetas de clases (categorías) desde un fichero de texto | |
| LABELS = Path('class_names.txt').read_text().splitlines() | |
| # Definimos la arquitectura de la red neuronal convolucional (CNN) ya entrenada: | |
| model = nn.Sequential( | |
| # Primera capa: 1 canal de entrada, 32 filtros, tamaño de filtro 3x3 | |
| nn.Conv2d(1, 32, 3, padding='same'), | |
| # Función de activación no lineal ReLU (acelera y facilita el aprendizaje) | |
| nn.ReLU(), | |
| # Max Pooling: reduce la resolución espacial de las características | |
| # (comprime la imagen a la vez que mantiene zonas más “activas”) | |
| nn.MaxPool2d(2), | |
| nn.Conv2d(32, 64, 3, padding='same'), # Segunda capa: 32→64 filtros | |
| nn.ReLU(), | |
| nn.MaxPool2d(2), | |
| nn.Conv2d(64, 128, 3, padding='same'),# Tercera capa: 64→128 filtros | |
| nn.ReLU(), | |
| nn.MaxPool2d(2), | |
| # Aplana los datos resultantes para prepararlos para las capas | |
| # densas (total elementos = 128 canales * 3 * 3) | |
| nn.Flatten(), | |
| # Capa totalmente conectada: de 1152 (productos anteriores) | |
| # a 256 neuronas | |
| nn.Linear(1152, 256), | |
| nn.ReLU(), | |
| # Capa de salida: 1 neurona por clase del archivo de etiquetas | |
| nn.Linear(256, len(LABELS)), | |
| ) | |
| # Cargamos los pesos previamente entrenados del modelo | |
| state_dict = torch.load('pytorch_model.bin', map_location='cpu') | |
| model.load_state_dict(state_dict, strict=False) | |
| # Ponemos el modelo en modo inferencia (no entrenamiento) | |
| model.eval() | |
| # Función principal de predicción, procesará el dibujo | |
| # de Gradio y calculará su clase | |
| def predict(img): | |
| # Si no hay dibujo o la clave 'composite' no existe o está vacía, avisamos: | |
| if img is None or "composite" not in img or img["composite"] is None: | |
| return {"Por favor, dibuja algo": 1.0} | |
| # Extraemos la imagen resultado del canvas, canal RGBA | |
| # Array con forma (ej. [800, 800, 4]), tipo uint8 | |
| arr = img["composite"] | |
| # Convertimos de RGBA a escala de grises (Quick Draw es gris) | |
| arr_gray = arr[..., :3].mean(axis=2) | |
| # Convertimos a uint8 por si PIL lo necesita | |
| arr_gray_uint8 = arr_gray.astype("uint8") | |
| # Redimensionamos a 28x28 píxeles (tamaño de entrada del modelo) | |
| arr_img = Image.fromarray(arr_gray_uint8) | |
| arr_resized = np.array(arr_img.resize((28, 28), resample=Image.BILINEAR)) | |
| # Escalamos a rango [0,1] | |
| arr_normalized = arr_resized / 255.0 | |
| # Añadimos dimensiones de batch y canal: (1, 1, 28, 28) | |
| x = torch.tensor(arr_normalized, dtype=torch.float32).unsqueeze(0).unsqueeze(0) | |
| # Ejecutamos inferencia sin calcular gradientes (más eficiente) | |
| with torch.no_grad(): | |
| out = model(x) | |
| # Calculamos probabilidades con softmax | |
| probabilities = torch.nn.functional.softmax(out[0], dim=0) | |
| # Obtenemos las 5 clases más probables (top-5) | |
| values, indices = torch.topk(probabilities, 5) | |
| # Devolvemos un diccionario: categoría : probabilidad (~confianza) | |
| return {LABELS[i]: v.item() for i, v in zip(indices, values)} | |
| # Creamos la interfaz Gradio: | |
| # - El input es un sketchpad (zona para dibujar) | |
| # - El output son etiquetas: las categorías predecidas | |
| # - live=True: actualiza la predicción en tiempo real al dibujar | |
| demo = gr.Interface( | |
| predict, | |
| inputs='sketchpad', | |
| outputs='label', | |
| live=True) | |
| demo.launch(share=True) |