pictionary / app.py
alexvc99's picture
Upload 3 files
6da2c2f verified
from pathlib import Path
from PIL import Image
from torch import nn
import torch
import gradio as gr
import numpy as np
# Leemos las etiquetas de clases (categorías) desde un fichero de texto
LABELS = Path('class_names.txt').read_text().splitlines()
# Definimos la arquitectura de la red neuronal convolucional (CNN) ya entrenada:
model = nn.Sequential(
# Primera capa: 1 canal de entrada, 32 filtros, tamaño de filtro 3x3
nn.Conv2d(1, 32, 3, padding='same'),
# Función de activación no lineal ReLU (acelera y facilita el aprendizaje)
nn.ReLU(),
# Max Pooling: reduce la resolución espacial de las características
# (comprime la imagen a la vez que mantiene zonas más “activas”)
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, padding='same'), # Segunda capa: 32→64 filtros
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, padding='same'),# Tercera capa: 64→128 filtros
nn.ReLU(),
nn.MaxPool2d(2),
# Aplana los datos resultantes para prepararlos para las capas
# densas (total elementos = 128 canales * 3 * 3)
nn.Flatten(),
# Capa totalmente conectada: de 1152 (productos anteriores)
# a 256 neuronas
nn.Linear(1152, 256),
nn.ReLU(),
# Capa de salida: 1 neurona por clase del archivo de etiquetas
nn.Linear(256, len(LABELS)),
)
# Cargamos los pesos previamente entrenados del modelo
state_dict = torch.load('pytorch_model.bin', map_location='cpu')
model.load_state_dict(state_dict, strict=False)
# Ponemos el modelo en modo inferencia (no entrenamiento)
model.eval()
# Función principal de predicción, procesará el dibujo
# de Gradio y calculará su clase
def predict(img):
# Si no hay dibujo o la clave 'composite' no existe o está vacía, avisamos:
if img is None or "composite" not in img or img["composite"] is None:
return {"Por favor, dibuja algo": 1.0}
# Extraemos la imagen resultado del canvas, canal RGBA
# Array con forma (ej. [800, 800, 4]), tipo uint8
arr = img["composite"]
# Convertimos de RGBA a escala de grises (Quick Draw es gris)
arr_gray = arr[..., :3].mean(axis=2)
# Convertimos a uint8 por si PIL lo necesita
arr_gray_uint8 = arr_gray.astype("uint8")
# Redimensionamos a 28x28 píxeles (tamaño de entrada del modelo)
arr_img = Image.fromarray(arr_gray_uint8)
arr_resized = np.array(arr_img.resize((28, 28), resample=Image.BILINEAR))
# Escalamos a rango [0,1]
arr_normalized = arr_resized / 255.0
# Añadimos dimensiones de batch y canal: (1, 1, 28, 28)
x = torch.tensor(arr_normalized, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
# Ejecutamos inferencia sin calcular gradientes (más eficiente)
with torch.no_grad():
out = model(x)
# Calculamos probabilidades con softmax
probabilities = torch.nn.functional.softmax(out[0], dim=0)
# Obtenemos las 5 clases más probables (top-5)
values, indices = torch.topk(probabilities, 5)
# Devolvemos un diccionario: categoría : probabilidad (~confianza)
return {LABELS[i]: v.item() for i, v in zip(indices, values)}
# Creamos la interfaz Gradio:
# - El input es un sketchpad (zona para dibujar)
# - El output son etiquetas: las categorías predecidas
# - live=True: actualiza la predicción en tiempo real al dibujar
demo = gr.Interface(
predict,
inputs='sketchpad',
outputs='label',
live=True)
demo.launch(share=True)