AIOmarRehan commited on
Commit
91b1e2b
·
verified ·
1 Parent(s): 06da7a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -160
app.py CHANGED
@@ -1,160 +1,160 @@
1
- import gradio as gr
2
- import tensorflow as tf
3
- import numpy as np
4
- from PIL import Image
5
- import os
6
- from datasets import load_dataset
7
- import random
8
-
9
- # Load model
10
- try:
11
- model = tf.keras.models.load_model("saved_model/Sports_Balls_Classification.h5")
12
- except:
13
- # Fallback if model path is different in HF Spaces
14
- model = tf.keras.models.load_model("./saved_model/Sports_Balls_Classification.h5")
15
-
16
- # Class names
17
- CLASS_NAMES = [
18
- "american_football", "baseball", "basketball", "billiard_ball",
19
- "bowling_ball", "cricket_ball", "football", "golf_ball",
20
- "hockey_ball", "hockey_puck", "rugby_ball", "shuttlecock",
21
- "table_tennis_ball", "tennis_ball", "volleyball"
22
- ]
23
-
24
- def preprocess_image(img, target_size=(225, 225)):
25
- """Preprocess image for model prediction"""
26
- if isinstance(img, str):
27
- img = Image.open(img)
28
-
29
- img = img.convert("RGB")
30
- img = img.resize(target_size)
31
- img_array = np.array(img).astype("float32") / 255.0
32
- img_array = np.expand_dims(img_array, axis=0)
33
- return img_array
34
-
35
- def classify_sports_ball(image):
36
- """Classify sports ball in image"""
37
- try:
38
- # Preprocess
39
- input_tensor = preprocess_image(image)
40
-
41
- # Predict
42
- predictions = model.predict(input_tensor, verbose=0)
43
- probs = predictions[0]
44
-
45
- # Get top prediction
46
- class_idx = int(np.argmax(probs))
47
- confidence = float(np.max(probs))
48
-
49
- # Create prediction dictionary
50
- pred_dict = {CLASS_NAMES[i]: float(probs[i]) for i in range(len(CLASS_NAMES))}
51
-
52
- # Sort by confidence
53
- pred_dict = dict(sorted(pred_dict.items(), key=lambda x: x[1], reverse=True))
54
-
55
- return pred_dict
56
-
57
- except Exception as e:
58
- return {"error": str(e)}
59
-
60
- def load_random_dataset_image():
61
- """Load a random image from HuggingFace dataset"""
62
- try:
63
- dataset = load_dataset("Omarinooooo/test", split="train", trust_remote_code=True)
64
- random_idx = random.randint(0, len(dataset) - 1)
65
- sample = dataset[random_idx]
66
-
67
- # Handle different possible image column names
68
- image = None
69
- for col in ["image", "img", "photo", "picture"]:
70
- if col in sample:
71
- image = sample[col]
72
- break
73
-
74
- if image is None:
75
- # Try first column that might be an image
76
- for col, val in sample.items():
77
- if isinstance(val, Image.Image):
78
- image = val
79
- break
80
-
81
- if image is None:
82
- return None
83
-
84
- if not isinstance(image, Image.Image):
85
- image = Image.open(image)
86
-
87
- return image
88
-
89
- except Exception as e:
90
- print(f"Error loading dataset: {e}")
91
- return None
92
-
93
- # Create Gradio interface
94
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
95
- gr.Markdown(
96
- """
97
- # Sports Ball Classifier
98
-
99
- Upload an image of a sports ball to classify it. The model uses InceptionV3 transfer learning
100
- to identify 15 different types of sports balls.
101
-
102
- **Supported Sports Balls:**
103
- American Football, Baseball, Basketball, Billiard Ball, Bowling Ball, Cricket Ball, Football,
104
- Golf Ball, Hockey Ball, Hockey Puck, Rugby Ball, Shuttlecock, Table Tennis Ball, Tennis Ball, Volleyball
105
- """
106
- )
107
-
108
- with gr.Row():
109
- with gr.Column():
110
- image_input = gr.Image(
111
- type="pil",
112
- label="Upload Sports Ball Image",
113
- scale=1
114
- )
115
- with gr.Row():
116
- submit_button = gr.Button("Classify", variant="primary", scale=2)
117
- random_button = gr.Button("Random Dataset", variant="secondary", scale=1)
118
-
119
- with gr.Column():
120
- output = gr.Label(label="Prediction Confidence", num_top_classes=5)
121
-
122
- with gr.Row():
123
- gr.Markdown(
124
- """
125
- ### How to Use:
126
- 1. Upload or drag-and-drop an image containing a sports ball
127
- 2. Click the 'Classify' button
128
- 3. View the prediction results with confidence scores
129
-
130
- ### Model Details:
131
- - Architecture: InceptionV3 (transfer learning from ImageNet)
132
- - Training: Two-stage training (feature extraction + fine-tuning)
133
- - Accuracy: High performance across all 15 sports ball classes
134
- - Preprocessing: Automatic image resizing, normalization, and enhancement
135
- """
136
- )
137
-
138
- with gr.Row():
139
- gr.Examples(
140
- examples=[],
141
- inputs=image_input,
142
- label="Example Images (if available)",
143
- run_on_click=False
144
- )
145
-
146
- # Connect button to function
147
- submit_button.click(fn=classify_sports_ball, inputs=image_input, outputs=output)
148
- random_button.click(fn=load_random_dataset_image, outputs=image_input).then(
149
- fn=classify_sports_ball, inputs=image_input, outputs=output
150
- )
151
-
152
- # Also allow pressing Enter on image upload
153
- image_input.change(fn=classify_sports_ball, inputs=image_input, outputs=output)
154
-
155
- if __name__ == "__main__":
156
- demo.launch(
157
- server_name="0.0.0.0",
158
- server_port=7860,
159
- share=False
160
- )
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from PIL import Image
5
+ import os
6
+ from datasets import load_dataset
7
+ import random
8
+
9
+ # Load model
10
+ try:
11
+ model = tf.keras.models.load_model("saved_model/Sports_Balls_Classification.h5")
12
+ except:
13
+ # Fallback if model path is different in HF Spaces
14
+ model = tf.keras.models.load_model("./saved_model/Sports_Balls_Classification.h5")
15
+
16
+ # Class names
17
+ CLASS_NAMES = [
18
+ "american_football", "baseball", "basketball", "billiard_ball",
19
+ "bowling_ball", "cricket_ball", "football", "golf_ball",
20
+ "hockey_ball", "hockey_puck", "rugby_ball", "shuttlecock",
21
+ "table_tennis_ball", "tennis_ball", "volleyball"
22
+ ]
23
+
24
+ def preprocess_image(img, target_size=(225, 225)):
25
+ """Preprocess image for model prediction"""
26
+ if isinstance(img, str):
27
+ img = Image.open(img)
28
+
29
+ img = img.convert("RGB")
30
+ img = img.resize(target_size)
31
+ img_array = np.array(img).astype("float32") / 255.0
32
+ img_array = np.expand_dims(img_array, axis=0)
33
+ return img_array
34
+
35
+ def classify_sports_ball(image):
36
+ """Classify sports ball in image"""
37
+ try:
38
+ # Preprocess
39
+ input_tensor = preprocess_image(image)
40
+
41
+ # Predict
42
+ predictions = model.predict(input_tensor, verbose=0)
43
+ probs = predictions[0]
44
+
45
+ # Get top prediction
46
+ class_idx = int(np.argmax(probs))
47
+ confidence = float(np.max(probs))
48
+
49
+ # Create prediction dictionary
50
+ pred_dict = {CLASS_NAMES[i]: float(probs[i]) for i in range(len(CLASS_NAMES))}
51
+
52
+ # Sort by confidence
53
+ pred_dict = dict(sorted(pred_dict.items(), key=lambda x: x[1], reverse=True))
54
+
55
+ return pred_dict
56
+
57
+ except Exception as e:
58
+ return {"error": str(e)}
59
+
60
+ def load_random_dataset_image():
61
+ """Load a random image from HuggingFace dataset"""
62
+ try:
63
+ dataset = load_dataset("AIOmarRehan/Sports-Balls", split="test", trust_remote_code=True)
64
+ random_idx = random.randint(0, len(dataset) - 1)
65
+ sample = dataset[random_idx]
66
+
67
+ # Handle different possible image column names
68
+ image = None
69
+ for col in ["image", "img", "photo", "picture"]:
70
+ if col in sample:
71
+ image = sample[col]
72
+ break
73
+
74
+ if image is None:
75
+ # Try first column that might be an image
76
+ for col, val in sample.items():
77
+ if isinstance(val, Image.Image):
78
+ image = val
79
+ break
80
+
81
+ if image is None:
82
+ return None
83
+
84
+ if not isinstance(image, Image.Image):
85
+ image = Image.open(image)
86
+
87
+ return image
88
+
89
+ except Exception as e:
90
+ print(f"Error loading dataset: {e}")
91
+ return None
92
+
93
+ # Create Gradio interface
94
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
95
+ gr.Markdown(
96
+ """
97
+ # Sports Ball Classifier
98
+
99
+ Upload an image of a sports ball to classify it. The model uses InceptionV3 transfer learning
100
+ to identify 15 different types of sports balls.
101
+
102
+ **Supported Sports Balls:**
103
+ American Football, Baseball, Basketball, Billiard Ball, Bowling Ball, Cricket Ball, Football,
104
+ Golf Ball, Hockey Ball, Hockey Puck, Rugby Ball, Shuttlecock, Table Tennis Ball, Tennis Ball, Volleyball
105
+ """
106
+ )
107
+
108
+ with gr.Row():
109
+ with gr.Column():
110
+ image_input = gr.Image(
111
+ type="pil",
112
+ label="Upload Sports Ball Image",
113
+ scale=1
114
+ )
115
+ with gr.Row():
116
+ submit_button = gr.Button("Classify", variant="primary", scale=2)
117
+ random_button = gr.Button("Random Dataset", variant="secondary", scale=1)
118
+
119
+ with gr.Column():
120
+ output = gr.Label(label="Prediction Confidence", num_top_classes=5)
121
+
122
+ with gr.Row():
123
+ gr.Markdown(
124
+ """
125
+ ### How to Use:
126
+ 1. Upload or drag-and-drop an image containing a sports ball
127
+ 2. Click the 'Classify' button
128
+ 3. View the prediction results with confidence scores
129
+
130
+ ### Model Details:
131
+ - Architecture: InceptionV3 (transfer learning from ImageNet)
132
+ - Training: Two-stage training (feature extraction + fine-tuning)
133
+ - Accuracy: High performance across all 15 sports ball classes
134
+ - Preprocessing: Automatic image resizing, normalization, and enhancement
135
+ """
136
+ )
137
+
138
+ with gr.Row():
139
+ gr.Examples(
140
+ examples=[],
141
+ inputs=image_input,
142
+ label="Example Images (if available)",
143
+ run_on_click=False
144
+ )
145
+
146
+ # Connect button to function
147
+ submit_button.click(fn=classify_sports_ball, inputs=image_input, outputs=output)
148
+ random_button.click(fn=load_random_dataset_image, outputs=image_input).then(
149
+ fn=classify_sports_ball, inputs=image_input, outputs=output
150
+ )
151
+
152
+ # Also allow pressing Enter on image upload
153
+ image_input.change(fn=classify_sports_ball, inputs=image_input, outputs=output)
154
+
155
+ if __name__ == "__main__":
156
+ demo.launch(
157
+ server_name="0.0.0.0",
158
+ server_port=7860,
159
+ share=False
160
+ )