File size: 5,210 Bytes
91b1e2b aaaa2cf 91b1e2b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 | import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
import os
from datasets import load_dataset
import random
# Load the trained Keras model once at import time so every request reuses it.
# The relative path is tried against the current working directory first; if the
# file is not there, it is resolved against this script's own directory, so the
# app still starts when launched from a different working directory. Any real
# loading error (corrupt file, missing dependency) now propagates unchanged
# instead of being swallowed by a bare ``except``.
_MODEL_PATH = os.path.join("saved_model", "Sports_Balls_Classification.h5")
if not os.path.exists(_MODEL_PATH):
    _MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), _MODEL_PATH)
model = tf.keras.models.load_model(_MODEL_PATH)
# Label for each model output unit: CLASS_NAMES[i] names the probability at
# index i of the model's prediction vector, so this order must match the order
# of the model's output layer.
CLASS_NAMES = [
    "american_football",
    "baseball",
    "basketball",
    "billiard_ball",
    "bowling_ball",
    "cricket_ball",
    "football",
    "golf_ball",
    "hockey_ball",
    "hockey_puck",
    "rugby_ball",
    "shuttlecock",
    "table_tennis_ball",
    "tennis_ball",
    "volleyball",
]
def preprocess_image(img, target_size=(225, 225)):
    """Convert an image into a normalized single-image batch for the model.

    Args:
        img: A PIL image, a filesystem path (``str`` or ``os.PathLike``) to an
            image file, or a NumPy pixel array (e.g. uint8 ``H x W x 3``).
        target_size: ``(width, height)`` passed to ``PIL.Image.Image.resize``.

    Returns:
        A ``float32`` NumPy array of shape ``(1, height, width, 3)`` with pixel
        values scaled from ``[0, 255]`` to ``[0, 1]``.

    Raises:
        ValueError: If ``img`` is ``None`` (e.g. the image input was cleared).
    """
    if img is None:
        raise ValueError("No image provided")
    if isinstance(img, (str, os.PathLike)):
        # Decode while the file is open so the handle is released immediately
        # rather than lingering until the image object is garbage-collected.
        with Image.open(img) as opened:
            img = opened.convert("RGB")
    else:
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        img = img.convert("RGB")
    img = img.resize(target_size)
    img_array = np.asarray(img, dtype=np.float32) / 255.0
    # Add the leading batch axis expected by model.predict.
    return np.expand_dims(img_array, axis=0)
def classify_sports_ball(image):
    """Classify a sports-ball image for display in a ``gr.Label``.

    Args:
        image: The value of the ``gr.Image`` input (a PIL image, since the
            input uses ``type="pil"``), or ``None`` when the input is empty.

    Returns:
        ``None`` when ``image`` is ``None`` (this clears the label); otherwise
        a dict mapping every class name to its predicted probability, ordered
        from most to least likely.

    Raises:
        gr.Error: If preprocessing or prediction fails, or if the model's
            output length does not match ``CLASS_NAMES``. Gradio shows the
            message to the user.
    """
    if image is None:
        return None
    try:
        probs = model.predict(preprocess_image(image), verbose=0)[0]
    except Exception as exc:
        # gr.Label renders a {label: float confidence} mapping, so an
        # {"error": "..."} dict cannot be displayed; surface a Gradio error instead.
        raise gr.Error(f"Classification failed: {exc}") from exc
    if len(probs) != len(CLASS_NAMES):
        raise gr.Error(
            f"Model returned {len(probs)} scores but {len(CLASS_NAMES)} class names are defined"
        )
    scores = dict(zip(CLASS_NAMES, map(float, probs)))
    return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))
# Process-wide cache so the dataset split is loaded once rather than on every click.
_DATASET_CACHE = {}


