Bijan k commited on
Commit
91ed06c
·
1 Parent(s): db6a51d

Add application file

Browse files
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import tensorflow as tf
4
+ from tensorflow.keras.models import load_model
5
+ import numpy as np
6
+
7
+ # Load the pre-trained model
8
+ try:
9
+ model = load_model("my_model.h5")
10
+ except OSError as e:
11
+ print(f"Error loading model: {e}")
12
+
13
+
14
+ def classify_image(image):
15
+ try:
16
+ # Preprocess the image
17
+ image_gray = tf.image.rgb_to_grayscale(image)
18
+
19
+ # Resize the image to 32x32 (only once)
20
+ image_tensor = tf.image.resize(image_gray, (32, 32))
21
+
22
+ # Cast to float32
23
+ image_tensor = tf.cast(image_tensor, tf.float32)
24
+
25
+ # Add batch dimension
26
+ image_tensor = tf.expand_dims(image_tensor, 0)
27
+
28
+ # Normalize the data
29
+ image_tensor = image_tensor / 255.0
30
+
31
+ # Get the prediction
32
+ predictions = model.predict(image_tensor)
33
+
34
+ # For top-3 output format compatible with Gradio
35
+ if predictions.shape[1] == 10: # For MNIST (0-9 digits)
36
+ class_names = {i: str(i) for i in range(10)}
37
+ top_indices = np.argsort(predictions[0])[-3:][::-1] # Top 3 indices
38
+ confidences = {
39
+ class_names[i]: float(predictions[0][i]) for i in top_indices
40
+ }
41
+ return confidences
42
+ else:
43
+ # Fallback to simple argmax if model output doesn't match expected format
44
+ return {str(predictions.argmax()): float(predictions.max())}
45
+
46
+ except Exception as e:
47
+ return {"Error": str(e)}
48
+
49
+
50
+ # Check if examples directory exists
51
+ example_list = []
52
+ if os.path.exists("examples"):
53
+ example_list = [
54
+ ["examples/" + example]
55
+ for example in os.listdir("examples")
56
+ if os.path.isfile(os.path.join("examples", example))
57
+ ]
58
+
59
+ title = "MNIST Model 98% Accuracy"
60
+ description = "Model trained on MNIST dataset using EfficientNet to classify handwritten digits with 98% accuracy"
61
+ article = "For source code, visit [my GitHub](https://github.com/Bijan-K/Tensorflow-MNIST-98Acc.git) (includes Gradio implementation and training code)."
62
+
63
+ interface = gr.Interface(
64
+ fn=classify_image,
65
+ inputs=gr.Image(type="pil"),
66
+ outputs=gr.Label(num_top_classes=3),
67
+ examples=example_list,
68
+ title=title,
69
+ description=description,
70
+ article=article,
71
+ )
72
+
73
+ if __name__ == "__main__":
74
+ interface.launch()
examples/images (1).png ADDED
examples/images (2).png ADDED
examples/images.png ADDED
examples/sample_digit.png ADDED
my_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a1b5ca4fe6a8a40dc1e92b132417c44b2d84b472a99ebfdfecd4b6da9e936a3
3
+ size 49261456
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ tensorflow>=2.6.0
2
+ efficientnet>=1.1.0
3
+ gradio>=2.8.0
4
+ pillow>=8.0.0
5
+ numpy>=1.19.5