Sathwik P commited on
Commit
c555121
Β·
1 Parent(s): 8b1191e

Add real-time progressive display for batch processing

Browse files
Files changed (1) hide show
  1. app.py +61 -18
app.py CHANGED
@@ -96,16 +96,14 @@ def predict_single_image(image):
96
 
97
  def predict_batch(images, csv_file):
98
  """
99
- Run inference on multiple images or CSV with image URLs (unlimited)
100
 
101
  Args:
102
  images: List of PIL Images or file paths (or None)
103
  csv_file: CSV file with image URLs (or None)
104
 
105
- Returns:
106
- tuple: (gallery_data, json_results)
107
- - gallery_data: List of (image, caption) tuples for Gradio Gallery
108
- - json_results: Dictionary with summary and individual results
109
  """
110
  # Check if CSV file is provided
111
  if csv_file is not None:
@@ -115,17 +113,18 @@ def predict_batch(images, csv_file):
115
 
116
  # Validate columns
117
  if 'Answer' not in df.columns or 'Questions - QuestionId β†’ Name' not in df.columns:
118
- return [], {
119
  "error": "CSV must have 'Answer' and 'Questions - QuestionId β†’ Name' columns",
120
  "total_images": 0,
121
  "results": []
122
  }
 
123
 
124
  results = []
125
  gallery_images = []
126
  total_start_time = time.time()
127
 
128
- # Process each row
129
  for idx, row in df.iterrows():
130
  try:
131
  # Get image URL and expected class
@@ -165,41 +164,65 @@ def predict_batch(images, csv_file):
165
  "inference_time_ms": None,
166
  "match": "βœ—"
167
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
 
169
  total_time = (time.time() - total_start_time) * 1000
170
-
171
- # Calculate accuracy
172
  successful = [r for r in results if "error" not in r]
173
  matched = [r for r in successful if r["match"] == "βœ“"]
174
 
175
- json_results = {
176
  "source": "CSV",
 
177
  "total_images": len(df),
 
178
  "successful_predictions": len(successful),
179
  "failed_predictions": len(results) - len(successful),
180
  "matched_predictions": len(matched),
181
  "accuracy": f"{(len(matched) / len(successful) * 100):.2f}%" if successful else "0%",
182
  "total_processing_time_ms": f"{total_time:.2f}",
183
  "average_time_per_image_ms": f"{total_time / len(df):.2f}",
184
- "results": results
185
  }
186
 
187
- return gallery_images, json_results
188
 
189
  except Exception as e:
190
- return [], {
191
  "error": f"CSV processing error: {str(e)}",
192
  "total_images": 0,
193
  "results": []
194
  }
 
195
 
196
- # Process regular image uploads (no limit)
197
  if images is None or len(images) == 0:
198
- return [], {
199
  "error": "No images or CSV provided",
200
  "total_images": 0,
201
  "results": []
202
  }
 
203
 
204
  results = []
205
  gallery_images = []
@@ -249,20 +272,40 @@ def predict_batch(images, csv_file):
249
  gallery_images.append((error_img, f"#{idx + 1}: ERROR - {str(e)}"))
250
  except:
251
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
 
253
  total_time = (time.time() - total_start_time) * 1000
254
 
255
- json_results = {
256
  "source": "Direct Upload",
 
257
  "total_images": len(images),
 
258
  "successful_predictions": len([r for r in results if "error" not in r]),
259
  "failed_predictions": len([r for r in results if "error" in r]),
260
  "total_processing_time_ms": f"{total_time:.2f}",
261
  "average_time_per_image_ms": f"{total_time / len(images):.2f}",
262
- "results": results
263
  }
264
 
265
- return gallery_images, json_results
266
 
267
  # Create tabbed interface
268
  with gr.Blocks(title="🚌 Bus Inspection Classifier") as demo:
 
96
 
97
  def predict_batch(images, csv_file):
98
  """
99
+ Run inference on multiple images or CSV with image URLs (unlimited) with PROGRESSIVE DISPLAY
100
 
101
  Args:
102
  images: List of PIL Images or file paths (or None)
103
  csv_file: CSV file with image URLs (or None)
104
 
105
+ Yields:
106
+ tuple: (gallery_data, json_results) after each image is processed
 
 
107
  """
108
  # Check if CSV file is provided
109
  if csv_file is not None:
 
113
 
114
  # Validate columns
115
  if 'Answer' not in df.columns or 'Questions - QuestionId β†’ Name' not in df.columns:
116
+ yield [], {
117
  "error": "CSV must have 'Answer' and 'Questions - QuestionId β†’ Name' columns",
118
  "total_images": 0,
119
  "results": []
120
  }
