# Author: AIOmarRehan
# Last change: "Update app.py" (commit aaaa2cf, verified)
import functools
import io
import os
import random

import gradio as gr
import numpy as np
import tensorflow as tf
from datasets import load_dataset
from PIL import Image
# Load model
MODEL_RELATIVE_PATH = os.path.join("saved_model", "Sports_Balls_Classification.h5")


def _load_model(relative_path=MODEL_RELATIVE_PATH):
    """Load the Keras classifier from the working directory or the script's directory.

    The path is first tried relative to the current working directory, then
    relative to the directory containing this file, so the app still starts
    when it is launched from somewhere else (e.g. a hosting platform's root).

    Args:
        relative_path: Model file path, relative to either base directory.

    Returns:
        The loaded ``tf.keras`` model.

    Raises:
        OSError: if no candidate path could be loaded; the last underlying
            loading error is chained as ``__cause__``.
    """
    candidates = [relative_path]
    script_path = globals().get("__file__")  # absent in some interactive runs
    if script_path:
        script_dir = os.path.dirname(os.path.abspath(script_path))
        fallback = os.path.join(script_dir, relative_path)
        if fallback not in candidates:
            candidates.append(fallback)

    last_error = None
    for path in candidates:
        try:
            return tf.keras.models.load_model(path)
        except (OSError, ValueError) as err:
            # Only loading failures are retried; interrupts and unrelated
            # errors propagate. NOTE(review): ValueError is included because
            # some Keras versions report unreadable model files that way —
            # confirm against the pinned TensorFlow/Keras version.
            last_error = err
    raise OSError(f"Could not load model from any of: {candidates}") from last_error


model = _load_model()
# Human-readable labels, one per model output unit: the prediction vector's
# i-th probability is reported under CLASS_NAMES[i].
# NOTE(review): the list is in alphabetical order, which presumably mirrors
# the label encoding used at training time — confirm before reordering or
# adding classes.
CLASS_NAMES = [
    "american_football",
    "baseball",
    "basketball",
    "billiard_ball",
    "bowling_ball",
    "cricket_ball",
    "football",
    "golf_ball",
    "hockey_ball",
    "hockey_puck",
    "rugby_ball",
    "shuttlecock",
    "table_tennis_ball",
    "tennis_ball",
    "volleyball",
]
def preprocess_image(img, target_size=(225, 225)):
    """Turn an image into a single-item float32 batch for ``model.predict``.

    Args:
        img: A PIL image, a filesystem path (``str`` or ``os.PathLike``) to an
            image file, or a NumPy array accepted by ``Image.fromarray``.
        target_size: ``(width, height)`` passed to ``Image.resize``.

    Returns:
        ``np.ndarray`` of shape ``(1, height, width, 3)`` and dtype float32,
        with pixel values scaled from ``[0, 255]`` to ``[0, 1]``.

    Raises:
        ValueError: if ``img`` is ``None`` (e.g. the image input was cleared).
    """
    if img is None:
        raise ValueError("No image provided.")

    if isinstance(img, (str, os.PathLike)):
        # convert() decodes the pixels into a new image, so the file handle
        # can be released as soon as the with-block ends.
        with Image.open(img) as opened:
            img = opened.convert("RGB")
    elif isinstance(img, np.ndarray):
        img = Image.fromarray(img).convert("RGB")
    else:
        img = img.convert("RGB")

    img = img.resize(target_size)
    img_array = np.asarray(img, dtype=np.float32) / 255.0
    # Prepend the batch axis expected by model.predict.
    return np.expand_dims(img_array, axis=0)
