Sathwik P commited on
Commit
4209af4
Β·
1 Parent(s): d3c78a5

Add batch processing support for up to 50 images

Browse files
Files changed (1) hide show
  1. app.py +181 -42
app.py CHANGED
@@ -46,9 +46,9 @@ def preprocess_image(image):
46
  img_final = np.transpose(img_norm, (2, 0, 1))
47
  return np.expand_dims(img_final, axis=0).astype(np.float32)
48
 
49
- def predict(image):
50
  """
51
- Run inference on input image and return predictions with metrics
52
 
53
  Args:
54
  image: PIL Image or numpy array
@@ -90,60 +90,199 @@ def predict(image):
90
  "inference_time_ms": f"{inference_time:.2f}"
91
  }
92
 
93
- # Create Gradio interface
94
- demo = gr.Interface(
95
- fn=predict,
96
- inputs=gr.Image(type="pil", label="Upload Bus Inspection Image"),
97
- outputs=gr.JSON(label="Prediction Results"),
98
- title="🚌 Bus Inspection Classifier - SigLIP v2",
99
- description="""
100
- Upload an image of a bus component for automated classification.
101
-
102
- **18 Categories:**
103
- AC Mat | Alco brake camera | Alco-brake device | Back windshield | Bus back side | Bus front side | Bus side | Cabin | Driver grooming | First aid kit | Floormats & POS | Front windshield | Hat rack | ITMS Device | Jack & Spare tyre | Luggage compartment | RFID Card | Seats
104
-
105
- **Returns:**
106
- - `class_name`: Predicted bus component category
107
- - `confidence`: Model confidence score (%)
108
- - `inference_time_ms`: Processing time in milliseconds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
 
 
110
  ---
111
 
112
- ### πŸ”Œ API Usage
 
 
113
 
114
- **Python Example:**
115
  ```python
116
- import requests
117
- from PIL import Image
 
 
 
 
 
 
118
 
119
- # Your Hugging Face Space URL
120
- API_URL = "https://YOUR-USERNAME-bus-inspection.hf.space/api/predict"
 
121
 
122
- # Load your image
123
- image = Image.open("bus_image.jpg")
124
 
125
- # Make prediction
126
- response = requests.post(API_URL, files={"data": open("bus_image.jpg", "rb")})
127
- result = response.json()
128
 
129
- print(f"Class: {result['class_name']}")
130
- print(f"Confidence: {result['confidence']}")
131
- print(f"Time: {result['inference_time_ms']} ms")
 
 
132
  ```
133
 
134
- Using the **Gradio Client:**
135
  ```python
136
- from gradio_client import Client
137
 
138
- client = Client("YOUR-USERNAME/bus-inspection")
139
- result = client.predict("bus_image.jpg")
140
- print(result)
 
 
 
 
 
 
 
 
 
 
141
  ```
142
- """,
143
- examples=[],
144
- allow_flagging="never",
145
- analytics_enabled=False
146
- )
147
 
148
  if __name__ == "__main__":
149
  demo.launch()
 
46
  img_final = np.transpose(img_norm, (2, 0, 1))
47
  return np.expand_dims(img_final, axis=0).astype(np.float32)
48
 
49
+ def predict_single_image(image):
50
  """
51
+ Run inference on a single image
52
 
53
  Args:
54
  image: PIL Image or numpy array
 
90
  "inference_time_ms": f"{inference_time:.2f}"
91
  }
92
 