121
+ return
122
 
123
  results = []
124
  gallery_images = []
125
  total_start_time = time.time()
126
 
127
+ # Process each row PROGRESSIVELY
128
  for idx, row in df.iterrows():
129
  try:
130
  # Get image URL and expected class
 
164
  "inference_time_ms": None,
165
  "match": "βœ—"
166
  })
167
+
168
+ # YIELD after each image - REAL-TIME UPDATE!
169
+ elapsed_time = (time.time() - total_start_time) * 1000
170
+ successful = [r for r in results if "error" not in r]
171
+ matched = [r for r in successful if r["match"] == "βœ“"]
172
+
173
+ json_results = {
174
+ "source": "CSV",
175
+ "status": f"Processing... {idx + 1}/{len(df)}",
176
+ "total_images": len(df),
177
+ "processed": idx + 1,
178
+ "successful_predictions": len(successful),
179
+ "failed_predictions": len(results) - len(successful),
180
+ "matched_predictions": len(matched),
181
+ "accuracy": f"{(len(matched) / len(successful) * 100):.2f}%" if successful else "0%",
182
+ "elapsed_time_ms": f"{elapsed_time:.2f}",
183
+ "average_time_per_image_ms": f"{elapsed_time / (idx + 1):.2f}",
184
+ "results": results[-10:] # Show last 10 for performance
185
+ }
186
+
187
+ yield gallery_images.copy(), json_results
188
 
189
+ # Final yield with complete results
190
  total_time = (time.time() - total_start_time) * 1000
 
 
191
  successful = [r for r in results if "error" not in r]
192
  matched = [r for r in successful if r["match"] == "βœ“"]
193
 
194
+ final_results = {
195
  "source": "CSV",
196
+ "status": "βœ… Complete!",
197
  "total_images": len(df),
198
+ "processed": len(df),
199
  "successful_predictions": len(successful),
200
  "failed_predictions": len(results) - len(successful),
201
  "matched_predictions": len(matched),
202
  "accuracy": f"{(len(matched) / len(successful) * 100):.2f}%" if successful else "0%",
203
  "total_processing_time_ms": f"{total_time:.2f}",
204
  "average_time_per_image_ms": f"{total_time / len(df):.2f}",
205
+ "results": results # Full results at the end
206
  }
207
 
208
+ yield gallery_images, final_results
209
 
210
  except Exception as e:
211
+ yield [], {
212
  "error": f"CSV processing error: {str(e)}",
213
  "total_images": 0,
214
  "results": []
215
  }
216
+ return
217
 
218
+ # Process regular image uploads (no limit) PROGRESSIVELY
219
  if images is None or len(images) == 0:
220
+ yield [], {
221
  "error": "No images or CSV provided",
222
  "total_images": 0,
223
  "results": []
224
  }
225
+ return
226
 
227
  results = []
228
  gallery_images = []
 
272
  gallery_images.append((error_img, f"#{idx + 1}: ERROR - {str(e)}"))
273
  except:
274
  pass
275
+
276
+ # YIELD after each image - REAL-TIME UPDATE!
277
+ elapsed_time = (time.time() - total_start_time) * 1000
278
+
279
+ json_results = {
280
+ "source": "Direct Upload",
281
+ "status": f"Processing... {idx + 1}/{len(images)}",
282
+ "total_images": len(images),
283
+ "processed": idx + 1,
284
+ "successful_predictions": len([r for r in results if "error" not in r]),
285
+ "failed_predictions": len([r for r in results if "error" in r]),
286
+ "elapsed_time_ms": f"{elapsed_time:.2f}",
287
+ "average_time_per_image_ms": f"{elapsed_time / (idx + 1):.2f}",
288
+ "results": results[-10:] # Show last 10 for performance
289
+ }
290
+
291
+ yield gallery_images.copy(), json_results
292
 
293
+ # Final yield with complete results
294
  total_time = (time.time() - total_start_time) * 1000
295
 
296
+ final_results = {
297
  "source": "Direct Upload",
298
+ "status": "βœ… Complete!",
299
  "total_images": len(images),
300
+ "processed": len(images),
301
  "successful_predictions": len([r for r in results if "error" not in r]),
302
  "failed_predictions": len([r for r in results if "error" in r]),
303
  "total_processing_time_ms": f"{total_time:.2f}",
304
  "average_time_per_image_ms": f"{total_time / len(images):.2f}",
305
+ "results": results # Full results at the end
306
  }
307
 
308
+ yield gallery_images, final_results
309
 
310
  # Create tabbed interface
311
  with gr.Blocks(title="🚌 Bus Inspection Classifier") as demo: