Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,32 +1,48 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import tensorflow as tf
|
| 3 |
-
from tensorflow import keras
|
| 4 |
import numpy as np
|
| 5 |
from huggingface_hub import hf_hub_download
|
| 6 |
|
| 7 |
-
#
|
| 8 |
model_path = hf_hub_download(
|
| 9 |
repo_id="GiladtheFixer/my_mnist_model",
|
| 10 |
filename="mnist_model.keras"
|
| 11 |
)
|
| 12 |
-
# 讛砖讬谞讜讬 讛讬讞讬讚 - 讛讜住驻谞讜 compile=False
|
| 13 |
-
model = tf.keras.models.load_model(model_path, compile=False)
|
| 14 |
|
| 15 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
def predict_digit(sketch_data):
|
|
|
|
| 17 |
img = sketch_data["composite"]
|
| 18 |
alpha_channel = img[..., 3]
|
| 19 |
img = alpha_channel / 255.0
|
|
|
|
|
|
|
| 20 |
resized = tf.image.resize(
|
| 21 |
tf.expand_dims(img, -1),
|
| 22 |
[28, 28],
|
| 23 |
method='bilinear'
|
| 24 |
)
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
| 27 |
pred = model.predict(input_data, verbose=0)
|
|
|
|
|
|
|
| 28 |
return {str(i): float(pred[0][i]) for i in range(10)}
|
| 29 |
|
|
|
|
| 30 |
demo = gr.Interface(
|
| 31 |
fn=predict_digit,
|
| 32 |
inputs=[
|
|
@@ -34,13 +50,13 @@ demo = gr.Interface(
|
|
| 34 |
label="draw some digit",
|
| 35 |
height=400,
|
| 36 |
width=400,
|
| 37 |
-
|
| 38 |
interactive=True
|
| 39 |
)
|
| 40 |
],
|
| 41 |
outputs=gr.Label(num_top_classes=3),
|
| 42 |
-
title="
|
| 43 |
-
description="
|
| 44 |
allow_flagging="never"
|
| 45 |
)
|
| 46 |
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import tensorflow as tf
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
from huggingface_hub import hf_hub_download
|
| 5 |
|
| 6 |
+
# 讟注讬谞转 讛诪讜讚诇
|
| 7 |
model_path = hf_hub_download(
|
| 8 |
repo_id="GiladtheFixer/my_mnist_model",
|
| 9 |
filename="mnist_model.keras"
|
| 10 |
)
|
|
|
|
|
|
|
| 11 |
|
| 12 |
+
# 讟注讬谞转 讛诪讜讚诇 讘爪讜专讛 诪讜转讗诪转
|
| 13 |
+
model = tf.keras.models.load_model(
|
| 14 |
+
model_path,
|
| 15 |
+
compile=False,
|
| 16 |
+
custom_objects={
|
| 17 |
+
'RandomRotation': tf.keras.layers.RandomRotation,
|
| 18 |
+
'RandomZoom': tf.keras.layers.RandomZoom,
|
| 19 |
+
'RandomTranslation': tf.keras.layers.RandomTranslation
|
| 20 |
+
}
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
def predict_digit(sketch_data):
|
| 24 |
+
# 注讬讘讜讚 讛转诪讜谞讛
|
| 25 |
img = sketch_data["composite"]
|
| 26 |
alpha_channel = img[..., 3]
|
| 27 |
img = alpha_channel / 255.0
|
| 28 |
+
|
| 29 |
+
# 砖讬谞讜讬 讙讜讚诇 讜讛讜住驻转 诪诪讚
|
| 30 |
resized = tf.image.resize(
|
| 31 |
tf.expand_dims(img, -1),
|
| 32 |
[28, 28],
|
| 33 |
method='bilinear'
|
| 34 |
)
|
| 35 |
+
|
| 36 |
+
# 讛讻谞转 讛拽诇讟 诇诪讜讚诇
|
| 37 |
+
input_data = tf.expand_dims(resized, 0) # 讛讜住驻转 诪诪讚 讛讗爪讜讜讛
|
| 38 |
+
|
| 39 |
+
# 讞讬讝讜讬
|
| 40 |
pred = model.predict(input_data, verbose=0)
|
| 41 |
+
|
| 42 |
+
# 讛讞讝专转 讛转讜爪讗讜转
|
| 43 |
return {str(i): float(pred[0][i]) for i in range(10)}
|
| 44 |
|
| 45 |
+
# 讬爪讬专转 诪诪砖拽
|
| 46 |
demo = gr.Interface(
|
| 47 |
fn=predict_digit,
|
| 48 |
inputs=[
|
|
|
|
| 50 |
label="draw some digit",
|
| 51 |
height=400,
|
| 52 |
width=400,
|
| 53 |
+
brush_radius=8.0,
|
| 54 |
interactive=True
|
| 55 |
)
|
| 56 |
],
|
| 57 |
outputs=gr.Label(num_top_classes=3),
|
| 58 |
+
title="MNIST Digit Recognition",
|
| 59 |
+
description="Draw a digit (0-9) and click submit to see the prediction",
|
| 60 |
allow_flagging="never"
|
| 61 |
)
|
| 62 |
|