# Author: barryallen16 · commit 9f475de ("Update app.py")
import json
import os

import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.efficientnet import preprocess_input
# Load the inference model (without augmentation layers)
model = tf.keras.models.load_model('./indo_fashion_classification_model.keras')
# Load class labels: a JSON object whose keys are stringified class indices
# ("0", "1", ...) and whose values are the human-readable class names.
# JSON text is UTF-8, so decode it explicitly rather than with the
# platform's locale-dependent default encoding.
with open('class_labels.json', 'r', encoding='utf-8') as f:
    labels = json.load(f)
def _to_rgb(image):
    """Return ``image`` as a 3-channel (H, W, 3) array.

    Single-channel inputs, (H, W) or (H, W, 1), are replicated across three
    channels. A trailing alpha channel is dropped from grayscale+alpha
    (H, W, 2) and RGBA (H, W, 4) inputs. Any other shape is returned unchanged.
    """
    if image.ndim == 2:  # plain grayscale: add an explicit channel axis
        image = image[:, :, np.newaxis]
    channels = image.shape[2]
    if channels == 4:  # RGBA -> RGB
        image = image[:, :, :3]
    elif channels == 2:  # grayscale + alpha -> grayscale
        image = image[:, :, :1]
    if image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    return image


def _format_results(probabilities):
    """Render the top-1 prediction and the top-3 ranking as Markdown.

    ``probabilities`` is the model's score vector for a single image. Its
    indices are looked up in the module-level ``labels`` mapping, whose keys
    are stringified class indices.
    """
    best = int(np.argmax(probabilities))
    top_3 = np.argsort(probabilities)[-3:][::-1]
    # Markdown renders a single newline as a space, so separate the
    # headline fields with blank lines to keep them on their own lines.
    lines = [
        f"**Predicted Class:** {labels[str(best)]}",
        "",
        f"**Confidence:** {probabilities[best]:.2%}",
        "",
        "**Top 3 Predictions:**",
        "",
    ]
    lines.extend(
        f"{rank}. {labels[str(idx)]}: {probabilities[idx]:.2%}"
        for rank, idx in enumerate(top_3, 1)
    )
    return "\n".join(lines) + "\n"


def predict_image(image):
    """Classify one uploaded image and return the results as Markdown.

    Args:
        image: The array delivered by the ``gr.Image(type="numpy")`` input,
            or ``None`` when the input is empty (e.g. after clearing).

    Returns:
        A Markdown string with the top prediction, its confidence and the
        top-3 ranking, or ``None`` if no image was provided.
    """
    if image is None:
        return None
    image = _to_rgb(image)
    # Resize to match the model's 224x224 input, then apply EfficientNet
    # preprocessing.
    image = tf.image.resize(image, (224, 224))
    image = preprocess_input(image)
    # The model expects a batch, so add a leading batch dimension of 1.
    predictions = model.predict(tf.expand_dims(image, 0), verbose=0)
    return _format_results(predictions[0])
# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
    # Build Markdown from explicit lines so rendering does not depend on how
    # the indentation of a triple-quoted literal is handled.
    class_list = ", ".join(sorted(labels.values()))
    gr.Markdown(
        "# 🪷 Indian Ethnic Wear Classifier\n\n"
        "Upload an image of Indian fashion attire to classify it using our "
        "EfficientNetB0 model trained on the Indo Fashion Dataset.\n\n"
        f"**Available Classes:** {class_list}"
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="numpy",
                label="Upload Fashion Image",
                height=400,
                width=400,
                sources=["upload", "webcam", "clipboard"],
                show_download_button=True,
            )
            # Only offer example images that actually exist next to the app.
            # Gradio processes example files when gr.Examples is created, so
            # a missing file can break startup.
            example_paths = [
                path
                for path in ("example1.jpg", "example2.jpg", "example3.jpg")
                if os.path.isfile(path)
            ]
            if example_paths:
                gr.Examples(
                    examples=[[path] for path in example_paths],
                    inputs=input_image,
                    label="Try these examples (if available)",
                )
        with gr.Column(scale=1):
            output_text = gr.Markdown(
                label="Classification Results",
                show_label=True,
            )
    with gr.Row():
        predict_btn = gr.Button("🎯 Classify Image", variant="primary", size="lg")
        clear_btn = gr.Button("🗑️ Clear", variant="secondary")
    # Classify on explicit button press and whenever the input image changes.
    predict_btn.click(
        fn=predict_image,
        inputs=input_image,
        outputs=output_text,
    )
    input_image.change(
        fn=predict_image,
        inputs=input_image,
        outputs=output_text,
    )
    # Reset both the image input and the results panel.
    clear_btn.click(
        fn=lambda: (None, ""),
        inputs=[],
        outputs=[input_image, output_text],
    )
    gr.Markdown(
        "---\n\n"
        "**Note:**\n\n"
        f"- The model classifies images into {len(labels)} categories of Indian ethnic wear\n"
        "- For best results, use clear, well-lit images focused on the clothing\n"
        "- Supported formats: JPG, PNG, WebP\n"
        "- Model: EfficientNetB0 trained on Indo Fashion Dataset"
    )
if __name__ == "__main__":
    # Start the Gradio server. `share` requests a public link; `show_error`
    # surfaces exceptions raised by event handlers in the browser UI.
    launch_kwargs = dict(show_error=True, share=True)
    demo.launch(**launch_kwargs)