93
+ def predict_batch(images):
94
+ """
95
+ Run inference on multiple images (up to 50)
96
+
97
+ Args:
98
+ images: List of PIL Images or file paths
99
+
100
+ Returns:
101
+ dict: Summary and list of individual results
102
+ """
103
+ if images is None or len(images) == 0:
104
+ return {
105
+ "error": "No images provided",
106
+ "total_images": 0,
107
+ "results": []
108
+ }
109
+
110
+ # Limit to 50 images
111
+ if len(images) > 50:
112
+ return {
113
+ "error": "Maximum 50 images allowed",
114
+ "total_images": len(images),
115
+ "results": []
116
+ }
117
+
118
+ results = []
119
+ total_start_time = time.time()
120
+
121
+ for idx, img in enumerate(images):
122
+ try:
123
+ # Handle file path or PIL Image
124
+ if isinstance(img, str):
125
+ image = Image.open(img).convert('RGB')
126
+ elif isinstance(img, np.ndarray):
127
+ image = Image.fromarray(img).convert('RGB')
128
+ else:
129
+ image = img.convert('RGB')
130
+
131
+ # Get prediction
132
+ result = predict_single_image(image)
133
+ result["image_index"] = idx + 1
134
+ results.append(result)
135
+
136
+ except Exception as e:
137
+ results.append({
138
+ "image_index": idx + 1,
139
+ "error": str(e),
140
+ "class_name": None,
141
+ "confidence": None,
142
+ "inference_time_ms": None
143
+ })
144
+
145
+ total_time = (time.time() - total_start_time) * 1000
146
+
147
+ return {
148
+ "total_images": len(images),
149
+ "successful_predictions": len([r for r in results if "error" not in r]),
150
+ "failed_predictions": len([r for r in results if "error" in r]),
151
+ "total_processing_time_ms": f"{total_time:.2f}",
152
+ "average_time_per_image_ms": f"{total_time / len(images):.2f}",
153
+ "results": results
154
+ }
155
+
156
+ # Create tabbed interface
157
+ with gr.Blocks(title="🚌 Bus Inspection Classifier") as demo:
158
+ gr.Markdown("# 🚌 Bus Inspection Classifier - SigLIP v2")
159
+ gr.Markdown("""
160
+ Automated bus component classification using the **SigLIP v2** vision model.
161
+
162
+ **18 Categories:** AC Mat | Alco brake camera | Alco-brake device | Back windshield | Bus back side | Bus front side | Bus side | Cabin | Driver grooming | First aid kit | Floormats & POS | Front windshield | Hat rack | ITMS Device | Jack & Spare tyre | Luggage compartment | RFID Card | Seats
163
+ """)
164
+
165
+ with gr.Tabs():
166
+ # Single Image Tab
167
+ with gr.Tab("Single Image"):
168
+ gr.Markdown("### Upload a single bus inspection image")
169
+ with gr.Row():
170
+ with gr.Column():
171
+ single_input = gr.Image(type="pil", label="Upload Image")
172
+ single_button = gr.Button("Classify", variant="primary")
173
+ with gr.Column():
174
+ single_output = gr.JSON(label="Prediction Result")
175
+
176
+ single_button.click(
177
+ fn=predict_single_image,
178
+ inputs=single_input,
179
+ outputs=single_output
180
+ )
181
+
182
+ gr.Markdown("""
183
+ **Returns:**
184
+ - `class_name`: Predicted bus component category
185
+ - `confidence`: Model confidence score (%)
186
+ - `inference_time_ms`: Processing time in milliseconds
187
+ """)
188
+
189
+ # Batch Processing Tab
190
+ with gr.Tab("Batch Processing (Up to 50 Images)"):
191
+ gr.Markdown("### Upload multiple images for batch classification")
192
+ with gr.Row():
193
+ with gr.Column():
194
+ batch_input = gr.File(
195
+ file_count="multiple",
196
+ label="Upload Images (Max 50)",
197
+ file_types=["image"]
198
+ )
199
+ batch_button = gr.Button("Classify Batch", variant="primary")
200
+ with gr.Column():
201
+ batch_output = gr.JSON(label="Batch Results")
202
+
203
+ batch_button.click(
204
+ fn=predict_batch,
205
+ inputs=batch_input,
206
+ outputs=batch_output
207
+ )
208
+
209
+ gr.Markdown("""
210
+ **Returns:**
211
+ ```json
212
+ {
213
+ "total_images": 10,
214
+ "successful_predictions": 10,
215
+ "failed_predictions": 0,
216
+ "total_processing_time_ms": "456.78",
217
+ "average_time_per_image_ms": "45.68",
218
+ "results": [
219
+ {
220
+ "image_index": 1,
221
+ "class_name": "Bus front side",
222
+ "confidence": "98.45%",
223
+ "inference_time_ms": "43.21"
224
+ },
225
+ ...
226
+ ]
227
+ }
228
+ ```
229
+ """)
230
 
231
+ # API Documentation
232
+ gr.Markdown("""
233
  ---
234
 
235
+ ## πŸ”Œ API Usage
236
+
237
+ ### Single Image API
238
 
239
+ **Using Gradio Client (Python):**
240
  ```python
241
+ from gradio_client import Client
242
+
243
+ client = Client("Wicky/bus-inspection-classifier")
244
+ result = client.predict("bus_image.jpg", api_name="/predict")
245
+ print(result)
246
+ ```
247
+
248
+ ### Batch Processing API
249
 
250
+ **Using Gradio Client (Python):**
251
+ ```python
252
+ from gradio_client import Client
253
 
254
+ client = Client("Wicky/bus-inspection-classifier")
 
255
 
256
+ # Upload multiple images
257
+ image_files = ["img1.jpg", "img2.jpg", "img3.jpg"]
258
+ result = client.predict(image_files, api_name="/predict_batch")
259
 
260
+ print(f"Total: {result['total_images']}")
261
+ print(f"Successful: {result['successful_predictions']}")
262
+
263
+ for res in result['results']:
264
+ print(f"Image {res['image_index']}: {res['class_name']} ({res['confidence']})")
265
  ```
266
 
267
+ **Using Python Requests:**
268
  ```python
269
+ import requests
270
 
271
+ files = [
272
+ ('files', open('img1.jpg', 'rb')),
273
+ ('files', open('img2.jpg', 'rb')),
274
+ ('files', open('img3.jpg', 'rb'))
275
+ ]
276
+
277
+ response = requests.post(
278
+ "https://Wicky-bus-inspection-classifier.hf.space/api/predict_batch",
279
+ files=files
280
+ )
281
+
282
+ results = response.json()
283
+ print(results)
284
  ```
285
+ """)
 
 
 
 
286
 
287
  if __name__ == "__main__":
288
  demo.launch()