Mikeztrada commited on
Commit
b2d59e6
·
verified ·
1 Parent(s): 1921e5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -6,17 +6,22 @@ import numpy as np
6
  model = tf.keras.models.load_model("quickdraw_model.keras")
7
 
8
  def predict(image):
9
- # La imagen llega como un array (28x28) o RGB, normalizamos y aplanamos
10
- image = image[...,0] if image.ndim == 3 else image # si es RGB, toma un canal
 
 
 
 
 
11
  image = image / 255.0
12
  image = image.reshape(1, 784)
13
  preds = model.predict(image)
14
  class_idx = np.argmax(preds)
15
- return class_idx # puedes mapear a etiquetas si quieres
16
 
17
  iface = gr.Interface(
18
  fn=predict,
19
- inputs=gr.Image(shape=(28, 28), image_mode='L', source='upload'),
20
  outputs="label",
21
  title="QuickDraw API",
22
  description="API para reconocer dibujos estilo QuickDraw"
 
6
  model = tf.keras.models.load_model("quickdraw_model.keras")
7
 
8
  def predict(image):
9
+ import cv2
10
+ import numpy as np
11
+ # Si la imagen es RGB, conviértela a escala de grises
12
+ if image.ndim == 3:
13
+ image = np.mean(image, axis=2)
14
+ # Redimensiona a 28x28
15
+ image = cv2.resize(image, (28, 28))
16
  image = image / 255.0
17
  image = image.reshape(1, 784)
18
  preds = model.predict(image)
19
  class_idx = np.argmax(preds)
20
+ return str(class_idx)
21
 
22
  iface = gr.Interface(
23
  fn=predict,
24
+ inputs=gr.Image(image_mode='L', source='upload', tool=None),
25
  outputs="label",
26
  title="QuickDraw API",
27
  description="API para reconocer dibujos estilo QuickDraw"