GiladtheFixer commited on
Commit
a8fadd0
verified
1 Parent(s): c0013b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -10
app.py CHANGED
@@ -1,32 +1,48 @@
1
  import gradio as gr
2
  import tensorflow as tf
3
- from tensorflow import keras
4
  import numpy as np
5
  from huggingface_hub import hf_hub_download
6
 
7
- # 讛讜专讚转 讜讛讟注谞转 讛诪讜讚诇
8
  model_path = hf_hub_download(
9
  repo_id="GiladtheFixer/my_mnist_model",
10
  filename="mnist_model.keras"
11
  )
12
- # 讛砖讬谞讜讬 讛讬讞讬讚 - 讛讜住驻谞讜 compile=False
13
- model = tf.keras.models.load_model(model_path, compile=False)
14
 
15
- # 砖讗专 讛拽讜讚 谞砖讗专 讘讚讬讜拽 讗讜转讜 讚讘专
 
 
 
 
 
 
 
 
 
 
16
  def predict_digit(sketch_data):
 
17
  img = sketch_data["composite"]
18
  alpha_channel = img[..., 3]
19
  img = alpha_channel / 255.0
 
 
20
  resized = tf.image.resize(
21
  tf.expand_dims(img, -1),
22
  [28, 28],
23
  method='bilinear'
24
  )
25
- resized = tf.squeeze(resized)
26
- input_data = resized.numpy().reshape(1, 28, 28)
 
 
 
27
  pred = model.predict(input_data, verbose=0)
 
 
28
  return {str(i): float(pred[0][i]) for i in range(10)}
29
 
 
30
  demo = gr.Interface(
31
  fn=predict_digit,
32
  inputs=[
@@ -34,13 +50,13 @@ demo = gr.Interface(
34
  label="draw some digit",
35
  height=400,
36
  width=400,
37
- brush=None,
38
  interactive=True
39
  )
40
  ],
41
  outputs=gr.Label(num_top_classes=3),
42
- title="MNIST_by Gilad",
43
- description="draw some digit with brush or clear your board then click submit",
44
  allow_flagging="never"
45
  )
46
 
 
1
  import gradio as gr
2
  import tensorflow as tf
 
3
  import numpy as np
4
  from huggingface_hub import hf_hub_download
5
 
6
+ # 讟注讬谞转 讛诪讜讚诇
7
  model_path = hf_hub_download(
8
  repo_id="GiladtheFixer/my_mnist_model",
9
  filename="mnist_model.keras"
10
  )
 
 
11
 
12
+ # 讟注讬谞转 讛诪讜讚诇 讘爪讜专讛 诪讜转讗诪转
13
+ model = tf.keras.models.load_model(
14
+ model_path,
15
+ compile=False,
16
+ custom_objects={
17
+ 'RandomRotation': tf.keras.layers.RandomRotation,
18
+ 'RandomZoom': tf.keras.layers.RandomZoom,
19
+ 'RandomTranslation': tf.keras.layers.RandomTranslation
20
+ }
21
+ )
22
+
23
  def predict_digit(sketch_data):
24
+ # 注讬讘讜讚 讛转诪讜谞讛
25
  img = sketch_data["composite"]
26
  alpha_channel = img[..., 3]
27
  img = alpha_channel / 255.0
28
+
29
+ # 砖讬谞讜讬 讙讜讚诇 讜讛讜住驻转 诪诪讚
30
  resized = tf.image.resize(
31
  tf.expand_dims(img, -1),
32
  [28, 28],
33
  method='bilinear'
34
  )
35
+
36
+ # 讛讻谞转 讛拽诇讟 诇诪讜讚诇
37
+ input_data = tf.expand_dims(resized, 0) # 讛讜住驻转 诪诪讚 讛讗爪讜讜讛
38
+
39
+ # 讞讬讝讜讬
40
  pred = model.predict(input_data, verbose=0)
41
+
42
+ # 讛讞讝专转 讛转讜爪讗讜转
43
  return {str(i): float(pred[0][i]) for i in range(10)}
44
 
45
+ # 讬爪讬专转 诪诪砖拽
46
  demo = gr.Interface(
47
  fn=predict_digit,
48
  inputs=[
 
50
  label="draw some digit",
51
  height=400,
52
  width=400,
53
+ brush_radius=8.0,
54
  interactive=True
55
  )
56
  ],
57
  outputs=gr.Label(num_top_classes=3),
58
+ title="MNIST Digit Recognition",
59
+ description="Draw a digit (0-9) and click submit to see the prediction",
60
  allow_flagging="never"
61
  )
62