alexvc99 commited on
Commit
6da2c2f
·
verified ·
1 Parent(s): 025525c

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +83 -0
  2. class_names.txt +100 -0
  3. pytorch_model.bin +3 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from PIL import Image
3
+ from torch import nn
4
+
5
+ import torch
6
+ import gradio as gr
7
+ import numpy as np
8
+
9
+ # Leemos las etiquetas de clases (categorías) desde un fichero de texto
10
+ LABELS = Path('class_names.txt').read_text().splitlines()
11
+
12
+ # Definimos la arquitectura de la red neuronal convolucional (CNN) ya entrenada:
13
+ model = nn.Sequential(
14
+ # Primera capa: 1 canal de entrada, 32 filtros, tamaño de filtro 3x3
15
+ nn.Conv2d(1, 32, 3, padding='same'),
16
+ # Función de activación no lineal ReLU (acelera y facilita el aprendizaje)
17
+ nn.ReLU(),
18
+ # Max Pooling: reduce la resolución espacial de las características
19
+ # (comprime la imagen a la vez que mantiene zonas más “activas”)
20
+ nn.MaxPool2d(2),
21
+ nn.Conv2d(32, 64, 3, padding='same'), # Segunda capa: 32→64 filtros
22
+ nn.ReLU(),
23
+ nn.MaxPool2d(2),
24
+ nn.Conv2d(64, 128, 3, padding='same'),# Tercera capa: 64→128 filtros
25
+ nn.ReLU(),
26
+ nn.MaxPool2d(2),
27
+ # Aplana los datos resultantes para prepararlos para las capas
28
+ # densas (total elementos = 128 canales * 3 * 3)
29
+ nn.Flatten(),
30
+ # Capa totalmente conectada: de 1152 (productos anteriores)
31
+ # a 256 neuronas
32
+ nn.Linear(1152, 256),
33
+ nn.ReLU(),
34
+ # Capa de salida: 1 neurona por clase del archivo de etiquetas
35
+ nn.Linear(256, len(LABELS)),
36
+ )
37
+ # Cargamos los pesos previamente entrenados del modelo
38
+ state_dict = torch.load('pytorch_model.bin', map_location='cpu')
39
+ model.load_state_dict(state_dict, strict=False)
40
+ # Ponemos el modelo en modo inferencia (no entrenamiento)
41
+ model.eval()
42
+
43
+ # Función principal de predicción, procesará el dibujo
44
+ # de Gradio y calculará su clase
45
+ def predict(img):
46
+ # Si no hay dibujo o la clave 'composite' no existe o está vacía, avisamos:
47
+ if img is None or "composite" not in img or img["composite"] is None:
48
+ return {"Por favor, dibuja algo": 1.0}
49
+ # Extraemos la imagen resultado del canvas, canal RGBA
50
+ # Array con forma (ej. [800, 800, 4]), tipo uint8
51
+ arr = img["composite"]
52
+ # Convertimos de RGBA a escala de grises (Quick Draw es gris)
53
+ arr_gray = arr[..., :3].mean(axis=2)
54
+ # Convertimos a uint8 por si PIL lo necesita
55
+ arr_gray_uint8 = arr_gray.astype("uint8")
56
+ # Redimensionamos a 28x28 píxeles (tamaño de entrada del modelo)
57
+ arr_img = Image.fromarray(arr_gray_uint8)
58
+ arr_resized = np.array(arr_img.resize((28, 28), resample=Image.BILINEAR))
59
+ # Escalamos a rango [0,1]
60
+ arr_normalized = arr_resized / 255.0
61
+ # Añadimos dimensiones de batch y canal: (1, 1, 28, 28)
62
+ x = torch.tensor(arr_normalized, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
63
+ # Ejecutamos inferencia sin calcular gradientes (más eficiente)
64
+ with torch.no_grad():
65
+ out = model(x)
66
+ # Calculamos probabilidades con softmax
67
+ probabilities = torch.nn.functional.softmax(out[0], dim=0)
68
+ # Obtenemos las 5 clases más probables (top-5)
69
+ values, indices = torch.topk(probabilities, 5)
70
+ # Devolvemos un diccionario: categoría : probabilidad (~confianza)
71
+ return {LABELS[i]: v.item() for i, v in zip(indices, values)}
72
+
73
+ # Creamos la interfaz Gradio:
74
+ # - El input es un sketchpad (zona para dibujar)
75
+ # - El output son etiquetas: las categorías predecidas
76
+ # - live=True: actualiza la predicción en tiempo real al dibujar
77
+ demo = gr.Interface(
78
+ predict,
79
+ inputs='sketchpad',
80
+ outputs='label',
81
+ live=True)
82
+
83
+ demo.launch(share=True)
class_names.txt ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ airplane
2
+ alarm_clock
3
+ anvil
4
+ apple
5
+ axe
6
+ baseball
7
+ baseball_bat
8
+ basketball
9
+ beard
10
+ bed
11
+ bench
12
+ bicycle
13
+ bird
14
+ book
15
+ bread
16
+ bridge
17
+ broom
18
+ butterfly
19
+ camera
20
+ candle
21
+ car
22
+ cat
23
+ ceiling_fan
24
+ cell_phone
25
+ chair
26
+ circle
27
+ clock
28
+ cloud
29
+ coffee_cup
30
+ cookie
31
+ cup
32
+ diving_board
33
+ donut
34
+ door
35
+ drums
36
+ dumbbell
37
+ envelope
38
+ eye
39
+ eyeglasses
40
+ face
41
+ fan
42
+ flower
43
+ frying_pan
44
+ grapes
45
+ hammer
46
+ hat
47
+ headphones
48
+ helmet
49
+ hot_dog
50
+ ice_cream
51
+ key
52
+ knife
53
+ ladder
54
+ laptop
55
+ light_bulb
56
+ lightning
57
+ line
58
+ lollipop
59
+ microphone
60
+ moon
61
+ mountain
62
+ moustache
63
+ mushroom
64
+ pants
65
+ paper_clip
66
+ pencil
67
+ pillow
68
+ pizza
69
+ power_outlet
70
+ radio
71
+ rainbow
72
+ rifle
73
+ saw
74
+ scissors
75
+ screwdriver
76
+ shorts
77
+ shovel
78
+ smiley_face
79
+ snake
80
+ sock
81
+ spider
82
+ spoon
83
+ square
84
+ star
85
+ stop_sign
86
+ suitcase
87
+ sun
88
+ sword
89
+ syringe
90
+ t-shirt
91
+ table
92
+ tennis_racquet
93
+ tent
94
+ tooth
95
+ traffic_light
96
+ tree
97
+ triangle
98
+ umbrella
99
+ wheel
100
+ wristwatch
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:effb6ea6f1593c09e8247944028ed9c309b5ff1cef82ba38b822bee2ca4d0f3c
3
+ size 1656903