Sathwik P commited on
Commit
8b1191e
Β·
1 Parent(s): 317b639

Add unlimited CSV batch processing with given class matching

Browse files
Files changed (2) hide show
  1. app.py +110 -16
  2. requirements.txt +2 -0
app.py CHANGED
@@ -3,6 +3,9 @@ import onnxruntime as ort
3
  import numpy as np
4
  from PIL import Image
5
  import time
 
 
 
6
 
7
  # Load class names
8
  CLASS_NAMES = [
@@ -90,33 +93,114 @@ def predict_single_image(image):
90
  "inference_time_ms": f"{inference_time:.2f}"
91
  }
92
 
93
- def predict_batch(images):
 
94
  """
95
- Run inference on multiple images (up to 50)
96
 
97
  Args:
98
- images: List of PIL Images or file paths
 
99
 
100
  Returns:
101
  tuple: (gallery_data, json_results)
102
  - gallery_data: List of (image, caption) tuples for Gradio Gallery
103
  - json_results: Dictionary with summary and individual results
104
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  if images is None or len(images) == 0:
106
  return [], {
107
- "error": "No images provided",
108
  "total_images": 0,
109
  "results": []
110
  }
111
 
112
- # Limit to 50 images
113
- if len(images) > 50:
114
- return [], {
115
- "error": "Maximum 50 images allowed",
116
- "total_images": len(images),
117
- "results": []
118
- }
119
-
120
  results = []
121
  gallery_images = []
122
  total_start_time = time.time()
@@ -169,6 +253,7 @@ def predict_batch(images):
169
  total_time = (time.time() - total_start_time) * 1000
170
 
171
  json_results = {
 
172
  "total_images": len(images),
173
  "successful_predictions": len([r for r in results if "error" not in r]),
174
  "failed_predictions": len([r for r in results if "error" in r]),
@@ -213,14 +298,23 @@ with gr.Blocks(title="🚌 Bus Inspection Classifier") as demo:
213
  """)
214
 
215
  # Batch Processing Tab
216
- with gr.Tab("Batch Processing (Up to 50 Images)"):
217
- gr.Markdown("### Upload multiple images for batch classification")
 
 
218
 
219
  batch_input = gr.File(
220
  file_count="multiple",
221
- label="Upload Images (Max 50)",
222
  file_types=["image"]
223
  )
 
 
 
 
 
 
 
224
  batch_button = gr.Button("Classify Batch", variant="primary", size="lg")
225
 
226
  # Gallery to show images with predictions
@@ -238,7 +332,7 @@ with gr.Blocks(title="🚌 Bus Inspection Classifier") as demo:
238
 
239
  batch_button.click(
240
  fn=predict_batch,
241
- inputs=batch_input,
242
  outputs=[batch_gallery, batch_output]
243
  )
244
 
 
3
  import numpy as np
4
  from PIL import Image
5
  import time
6
+ import pandas as pd
7
+ import requests
8
+ from io import BytesIO
9
 
10
  # Load class names
11
  CLASS_NAMES = [
 
93
  "inference_time_ms": f"{inference_time:.2f}"
94
  }
95
 
96
+
97
+ def predict_batch(images, csv_file):
98
  """
99
+ Run inference on multiple images or CSV with image URLs (unlimited)
100
 
101
  Args:
102
+ images: List of PIL Images or file paths (or None)
103
+ csv_file: CSV file with image URLs (or None)
104
 
105
  Returns:
106
  tuple: (gallery_data, json_results)
107
  - gallery_data: List of (image, caption) tuples for Gradio Gallery
108
  - json_results: Dictionary with summary and individual results
109
  """
