DrNikDJ commited on
Commit
e96b839
·
verified ·
1 Parent(s): 737135b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ # 1. Load the model
7
+ # Ensure 'my_cifar_model.keras' is uploaded to the same Space directory
8
+ model = tf.keras.models.load_model('my_cifar_model.keras')
9
+
10
+ # 2. Define the class labels (Matches CIFAR-10 order)
11
+ labels = [
12
+ 'airplane', 'automobile', 'bird', 'cat', 'deer',
13
+ 'dog', 'frog', 'horse', 'ship', 'truck'
14
+ ]
15
+
16
+ def predict(img):
17
+ """
18
+ Takes an input image, processes it, and returns
19
+ the top classification probabilities.
20
+ """
21
+ if img is None:
22
+ return None
23
+
24
+ # Preprocessing:
25
+ # Convert to PIL Image if it's a numpy array, then resize to 32x32
26
+ img = Image.fromarray(img).resize((32, 32))
27
+
28
+ # Convert to array and normalize (0 to 1)
29
+ img_array = np.array(img).astype('float32') / 255.0
30
+
31
+ # Add batch dimension: (32, 32, 3) -> (1, 32, 32, 3)
32
+ img_array = np.expand_dims(img_array, axis=0)
33
+
34
+ # Perform prediction
35
+ predictions = model.predict(img_array).flatten()
36
+
37
+ # Apply Softmax to get probabilities (if not already in the model output)
38
+ score = tf.nn.softmax(predictions).numpy()
39
+
40
+ # Create a dictionary of {Label: Probability}
41
+ return {labels[i]: float(score[i]) for i in range(10)}
42
+
43
+ # 3. Create the Gradio Interface
44
+ demo = gr.Interface(
45
+ fn=predict,
46
+ inputs=gr.Image(),
47
+ outputs=gr.Label(num_top_classes=3),
48
+ title="CIFAR-10 Image Classifier",
49
+ description="Upload an image and the model will predict its category among 10 classes.",
50
+ examples=["airplane_example.jpg", "cat_example.jpg"] # Optional: if you upload these files too
51
+ )
52
+
53
+ # 4. Launch the app
54
+ if __name__ == "__main__":
55
+ demo.launch()