import functools
import io
import logging
import os
import random

import gradio as gr
import numpy as np
import tensorflow as tf
from datasets import load_dataset
from PIL import Image

logger = logging.getLogger(__name__)

# Directory containing this script, so the model can be found even when the
# app is launched from a different working directory.
APP_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_FILENAME = "Sports_Balls_Classification.h5"
# Candidate model locations, tried in order: the original CWD-relative path
# first, then the same path resolved relative to this file.
MODEL_PATHS = (
    os.path.join("saved_model", MODEL_FILENAME),
    os.path.join(APP_DIR, "saved_model", MODEL_FILENAME),
)


def load_classifier(paths=MODEL_PATHS):
    """Load the Keras classifier from the first existing path in *paths*.

    Args:
        paths: Candidate model file paths, tried in order.

    Returns:
        The model returned by ``tf.keras.models.load_model``.

    Raises:
        FileNotFoundError: If none of the candidate paths exists.
    """
    for path in paths:
        if os.path.exists(path):
            return tf.keras.models.load_model(path)
    raise FileNotFoundError("Model file not found; tried: " + ", ".join(paths))


# Load the model once at import time so every request reuses it.
model = load_classifier()

# Class names, in the order of the model's output units.
CLASS_NAMES = [
    "american_football", "baseball", "basketball", "billiard_ball",
    "bowling_ball", "cricket_ball", "football", "golf_ball", "hockey_ball",
    "hockey_puck", "rugby_ball", "shuttlecock", "table_tennis_ball",
    "tennis_ball", "volleyball",
]

DATASET_NAME = "AIOmarRehan/Sports-Balls"
# Column names checked, in order, when looking for the image in a dataset row.
IMAGE_COLUMNS = ("image", "img", "photo", "picture")


def preprocess_image(img, target_size=(225, 225)):
    """Preprocess image for model prediction.

    Args:
        img: A PIL image, or a filesystem path (str / os.PathLike) to one.
        target_size: (width, height) passed to ``PIL.Image.resize``.

    Returns:
        A float32 array of shape (1, height, width, 3) with values in [0, 1].
    """
    if isinstance(img, (str, os.PathLike)):
        # Close the file handle once the pixels have been copied out.
        with Image.open(img) as opened:
            img = opened.convert("RGB")
    else:
        img = img.convert("RGB")
    img = img.resize(target_size)
    img_array = np.asarray(img, dtype="float32") / 255.0
    # Add the batch dimension expected by model.predict.
    return np.expand_dims(img_array, axis=0)


def classify_sports_ball(image):
    """Classify *image* and return per-class confidences, highest first.

    Args:
        image: A PIL image, an image path, or None.

    Returns:
        A dict mapping each class name to its probability, sorted by
        descending probability. Returns None when *image* is None (e.g. the
        image widget was cleared), so the Label output is cleared instead of
        showing an error.

    Raises:
        gr.Error: If preprocessing or inference fails, or the model output
            size does not match CLASS_NAMES. Gradio displays the message to
            the user; a plain error dict cannot be rendered by gr.Label.
    """
    if image is None:
        return None

    try:
        input_tensor = preprocess_image(image)
        probs = model.predict(input_tensor, verbose=0)[0]
    except Exception as e:  # UI boundary: log, then surface to the user.
        logger.exception("Classification failed")
        raise gr.Error(f"Classification failed: {e}") from e

    if len(probs) != len(CLASS_NAMES):
        raise gr.Error(
            f"Model returned {len(probs)} scores but {len(CLASS_NAMES)} "
            "class names are defined"
        )

    pred_dict = {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
    return dict(sorted(pred_dict.items(), key=lambda kv: kv[1], reverse=True))


@functools.lru_cache(maxsize=1)
def _load_test_split():
    """Load the dataset's test split once and reuse it on later calls."""
    # NOTE(review): trust_remote_code allows running code from the dataset
    # repo. Consider dropping it if the dataset has no loading script.
    return load_dataset(DATASET_NAME, split="test", trust_remote_code=True)


def _extract_image(sample):
    """Return the image value stored in a dataset row, or None if not found."""
    # Use the first known image column that is present in the row.
    image = next((sample[col] for col in IMAGE_COLUMNS if col in sample), None)
    if image is None:
        # Fall back to the first value that is already a decoded PIL image.
        image = next(
            (val for val in sample.values() if isinstance(val, Image.Image)),
            None,
        )
    return image


def load_random_dataset_image():
    """Return a random image from the dataset's test split.

    Returns:
        A PIL image, or None if the dataset cannot be loaded, is empty, or
        the selected row has no usable image. Failures are logged and never
        raised, so the UI keeps working.
    """
    try:
        dataset = _load_test_split()
        if len(dataset) == 0:
            logger.warning("Dataset %s has an empty test split", DATASET_NAME)
            return None

        sample = dataset[random.randrange(len(dataset))]
        image = _extract_image(sample)
        if image is None:
            return None

        if isinstance(image, dict):
            # Undecoded image feature: {"bytes": ..., "path": ...}.
            data = image.get("bytes")
            image = io.BytesIO(data) if data else image.get("path")
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        return image
    except Exception:  # Best effort: a failed load should not break the UI.
        logger.exception("Error loading dataset")
        return None


# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Sports Ball Classifier
        Upload an image of a sports ball to classify it. The model uses InceptionV3 transfer learning to identify 15 different types of sports balls.

        **Supported Sports Balls:**
        American Football, Baseball, Basketball, Billiard Ball, Bowling Ball, Cricket Ball, Football, Golf Ball, Hockey Ball, Hockey Puck, Rugby Ball, Shuttlecock, Table Tennis Ball, Tennis Ball, Volleyball
        """
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Sports Ball Image",
                scale=1,
            )
            with gr.Row():
                submit_button = gr.Button("Classify", variant="primary", scale=2)
                random_button = gr.Button("Random Dataset", variant="secondary", scale=1)

        with gr.Column():
            output = gr.Label(label="Prediction Confidence", num_top_classes=5)

    with gr.Row():
        gr.Markdown(
            """
            ### How to Use:
            1. Upload or drag-and-drop an image containing a sports ball, or click 'Random Dataset' to load one from the test set
            2. Click the 'Classify' button (new images are also classified automatically)
            3. View the prediction results with confidence scores

            ### Model Details:
            - Architecture: InceptionV3 (transfer learning from ImageNet)
            - Training: Two-stage training (feature extraction + fine-tuning)
            - Accuracy: High performance across all 15 sports ball classes
            - Preprocessing: Automatic RGB conversion, resizing to 225x225, and scaling pixel values to [0, 1]
            """
        )

    with gr.Row():
        # NOTE(review): no example images are configured yet; populate
        # `examples` or remove this widget.
        gr.Examples(
            examples=[],
            inputs=image_input,
            label="Example Images (Available)",
            run_on_click=False,
        )

    # Connect button to function
    submit_button.click(fn=classify_sports_ball, inputs=image_input, outputs=output)
    # Loading a random image updates image_input, which fires the .change
    # listener below. Chaining another classify call here would run inference
    # twice.
    random_button.click(fn=load_random_dataset_image, outputs=image_input)
    # Classify automatically whenever the image changes: upload, clear, or an
    # update from the Random Dataset button.
    image_input.change(fn=classify_sports_ball, inputs=image_input, outputs=output)


if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )