from pathlib import Path from PIL import Image from torch import nn import torch import gradio as gr import numpy as np # Leemos las etiquetas de clases (categorías) desde un fichero de texto LABELS = Path('class_names.txt').read_text().splitlines() # Definimos la arquitectura de la red neuronal convolucional (CNN) ya entrenada: model = nn.Sequential( # Primera capa: 1 canal de entrada, 32 filtros, tamaño de filtro 3x3 nn.Conv2d(1, 32, 3, padding='same'), # Función de activación no lineal ReLU (acelera y facilita el aprendizaje) nn.ReLU(), # Max Pooling: reduce la resolución espacial de las características # (comprime la imagen a la vez que mantiene zonas más “activas”) nn.MaxPool2d(2), nn.Conv2d(32, 64, 3, padding='same'), # Segunda capa: 32→64 filtros nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(64, 128, 3, padding='same'),# Tercera capa: 64→128 filtros nn.ReLU(), nn.MaxPool2d(2), # Aplana los datos resultantes para prepararlos para las capas # densas (total elementos = 128 canales * 3 * 3) nn.Flatten(), # Capa totalmente conectada: de 1152 (productos anteriores) # a 256 neuronas nn.Linear(1152, 256), nn.ReLU(), # Capa de salida: 1 neurona por clase del archivo de etiquetas nn.Linear(256, len(LABELS)), ) # Cargamos los pesos previamente entrenados del modelo state_dict = torch.load('pytorch_model.bin', map_location='cpu') model.load_state_dict(state_dict, strict=False) # Ponemos el modelo en modo inferencia (no entrenamiento) model.eval() # Función principal de predicción, procesará el dibujo # de Gradio y calculará su clase def predict(img): # Si no hay dibujo o la clave 'composite' no existe o está vacía, avisamos: if img is None or "composite" not in img or img["composite"] is None: return {"Por favor, dibuja algo": 1.0} # Extraemos la imagen resultado del canvas, canal RGBA # Array con forma (ej. [800, 800, 4]), tipo uint8 arr = img["composite"] # Convertimos de RGBA a escala de grises (Quick Draw es gris) arr_gray = arr[..., :3].mean(axis=2) # Convertimos a uint8 por si PIL lo necesita arr_gray_uint8 = arr_gray.astype("uint8") # Redimensionamos a 28x28 píxeles (tamaño de entrada del modelo) arr_img = Image.fromarray(arr_gray_uint8) arr_resized = np.array(arr_img.resize((28, 28), resample=Image.BILINEAR)) # Escalamos a rango [0,1] arr_normalized = arr_resized / 255.0 # Añadimos dimensiones de batch y canal: (1, 1, 28, 28) x = torch.tensor(arr_normalized, dtype=torch.float32).unsqueeze(0).unsqueeze(0) # Ejecutamos inferencia sin calcular gradientes (más eficiente) with torch.no_grad(): out = model(x) # Calculamos probabilidades con softmax probabilities = torch.nn.functional.softmax(out[0], dim=0) # Obtenemos las 5 clases más probables (top-5) values, indices = torch.topk(probabilities, 5) # Devolvemos un diccionario: categoría : probabilidad (~confianza) return {LABELS[i]: v.item() for i, v in zip(indices, values)} # Creamos la interfaz Gradio: # - El input es un sketchpad (zona para dibujar) # - El output son etiquetas: las categorías predecidas # - live=True: actualiza la predicción en tiempo real al dibujar demo = gr.Interface( predict, inputs='sketchpad', outputs='label', live=True) demo.launch(share=True)