def _get_test_split():
    """Return the ``test`` split of the Sports-Balls dataset, loading it only once."""
    if "test" not in _DATASET_CACHE:
        # NOTE(review): trust_remote_code=True allows the Hub repository to run
        # arbitrary Python while loading; confirm this image dataset needs it.
        _DATASET_CACHE["test"] = load_dataset(
            "AIOmarRehan/Sports-Balls", split="test", trust_remote_code=True
        )
    return _DATASET_CACHE["test"]


def _as_pil_image(value):
    """Coerce a dataset cell to a PIL image; return ``None`` if the cell is empty."""
    if value is None or isinstance(value, Image.Image):
        return value
    if isinstance(value, dict):
        # An undecoded `datasets` Image feature is a {"bytes": ..., "path": ...} dict.
        if value.get("bytes"):
            import io

            return Image.open(io.BytesIO(value["bytes"]))
        value = value.get("path")
        if not value:
            return None
    return Image.open(value)


def load_random_dataset_image():
    """Return a random image from the dataset's test split, or ``None``.

    The image is read from the first present column among "image", "img",
    "photo" and "picture". If none of those holds a value, the first PIL image
    found among the sample's values is used. Any error is printed and ``None``
    is returned, so the UI simply receives no image.
    """
    try:
        dataset = _get_test_split()
        sample = dataset[random.randrange(len(dataset))]
        image = next(
            (sample[col] for col in ("image", "img", "photo", "picture") if col in sample),
            None,
        )
        if image is None:
            image = next(
                (val for val in sample.values() if isinstance(val, Image.Image)), None
            )
        return _as_pil_image(image)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return None
# Build the Gradio interface: image input and buttons on the left, ranked
# predictions on the right, usage notes and examples below.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Sports Ball Classifier
        Upload an image of a sports ball to classify it. The model uses InceptionV3 transfer learning
        to identify 15 different types of sports balls.
        **Supported Sports Balls:**
        American Football, Baseball, Basketball, Billiard Ball, Bowling Ball, Cricket Ball, Football,
        Golf Ball, Hockey Ball, Hockey Puck, Rugby Ball, Shuttlecock, Table Tennis Ball, Tennis Ball, Volleyball
        """
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Sports Ball Image",
                scale=1,
            )
            with gr.Row():
                submit_button = gr.Button("Classify", variant="primary", scale=2)
                random_button = gr.Button("Random Dataset", variant="secondary", scale=1)
        with gr.Column():
            output = gr.Label(label="Prediction Confidence", num_top_classes=5)
    with gr.Row():
        gr.Markdown(
            """
            ### How to Use:
            1. Upload or drag-and-drop an image containing a sports ball, or click 'Random Dataset' to load a random test image
            2. Click the 'Classify' button
            3. View the prediction results with confidence scores
            ### Model Details:
            - Architecture: InceptionV3 (transfer learning from ImageNet)
            - Training: Two-stage training (feature extraction + fine-tuning)
            - Accuracy: High performance across all 15 sports ball classes
            - Preprocessing: Automatic RGB conversion, resizing, and pixel normalization
            """
        )
    with gr.Row():
        # NOTE: no example images are configured (examples=[]); add image file
        # paths here to populate this section.
        gr.Examples(
            examples=[],
            inputs=image_input,
            label="Example Images (Available)",
            run_on_click=False,
        )

    # Explicit classification on button click.
    submit_button.click(fn=classify_sports_ball, inputs=image_input, outputs=output)
    # Load a random test image. Classification is deliberately not chained with
    # .then(...): the image_input.change listener below already fires when this
    # handler updates the image, so chaining would run the model twice per click.
    random_button.click(fn=load_random_dataset_image, outputs=image_input)
    # Re-classify whenever the image value changes (upload, clear, or update
    # from another event handler).
    image_input.change(fn=classify_sports_ball, inputs=image_input, outputs=output)
if __name__ == "__main__":
    # Serve on every network interface at port 7860, without requesting a
    # public Gradio share link.
    launch_kwargs = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    demo.launch(**launch_kwargs)
|