def classify_sports_ball(image):
    """Classify ``image`` and return per-class confidences for ``gr.Label``.

    Args:
        image: The image to classify (anything ``preprocess_image`` accepts),
            or ``None`` when there is nothing to classify.

    Returns:
        ``None`` when ``image`` is ``None`` (this clears the label). Otherwise,
        a dict mapping every name in ``CLASS_NAMES`` to its predicted
        probability, ordered from most to least likely.

    Raises:
        gr.Error: if preprocessing or prediction fails, or if the model's
            output size does not match ``CLASS_NAMES``. Gradio then shows the
            message to the user instead of ``gr.Label`` receiving a
            non-numeric "confidence".
    """
    if image is None:
        # e.g. the input was cleared, or the random loader returned None.
        return None

    try:
        input_tensor = preprocess_image(image)
        probs = model.predict(input_tensor, verbose=0)[0]
    except Exception as e:
        raise gr.Error(f"Prediction failed: {e}") from e

    if len(probs) != len(CLASS_NAMES):
        raise gr.Error(
            f"Model returned {len(probs)} scores but "
            f"{len(CLASS_NAMES)} class names are defined."
        )

    pred_dict = {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
    # Most likely class first.
    return dict(sorted(pred_dict.items(), key=lambda item: item[1], reverse=True))
_DATASET_NAME = "AIOmarRehan/Sports-Balls"
# Column names checked, in order, for the image of a dataset row.
_IMAGE_COLUMNS = ("image", "img", "photo", "picture")


@functools.lru_cache(maxsize=1)
def _load_test_split():
    """Load the dataset's test split once per process and reuse it on later clicks."""
    # NOTE(review): trust_remote_code=True allows the dataset repository to run
    # arbitrary code on this host; drop it if the dataset has no loading script.
    return load_dataset(_DATASET_NAME, split="test", trust_remote_code=True)


def _extract_image(sample):
    """Return the image stored in one dataset row as a PIL image, or None if there is none."""
    image = None
    for col in _IMAGE_COLUMNS:
        if col in sample:
            image = sample[col]
            break
    if image is None:
        # Otherwise, take the first value that is already a PIL image.
        image = next((v for v in sample.values() if isinstance(v, Image.Image)), None)

    if image is None or isinstance(image, Image.Image):
        return image
    if isinstance(image, dict):
        # NOTE(review): assumes an undecoded image value shaped like
        # {"bytes": ..., "path": ...} — verify against the dataset features.
        if image.get("bytes") is not None:
            return Image.open(io.BytesIO(image["bytes"]))
        image = image.get("path")
        if image is None:
            return None
    return Image.open(image)


def load_random_dataset_image():
    """Return a random image from the dataset's test split, or None on failure.

    Failures (download, empty split, undecodable image) are printed rather
    than raised, so the "Random Dataset" button leaves the app usable.
    """
    try:
        dataset = _load_test_split()
        if len(dataset) == 0:
            print("Error loading dataset: test split is empty")
            return None
        sample = dataset[random.randrange(len(dataset))]
        return _extract_image(sample)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return None
# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header: title, short description, and the supported classes.
    gr.Markdown(
        """
# Sports Ball Classifier
Upload an image of a sports ball to classify it. The model uses InceptionV3 transfer learning
to identify 15 different types of sports balls.
**Supported Sports Balls:**
American Football, Baseball, Basketball, Billiard Ball, Bowling Ball, Cricket Ball, Football,
Golf Ball, Hockey Ball, Hockey Puck, Rugby Ball, Shuttlecock, Table Tennis Ball, Tennis Ball, Volleyball
"""
    )
    with gr.Row():
        # Left column: image input and action buttons.
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Sports Ball Image",
                scale=1,
            )
            with gr.Row():
                submit_button = gr.Button("Classify", variant="primary", scale=2)
                random_button = gr.Button("Random Dataset", variant="secondary", scale=1)
        # Right column: the five most likely classes with their confidences.
        with gr.Column():
            output = gr.Label(label="Prediction Confidence", num_top_classes=5)
    with gr.Row():
        gr.Markdown(
            """
### How to Use:
1. Upload or drag-and-drop an image containing a sports ball, or click 'Random Dataset' to load one from the dataset
2. Click the 'Classify' button (new images are also classified automatically)
3. View the prediction results with confidence scores
### Model Details:
- Architecture: InceptionV3 (transfer learning from ImageNet)
- Training: Two-stage training (feature extraction + fine-tuning)
- Accuracy: High performance across all 15 sports ball classes
- Preprocessing: RGB conversion, resizing to 225x225, and scaling of pixel values to [0, 1]
"""
        )

    # Explicit classification on demand.
    submit_button.click(fn=classify_sports_ball, inputs=image_input, outputs=output)
    # The loaded image is written into image_input. That update fires
    # image_input.change below, which classifies it, so no .then() chain is
    # needed (chaining ran the model twice per click).
    random_button.click(fn=load_random_dataset_image, outputs=image_input)
    # Auto-classify whenever the displayed image changes, whether from user
    # input or from the random button.
    image_input.change(fn=classify_sports_ball, inputs=image_input, outputs=output)
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860, without requesting a
    # public Gradio share link.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, share=False)
    demo.launch(**launch_options)