barryallen16 commited on
Commit
9f475de
·
verified ·
1 Parent(s): 0a2a6d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -46
app.py CHANGED
@@ -3,86 +3,103 @@ import tensorflow as tf
3
  import json
4
  import gradio as gr
5
  from tensorflow.keras.applications.efficientnet import preprocess_input
6
- from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, Input, GlobalAveragePooling2D
7
- from tensorflow.keras import Model
8
 
9
- old_model = tf.keras.models.load_model('indo_fashion_classification_model.keras')
 
10
 
11
- with open('class_labels_new_keras.json', 'r') as f:
 
12
  labels = json.load(f)
13
 
14
- pretrained_model = tf.keras.applications.efficientnet.EfficientNetB0(
15
- input_shape=(224, 224, 3),
16
- include_top=False,
17
- weights='imagenet'
18
- )
19
- pretrained_model.trainable = False
20
-
21
- inputs = Input(shape=(224, 224, 3))
22
- x = pretrained_model(inputs, training=False)
23
- x = GlobalAveragePooling2D()(x)
24
- x = Dense(128, activation='relu')(x)
25
- x = BatchNormalization()(x)
26
- x = Dropout(0.45)(x)
27
- x = Dense(256, activation='relu')(x)
28
- x = BatchNormalization()(x)
29
- x = Dropout(0.45)(x)
30
- outputs = Dense(len(labels), activation='softmax', dtype='float32')(x)
31
-
32
- new_model = Model(inputs=inputs, outputs=outputs)
33
-
34
- new_model.set_weights(old_model.get_weights())
35
-
36
  def predict_image(image):
37
  if image is None:
38
  return None
39
 
40
- # Convert PIL image to numpy array and ensure RGB
41
- image = np.array(image)
42
- if image.ndim == 2: # Grayscale image
43
- image = np.stack([image] * 3, axis=-1) # Convert to RGB by stacking
44
- elif image.shape[-1] == 4: # RGBA image
45
- image = image[..., :3] # Remove alpha channel
 
46
 
47
- # Resize image to (224, 224)
48
  image = tf.image.resize(image, (224, 224))
49
- image = tf.expand_dims(image, 0) # Add batch dimension
50
- image = preprocess_input(image) # Apply EfficientNet preprocessing
51
 
52
- # Make prediction
53
- predictions = model.predict(image, verbose=0)
 
 
 
 
 
 
54
  class_idx = np.argmax(predictions[0])
55
  confidence = predictions[0][class_idx]
56
  class_name = labels[str(class_idx)]
57
 
58
- return f"**Predicted Class:** {class_name}\n**Confidence:** {confidence:.2%}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
 
60
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
61
  gr.Markdown(
62
  """
63
- # Indian ethnic wear classifier
64
- Upload an image of Indian fashion attire to classify it using the EfficientNetB0 model trained on the Indo Fashion Dataset.
65
 
66
- **Classes:** """ + ", ".join(labels.values()) + """
 
 
 
67
  """
68
  )
69
 
70
  with gr.Row():
71
  with gr.Column(scale=1):
72
  input_image = gr.Image(
73
- type="pil",
74
- label="Upload Image",
75
  height=400,
76
  width=400,
77
- sources=["upload", "webcam"]
 
 
 
 
 
 
 
 
 
 
 
78
  )
 
79
  with gr.Column(scale=1):
80
  output_text = gr.Markdown(
81
- label="Prediction",
82
  show_label=True
83
  )
84
 
85
- predict_btn = gr.Button("Classify Image", variant="primary", size="lg")
 
 
 
86
  predict_btn.click(
87
  fn=predict_image,
88
  inputs=input_image,
@@ -93,4 +110,27 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) a
93
  fn=predict_image,
94
  inputs=input_image,
95
  outputs=output_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  )
 
3
  import json
4
  import gradio as gr
5
  from tensorflow.keras.applications.efficientnet import preprocess_input
 
 
6
 
7
+ # Load the inference model (without augmentation layers)
8
+ model = tf.keras.models.load_model('./indo_fashion_classification_model.keras')
9
 
10
+ # Load class labels
11
+ with open('class_labels.json', 'r') as f:
12
  labels = json.load(f)
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def predict_image(image):
15
  if image is None:
16
  return None
17
 
18
+ # Convert image to RGB if it's grayscale or has alpha channel
19
+ if len(image.shape) == 2: # Grayscale image
20
+ image = np.stack((image,) * 3, axis=-1)
21
+ elif image.shape[2] == 4: # RGBA image
22
+ image = image[:, :, :3] # Remove alpha channel
23
+ elif image.shape[2] == 1: # Single channel
24
+ image = np.concatenate([image] * 3, axis=-1)
25
 
26
+ # Resize to match model input shape
27
  image = tf.image.resize(image, (224, 224))
 
 
28
 
29
+ # Preprocess for EfficientNet
30
+ image = preprocess_input(image)
31
+
32
+ # Add batch dimension and make prediction
33
+ image_batch = tf.expand_dims(image, 0)
34
+ predictions = model.predict(image_batch, verbose=0)
35
+
36
+ # Get top prediction
37
  class_idx = np.argmax(predictions[0])
38
  confidence = predictions[0][class_idx]
39
  class_name = labels[str(class_idx)]
40
 
41
+ # Get top 3 predictions
42
+ top_3_indices = np.argsort(predictions[0])[-3:][::-1]
43
+ top_3_predictions = []
44
+
45
+ for idx in top_3_indices:
46
+ top_3_predictions.append({
47
+ 'class': labels[str(idx)],
48
+ 'confidence': f"{predictions[0][idx]:.2%}"
49
+ })
50
+
51
+ # Format output
52
+ result = f"**Predicted Class:** {class_name}\n**Confidence:** {confidence:.2%}\n\n"
53
+ result += "**Top 3 Predictions:**\n"
54
+ for i, pred in enumerate(top_3_predictions, 1):
55
+ result += f"{i}. {pred['class']}: {pred['confidence']}\n"
56
+
57
+ return result
58
 
59
+ # Create the Gradio interface
60
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
61
  gr.Markdown(
62
  """
63
+ # 🪷 Indian Ethnic Wear Classifier
 
64
 
65
+ Upload an image of Indian fashion attire to classify it using our EfficientNetB0 model trained on the Indo Fashion Dataset.
66
+
67
+ **Available Classes:**
68
+ """ + ", ".join(sorted(labels.values())) + """
69
  """
70
  )
71
 
72
  with gr.Row():
73
  with gr.Column(scale=1):
74
  input_image = gr.Image(
75
+ type="numpy",
76
+ label="Upload Fashion Image",
77
  height=400,
78
  width=400,
79
+ sources=["upload", "webcam", "clipboard"],
80
+ show_download_button=True
81
+ )
82
+
83
+ gr.Examples(
84
+ examples=[
85
+ ["example1.jpg"], # You can add example images here
86
+ ["example2.jpg"],
87
+ ["example3.jpg"]
88
+ ],
89
+ inputs=input_image,
90
+ label="Try these examples (if available)"
91
  )
92
+
93
  with gr.Column(scale=1):
94
  output_text = gr.Markdown(
95
+ label="Classification Results",
96
  show_label=True
97
  )
98
 
99
+ with gr.Row():
100
+ predict_btn = gr.Button("🎯 Classify Image", variant="primary", size="lg")
101
+ clear_btn = gr.Button("🗑️ Clear", variant="secondary")
102
+
103
  predict_btn.click(
104
  fn=predict_image,
105
  inputs=input_image,
 
110
  fn=predict_image,
111
  inputs=input_image,
112
  outputs=output_text
113
+ )
114
+
115
+ clear_btn.click(
116
+ fn=lambda: (None, ""),
117
+ inputs=[],
118
+ outputs=[input_image, output_text]
119
+ )
120
+
121
+ gr.Markdown(
122
+ """
123
+ ---
124
+ **Note:**
125
+ - The model classifies images into 15 categories of Indian ethnic wear
126
+ - For best results, use clear, well-lit images focused on the clothing
127
+ - Supported formats: JPG, PNG, WebP
128
+ - Model: EfficientNetB0 trained on Indo Fashion Dataset
129
+ """
130
+ )
131
+
132
+ if __name__ == "__main__":
133
+ demo.launch(
134
+ share=True,
135
+ show_error=True
136
  )