import functools
import os
import random

import gradio as gr
import numpy as np
import tensorflow as tf
from datasets import load_dataset
from PIL import Image
| |
|
| | |
# Model file location, relative to either the working directory or this script.
_MODEL_RELATIVE_PATH = os.path.join("saved_model", "Sports_Balls_Classification.h5")


def _resolve_model_path(relative_path=_MODEL_RELATIVE_PATH):
    """Return the path of the saved Keras model to load.

    The path is first tried relative to the current working directory (the
    original behaviour), then relative to the directory containing this
    script, so the app also starts when launched from another directory.
    If neither file exists, the CWD-relative path is returned unchanged so
    that ``load_model`` reports the missing file itself.

    Args:
        relative_path: Model path relative to the CWD / script directory.

    Returns:
        The first existing candidate path, or ``relative_path``.
    """
    candidates = [relative_path]
    # __file__ is absent in some interactive contexts (e.g. notebooks).
    script_file = globals().get("__file__")
    if script_file:
        script_dir = os.path.dirname(os.path.abspath(script_file))
        candidates.append(os.path.join(script_dir, relative_path))
    for candidate in candidates:
        if os.path.isfile(candidate):
            return candidate
    return relative_path


# Loaded once at import time; every request below reuses this model object.
model = tf.keras.models.load_model(_resolve_model_path())
| |
|
| | |
# Label for each model output: probability i is reported under CLASS_NAMES[i],
# so this order presumably must match the label order used during training.
CLASS_NAMES = """
    american_football baseball basketball billiard_ball
    bowling_ball cricket_ball football golf_ball
    hockey_ball hockey_puck rugby_ball shuttlecock
    table_tennis_ball tennis_ball volleyball
""".split()
| |
|
def preprocess_image(img, target_size=(225, 225)):
    """Convert an image into a batched float32 array for model prediction.

    Args:
        img: A PIL image, or a filesystem path (``str`` or ``os.PathLike``)
            pointing to an image file.
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        A ``numpy.ndarray`` of shape ``(1, height, width, 3)`` and dtype
        float32, with pixel values scaled from ``[0, 255]`` to ``[0, 1]``.
    """
    if isinstance(img, (str, os.PathLike)):
        # Image.open is lazy and keeps the file open; convert inside the
        # context manager so pixels are loaded and the handle is closed.
        with Image.open(img) as opened:
            img = opened.convert("RGB")
    else:
        img = img.convert("RGB")

    img = img.resize(target_size)
    pixels = np.asarray(img, dtype=np.float32) / 255.0
    # Prepend a batch axis of size 1: model.predict works on batches.
    return np.expand_dims(pixels, axis=0)
| |
|
def classify_sports_ball(image):
    """Classify an image and return per-class confidences for ``gr.Label``.

    Args:
        image: Value of the image input (a PIL image when the component is
            configured with ``type="pil"``), or ``None`` when it is empty,
            e.g. after the user clears it.

    Returns:
        ``None`` when there is no image, so the label is left empty.
        Otherwise a dict mapping every name in ``CLASS_NAMES`` to its
        predicted probability, ordered from most to least likely.
        If preprocessing or prediction fails, a plain ``"Error: ..."``
        string is returned; ``gr.Label`` can display a string, whereas a
        dict with a non-numeric value is not a valid confidence mapping.
    """
    # The change event also fires when the image is cleared.
    if image is None:
        return None

    try:
        input_tensor = preprocess_image(image)
        # Single-image batch -> take the one row of class probabilities.
        probs = model.predict(input_tensor, verbose=0)[0]

        # probs[i] is the score for CLASS_NAMES[i].
        pred_dict = {name: float(probs[i]) for i, name in enumerate(CLASS_NAMES)}
        return dict(sorted(pred_dict.items(), key=lambda item: item[1], reverse=True))
    except Exception as e:
        # UI boundary: report the failure instead of crashing the event.
        print(f"Error during classification: {e}")
        return f"Error: {e}"
| |
|
# Column names checked first, in order, when looking for a row's image.
_IMAGE_COLUMNS = ("image", "img", "photo", "picture")


@functools.lru_cache(maxsize=1)
def _load_test_split():
    """Load the Sports-Balls ``test`` split once and reuse it on later calls."""
    # NOTE(review): trust_remote_code=True lets the dataset repository run
    # arbitrary Python on this machine. Confirm it is needed (i.e. the
    # dataset uses a loading script) and drop it otherwise.
    return load_dataset("AIOmarRehan/Sports-Balls", split="test", trust_remote_code=True)


def _extract_image(sample):
    """Return the image value from one dataset row, or ``None`` if none is found."""
    # The first well-known column present wins, even if its value is None...
    image = next((sample[col] for col in _IMAGE_COLUMNS if col in sample), None)
    if image is None:
        # ...in which case fall back to any value that is already a PIL image.
        image = next((val for val in sample.values() if isinstance(val, Image.Image)), None)
    return image


def load_random_dataset_image():
    """Pick a random image from the dataset's test split.

    The dataset object is cached after the first successful load, so later
    clicks do not repeat the ``load_dataset`` call.

    Returns:
        A PIL image, or ``None`` if the row has no image or any error occurs
        (the error is printed).
    """
    try:
        dataset = _load_test_split()
        # randrange raises ValueError on an empty split; handled below.
        sample = dataset[random.randrange(len(dataset))]

        image = _extract_image(sample)
        if image is None:
            return None

        if not isinstance(image, Image.Image):
            image = Image.open(image)

        return image

    except Exception as e:
        print(f"Error loading dataset: {e}")
        return None
| |
|
| | |
# Gradio UI: image input and buttons on the left, top-5 prediction label on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Sports Ball Classifier

        Upload an image of a sports ball to classify it. The model uses InceptionV3 transfer learning
        to identify 15 different types of sports balls.

        **Supported Sports Balls:**
        American Football, Baseball, Basketball, Billiard Ball, Bowling Ball, Cricket Ball, Football,
        Golf Ball, Hockey Ball, Hockey Puck, Rugby Ball, Shuttlecock, Table Tennis Ball, Tennis Ball, Volleyball
        """
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Sports Ball Image",
                scale=1,
            )
            with gr.Row():
                submit_button = gr.Button("Classify", variant="primary", scale=2)
                random_button = gr.Button("Random Dataset", variant="secondary", scale=1)

        with gr.Column():
            output = gr.Label(label="Prediction Confidence", num_top_classes=5)

    with gr.Row():
        # Preprocessing text matches preprocess_image: RGB convert, resize, /255 scaling.
        gr.Markdown(
            """
            ### How to Use:
            1. Upload or drag-and-drop an image containing a sports ball
            2. Click the 'Classify' button
            3. View the prediction results with confidence scores

            ### Model Details:
            - Architecture: InceptionV3 (transfer learning from ImageNet)
            - Training: Two-stage training (feature extraction + fine-tuning)
            - Accuracy: High performance across all 15 sports ball classes
            - Preprocessing: Automatic RGB conversion, image resizing, and normalization
            """
        )

    # Explicitly (re-)classify the current image.
    submit_button.click(fn=classify_sports_ball, inputs=image_input, outputs=output)
    # Loading a random sample updates image_input, which already fires the
    # change handler below; chaining another classify call here would run
    # the model twice per click.
    random_button.click(fn=load_random_dataset_image, outputs=image_input)

    # Classify automatically whenever the image changes (upload, replace, clear).
    image_input.change(fn=classify_sports_ball, inputs=image_input, outputs=output)
| |
|
if __name__ == "__main__":
    # Listen on every network interface at port 7860; no public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)
| |
|