GiladtheFixer commited on
Commit
9f8712e
verified
1 Parent(s): d4cf7c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -77
app.py CHANGED
@@ -1,92 +1,52 @@
1
  import gradio as gr
2
  import tensorflow as tf
 
3
  import numpy as np
4
  from huggingface_hub import hf_hub_download
 
5
 
6
- # 诪谞讬注转 砖讬诪讜砖 讘-GPU - 诇驻注诪讬诐 讝讛 注讜讝专 讘诪拽专讛 砖诇 讘注讬讜转 转讗讬诪讜转
7
- import os
8
- os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
9
 
10
- try:
11
- # 讟注讬谞转 讛诪讜讚诇
12
- model_path = hf_hub_download(
13
- repo_id="GiladtheFixer/my_mnist_model",
14
- filename="mnist_model.keras"
15
- )
16
 
17
- # 讬爪讬专转 讛诪讜讚诇 诪讞讚砖
18
- model = tf.keras.Sequential([
19
- tf.keras.layers.Input(shape=(28, 28, 1)),
20
- tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
21
- tf.keras.layers.BatchNormalization(),
22
- tf.keras.layers.MaxPooling2D((2, 2)),
23
-
24
- tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
25
- tf.keras.layers.BatchNormalization(),
26
- tf.keras.layers.MaxPooling2D((2, 2)),
27
-
28
- tf.keras.layers.Flatten(),
29
- tf.keras.layers.Dense(256, activation='relu'),
30
- tf.keras.layers.BatchNormalization(),
31
- tf.keras.layers.Dropout(0.3),
32
-
33
- tf.keras.layers.Dense(128, activation='relu'),
34
- tf.keras.layers.BatchNormalization(),
35
- tf.keras.layers.Dropout(0.2),
36
-
37
- tf.keras.layers.Dense(10, activation='softmax')
38
- ])
39
 
40
- # 讟注讬谞转 讛诪砖拽讜诇讜转 诪讛诪讜讚诇 讛砖诪讜专
41
- model.load_weights(model_path)
42
 
43
- except Exception as e:
44
- print(f"Error loading model: {str(e)}")
45
- raise e
 
46
 
47
- def predict_digit(sketch_data):
48
- try:
49
- # 注讬讘讜讚 讛转诪讜谞讛
50
- img = sketch_data["composite"]
51
- alpha_channel = img[..., 3]
52
- img = alpha_channel / 255.0
53
-
54
- # 砖讬谞讜讬 讙讜讚诇 讜讛讜住驻转 诪诪讚
55
- resized = tf.image.resize(
56
- tf.expand_dims(img, -1),
57
- [28, 28],
58
- method='bilinear'
59
- )
60
-
61
- # 讛讻谞转 讛拽诇讟 诇诪讜讚诇
62
- input_data = tf.expand_dims(resized, 0)
63
-
64
- # 讞讬讝讜讬
65
- pred = model.predict(input_data, verbose=0)
66
-
67
- # 讛讞讝专转 讛转讜爪讗讜转
68
- return {str(i): float(pred[0][i]) for i in range(10)}
69
- except Exception as e:
70
- print(f"Error in prediction: {str(e)}")
71
- return {str(i): 0.0 for i in range(10)}
72
 
73
- # 讬爪讬专转 诪诪砖拽
74
- demo = gr.Interface(
75
  fn=predict_digit,
76
- inputs=[
77
- gr.Sketchpad(
78
- label="Draw a digit (0-9)",
79
- height=400,
80
- width=400,
81
- brush_radius=8.0,
82
- interactive=True
83
- )
84
- ],
85
  outputs=gr.Label(num_top_classes=3),
86
- title="MNIST Digit Recognition",
87
- description="Draw a digit (0-9) and click submit to see the prediction",
88
- allow_flagging="never"
 
89
  )
90
 
91
- if __name__ == "__main__":
92
- demo.launch()
 
1
  import gradio as gr
2
  import tensorflow as tf
3
+ from tensorflow import keras
4
  import numpy as np
5
  from huggingface_hub import hf_hub_download
6
+ import cv2
7
 
8
+ # 讛讜专讚转 讛诪讜讚诇 诪-Hugging Face Hub
9
+ model_path = hf_hub_download(repo_id="GiladtheFixer/my_mnist_model", filename="mnist_model.keras")
10
+ model = keras.models.load_model(model_path)
11
 
12
+ def preprocess_image(image):
13
+ # 讛诪专转 讛转诪讜谞讛 诇讙讜讜谞讬 讗驻讜专 讗诐 讛讬讗 爪讘注讜谞讬转
14
+ if len(image.shape) == 3:
15
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
 
 
16
 
17
+ # 砖讬谞讜讬 讙讜讚诇 讛转诪讜谞讛 诇-28x28
18
+ image = cv2.resize(image, (28, 28))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ # 谞专诪讜诇 讛注专讻讬诐 诇-0-1
21
+ image = image.astype('float32') / 255.0
22
 
23
+ # 讛讜住驻转 诪诪讚 谞讜住祝 注讘讜专 讛注专讜抓
24
+ image = image.reshape(1, 28, 28, 1)
25
+
26
+ return image
27
 
28
+ def predict_digit(image):
29
+ # 注讬讘讜讚 诪拽讚讬诐 砖诇 讛转诪讜谞讛
30
+ processed_image = preprocess_image(image)
31
+
32
+ # 讞讬讝讜讬 讘讗诪爪注讜转 讛诪讜讚诇
33
+ prediction = model.predict(processed_image)
34
+
35
+ # 讬爪讬专转 诪讬诇讜谉 注诐 讛转讜爪讗讜转
36
+ result = {str(i): float(prediction[0][i]) for i in range(10)}
37
+
38
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+ # 讬爪讬专转 诪诪砖拽 Gradio
41
+ iface = gr.Interface(
42
  fn=predict_digit,
43
+ inputs=gr.Image(shape=(28, 28), image_mode="L", source="canvas", tool="pencil"),
 
 
 
 
 
 
 
 
44
  outputs=gr.Label(num_top_classes=3),
45
+ title="讝讬讛讜讬 住驻专讜转 讘讻转讘 讬讚",
46
+ description="爪讬讬专 住驻专讛 讘讬谉 0 诇-9 讜讛诪讜讚诇 讬讝讛讛 讗讜转讛",
47
+ examples=[],
48
+ theme=gr.themes.Default()
49
  )
50
 
51
+ # 讛驻注诇转 讛讗驻诇讬拽爪讬讛
52
+ iface.launch()