110
+ # Check if CSV file is provided
111
+ if csv_file is not None:
112
+ try:
113
+ # Read CSV
114
+ df = pd.read_csv(csv_file)
115
+
116
+ # Validate columns
117
+ if 'Answer' not in df.columns or 'Questions - QuestionId β†’ Name' not in df.columns:
118
+ return [], {
119
+ "error": "CSV must have 'Answer' and 'Questions - QuestionId β†’ Name' columns",
120
+ "total_images": 0,
121
+ "results": []
122
+ }
123
+
124
+ results = []
125
+ gallery_images = []
126
+ total_start_time = time.time()
127
+
128
+ # Process each row
129
+ for idx, row in df.iterrows():
130
+ try:
131
+ # Get image URL and expected class
132
+ img_url = row['Answer']
133
+ given_class = row['Questions - QuestionId β†’ Name']
134
+
135
+ # Download image from URL
136
+ response = requests.get(img_url, timeout=10)
137
+ response.raise_for_status()
138
+ image = Image.open(BytesIO(response.content)).convert('RGB')
139
+
140
+ # Get prediction
141
+ result = predict_single_image(image)
142
+ result["image_index"] = idx + 1
143
+ result["given_class"] = given_class
144
+ result["image_url"] = img_url
145
+
146
+ # Check if matches
147
+ result["match"] = "βœ“" if given_class.lower() in result["class_name"].lower() or result["class_name"].lower() in given_class.lower() else "βœ—"
148
+
149
+ results.append(result)
150
+
151
+ # Create caption for gallery
152
+ caption = f"#{idx + 1}: {result['class_name']} {result['match']}\nGiven: {given_class}\nConf: {result['confidence']} | {result['inference_time_ms']}ms"
153
+
154
+ # Add to gallery
155
+ gallery_images.append((image, caption))
156
+
157
+ except Exception as e:
158
+ results.append({
159
+ "image_index": idx + 1,
160
+ "given_class": row.get('Questions - QuestionId β†’ Name', 'Unknown'),
161
+ "image_url": row.get('Answer', 'Unknown'),
162
+ "error": str(e),
163
+ "class_name": None,
164
+ "confidence": None,
165
+ "inference_time_ms": None,
166
+ "match": "βœ—"
167
+ })
168
+
169
+ total_time = (time.time() - total_start_time) * 1000
170
+
171
+ # Calculate accuracy
172
+ successful = [r for r in results if "error" not in r]
173
+ matched = [r for r in successful if r["match"] == "βœ“"]
174
+
175
+ json_results = {
176
+ "source": "CSV",
177
+ "total_images": len(df),
178
+ "successful_predictions": len(successful),
179
+ "failed_predictions": len(results) - len(successful),
180
+ "matched_predictions": len(matched),
181
+ "accuracy": f"{(len(matched) / len(successful) * 100):.2f}%" if successful else "0%",
182
+ "total_processing_time_ms": f"{total_time:.2f}",
183
+ "average_time_per_image_ms": f"{total_time / len(df):.2f}",
184
+ "results": results
185
+ }
186
+
187
+ return gallery_images, json_results
188
+
189
+ except Exception as e:
190
+ return [], {
191
+ "error": f"CSV processing error: {str(e)}",
192
+ "total_images": 0,
193
+ "results": []
194
+ }
195
+
196
+ # Process regular image uploads (no limit)
197
  if images is None or len(images) == 0:
198
  return [], {
199
+ "error": "No images or CSV provided",
200
  "total_images": 0,
201
  "results": []
202
  }
203
 
 
 
 
 
 
 
 
 
204
  results = []
205
  gallery_images = []
206
  total_start_time = time.time()
 
253
  total_time = (time.time() - total_start_time) * 1000
254
 
255
  json_results = {
256
+ "source": "Direct Upload",
257
  "total_images": len(images),
258
  "successful_predictions": len([r for r in results if "error" not in r]),
259
  "failed_predictions": len([r for r in results if "error" in r]),
 
298
  """)
299
 
300
  # Batch Processing Tab
301
+ with gr.Tab("Batch Processing (Unlimited)"):
302
+ gr.Markdown("### Upload images OR CSV file with image URLs")
303
+ gr.Markdown("**Option 1:** Upload multiple images directly")
304
+ gr.Markdown("**Option 2:** Upload CSV with columns: `Questions - QuestionId β†’ Name` (given class) and `Answer` (image URL)")
305
 
306
  batch_input = gr.File(
307
  file_count="multiple",
308
+ label="Upload Images",
309
  file_types=["image"]
310
  )
311
+
312
+ csv_input = gr.File(
313
+ file_count="single",
314
+ label="OR Upload CSV with Image URLs",
315
+ file_types=[".csv"]
316
+ )
317
+
318
  batch_button = gr.Button("Classify Batch", variant="primary", size="lg")
319
 
320
  # Gallery to show images with predictions
 
332
 
333
  batch_button.click(
334
  fn=predict_batch,
335
+ inputs=[batch_input, csv_input],
336
  outputs=[batch_gallery, batch_output]
337
  )
338
 
requirements.txt CHANGED
@@ -2,3 +2,5 @@ gradio
2
  onnxruntime
3
  numpy
4
  Pillow
 
 
 
2
  onnxruntime
3
  numpy
4
  Pillow
5
+ pandas
6
+ requests